ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/wavelets.py CHANGED
@@ -1,194 +1,187 @@
1
1
  import typing
2
2
 
3
+ import ezmsg.core as ez
3
4
  import numpy as np
4
5
  import numpy.typing as npt
5
6
  import pywt
6
- import ezmsg.core as ez
7
+ from ezmsg.baseproc import (
8
+ BaseStatefulTransformer,
9
+ BaseTransformerUnit,
10
+ processor_state,
11
+ )
7
12
  from ezmsg.util.messages.axisarray import AxisArray
8
13
  from ezmsg.util.messages.util import replace
9
- from ezmsg.util.generator import consumer
10
14
 
11
- from .base import GenAxisArray
12
- from .filterbank import filterbank, FilterbankMode, MinPhaseMode
15
+ from .filterbank import FilterbankMode, MinPhaseMode, filterbank
13
16
 
14
17
 
15
- @consumer
16
- def cwt(
17
- frequencies: list | tuple | npt.NDArray | None,
18
- wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
19
- min_phase: MinPhaseMode = MinPhaseMode.NONE,
20
- axis: str = "time",
21
- scales: list | tuple | npt.NDArray | None = None,
22
- ) -> typing.Generator[AxisArray, AxisArray, None]:
18
+ class CWTSettings(ez.Settings):
19
+ """
20
+ Settings for :obj:`CWT`
21
+ See :obj:`cwt` for argument details.
23
22
  """
24
- Perform a continuous wavelet transform.
25
- The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
26
23
 
27
- Args:
28
- frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
29
- Note: frequencies will be sorted from smallest to largest.
30
- wavelet: Wavelet object or name of wavelet to use.
31
- min_phase: See filterbank MinPhaseMode for details.
32
- axis: The target axis for operation. Note that this will be moved to the -1th dimension
33
- because fft and matrix multiplication is much faster on the last axis.
34
- This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
35
- scales: The scales to use. If None, the scales will be calculated from the frequencies.
36
- Note: Scales will be sorted from largest to smallest.
37
- Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
38
- `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
24
+ frequencies: list | tuple | npt.NDArray | None
25
+ wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
26
+ min_phase: MinPhaseMode = MinPhaseMode.NONE
27
+ axis: str = "time"
28
+ scales: list | tuple | npt.NDArray | None = None
39
29
 
40
- Returns:
41
- A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
42
- and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
43
- """
44
- precision = 10
45
- msg_out: AxisArray | None = None
46
-
47
- # Check parameters
48
- if frequencies is None and scales is None:
49
- raise ValueError("Either frequencies or scales must be provided.")
50
- if frequencies is not None and scales is not None:
51
- raise ValueError("Only one of frequencies or scales can be provided.")
52
- if scales is not None:
53
- scales = np.sort(scales)[::-1]
54
- assert np.all(scales > 0), "scales must be positive."
55
- assert scales.ndim == 1, "scales must be a 1D list, tuple, or array."
56
-
57
- if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
58
- wavelet = pywt.DiscreteContinuousWavelet(wavelet)
59
-
60
- if frequencies is not None:
61
- frequencies = np.sort(frequencies)
62
- assert np.all(frequencies > 0), "frequencies must be positive."
63
- assert frequencies.ndim == 1, "frequencies must be a 1D list, tuple, or array."
64
-
65
- # State variables
30
+
31
+ @processor_state
32
+ class CWTState:
66
33
  neg_rt_scales: npt.NDArray | None = None
67
- int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
68
- int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
34
+ int_psi_scales: list[npt.NDArray] | None = None
69
35
  template: AxisArray | None = None
70
36
  fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None
71
37
  last_conv_samp: npt.NDArray | None = None
72
38
 
73
- # Reset if input changed
74
- check_input = {
75
- "kind": None, # Need to recalc kernels at same complexity as input
76
- "gain": None, # Need to recalc freqs
77
- "shape": None, # Need to recalc template and buffer
78
- "key": None, # Buffer obsolete
79
- }
80
-
81
- while True:
82
- msg_in: AxisArray = yield msg_out
83
- ax_idx = msg_in.get_axis_idx(axis)
84
- in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]
85
-
86
- b_reset = msg_in.data.dtype.kind != check_input["kind"]
87
- b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
88
- b_reset = b_reset or in_shape != check_input["shape"]
89
- b_reset = b_reset or msg_in.key != check_input["key"]
90
- b_reset = b_reset and msg_in.data.size > 0
91
- if b_reset:
92
- check_input["kind"] = msg_in.data.dtype.kind
93
- check_input["gain"] = msg_in.axes[axis].gain
94
- check_input["shape"] = in_shape
95
- check_input["key"] = msg_in.key
96
-
97
- if frequencies is not None:
98
- scales = pywt.frequency2scale(
99
- wavelet, frequencies * msg_in.axes[axis].gain, precision=precision
100
- )
101
- neg_rt_scales = -np.sqrt(scales)[:, None]
102
-
103
- # convert int_psi, wave_xvec to the same precision as the data
104
- dt_data = msg_in.data.dtype # _check_dtype(msg_in.data)
105
- dt_cplx = np.result_type(dt_data, np.complex64)
106
- dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
107
- int_psi = np.asarray(int_psi, dtype=dt_psi)
108
- # TODO: Currently int_psi cannot be made non-complex once it is complex.
109
-
110
- # Calculate waves for each scale
111
- wave_xvec = np.asarray(wave_xvec, dtype=msg_in.data.real.dtype)
112
- wave_range = wave_xvec[-1] - wave_xvec[0]
113
- step = wave_xvec[1] - wave_xvec[0]
114
- int_psi_scales = []
115
- for scale in scales:
116
- reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
117
- if reix[-1] >= int_psi.size:
118
- reix = np.extract(reix < int_psi.size, reix)
119
- int_psi_scales.append(int_psi[reix][::-1])
120
-
121
- # CONV is probably best because we often get huge kernels.
122
- fbgen = filterbank(
123
- int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
124
- )
125
39
 
126
- freqs = (
127
- pywt.scale2frequency(wavelet, scales, precision)
128
- / msg_in.axes[axis].gain
40
+ class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
41
+ def _hash_message(self, message: AxisArray) -> int:
42
+ ax_idx = message.get_axis_idx(self.settings.axis)
43
+ in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
44
+ return hash(
45
+ (
46
+ message.data.dtype.kind,
47
+ message.axes[self.settings.axis].gain,
48
+ in_shape,
49
+ message.key,
129
50
  )
130
- # Create output template
131
- dummy_shape = in_shape + (len(scales), 0)
132
- template = AxisArray(
133
- np.zeros(
134
- dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data
135
- ),
136
- dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1 :] + ["freq", axis],
137
- axes={
138
- **msg_in.axes,
139
- "freq": AxisArray.CoordinateAxis(
140
- unit="Hz", data=freqs, dims=["freq"]
141
- ),
142
- },
143
- key=msg_in.key,
144
- )
145
- last_conv_samp = np.zeros(
146
- dummy_shape[:-1] + (1,), dtype=template.data.dtype
51
+ )
52
+
53
+ def _reset_state(self, message: AxisArray) -> None:
54
+ precision = 10
55
+
56
+ # Process wavelet
57
+ wavelet = (
58
+ self.settings.wavelet
59
+ if isinstance(self.settings.wavelet, (pywt.ContinuousWavelet, pywt.Wavelet))
60
+ else pywt.DiscreteContinuousWavelet(self.settings.wavelet)
61
+ )
62
+ # Process wavelet integration
63
+ int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
64
+ int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
65
+
66
+ # Calculate scales and frequencies
67
+ if self.settings.frequencies is not None:
68
+ frequencies = np.sort(np.array(self.settings.frequencies))
69
+ scales = pywt.frequency2scale(
70
+ wavelet,
71
+ frequencies * message.axes[self.settings.axis].gain,
72
+ precision=precision,
147
73
  )
74
+ else:
75
+ scales = np.sort(self.settings.scales)[::-1]
76
+
77
+ self._state.neg_rt_scales = -np.sqrt(scales)[:, None]
78
+
79
+ # Convert to appropriate dtype
80
+ dt_data = message.data.dtype
81
+ dt_cplx = np.result_type(dt_data, np.complex64)
82
+ dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
83
+ int_psi = np.asarray(int_psi, dtype=dt_psi)
84
+ # Note: Currently int_psi cannot be made non-complex once it is complex.
85
+
86
+ # Calculate waves for each scale
87
+ wave_xvec = np.asarray(wave_xvec, dtype=message.data.real.dtype)
88
+ wave_range = wave_xvec[-1] - wave_xvec[0]
89
+ step = wave_xvec[1] - wave_xvec[0]
90
+ self._state.int_psi_scales = []
91
+ for scale in scales:
92
+ reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
93
+ if reix[-1] >= int_psi.size:
94
+ reix = np.extract(reix < int_psi.size, reix)
95
+ self._state.int_psi_scales.append(int_psi[reix][::-1])
96
+
97
+ # Setup filterbank generator
98
+ self._state.fbgen = filterbank(
99
+ self._state.int_psi_scales,
100
+ mode=FilterbankMode.CONV,
101
+ min_phase=self.settings.min_phase,
102
+ axis=self.settings.axis,
103
+ )
104
+
105
+ # Create output template
106
+ ax_idx = message.get_axis_idx(self.settings.axis)
107
+ in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
108
+ freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain
109
+ dummy_shape = in_shape + (len(scales), 0)
110
+ self._state.template = AxisArray(
111
+ np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
112
+ dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
113
+ axes={
114
+ **message.axes,
115
+ "freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
116
+ },
117
+ key=message.key,
118
+ )
119
+ self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)
148
120
 
149
- conv_msg = fbgen.send(msg_in)
121
+ def _process(self, message: AxisArray) -> AxisArray:
122
+ conv_msg = self._state.fbgen.send(message)
150
123
 
151
124
  # Prepend with last_conv_samp before doing diff
152
- dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
153
- coef = neg_rt_scales * np.diff(dat, axis=-1)
154
- # Store last_conv_samp for next iteration.
155
- last_conv_samp = conv_msg.data[..., -1:]
125
+ dat = np.concatenate((self._state.last_conv_samp, conv_msg.data), axis=-1)
126
+ coef = self._state.neg_rt_scales * np.diff(dat, axis=-1)
127
+ # Store last_conv_samp for next iteration
128
+ self._state.last_conv_samp = conv_msg.data[..., -1:]
156
129
 
157
- if template.data.dtype.kind != "c":
130
+ if self._state.template.data.dtype.kind != "c":
158
131
  coef = coef.real
159
132
 
160
133
  # pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have
161
134
  # that luxury when streaming.
162
135
  # d = (coef.shape[-1] - msg_in.data.shape[ax_idx]) / 2.
163
136
  # coef = coef[..., math.floor(d):-math.ceil(d)]
164
- msg_out = replace(
165
- template, data=coef, axes={**template.axes, axis: msg_in.axes[axis]}
137
+ return replace(
138
+ self._state.template,
139
+ data=coef,
140
+ axes={
141
+ **self._state.template.axes,
142
+ self.settings.axis: message.axes[self.settings.axis],
143
+ },
166
144
  )
167
145
 
168
146
 
169
- class CWTSettings(ez.Settings):
170
- """
171
- Settings for :obj:`CWT`
172
- See :obj:`cwt` for argument details.
173
- """
174
-
175
- scales: list | tuple | npt.NDArray
176
- wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
177
- min_phase: MinPhaseMode = MinPhaseMode.NONE
178
- axis: str = "time"
147
+ class CWT(BaseTransformerUnit[CWTSettings, AxisArray, AxisArray, CWTTransformer]):
148
+ SETTINGS = CWTSettings
179
149
 
180
150
 
181
- class CWT(GenAxisArray):
182
- """
183
- :obj:`Unit` for :obj:`common_rereference`.
151
+ def cwt(
152
+ frequencies: list | tuple | npt.NDArray | None,
153
+ wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
154
+ min_phase: MinPhaseMode = MinPhaseMode.NONE,
155
+ axis: str = "time",
156
+ scales: list | tuple | npt.NDArray | None = None,
157
+ ) -> CWTTransformer:
184
158
  """
159
+ Perform a continuous wavelet transform.
160
+ The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
185
161
 
186
- SETTINGS = CWTSettings
162
+ Args:
163
+ frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
164
+ Note: frequencies will be sorted from smallest to largest.
165
+ wavelet: Wavelet object or name of wavelet to use.
166
+ min_phase: See filterbank MinPhaseMode for details.
167
+ axis: The target axis for operation. Note that this will be moved to the -1th dimension
168
+ because fft and matrix multiplication is much faster on the last axis.
169
+ This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
170
+ scales: The scales to use. If None, the scales will be calculated from the frequencies.
171
+ Note: Scales will be sorted from largest to smallest.
172
+ Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
173
+ `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
187
174
 
188
- def construct_generator(self):
189
- self.STATE.gen = cwt(
190
- scales=self.SETTINGS.scales,
191
- wavelet=self.SETTINGS.wavelet,
192
- min_phase=self.SETTINGS.min_phase,
193
- axis=self.SETTINGS.axis,
175
+ Returns:
176
+ A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
177
+ and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
178
+ """
179
+ return CWTTransformer(
180
+ CWTSettings(
181
+ frequencies=frequencies,
182
+ wavelet=wavelet,
183
+ min_phase=min_phase,
184
+ axis=axis,
185
+ scales=scales,
194
186
  )
187
+ )