ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +133 -101
  6. ezmsg/sigproc/bandpower.py +64 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -84
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/wavelets.py CHANGED
@@ -6,190 +6,191 @@ import pywt
6
6
  import ezmsg.core as ez
7
7
  from ezmsg.util.messages.axisarray import AxisArray
8
8
  from ezmsg.util.messages.util import replace
9
- from ezmsg.util.generator import consumer
10
9
 
11
- from .base import GenAxisArray
10
+ from .base import (
11
+ BaseStatefulTransformer,
12
+ BaseTransformerUnit,
13
+ processor_state,
14
+ )
12
15
  from .filterbank import filterbank, FilterbankMode, MinPhaseMode
13
16
 
14
17
 
15
- @consumer
16
- def cwt(
17
- frequencies: list | tuple | npt.NDArray | None,
18
- wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
19
- min_phase: MinPhaseMode = MinPhaseMode.NONE,
20
- axis: str = "time",
21
- scales: list | tuple | npt.NDArray | None = None,
22
- ) -> typing.Generator[AxisArray, AxisArray, None]:
18
+ class CWTSettings(ez.Settings):
19
+ """
20
+ Settings for :obj:`CWT`
21
+ See :obj:`cwt` for argument details.
23
22
  """
24
- Perform a continuous wavelet transform.
25
- The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
26
23
 
27
- Args:
28
- frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
29
- Note: frequencies will be sorted from smallest to largest.
30
- wavelet: Wavelet object or name of wavelet to use.
31
- min_phase: See filterbank MinPhaseMode for details.
32
- axis: The target axis for operation. Note that this will be moved to the -1th dimension
33
- because fft and matrix multiplication is much faster on the last axis.
34
- This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
35
- scales: The scales to use. If None, the scales will be calculated from the frequencies.
36
- Note: Scales will be sorted from largest to smallest.
37
- Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
38
- `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
24
+ frequencies: list | tuple | npt.NDArray | None
25
+ wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
26
+ min_phase: MinPhaseMode = MinPhaseMode.NONE
27
+ axis: str = "time"
28
+ scales: list | tuple | npt.NDArray | None = None
39
29
 
40
- Returns:
41
- A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
42
- and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
43
- """
44
- precision = 10
45
- msg_out: AxisArray | None = None
46
-
47
- # Check parameters
48
- if frequencies is None and scales is None:
49
- raise ValueError("Either frequencies or scales must be provided.")
50
- if frequencies is not None and scales is not None:
51
- raise ValueError("Only one of frequencies or scales can be provided.")
52
- if scales is not None:
53
- scales = np.sort(scales)[::-1]
54
- assert np.all(scales > 0), "scales must be positive."
55
- assert scales.ndim == 1, "scales must be a 1D list, tuple, or array."
56
-
57
- if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
58
- wavelet = pywt.DiscreteContinuousWavelet(wavelet)
59
-
60
- if frequencies is not None:
61
- frequencies = np.sort(frequencies)
62
- assert np.all(frequencies > 0), "frequencies must be positive."
63
- assert frequencies.ndim == 1, "frequencies must be a 1D list, tuple, or array."
64
-
65
- # State variables
30
+
31
+ @processor_state
32
+ class CWTState:
66
33
  neg_rt_scales: npt.NDArray | None = None
67
- int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
68
- int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
34
+ int_psi_scales: list[npt.NDArray] | None = None
69
35
  template: AxisArray | None = None
70
36
  fbgen: typing.Generator[AxisArray, AxisArray, None] | None = None
71
37
  last_conv_samp: npt.NDArray | None = None
72
38
 
73
- # Reset if input changed
74
- check_input = {
75
- "kind": None, # Need to recalc kernels at same complexity as input
76
- "gain": None, # Need to recalc freqs
77
- "shape": None, # Need to recalc template and buffer
78
- "key": None, # Buffer obsolete
79
- }
80
-
81
- while True:
82
- msg_in: AxisArray = yield msg_out
83
- ax_idx = msg_in.get_axis_idx(axis)
84
- in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]
85
-
86
- b_reset = msg_in.data.dtype.kind != check_input["kind"]
87
- b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
88
- b_reset = b_reset or in_shape != check_input["shape"]
89
- b_reset = b_reset or msg_in.key != check_input["key"]
90
- b_reset = b_reset and msg_in.data.size > 0
91
- if b_reset:
92
- check_input["kind"] = msg_in.data.dtype.kind
93
- check_input["gain"] = msg_in.axes[axis].gain
94
- check_input["shape"] = in_shape
95
- check_input["key"] = msg_in.key
96
-
97
- if frequencies is not None:
98
- scales = pywt.frequency2scale(
99
- wavelet, frequencies * msg_in.axes[axis].gain, precision=precision
100
- )
101
- neg_rt_scales = -np.sqrt(scales)[:, None]
102
-
103
- # convert int_psi, wave_xvec to the same precision as the data
104
- dt_data = msg_in.data.dtype # _check_dtype(msg_in.data)
105
- dt_cplx = np.result_type(dt_data, np.complex64)
106
- dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
107
- int_psi = np.asarray(int_psi, dtype=dt_psi)
108
- # TODO: Currently int_psi cannot be made non-complex once it is complex.
109
-
110
- # Calculate waves for each scale
111
- wave_xvec = np.asarray(wave_xvec, dtype=msg_in.data.real.dtype)
112
- wave_range = wave_xvec[-1] - wave_xvec[0]
113
- step = wave_xvec[1] - wave_xvec[0]
114
- int_psi_scales = []
115
- for scale in scales:
116
- reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
117
- if reix[-1] >= int_psi.size:
118
- reix = np.extract(reix < int_psi.size, reix)
119
- int_psi_scales.append(int_psi[reix][::-1])
120
-
121
- # CONV is probably best because we often get huge kernels.
122
- fbgen = filterbank(
123
- int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
124
- )
125
39
 
126
- freqs = (
127
- pywt.scale2frequency(wavelet, scales, precision)
128
- / msg_in.axes[axis].gain
129
- )
130
- # Create output template
131
- dummy_shape = in_shape + (len(scales), 0)
132
- template = AxisArray(
133
- np.zeros(
134
- dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data
135
- ),
136
- dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1 :] + ["freq", axis],
137
- axes={
138
- **msg_in.axes,
139
- "freq": AxisArray.CoordinateAxis(
140
- unit="Hz", data=freqs, dims=["freq"]
141
- ),
142
- },
143
- key=msg_in.key,
40
+ class CWTTransformer(
41
+ BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]
42
+ ):
43
+ def _hash_message(self, message: AxisArray) -> int:
44
+ ax_idx = message.get_axis_idx(self.settings.axis)
45
+ in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
46
+ return hash(
47
+ (
48
+ message.data.dtype.kind,
49
+ message.axes[self.settings.axis].gain,
50
+ in_shape,
51
+ message.key,
144
52
  )
145
- last_conv_samp = np.zeros(
146
- dummy_shape[:-1] + (1,), dtype=template.data.dtype
53
+ )
54
+
55
+ def _reset_state(self, message: AxisArray) -> None:
56
+ precision = 10
57
+
58
+ # Process wavelet
59
+ wavelet = (
60
+ self.settings.wavelet
61
+ if isinstance(self.settings.wavelet, (pywt.ContinuousWavelet, pywt.Wavelet))
62
+ else pywt.DiscreteContinuousWavelet(self.settings.wavelet)
63
+ )
64
+ # Process wavelet integration
65
+ int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
66
+ int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
67
+
68
+ # Calculate scales and frequencies
69
+ if self.settings.frequencies is not None:
70
+ frequencies = np.sort(np.array(self.settings.frequencies))
71
+ scales = pywt.frequency2scale(
72
+ wavelet,
73
+ frequencies * message.axes[self.settings.axis].gain,
74
+ precision=precision,
147
75
  )
76
+ else:
77
+ scales = np.sort(self.settings.scales)[::-1]
78
+
79
+ self._state.neg_rt_scales = -np.sqrt(scales)[:, None]
80
+
81
+ # Convert to appropriate dtype
82
+ dt_data = message.data.dtype
83
+ dt_cplx = np.result_type(dt_data, np.complex64)
84
+ dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
85
+ int_psi = np.asarray(int_psi, dtype=dt_psi)
86
+ # Note: Currently int_psi cannot be made non-complex once it is complex.
87
+
88
+ # Calculate waves for each scale
89
+ wave_xvec = np.asarray(wave_xvec, dtype=message.data.real.dtype)
90
+ wave_range = wave_xvec[-1] - wave_xvec[0]
91
+ step = wave_xvec[1] - wave_xvec[0]
92
+ self._state.int_psi_scales = []
93
+ for scale in scales:
94
+ reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
95
+ if reix[-1] >= int_psi.size:
96
+ reix = np.extract(reix < int_psi.size, reix)
97
+ self._state.int_psi_scales.append(int_psi[reix][::-1])
98
+
99
+ # Setup filterbank generator
100
+ self._state.fbgen = filterbank(
101
+ self._state.int_psi_scales,
102
+ mode=FilterbankMode.CONV,
103
+ min_phase=self.settings.min_phase,
104
+ axis=self.settings.axis,
105
+ )
148
106
 
149
- conv_msg = fbgen.send(msg_in)
107
+ # Create output template
108
+ ax_idx = message.get_axis_idx(self.settings.axis)
109
+ in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
110
+ freqs = (
111
+ pywt.scale2frequency(wavelet, scales, precision)
112
+ / message.axes[self.settings.axis].gain
113
+ )
114
+ dummy_shape = in_shape + (len(scales), 0)
115
+ self._state.template = AxisArray(
116
+ np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
117
+ dims=message.dims[:ax_idx]
118
+ + message.dims[ax_idx + 1 :]
119
+ + ["freq", self.settings.axis],
120
+ axes={
121
+ **message.axes,
122
+ "freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
123
+ },
124
+ key=message.key,
125
+ )
126
+ self._state.last_conv_samp = np.zeros(
127
+ dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype
128
+ )
129
+
130
+ def _process(self, message: AxisArray) -> AxisArray:
131
+ conv_msg = self._state.fbgen.send(message)
150
132
 
151
133
  # Prepend with last_conv_samp before doing diff
152
- dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
153
- coef = neg_rt_scales * np.diff(dat, axis=-1)
154
- # Store last_conv_samp for next iteration.
155
- last_conv_samp = conv_msg.data[..., -1:]
134
+ dat = np.concatenate((self._state.last_conv_samp, conv_msg.data), axis=-1)
135
+ coef = self._state.neg_rt_scales * np.diff(dat, axis=-1)
136
+ # Store last_conv_samp for next iteration
137
+ self._state.last_conv_samp = conv_msg.data[..., -1:]
156
138
 
157
- if template.data.dtype.kind != "c":
139
+ if self._state.template.data.dtype.kind != "c":
158
140
  coef = coef.real
159
141
 
160
142
  # pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have
161
143
  # that luxury when streaming.
162
144
  # d = (coef.shape[-1] - msg_in.data.shape[ax_idx]) / 2.
163
145
  # coef = coef[..., math.floor(d):-math.ceil(d)]
164
- msg_out = replace(
165
- template, data=coef, axes={**template.axes, axis: msg_in.axes[axis]}
146
+ return replace(
147
+ self._state.template,
148
+ data=coef,
149
+ axes={
150
+ **self._state.template.axes,
151
+ self.settings.axis: message.axes[self.settings.axis],
152
+ },
166
153
  )
167
154
 
168
155
 
169
- class CWTSettings(ez.Settings):
170
- """
171
- Settings for :obj:`CWT`
172
- See :obj:`cwt` for argument details.
173
- """
174
- frequencies: list | tuple | npt.NDArray | None
175
- wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet
176
- min_phase: MinPhaseMode = MinPhaseMode.NONE
177
- axis: str = "time"
178
- scales: list | tuple | npt.NDArray | None = None
156
+ class CWT(BaseTransformerUnit[CWTSettings, AxisArray, AxisArray, CWTTransformer]):
157
+ SETTINGS = CWTSettings
179
158
 
180
159
 
181
- class CWT(GenAxisArray):
182
- """
183
- :obj:`Unit` for :obj:`common_rereference`.
160
+ def cwt(
161
+ frequencies: list | tuple | npt.NDArray | None,
162
+ wavelet: str | pywt.ContinuousWavelet | pywt.Wavelet,
163
+ min_phase: MinPhaseMode = MinPhaseMode.NONE,
164
+ axis: str = "time",
165
+ scales: list | tuple | npt.NDArray | None = None,
166
+ ) -> CWTTransformer:
184
167
  """
168
+ Perform a continuous wavelet transform.
169
+ The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
185
170
 
186
- SETTINGS = CWTSettings
171
+ Args:
172
+ frequencies: The wavelet frequencies to use in Hz. If `None` provided then the scales will be used.
173
+ Note: frequencies will be sorted from smallest to largest.
174
+ wavelet: Wavelet object or name of wavelet to use.
175
+ min_phase: See filterbank MinPhaseMode for details.
176
+ axis: The target axis for operation. Note that this will be moved to the -1th dimension
177
+ because fft and matrix multiplication is much faster on the last axis.
178
+ This axis must be in the msg.axes and it must be of type AxisArray.LinearAxis.
179
+ scales: The scales to use. If None, the scales will be calculated from the frequencies.
180
+ Note: Scales will be sorted from largest to smallest.
181
+ Note: Use of scales is deprecated in favor of frequencies. Convert scales to frequencies using
182
+ `pywt.scale2frequency(wavelet, scales, precision=10) * fs` where fs is the sampling frequency.
187
183
 
188
- def construct_generator(self):
189
- self.STATE.gen = cwt(
190
- frequencies=self.SETTINGS.frequencies,
191
- wavelet=self.SETTINGS.wavelet,
192
- min_phase=self.SETTINGS.min_phase,
193
- axis=self.SETTINGS.axis,
194
- scales=self.SETTINGS.scales,
184
+ Returns:
185
+ A primed Generator object that expects an :obj:`AxisArray` via `.send(axis_array)` of continuous data
186
+ and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
187
+ """
188
+ return CWTTransformer(
189
+ CWTSettings(
190
+ frequencies=frequencies,
191
+ wavelet=wavelet,
192
+ min_phase=min_phase,
193
+ axis=axis,
194
+ scales=scales,
195
195
  )
196
+ )