/jax/scipy/signal.py

https://github.com/google/jax · Python · 139 lines · 103 code · 19 blank · 17 comment · 46 complexity · eb32b32d1b6bc5d3160fdceb11734ba9 MD5 · raw file

  1. # Copyright 2020 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # https://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import scipy.signal as osp_signal
  15. import warnings
  16. import numpy as np
  17. from .. import lax
  18. from ..numpy import lax_numpy as jnp
  19. from ..numpy import linalg
  20. from ..numpy.lax_numpy import _promote_dtypes_inexact
  21. from ..numpy._util import _wraps
  22. # Note: we do not re-use the code from jax.numpy.convolve here, because the handling
  23. # of padding differs slightly between the two implementations (particularly for
  24. # mode='same').
  25. def _convolve_nd(in1, in2, mode, *, precision):
  26. if mode not in ["full", "same", "valid"]:
  27. raise ValueError("mode must be one of ['full', 'same', 'valid']")
  28. if in1.ndim != in2.ndim:
  29. raise ValueError("in1 and in2 must have the same number of dimensions")
  30. if in1.size == 0 or in2.size == 0:
  31. raise ValueError(f"zero-size arrays not supported in convolutions, got shapes {in1.shape} and {in2.shape}.")
  32. in1, in2 = _promote_dtypes_inexact(in1, in2)
  33. no_swap = all(s1 >= s2 for s1, s2 in zip(in1.shape, in2.shape))
  34. swap = all(s1 <= s2 for s1, s2 in zip(in1.shape, in2.shape))
  35. if not (no_swap or swap):
  36. raise ValueError("One input must be smaller than the other in every dimension.")
  37. shape_o = in2.shape
  38. if swap:
  39. in1, in2 = in2, in1
  40. shape = in2.shape
  41. in2 = in2[tuple(slice(None, None, -1) for s in shape)]
  42. if mode == 'valid':
  43. padding = [(0, 0) for s in shape]
  44. elif mode == 'same':
  45. padding = [(s - 1 - (s_o - 1) // 2, s - s_o + (s_o - 1) // 2)
  46. for (s, s_o) in zip(shape, shape_o)]
  47. elif mode == 'full':
  48. padding = [(s - 1, s - 1) for s in shape]
  49. strides = tuple(1 for s in shape)
  50. result = lax.conv_general_dilated(in1[None, None], in2[None, None], strides,
  51. padding, precision=precision)
  52. return result[0, 0]
  53. @_wraps(osp_signal.convolve)
  54. def convolve(in1, in2, mode='full', method='auto',
  55. precision=None):
  56. if method != 'auto':
  57. warnings.warn("convolve() ignores method argument")
  58. if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
  59. raise NotImplementedError("convolve() does not support complex inputs")
  60. if jnp.ndim(in1) != 1 or jnp.ndim(in2) != 1:
  61. raise ValueError("convolve() only supports 1-dimensional inputs.")
  62. return _convolve_nd(in1, in2, mode, precision=precision)
  63. @_wraps(osp_signal.convolve2d)
  64. def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
  65. precision=None):
  66. if boundary != 'fill' or fillvalue != 0:
  67. raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
  68. if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
  69. raise NotImplementedError("convolve2d() does not support complex inputs")
  70. if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
  71. raise ValueError("convolve2d() only supports 2-dimensional inputs.")
  72. return _convolve_nd(in1, in2, mode, precision=precision)
  73. @_wraps(osp_signal.correlate)
  74. def correlate(in1, in2, mode='full', method='auto',
  75. precision=None):
  76. if method != 'auto':
  77. warnings.warn("correlate() ignores method argument")
  78. if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
  79. raise NotImplementedError("correlate() does not support complex inputs")
  80. if jnp.ndim(in1) != 1 or jnp.ndim(in2) != 1:
  81. raise ValueError("correlate() only supports {ndim}-dimensional inputs.")
  82. return _convolve_nd(in1, in2[::-1], mode, precision=precision)
  83. @_wraps(osp_signal.correlate)
  84. def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
  85. precision=None):
  86. if boundary != 'fill' or fillvalue != 0:
  87. raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0")
  88. if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
  89. raise NotImplementedError("correlate2d() does not support complex inputs")
  90. if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
  91. raise ValueError("correlate2d() only supports {ndim}-dimensional inputs.")
  92. return _convolve_nd(in1[::-1, ::-1], in2, mode, precision=precision)[::-1, ::-1]
  93. @_wraps(osp_signal.detrend)
  94. def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
  95. if overwrite_data is not None:
  96. raise NotImplementedError("overwrite_data argument not implemented.")
  97. if type not in ['constant', 'linear']:
  98. raise ValueError("Trend type must be 'linear' or 'constant'.")
  99. data, = _promote_dtypes_inexact(jnp.asarray(data))
  100. if type == 'constant':
  101. return data - data.mean(axis, keepdims=True)
  102. else:
  103. N = data.shape[axis]
  104. # bp is static, so we use np operations to avoid pushing to device.
  105. bp = np.sort(np.unique(np.r_[0, bp, N]))
  106. if bp[0] < 0 or bp[-1] > N:
  107. raise ValueError("Breakpoints must be non-negative and less than length of data along given axis.")
  108. data = jnp.moveaxis(data, axis, 0)
  109. shape = data.shape
  110. data = data.reshape(N, -1)
  111. for m in range(len(bp) - 1):
  112. Npts = bp[m + 1] - bp[m]
  113. A = jnp.vstack([
  114. jnp.ones(Npts, dtype=data.dtype),
  115. jnp.arange(1, Npts + 1, dtype=data.dtype) / Npts
  116. ]).T
  117. sl = slice(bp[m], bp[m + 1])
  118. coef, *_ = linalg.lstsq(A, data[sl])
  119. data = data.at[sl].add(-jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
  120. return jnp.moveaxis(data.reshape(shape), 0, axis)