ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
  4. ezmsg/sigproc/affinetransform.py +16 -42
  5. ezmsg/sigproc/aggregate.py +17 -34
  6. ezmsg/sigproc/bandpower.py +12 -20
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/coordinatespaces.py +142 -0
  13. ezmsg/sigproc/decimate.py +3 -7
  14. ezmsg/sigproc/denormalize.py +6 -11
  15. ezmsg/sigproc/detrend.py +3 -4
  16. ezmsg/sigproc/diff.py +8 -17
  17. ezmsg/sigproc/downsample.py +11 -20
  18. ezmsg/sigproc/ewma.py +11 -28
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +3 -4
  21. ezmsg/sigproc/fbcca.py +34 -59
  22. ezmsg/sigproc/filter.py +19 -45
  23. ezmsg/sigproc/filterbank.py +37 -74
  24. ezmsg/sigproc/filterbankdesign.py +7 -14
  25. ezmsg/sigproc/fir_hilbert.py +13 -30
  26. ezmsg/sigproc/fir_pmc.py +5 -10
  27. ezmsg/sigproc/firfilter.py +12 -14
  28. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  29. ezmsg/sigproc/kaiser.py +11 -15
  30. ezmsg/sigproc/math/abs.py +4 -3
  31. ezmsg/sigproc/math/add.py +121 -0
  32. ezmsg/sigproc/math/clip.py +4 -1
  33. ezmsg/sigproc/math/difference.py +100 -36
  34. ezmsg/sigproc/math/invert.py +3 -3
  35. ezmsg/sigproc/math/log.py +5 -6
  36. ezmsg/sigproc/math/scale.py +2 -0
  37. ezmsg/sigproc/messages.py +1 -2
  38. ezmsg/sigproc/quantize.py +3 -6
  39. ezmsg/sigproc/resample.py +17 -38
  40. ezmsg/sigproc/rollingscaler.py +12 -37
  41. ezmsg/sigproc/sampler.py +19 -37
  42. ezmsg/sigproc/scaler.py +11 -22
  43. ezmsg/sigproc/signalinjector.py +7 -18
  44. ezmsg/sigproc/slicer.py +14 -34
  45. ezmsg/sigproc/spectral.py +3 -3
  46. ezmsg/sigproc/spectrogram.py +12 -19
  47. ezmsg/sigproc/spectrum.py +17 -38
  48. ezmsg/sigproc/transpose.py +12 -24
  49. ezmsg/sigproc/util/asio.py +25 -156
  50. ezmsg/sigproc/util/axisarray_buffer.py +12 -26
  51. ezmsg/sigproc/util/buffer.py +22 -43
  52. ezmsg/sigproc/util/message.py +17 -31
  53. ezmsg/sigproc/util/profile.py +23 -174
  54. ezmsg/sigproc/util/sparse.py +7 -15
  55. ezmsg/sigproc/util/typeresolution.py +17 -83
  56. ezmsg/sigproc/wavelets.py +10 -19
  57. ezmsg/sigproc/window.py +29 -83
  58. ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
  59. ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
  60. ezmsg/sigproc/synth.py +0 -774
  61. ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
  62. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  63. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
  64. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/fbcca.py CHANGED
@@ -1,29 +1,26 @@
1
- import typing
2
1
  import math
2
+ import typing
3
3
  from dataclasses import field
4
4
 
5
- import numpy as np
6
-
7
5
  import ezmsg.core as ez
8
- from ezmsg.util.messages.axisarray import AxisArray
9
- from ezmsg.util.messages.util import replace
10
-
11
- from .sampler import SampleTriggerMessage
12
- from .window import WindowTransformer, WindowSettings
13
-
14
- from .base import (
6
+ import numpy as np
7
+ from ezmsg.baseproc import (
8
+ BaseProcessor,
9
+ BaseStatefulProcessor,
15
10
  BaseTransformer,
16
11
  BaseTransformerUnit,
17
12
  CompositeProcessor,
18
- BaseProcessor,
19
- BaseStatefulProcessor,
20
13
  )
14
+ from ezmsg.util.messages.axisarray import AxisArray
15
+ from ezmsg.util.messages.util import replace
21
16
 
22
- from .kaiser import KaiserFilterSettings
23
17
  from .filterbankdesign import (
24
18
  FilterbankDesignSettings,
25
19
  FilterbankDesignTransformer,
26
20
  )
21
+ from .kaiser import KaiserFilterSettings
22
+ from .sampler import SampleTriggerMessage
23
+ from .window import WindowSettings, WindowTransformer
27
24
 
28
25
 
29
26
  class FBCCASettings(ez.Settings):
@@ -33,7 +30,7 @@ class FBCCASettings(ez.Settings):
33
30
 
34
31
  time_dim: str
35
32
  """
36
- The time dim in the data array.
33
+ The time dim in the data array.
37
34
  """
38
35
 
39
36
  ch_dim: str
@@ -49,40 +46,41 @@ class FBCCASettings(ez.Settings):
49
46
 
50
47
  harmonics: int = 5
51
48
  """
52
- The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
53
- 5 (default): Evaluate 5 harmonics of the base frequency.
54
- Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
49
+ The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
50
+ 5 (default): Evaluate 5 harmonics of the base frequency.
51
+ Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
55
52
  presence of signals with higher frequency harmonic content
56
53
  """
57
54
 
58
55
  freqs: typing.List[float] = field(default_factory=list)
59
56
  """
60
- Frequencies (in hz) to evaluate the presence of within the input signal.
61
- [] (default): an empty list; frequencies will be found within the input SampleMessages.
57
+ Frequencies (in hz) to evaluate the presence of within the input signal.
58
+ [] (default): an empty list; frequencies will be found within the input SampleMessages.
62
59
  AxisArrays have no good place to put this metadata, so specify frequencies here if only AxisArrays
63
60
  will be passed as input to the generator. If the input has a `trigger` attr of type :obj:`SampleTriggerMessage`,
64
- the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
65
- This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from the ezmsg-tasks package.
61
+ the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
62
+ This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from
63
+ the ezmsg-tasks package.
66
64
  NOTE: Avoid frequencies that have line-noise (60 Hz/50 Hz) as a harmonic.
67
65
  """
68
66
 
69
67
  softmax_beta: float = 1.0
70
68
  """
71
- Beta parameter for softmax on output --> "probabilities".
72
- 1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
69
+ Beta parameter for softmax on output --> "probabilities".
70
+ 1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
73
71
  If 0.0, the maximum singular value of the SVD for each design matrix is output
74
72
  """
75
73
 
76
74
  target_freq_dim: str = "target_freq"
77
75
  """
78
- Name for dim to put target frequency outputs on.
76
+ Name for dim to put target frequency outputs on.
79
77
  'target_freq' (default)
80
78
  """
81
79
 
82
80
  max_int_time: float = 0.0
83
81
  """
84
- Maximum integration time (in seconds) to use for calculation.
85
- 0 (default): Use all time provided for the calculation.
82
+ Maximum integration time (in seconds) to use for calculation.
83
+ 0 (default): Use all time provided for the calculation.
86
84
  Useful for artificially limiting the amount of data used for the CCA method to evaluate
87
85
  the necessary integration time for good decoding performance
88
86
  """
@@ -136,9 +134,7 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
136
134
  if filterbank_dim_idx is not None:
137
135
  new_order.append(filterbank_dim_idx)
138
136
  new_order.extend([time_dim_idx, ch_dim_idx])
139
- out_dims = [
140
- message.dims[i] for i in new_order if message.dims[i] not in rm_dims
141
- ]
137
+ out_dims = [message.dims[i] for i in new_order if message.dims[i] not in rm_dims]
142
138
  data_arr = message.data.transpose(new_order)
143
139
 
144
140
  # Add a singleton dim for filterbank dim if we don't have one
@@ -158,10 +154,7 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
158
154
  axis_name: axis
159
155
  for axis_name, axis in message.axes.items()
160
156
  if axis_name not in rm_dims
161
- and not (
162
- isinstance(axis, AxisArray.CoordinateAxis)
163
- and any(d in rm_dims for d in axis.dims)
164
- )
157
+ and not (isinstance(axis, AxisArray.CoordinateAxis) and any(d in rm_dims for d in axis.dims))
165
158
  }
166
159
  out_axes[self.settings.target_freq_dim] = AxisArray.CoordinateAxis(
167
160
  np.array(test_freqs), [self.settings.target_freq_dim]
@@ -193,15 +186,9 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
193
186
  ]
194
187
  )
195
188
 
196
- for test_idx, arr in enumerate(
197
- data_arr
198
- ): # iterate over first dim; arr is (filterbank x time x ch)
199
- for band_idx, band in enumerate(
200
- arr
201
- ): # iterate over second dim: arr is (time x ch)
202
- calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(
203
- band[:max_samp, ...], Y
204
- )
189
+ for test_idx, arr in enumerate(data_arr): # iterate over first dim; arr is (filterbank x time x ch)
190
+ for band_idx, band in enumerate(arr): # iterate over second dim: arr is (time x ch)
191
+ calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(band[:max_samp, ...], Y)
205
192
 
206
193
  # Combine per-subband canonical correlations using a weighted sum
207
194
  # https://iopscience.iop.org/article/10.1088/1741-2560/12/4/046008
@@ -209,9 +196,7 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
209
196
  calc_output = ((calc_output**2) * freq_weights[None, :, None]).sum(axis=1)
210
197
 
211
198
  if self.settings.softmax_beta != 0:
212
- calc_output = calc_softmax(
213
- calc_output, axis=-1, beta=self.settings.softmax_beta
214
- )
199
+ calc_output = calc_softmax(calc_output, axis=-1, beta=self.settings.softmax_beta)
215
200
 
216
201
  output = replace(
217
202
  message,
@@ -244,9 +229,7 @@ class StreamingFBCCASettings(FBCCASettings):
244
229
  subbands: int = 12
245
230
 
246
231
 
247
- class StreamingFBCCATransformer(
248
- CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]
249
- ):
232
+ class StreamingFBCCATransformer(CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]):
250
233
  @staticmethod
251
234
  def _initialize_processors(
252
235
  settings: StreamingFBCCASettings,
@@ -254,9 +237,7 @@ class StreamingFBCCATransformer(
254
237
  pipeline = {}
255
238
 
256
239
  if settings.filterbank_dim is not None:
257
- cut_freqs = (
258
- np.arange(settings.subbands + 1) * settings.filter_bw
259
- ) + settings.filter_low
240
+ cut_freqs = (np.arange(settings.subbands + 1) * settings.filter_bw) + settings.filter_low
260
241
  filters = [
261
242
  KaiserFilterSettings(
262
243
  axis=settings.time_dim,
@@ -269,9 +250,7 @@ class StreamingFBCCATransformer(
269
250
  ]
270
251
 
271
252
  pipeline["filterbank"] = FilterbankDesignTransformer(
272
- FilterbankDesignSettings(
273
- filters=filters, new_axis=settings.filterbank_dim
274
- )
253
+ FilterbankDesignSettings(filters=filters, new_axis=settings.filterbank_dim)
275
254
  )
276
255
 
277
256
  pipeline["window"] = WindowTransformer(
@@ -289,11 +268,7 @@ class StreamingFBCCATransformer(
289
268
  return pipeline
290
269
 
291
270
 
292
- class StreamingFBCCA(
293
- BaseTransformerUnit[
294
- StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer
295
- ]
296
- ):
271
+ class StreamingFBCCA(BaseTransformerUnit[StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer]):
297
272
  SETTINGS = StreamingFBCCASettings
298
273
 
299
274
 
ezmsg/sigproc/filter.py CHANGED
@@ -1,21 +1,21 @@
1
- from abc import abstractmethod, ABC
2
- from dataclasses import dataclass, field
3
1
  import typing
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass, field
4
4
 
5
5
  import ezmsg.core as ez
6
- from ezmsg.util.messages.axisarray import AxisArray
7
- from ezmsg.util.messages.util import replace
8
6
  import numpy as np
9
7
  import numpy.typing as npt
10
8
  import scipy.signal
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
11
11
 
12
12
  from ezmsg.sigproc.base import (
13
- processor_state,
13
+ BaseConsumerUnit,
14
14
  BaseStatefulTransformer,
15
15
  BaseTransformerUnit,
16
16
  SettingsType,
17
- BaseConsumerUnit,
18
17
  TransformerType,
18
+ processor_state,
19
19
  )
20
20
 
21
21
 
@@ -68,9 +68,7 @@ class FilterState:
68
68
  zi: npt.NDArray | None = None
69
69
 
70
70
 
71
- class FilterTransformer(
72
- BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]
73
- ):
71
+ class FilterTransformer(BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]):
74
72
  """
75
73
  Filter data using the provided coefficients.
76
74
  """
@@ -108,9 +106,7 @@ class FilterTransformer(
108
106
  zi = scipy.signal.sosfilt_zi(*coefs)
109
107
 
110
108
  zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
111
- n_tile = (
112
- message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
113
- )
109
+ n_tile = message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
114
110
 
115
111
  if self.settings.coef_type == "sos":
116
112
  zi_expand = (slice(None),) + zi_expand
@@ -144,17 +140,11 @@ class FilterTransformer(
144
140
  reset_needed = False
145
141
 
146
142
  if self.settings.coef_type == "ba":
147
- if isinstance(old_coefs, FilterCoefficients) and isinstance(
148
- coefs, FilterCoefficients
149
- ):
150
- if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(
151
- coefs.a
152
- ):
143
+ if isinstance(old_coefs, FilterCoefficients) and isinstance(coefs, FilterCoefficients):
144
+ if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(coefs.a):
153
145
  reset_needed = True
154
146
  elif isinstance(old_coefs, tuple) and isinstance(coefs, tuple):
155
- if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(
156
- coefs[1]
157
- ):
147
+ if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(coefs[1]):
158
148
  reset_needed = True
159
149
  else:
160
150
  reset_needed = True
@@ -173,36 +163,26 @@ class FilterTransformer(
173
163
  axis = message.dims[0] if self.settings.axis is None else self.settings.axis
174
164
  axis_idx = message.get_axis_idx(axis)
175
165
  _, coefs = _normalize_coefs(self.settings.coefs)
176
- filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[
177
- self.settings.coef_type
178
- ]
179
- dat_out, self.state.zi = filt_func(
180
- *coefs, message.data, axis=axis_idx, zi=self.state.zi
181
- )
166
+ filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[self.settings.coef_type]
167
+ dat_out, self.state.zi = filt_func(*coefs, message.data, axis=axis_idx, zi=self.state.zi)
182
168
  else:
183
169
  dat_out = message.data
184
170
 
185
171
  return replace(message, data=dat_out)
186
172
 
187
173
 
188
- class Filter(
189
- BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]
190
- ):
174
+ class Filter(BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]):
191
175
  SETTINGS = FilterSettings
192
176
 
193
177
 
194
- def filtergen(
195
- axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str
196
- ) -> FilterTransformer:
178
+ def filtergen(axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str) -> FilterTransformer:
197
179
  """
198
180
  Filter data using the provided coefficients.
199
181
 
200
182
  Returns:
201
183
  :obj:`FilterTransformer`.
202
184
  """
203
- return FilterTransformer(
204
- FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type)
205
- )
185
+ return FilterTransformer(FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type))
206
186
 
207
187
 
208
188
  @processor_state
@@ -230,9 +210,7 @@ class FilterByDesignTransformer(
230
210
  """Return a function that takes sampling frequency and returns filter coefficients."""
231
211
  ...
232
212
 
233
- def update_settings(
234
- self, new_settings: typing.Optional[SettingsType] = None, **kwargs
235
- ) -> None:
213
+ def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
236
214
  """
237
215
  Update settings and mark that filter coefficients need to be recalculated.
238
216
 
@@ -271,9 +249,7 @@ class FilterByDesignTransformer(
271
249
  b, a = coefs
272
250
  coefs = scipy.signal.tf2sos(b, a)
273
251
 
274
- self.state.filter.update_coefficients(
275
- coefs, coef_type=self.settings.coef_type
276
- )
252
+ self.state.filter.update_coefficients(coefs, coef_type=self.settings.coef_type)
277
253
  self.state.needs_redesign = False
278
254
 
279
255
  return super().__call__(message)
@@ -298,9 +274,7 @@ class FilterByDesignTransformer(
298
274
  b, a = coefs
299
275
  coefs = scipy.signal.tf2sos(b, a)
300
276
 
301
- new_settings = FilterSettings(
302
- axis=axis, coef_type=self.settings.coef_type, coefs=coefs
303
- )
277
+ new_settings = FilterSettings(axis=axis, coef_type=self.settings.coef_type, coefs=coefs)
304
278
  self.state.filter = FilterTransformer(settings=new_settings)
305
279
 
306
280
  def _process(self, message: AxisArray) -> AxisArray:
ezmsg/sigproc/filterbank.py CHANGED
@@ -2,20 +2,20 @@ import functools
2
2
  import math
3
3
  import typing
4
4
 
5
+ import ezmsg.core as ez
5
6
  import numpy as np
6
- import scipy.signal as sps
7
- import scipy.fft as sp_fft
8
- from scipy.special import lambertw
9
7
  import numpy.typing as npt
10
- import ezmsg.core as ez
11
- from ezmsg.util.messages.axisarray import AxisArray
12
- from ezmsg.util.messages.util import replace
13
-
14
- from .base import (
8
+ import scipy.fft as sp_fft
9
+ import scipy.signal as sps
10
+ from ezmsg.baseproc import (
15
11
  BaseStatefulTransformer,
16
12
  BaseTransformerUnit,
17
13
  processor_state,
18
14
  )
15
+ from ezmsg.util.messages.axisarray import AxisArray
16
+ from ezmsg.util.messages.util import replace
17
+ from scipy.special import lambertw
18
+
19
19
  from .spectrum import OptionsEnum
20
20
  from .window import WindowTransformer
21
21
 
@@ -32,8 +32,14 @@ class MinPhaseMode(OptionsEnum):
32
32
  """The mode of operation for the filterbank."""
33
33
 
34
34
  NONE = "No kernel modification"
35
- HILBERT = "Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions"
36
- HOMOMORPHIC = "Works best with filters with an odd number of taps, and the resulting minimum phase filter will have a magnitude response that approximates the square root of the original filter’s magnitude response using half the number of taps"
35
+ HILBERT = (
36
+ "Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions"
37
+ )
38
+ HOMOMORPHIC = (
39
+ "Works best with filters with an odd number of taps, and the resulting minimum phase filter "
40
+ "will have a magnitude response that approximates the square root of the original filter’s "
41
+ "magnitude response using half the number of taps"
42
+ )
37
43
  # HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
38
44
 
39
45
 
@@ -78,23 +84,17 @@ class FilterbankState:
78
84
  mode: FilterbankMode | None = None
79
85
 
80
86
 
81
- class FilterbankTransformer(
82
- BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]
83
- ):
87
+ class FilterbankTransformer(BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]):
84
88
  def _hash_message(self, message: AxisArray) -> int:
85
89
  axis = self.settings.axis or message.dims[0]
86
90
  gain = message.axes[axis].gain if axis in message.axes else 1.0
87
91
  targ_ax_ix = message.get_axis_idx(axis)
88
- in_shape = (
89
- message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
90
- )
92
+ in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
91
93
 
92
94
  return hash(
93
95
  (
94
96
  message.key,
95
- gain
96
- if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
97
- else None,
97
+ gain if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO] else None,
98
98
  message.data.dtype.kind,
99
99
  in_shape,
100
100
  )
@@ -104,9 +104,7 @@ class FilterbankTransformer(
104
104
  axis = self.settings.axis or message.dims[0]
105
105
  gain = message.axes[axis].gain if axis in message.axes else 1.0
106
106
  targ_ax_ix = message.get_axis_idx(axis)
107
- in_shape = (
108
- message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
109
- )
107
+ in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
110
108
 
111
109
  kernels = self.settings.kernels
112
110
  if self.settings.min_phase != MinPhaseMode.NONE:
@@ -118,9 +116,7 @@ class FilterbankTransformer(
118
116
  kernels = [sps.minimum_phase(k, method=method) for k in kernels]
119
117
 
120
118
  # Determine if this will be operating with complex data.
121
- b_complex = message.data.dtype.kind == "c" or any(
122
- [_.dtype.kind == "c" for _ in kernels]
123
- )
119
+ b_complex = message.data.dtype.kind == "c" or any([_.dtype.kind == "c" for _ in kernels])
124
120
 
125
121
  # Calculate window_dur, window_shift, nfft
126
122
  max_kernel_len = max([_.size for _ in kernels])
@@ -130,17 +126,13 @@ class FilterbankTransformer(
130
126
 
131
127
  # Prepare previous iteration's overlap tail to add to input -- all zeros.
132
128
  tail_shape = in_shape + (len(kernels), self._state.overlap)
133
- self._state.tail = np.zeros(
134
- tail_shape, dtype="complex" if b_complex else "float"
135
- )
129
+ self._state.tail = np.zeros(tail_shape, dtype="complex" if b_complex else "float")
136
130
 
137
131
  # Prepare output template -- kernels axis immediately before the target axis
138
132
  dummy_shape = in_shape + (len(kernels), 0)
139
133
  self._state.template = AxisArray(
140
134
  data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"),
141
- dims=message.dims[:targ_ax_ix]
142
- + message.dims[targ_ax_ix + 1 :]
143
- + [self.settings.new_axis, axis],
135
+ dims=message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [self.settings.new_axis, axis],
144
136
  axes=message.axes.copy(),
145
137
  key=message.key,
146
138
  )
@@ -155,8 +147,7 @@ class FilterbankTransformer(
155
147
  dummy_arr = np.zeros(n_dummy)
156
148
  self._state.mode = (
157
149
  FilterbankMode.CONV
158
- if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full")
159
- == "direct"
150
+ if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full") == "direct"
160
151
  else FilterbankMode.FFT
161
152
  )
162
153
 
@@ -166,16 +157,11 @@ class FilterbankTransformer(
166
157
  len(kernels),
167
158
  self._state.overlap + message.data.shape[targ_ax_ix],
168
159
  )
169
- self._state.dest_arr = np.zeros(
170
- dest_shape, dtype="complex" if b_complex else "float"
171
- )
160
+ self._state.dest_arr = np.zeros(dest_shape, dtype="complex" if b_complex else "float")
172
161
  self._state.prep_kerns = kernels
173
162
  else: # FFT mode
174
163
  # Calculate optimal nfft and windowing size.
175
- opt_size = (
176
- -self._state.overlap
177
- * lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
178
- )
164
+ opt_size = -self._state.overlap * lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
179
165
  self._state.nfft = sp_fft.next_fast_len(math.ceil(opt_size))
180
166
  win_len = self._state.nfft - self._state.overlap
181
167
  # infft same as nfft. Keeping as separate variable because I might need it again.
@@ -201,19 +187,11 @@ class FilterbankTransformer(
201
187
  # for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
202
188
  # more fft backends.
203
189
  if b_complex:
204
- self._state.fft = functools.partial(
205
- sp_fft.fft, n=self._state.nfft, norm="backward"
206
- )
207
- self._state.ifft = functools.partial(
208
- sp_fft.ifft, n=self._state.infft, norm="backward"
209
- )
190
+ self._state.fft = functools.partial(sp_fft.fft, n=self._state.nfft, norm="backward")
191
+ self._state.ifft = functools.partial(sp_fft.ifft, n=self._state.infft, norm="backward")
210
192
  else:
211
- self._state.fft = functools.partial(
212
- sp_fft.rfft, n=self._state.nfft, norm="backward"
213
- )
214
- self._state.ifft = functools.partial(
215
- sp_fft.irfft, n=self._state.infft, norm="backward"
216
- )
193
+ self._state.fft = functools.partial(sp_fft.rfft, n=self._state.nfft, norm="backward")
194
+ self._state.ifft = functools.partial(sp_fft.irfft, n=self._state.infft, norm="backward")
217
195
 
218
196
  # Calculate fft of kernels
219
197
  self._state.prep_kerns = np.array([self._state.fft(_) for _ in kernels])
@@ -229,9 +207,7 @@ class FilterbankTransformer(
229
207
  in_dat = np.moveaxis(message.data, targ_ax_ix, -1)
230
208
  if self._state.mode == FilterbankMode.FFT:
231
209
  # Fix message.dims because we will pass it to windower
232
- move_dims = (
233
- message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
234
- )
210
+ move_dims = message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
235
211
  message = replace(message, data=in_dat, dims=move_dims)
236
212
  else:
237
213
  in_dat = message.data
@@ -239,22 +215,15 @@ class FilterbankTransformer(
239
215
  if self._state.mode == FilterbankMode.CONV:
240
216
  n_dest = in_dat.shape[-1] + self._state.overlap
241
217
  if self._state.dest_arr.shape[-1] < n_dest:
242
- pad = np.zeros(
243
- self._state.dest_arr.shape[:-1]
244
- + (n_dest - self._state.dest_arr.shape[-1],)
245
- )
246
- self._state.dest_arr = np.concatenate(
247
- [self._state.dest_arr, pad], axis=-1
248
- )
218
+ pad = np.zeros(self._state.dest_arr.shape[:-1] + (n_dest - self._state.dest_arr.shape[-1],))
219
+ self._state.dest_arr = np.concatenate([self._state.dest_arr, pad], axis=-1)
249
220
  self._state.dest_arr.fill(0)
250
221
 
251
222
  # Note: I tried several alternatives to this loop; all were slower than this.
252
223
  # numba.jit; stride_tricks + np.einsum; threading. Latter might be better with Python 3.13.
253
224
  for k_ix, k in enumerate(self._state.prep_kerns):
254
225
  n_out = in_dat.shape[-1] + k.shape[-1] - 1
255
- self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
256
- np.convolve, -1, in_dat, k, mode="full"
257
- )
226
+ self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(np.convolve, -1, in_dat, k, mode="full")
258
227
  self._state.dest_arr[..., : self._state.overlap] += self._state.tail
259
228
  new_tail = self._state.dest_arr[..., in_dat.shape[-1] : n_dest]
260
229
  if new_tail.size > 0:
@@ -278,18 +247,14 @@ class FilterbankTransformer(
278
247
  # Previous iteration's tail:
279
248
  overlapped[..., :1, : self._state.overlap] += self._state.tail
280
249
  # window-to-window:
281
- overlapped[..., 1:, : self._state.overlap] += overlapped[
282
- ..., :-1, -self._state.overlap :
283
- ]
250
+ overlapped[..., 1:, : self._state.overlap] += overlapped[..., :-1, -self._state.overlap :]
284
251
  # Save tail:
285
252
  new_tail = overlapped[..., -1:, -self._state.overlap :]
286
253
  if new_tail.size > 0:
287
254
  # All of the above code works if input is size-zero, but we don't want to save a zero-size tail.
288
255
  self._state.tail = new_tail
289
256
  # Concat over win axis, without overlap.
290
- res = overlapped[..., : -self._state.overlap].reshape(
291
- overlapped.shape[:-2] + (-1,)
292
- )
257
+ res = overlapped[..., : -self._state.overlap].reshape(overlapped.shape[:-2] + (-1,))
293
258
 
294
259
  return replace(
295
260
  self._state.template,
@@ -298,9 +263,7 @@ class FilterbankTransformer(
298
263
  )
299
264
 
300
265
 
301
- class Filterbank(
302
- BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]
303
- ):
266
+ class Filterbank(BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]):
304
267
  SETTINGS = FilterbankSettings
305
268
 
306
269
 
ezmsg/sigproc/filterbankdesign.py CHANGED
@@ -3,22 +3,19 @@ import typing
3
3
  import ezmsg.core as ez
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
-
7
- from ezmsg.util.messages.util import replace
8
- from ezmsg.util.messages.axisarray import AxisArray
9
-
10
- from .base import (
6
+ from ezmsg.baseproc import (
11
7
  BaseStatefulTransformer,
12
8
  processor_state,
13
9
  )
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
+ from ezmsg.util.messages.util import replace
14
12
 
15
13
  from .filterbank import (
16
- FilterbankTransformer,
17
- FilterbankSettings,
18
14
  FilterbankMode,
15
+ FilterbankSettings,
16
+ FilterbankTransformer,
19
17
  MinPhaseMode,
20
18
  )
21
-
22
19
  from .kaiser import KaiserFilterSettings, kaiser_design_fun
23
20
 
24
21
 
@@ -55,9 +52,7 @@ class FilterbankDesignState:
55
52
 
56
53
 
57
54
  class FilterbankDesignTransformer(
58
- BaseStatefulTransformer[
59
- FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState
60
- ],
55
+ BaseStatefulTransformer[FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState],
61
56
  ):
62
57
  """
63
58
  Transformer that designs and applies a filterbank based on Kaiser windowed FIR filters.
@@ -70,9 +65,7 @@ class FilterbankDesignTransformer(
70
65
  else:
71
66
  raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
72
67
 
73
- def update_settings(
74
- self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs
75
- ) -> None:
68
+ def update_settings(self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs) -> None:
76
69
  """
77
70
  Update settings and mark that filter coefficients need to be recalculated.
78
71