ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +119 -104
  6. ezmsg/sigproc/bandpower.py +58 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -84
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/filterbank.py CHANGED
@@ -10,11 +10,14 @@ import numpy.typing as npt
10
10
  import ezmsg.core as ez
11
11
  from ezmsg.util.messages.axisarray import AxisArray
12
12
  from ezmsg.util.messages.util import replace
13
- from ezmsg.util.generator import consumer
14
13
 
15
- from .base import GenAxisArray
14
+ from .base import (
15
+ BaseStatefulTransformer,
16
+ BaseTransformerUnit,
17
+ processor_state,
18
+ )
16
19
  from .spectrum import OptionsEnum
17
- from .window import windowing
20
+ from .window import WindowTransformer
18
21
 
19
22
 
20
23
  class FilterbankMode(OptionsEnum):
@@ -34,248 +37,293 @@ class MinPhaseMode(OptionsEnum):
34
37
  # HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
35
38
 
36
39
 
37
- @consumer
38
- def filterbank(
39
- kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
40
- mode: FilterbankMode = FilterbankMode.CONV,
41
- min_phase: MinPhaseMode = MinPhaseMode.NONE,
42
- axis: str = "time",
43
- new_axis: str = "kernel",
44
- ) -> typing.Generator[AxisArray, AxisArray, None]:
40
+ class FilterbankSettings(ez.Settings):
41
+ kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]
42
+
43
+ mode: FilterbankMode = FilterbankMode.CONV
44
+ """
45
+ "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
46
+ fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
47
+ incur a delay equal to the window length, which is larger than the largest kernel.
48
+ conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
49
+ and thus can provide shorter latency updates.
45
50
  """
46
- Perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
47
- This is intended to be used during online processing, therefore both direct and fft convolutions
48
- use the overlap-add method.
49
- Args:
50
- kernels:
51
- mode: "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
52
- fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will
53
- incur a delay equal to the window length, which is larger than the largest kernel.
54
- conv mode is less efficient but will return data for every incoming chunk regardless of how small it is
55
- and thus can provide shorter latency updates.
56
- min_phase: If not None, convert the kernels to minimum-phase equivalents. Valid options are
57
- 'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
58
- See `scipy.signal.minimum_phase` for details.
59
- axis: The name of the axis to operate on. This should usually be "time".
60
- new_axis: The name of the new axis corresponding to the kernel index.
61
-
62
- Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
63
- with the data payload containing the absolute value of the input :obj:`AxisArray` data.
64
51
 
52
+ min_phase: MinPhaseMode = MinPhaseMode.NONE
53
+ """
54
+ If not None, convert the kernels to minimum-phase equivalents. Valid options are
55
+ 'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported.
56
+ See `scipy.signal.minimum_phase` for details.
65
57
  """
66
- msg_out: AxisArray | None = None
67
58
 
68
- # State variables
59
+ axis: str = "time"
60
+ """The name of the axis to operate on. This should usually be "time"."""
61
+
62
+ new_axis: str = "kernel"
63
+ """The name of the new axis corresponding to the kernel index."""
64
+
65
+
66
+ @processor_state
67
+ class FilterbankState:
68
+ tail: npt.NDArray | None = None
69
69
  template: AxisArray | None = None
70
+ dest_arr: npt.NDArray | None = None
71
+ prep_kerns: npt.NDArray | list[npt.NDArray] | None = None
72
+ windower: WindowTransformer | None = None
73
+ fft: typing.Callable | None = None
74
+ ifft: typing.Callable | None = None
75
+ nfft: int | None = None
76
+ infft: int | None = None
77
+ overlap: int | None = None
78
+ mode: FilterbankMode | None = None
79
+
80
+
81
+ class FilterbankTransformer(
82
+ BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]
83
+ ):
84
+ def _hash_message(self, message: AxisArray) -> int:
85
+ axis = self.settings.axis or message.dims[0]
86
+ gain = message.axes[axis].gain if axis in message.axes else 1.0
87
+ targ_ax_ix = message.get_axis_idx(axis)
88
+ in_shape = (
89
+ message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
90
+ )
91
+
92
+ return hash(
93
+ (
94
+ message.key,
95
+ gain
96
+ if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
97
+ else None,
98
+ message.data.dtype.kind,
99
+ in_shape,
100
+ )
101
+ )
70
102
 
71
- # Reset if these change
72
- check_input = {
73
- "key": None,
74
- "template": None,
75
- "gain": None,
76
- "kind": None,
77
- "shape": None,
78
- }
79
-
80
- while True:
81
- msg_in: AxisArray = yield msg_out
82
-
83
- axis = axis or msg_in.dims[0]
84
- gain = msg_in.axes[axis].gain if axis in msg_in.axes else 1.0
85
- targ_ax_ix = msg_in.get_axis_idx(axis)
86
- in_shape = msg_in.data.shape[:targ_ax_ix] + msg_in.data.shape[targ_ax_ix + 1 :]
87
-
88
- b_reset = msg_in.key != check_input["key"]
89
- b_reset = b_reset or (
90
- gain != check_input["gain"]
91
- and mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
103
+ def _reset_state(self, message: AxisArray) -> None:
104
+ axis = self.settings.axis or message.dims[0]
105
+ gain = message.axes[axis].gain if axis in message.axes else 1.0
106
+ targ_ax_ix = message.get_axis_idx(axis)
107
+ in_shape = (
108
+ message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
92
109
  )
93
- b_reset = b_reset or msg_in.data.dtype.kind != check_input["kind"]
94
- b_reset = b_reset or in_shape != check_input["shape"]
95
- if b_reset:
96
- check_input["key"] = msg_in.key
97
- check_input["gain"] = gain
98
- check_input["kind"] = msg_in.data.dtype.kind
99
- check_input["shape"] = in_shape
100
-
101
- if min_phase != MinPhaseMode.NONE:
102
- method, half = {
103
- MinPhaseMode.HILBERT: ("hilbert", False),
104
- MinPhaseMode.HOMOMORPHIC: ("homomorphic", False),
105
- # MinPhaseMode.HOMOMORPHICFULL: ("homomorphic", True),
106
- }[min_phase]
107
- kernels = [
108
- sps.minimum_phase(
109
- k, method=method
110
- ) # , half=half) -- half requires later scipy >= 1.14
111
- for k in kernels
112
- ]
113
-
114
- # Determine if this will be operating with complex data.
115
- b_complex = msg_in.data.dtype.kind == "c" or any(
116
- [_.dtype.kind == "c" for _ in kernels]
110
+
111
+ kernels = self.settings.kernels
112
+ if self.settings.min_phase != MinPhaseMode.NONE:
113
+ method, half = {
114
+ MinPhaseMode.HILBERT: ("hilbert", False),
115
+ MinPhaseMode.HOMOMORPHIC: ("homomorphic", False),
116
+ # MinPhaseMode.HOMOMORPHICFULL: ("homomorphic", True),
117
+ }[self.settings.min_phase]
118
+ kernels = [sps.minimum_phase(k, method=method) for k in kernels]
119
+
120
+ # Determine if this will be operating with complex data.
121
+ b_complex = message.data.dtype.kind == "c" or any(
122
+ [_.dtype.kind == "c" for _ in kernels]
123
+ )
124
+
125
+ # Calculate window_dur, window_shift, nfft
126
+ max_kernel_len = max([_.size for _ in kernels])
127
+ # From sps._calc_oa_lens, where s2=max_kernel_len,:
128
+ # fallback_nfft = n_input + max_kernel_len - 1, but n_input is unbound.
129
+ self._state.overlap = max_kernel_len - 1
130
+
131
+ # Prepare previous iteration's overlap tail to add to input -- all zeros.
132
+ tail_shape = in_shape + (len(kernels), self._state.overlap)
133
+ self._state.tail = np.zeros(
134
+ tail_shape, dtype="complex" if b_complex else "float"
135
+ )
136
+
137
+ # Prepare output template -- kernels axis immediately before the target axis
138
+ dummy_shape = in_shape + (len(kernels), 0)
139
+ self._state.template = AxisArray(
140
+ data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"),
141
+ dims=message.dims[:targ_ax_ix]
142
+ + message.dims[targ_ax_ix + 1 :]
143
+ + [self.settings.new_axis, axis],
144
+ axes=message.axes.copy(),
145
+ key=message.key,
146
+ )
147
+
148
+ # Determine optimal mode. Assumes 100 msec chunks.
149
+ self._state.mode = self.settings.mode
150
+ if self._state.mode == FilterbankMode.AUTO:
151
+ # concatenate kernels into 1 mega kernel then check what's faster.
152
+ # Will typically return fft when combined kernel length is > 1500.
153
+ concat_kernel = np.concatenate(kernels)
154
+ n_dummy = max(2 * len(concat_kernel), int(0.1 / gain))
155
+ dummy_arr = np.zeros(n_dummy)
156
+ self._state.mode = (
157
+ FilterbankMode.CONV
158
+ if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full")
159
+ == "direct"
160
+ else FilterbankMode.FFT
117
161
  )
118
162
 
119
- # Calculate window_dur, window_shift, nfft
120
- max_kernel_len = max([_.size for _ in kernels])
121
- # From sps._calc_oa_lens, where s2=max_kernel_len,:
122
- # fallback_nfft = n_input + max_kernel_len - 1, but n_input is unbound.
123
- overlap = max_kernel_len - 1
124
-
125
- # Prepare previous iteration's overlap tail to add to input -- all zeros.
126
- tail_shape = in_shape + (len(kernels), overlap)
127
- tail = np.zeros(tail_shape, dtype="complex" if b_complex else "float")
128
-
129
- # Prepare output template -- kernels axis immediately before the target axis
130
- dummy_shape = in_shape + (len(kernels), 0)
131
- template = AxisArray(
132
- data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"),
133
- dims=msg_in.dims[:targ_ax_ix]
134
- + msg_in.dims[targ_ax_ix + 1 :]
135
- + [new_axis, axis],
136
- axes=msg_in.axes.copy(), # We do not have info for kernel/filter axis :(.
137
- key=msg_in.key,
163
+ if self._state.mode == FilterbankMode.CONV:
164
+ # Preallocate memory for convolution result and overlap-add
165
+ dest_shape = in_shape + (
166
+ len(kernels),
167
+ self._state.overlap + message.data.shape[targ_ax_ix],
138
168
  )
169
+ self._state.dest_arr = np.zeros(
170
+ dest_shape, dtype="complex" if b_complex else "float"
171
+ )
172
+ self._state.prep_kerns = kernels
173
+ else: # FFT mode
174
+ # Calculate optimal nfft and windowing size.
175
+ opt_size = (
176
+ -self._state.overlap
177
+ * lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
178
+ )
179
+ self._state.nfft = sp_fft.next_fast_len(math.ceil(opt_size))
180
+ win_len = self._state.nfft - self._state.overlap
181
+ # infft same as nfft. Keeping as separate variable because I might need it again.
182
+ self._state.infft = win_len + self._state.overlap
183
+
184
+ # Create windowing node.
185
+ # Note: We could do windowing manually to avoid the overhead of the message structure,
186
+ # but windowing is difficult to do correctly, so we lean on the heavily-tested `windowing` generator.
187
+ win_dur = win_len * gain
188
+ self._state.windower = WindowTransformer(
189
+ axis=axis,
190
+ newaxis="win",
191
+ window_dur=win_dur,
192
+ window_shift=win_dur,
193
+ zero_pad_until="none",
194
+ )
195
+
196
+ # Windowing output has an extra "win" dimension, so we need our tail to match.
197
+ self._state.tail = np.expand_dims(self._state.tail, -2)
139
198
 
140
- # Determine optimal mode. Assumes 100 msec chunks.
141
- if mode == FilterbankMode.AUTO:
142
- # concatenate kernels into 1 mega kernel then check what's faster.
143
- # Will typically return fft when combined kernel length is > 1500.
144
- concat_kernel = np.concatenate(kernels)
145
- n_dummy = max(2 * len(concat_kernel), int(0.1 / gain))
146
- dummy_arr = np.zeros(n_dummy)
147
- mode = sps.choose_conv_method(dummy_arr, concat_kernel, mode="full")
148
- mode = FilterbankMode.CONV if mode == "direct" else FilterbankMode.FFT
149
-
150
- if mode == FilterbankMode.CONV:
151
- # Preallocate memory for convolution result and overlap-add
152
- dest_shape = in_shape + (
153
- len(kernels),
154
- overlap + msg_in.data.shape[targ_ax_ix],
199
+ # Prepare fft functions
200
+ # Note: We could instead use `spectrum` but this adds overhead in creating the message structure
201
+ # for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
202
+ # more fft backends.
203
+ if b_complex:
204
+ self._state.fft = functools.partial(
205
+ sp_fft.fft, n=self._state.nfft, norm="backward"
155
206
  )
156
- dest_arr = np.zeros(
157
- dest_shape, dtype="complex" if b_complex else "float"
207
+ self._state.ifft = functools.partial(
208
+ sp_fft.ifft, n=self._state.infft, norm="backward"
158
209
  )
159
-
160
- elif mode == FilterbankMode.FFT:
161
- # Calculate optimal nfft and windowing size.
162
- opt_size = -overlap * lambertw(-1 / (2 * math.e * overlap), k=-1).real
163
- nfft = sp_fft.next_fast_len(math.ceil(opt_size))
164
- win_len = nfft - overlap
165
- # infft same as nfft. Keeping as separate variable because I might need it again.
166
- infft = win_len + overlap
167
-
168
- # Create windowing node.
169
- # Note: We could do windowing manually to avoid the overhead of the message structure,
170
- # but windowing is difficult to do correctly, so we lean on the heavily-tested `windowing` generator.
171
- win_dur = win_len * gain
172
- wingen = windowing(
173
- axis=axis,
174
- newaxis="win", # Big data chunks might yield more than 1 window.
175
- window_dur=win_dur,
176
- window_shift=win_dur, # Tumbling (not sliding) windows expected!
177
- zero_pad_until="none",
210
+ else:
211
+ self._state.fft = functools.partial(
212
+ sp_fft.rfft, n=self._state.nfft, norm="backward"
213
+ )
214
+ self._state.ifft = functools.partial(
215
+ sp_fft.irfft, n=self._state.infft, norm="backward"
178
216
  )
179
217
 
180
- # Windowing output has an extra "win" dimension, so we need our tail to match.
181
- tail = np.expand_dims(tail, -2)
182
-
183
- # Prepare fft functions
184
- # Note: We could instead use `spectrum` but this adds overhead in creating the message structure
185
- # for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
186
- # more fft backends.
187
- if b_complex:
188
- fft = functools.partial(sp_fft.fft, n=nfft, norm="backward")
189
- ifft = functools.partial(sp_fft.ifft, n=infft, norm="backward")
190
- else:
191
- fft = functools.partial(sp_fft.rfft, n=nfft, norm="backward")
192
- ifft = functools.partial(sp_fft.irfft, n=infft, norm="backward")
193
-
194
- # Calculate fft of kernels
195
- prep_kerns = np.array([fft(_) for _ in kernels])
196
- prep_kerns = np.expand_dims(prep_kerns, -2)
197
- # TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.
218
+ # Calculate fft of kernels
219
+ self._state.prep_kerns = np.array([self._state.fft(_) for _ in kernels])
220
+ self._state.prep_kerns = np.expand_dims(self._state.prep_kerns, -2)
221
+ # TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.
222
+
223
+ def _process(self, message: AxisArray) -> AxisArray:
224
+ axis = self.settings.axis or message.dims[0]
225
+ targ_ax_ix = message.get_axis_idx(axis)
198
226
 
199
227
  # Make sure target axis is in -1th position.
200
- if targ_ax_ix != (msg_in.data.ndim - 1):
201
- in_dat = np.moveaxis(msg_in.data, targ_ax_ix, -1)
202
- if mode == FilterbankMode.FFT:
203
- # Fix msg_in .dims because we will pass it to wingen
228
+ if targ_ax_ix != (message.data.ndim - 1):
229
+ in_dat = np.moveaxis(message.data, targ_ax_ix, -1)
230
+ if self._state.mode == FilterbankMode.FFT:
231
+ # Fix message.dims because we will pass it to windower
204
232
  move_dims = (
205
- msg_in.dims[:targ_ax_ix] + msg_in.dims[targ_ax_ix + 1 :] + [axis]
233
+ message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
206
234
  )
207
- msg_in = replace(msg_in, data=in_dat, dims=move_dims)
235
+ message = replace(message, data=in_dat, dims=move_dims)
208
236
  else:
209
- in_dat = msg_in.data
210
-
211
- if mode == FilterbankMode.CONV:
212
- n_dest = in_dat.shape[-1] + overlap
213
- if dest_arr.shape[-1] < n_dest:
214
- pad = np.zeros(dest_arr.shape[:-1] + (n_dest - dest_arr.shape[-1],))
215
- dest_arr = np.concatenate(dest_arr, pad, axis=-1)
216
- dest_arr.fill(0)
237
+ in_dat = message.data
238
+
239
+ if self._state.mode == FilterbankMode.CONV:
240
+ n_dest = in_dat.shape[-1] + self._state.overlap
241
+ if self._state.dest_arr.shape[-1] < n_dest:
242
+ pad = np.zeros(
243
+ self._state.dest_arr.shape[:-1]
244
+ + (n_dest - self._state.dest_arr.shape[-1],)
245
+ )
246
+ self._state.dest_arr = np.concatenate(
247
+ [self._state.dest_arr, pad], axis=-1
248
+ )
249
+ self._state.dest_arr.fill(0)
250
+
217
251
  # Note: I tried several alternatives to this loop; all were slower than this.
218
252
  # numba.jit; stride_tricks + np.einsum; threading. Latter might be better with Python 3.13.
219
- for k_ix, k in enumerate(kernels):
253
+ for k_ix, k in enumerate(self._state.prep_kerns):
220
254
  n_out = in_dat.shape[-1] + k.shape[-1] - 1
221
- dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
255
+ self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
222
256
  np.convolve, -1, in_dat, k, mode="full"
223
257
  )
224
- dest_arr[..., :overlap] += tail # Add previous overlap
225
- new_tail = dest_arr[..., in_dat.shape[-1] : n_dest]
258
+ self._state.dest_arr[..., : self._state.overlap] += self._state.tail
259
+ new_tail = self._state.dest_arr[..., in_dat.shape[-1] : n_dest]
226
260
  if new_tail.size > 0:
227
261
  # COPY overlap for next iteration
228
- tail = new_tail.copy()
229
- res = dest_arr[..., : in_dat.shape[-1]].copy()
230
- elif mode == FilterbankMode.FFT:
262
+ self._state.tail = new_tail.copy()
263
+ res = self._state.dest_arr[..., : in_dat.shape[-1]].copy()
264
+ else: # FFT mode
231
265
  # Slice into non-overlapping windows
232
- win_msg = wingen.send(msg_in)
233
- # Calculate spectra of each window
234
- spec_dat = fft(win_msg.data, axis=-1)
266
+ win_msg = self._state.windower.send(message)
267
+ # Calculate spectrum of each window
268
+ spec_dat = self._state.fft(win_msg.data, axis=-1)
235
269
  # Insert axis for filters
236
270
  spec_dat = np.expand_dims(spec_dat, -3)
237
271
 
238
272
  # Do the FFT convolution
239
273
  # TODO: handle fft_kernels being sparse. Maybe need np.dot.
240
- conv_spec = spec_dat * prep_kerns
241
- overlapped = ifft(conv_spec, axis=-1)
274
+ conv_spec = spec_dat * self._state.prep_kerns
275
+ overlapped = self._state.ifft(conv_spec, axis=-1)
242
276
 
243
277
  # Do the overlap-add on the `axis` axis
244
278
  # Previous iteration's tail:
245
- overlapped[..., :1, :overlap] += tail
279
+ overlapped[..., :1, : self._state.overlap] += self._state.tail
246
280
  # window-to-window:
247
- overlapped[..., 1:, :overlap] += overlapped[..., :-1, -overlap:]
281
+ overlapped[..., 1:, : self._state.overlap] += overlapped[
282
+ ..., :-1, -self._state.overlap :
283
+ ]
248
284
  # Save tail:
249
- new_tail = overlapped[..., -1:, -overlap:]
285
+ new_tail = overlapped[..., -1:, -self._state.overlap :]
250
286
  if new_tail.size > 0:
251
287
  # All of the above code works if input is size-zero, but we don't want to save a zero-size tail.
252
- tail = new_tail # Save the tail for the next iteration.
288
+ self._state.tail = new_tail
253
289
  # Concat over win axis, without overlap.
254
- res = overlapped[..., :-overlap].reshape(overlapped.shape[:-2] + (-1,))
290
+ res = overlapped[..., : -self._state.overlap].reshape(
291
+ overlapped.shape[:-2] + (-1,)
292
+ )
255
293
 
256
- msg_out = replace(
257
- template, data=res, axes={**template.axes, axis: msg_in.axes[axis]}
294
+ return replace(
295
+ self._state.template,
296
+ data=res,
297
+ axes={**self._state.template.axes, axis: message.axes[axis]},
258
298
  )
259
299
 
260
300
 
261
- class FilterbankSettings(ez.Settings):
262
- kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]
263
- mode: FilterbankMode = FilterbankMode.CONV
264
- min_phase: MinPhaseMode = MinPhaseMode.NONE
265
- axis: str = "time"
266
-
267
-
268
- class Filterbank(GenAxisArray):
269
- """Unit for :obj:`spectrum`"""
270
-
301
+ class Filterbank(
302
+ BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]
303
+ ):
271
304
  SETTINGS = FilterbankSettings
272
305
 
273
- INPUT_SETTINGS = ez.InputStream(FilterbankSettings)
274
306
 
275
- def construct_generator(self):
276
- self.STATE.gen = filterbank(
277
- kernels=self.SETTINGS.kernels,
278
- mode=self.SETTINGS.mode,
279
- min_phase=self.SETTINGS.min_phase,
280
- axis=self.SETTINGS.axis,
307
+ def filterbank(
308
+ kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
309
+ mode: FilterbankMode = FilterbankMode.CONV,
310
+ min_phase: MinPhaseMode = MinPhaseMode.NONE,
311
+ axis: str = "time",
312
+ new_axis: str = "kernel",
313
+ ) -> FilterbankTransformer:
314
+ """
315
+ Perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
316
+ This is intended to be used during online processing, therefore both direct and fft convolutions
317
+ use the overlap-add method.
318
+
319
+ Returns: :obj:`FilterbankTransformer`.
320
+ """
321
+ return FilterbankTransformer(
322
+ settings=FilterbankSettings(
323
+ kernels=kernels,
324
+ mode=mode,
325
+ min_phase=min_phase,
326
+ axis=axis,
327
+ new_axis=new_axis,
281
328
  )
329
+ )
ezmsg/sigproc/math/abs.py CHANGED
@@ -1,34 +1,29 @@
1
- import typing
2
-
3
1
  import numpy as np
4
- import ezmsg.core as ez
5
- from ezmsg.util.generator import consumer
6
2
  from ezmsg.util.messages.axisarray import AxisArray
7
3
  from ezmsg.util.messages.util import replace
8
4
 
9
- from ..base import GenAxisArray
5
+ from ..base import BaseTransformer, BaseTransformerUnit
10
6
 
11
7
 
12
- @consumer
13
- def abs() -> typing.Generator[AxisArray, AxisArray, None]:
14
- """
15
- Take the absolute value of the data. See :obj:`np.abs` for more details.
8
+ class AbsSettings:
9
+ pass
16
10
 
17
- Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
18
- with the data payload containing the absolute value of the input :obj:`AxisArray` data.
19
- """
20
- msg_out = AxisArray(np.array([]), dims=[""])
21
- while True:
22
- msg_in: AxisArray = yield msg_out
23
- msg_out = replace(msg_in, data=np.abs(msg_in.data))
24
11
 
12
+ class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
13
+ def _process(self, message: AxisArray) -> AxisArray:
14
+ return replace(message, data=np.abs(message.data))
25
15
 
26
- class AbsSettings(ez.Settings):
27
- pass
28
16
 
17
+ class Abs(
18
+ BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]
19
+ ): ... # SETTINGS = None
29
20
 
30
- class Abs(GenAxisArray):
31
- SETTINGS = AbsSettings
32
21
 
33
- def construct_generator(self):
34
- self.STATE.gen = abs()
22
+ def abs() -> AbsTransformer:
23
+ """
24
+ Take the absolute value of the data. See :obj:`np.abs` for more details.
25
+
26
+ Returns: :obj:`AbsTransformer`.
27
+
28
+ """
29
+ return AbsTransformer()
ezmsg/sigproc/math/clip.py CHANGED
@@ -1,16 +1,32 @@
1
- import typing
2
-
3
1
  import numpy as np
4
2
  import ezmsg.core as ez
5
- from ezmsg.util.generator import consumer
6
3
  from ezmsg.util.messages.axisarray import AxisArray
7
4
  from ezmsg.util.messages.util import replace
8
5
 
9
- from ..base import GenAxisArray
6
+ from ..base import BaseTransformer, BaseTransformerUnit
7
+
8
+
9
+ class ClipSettings(ez.Settings):
10
+ a_min: float
11
+ """Lower clip bound."""
10
12
 
13
+ a_max: float
14
+ """Upper clip bound."""
15
+
16
+
17
+ class ClipTransformer(BaseTransformer[ClipSettings, AxisArray, AxisArray]):
18
+ def _process(self, message: AxisArray) -> AxisArray:
19
+ return replace(
20
+ message,
21
+ data=np.clip(message.data, self.settings.a_min, self.settings.a_max),
22
+ )
23
+
24
+
25
+ class Clip(BaseTransformerUnit[ClipSettings, AxisArray, AxisArray, ClipTransformer]):
26
+ SETTINGS = ClipSettings
11
27
 
12
- @consumer
13
- def clip(a_min: float, a_max: float) -> typing.Generator[AxisArray, AxisArray, None]:
28
+
29
+ def clip(a_min: float, a_max: float) -> ClipTransformer:
14
30
  """
15
31
  Clips the data to be within the specified range. See :obj:`np.clip` for more details.
16
32
 
@@ -18,23 +34,7 @@ def clip(a_min: float, a_max: float) -> typing.Generator[AxisArray, AxisArray, N
18
34
  a_min: Lower clip bound
19
35
  a_max: Upper clip bound
20
36
 
21
- Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
22
- with the data payload containing the clipped version of the input :obj:`AxisArray` data.
37
+ Returns: :obj:`ClipTransformer`.
23
38
 
24
39
  """
25
- msg_out = AxisArray(np.array([]), dims=[""])
26
- while True:
27
- msg_in: AxisArray = yield msg_out
28
- msg_out = replace(msg_in, data=np.clip(msg_in.data, a_min, a_max))
29
-
30
-
31
- class ClipSettings(ez.Settings):
32
- a_min: float
33
- a_max: float
34
-
35
-
36
- class Clip(GenAxisArray):
37
- SETTINGS = ClipSettings
38
-
39
- def construct_generator(self):
40
- self.STATE.gen = clip(a_min=self.SETTINGS.a_min, a_max=self.SETTINGS.a_max)
40
+ return ClipTransformer(ClipSettings(a_min=a_min, a_max=a_max))