ezmsg-sigproc 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,14 @@
  # file generated by setuptools-scm
  # don't change, don't track in version control

- __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+ __all__ = [
+     "__version__",
+     "__version_tuple__",
+     "version",
+     "version_tuple",
+     "__commit_id__",
+     "commit_id",
+ ]

  TYPE_CHECKING = False
  if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
      from typing import Union

      VERSION_TUPLE = Tuple[Union[int, str], ...]
+     COMMIT_ID = Union[str, None]
  else:
      VERSION_TUPLE = object
+     COMMIT_ID = object

  version: str
  __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
+ commit_id: COMMIT_ID
+ __commit_id__: COMMIT_ID

- __version__ = version = '2.2.0'
- __version_tuple__ = version_tuple = (2, 2, 0)
+ __version__ = version = '2.3.0'
+ __version_tuple__ = version_tuple = (2, 3, 0)
+
+ __commit_id__ = commit_id = None
@@ -0,0 +1,86 @@
+ import ezmsg.core as ez
+ import numpy as np
+ import numpy.typing as npt
+ from ezmsg.sigproc.base import (
+     BaseTransformerUnit,
+     BaseStatefulTransformer,
+     processor_state,
+ )
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+
+ class DenormalizeSettings(ez.Settings):
+     low_rate: float = 2.0
+     """Low end of probable rate after denormalization (Hz)."""
+
+     high_rate: float = 40.0
+     """High end of probable rate after denormalization (Hz)."""
+
+     distribution: str = "uniform"
+     """Distribution to sample rates from. Options are 'uniform', 'normal', or 'constant'."""
+
+
+ @processor_state
+ class DenormalizeRateState:
+     gains: npt.NDArray | None = None
+     offsets: npt.NDArray | None = None
+
+
+ class DenormalizeTransformer(
+     BaseStatefulTransformer[
+         DenormalizeSettings, AxisArray, AxisArray, DenormalizeRateState
+     ]
+ ):
+     """
+     Scales data from a normalized distribution (mean=0, std=1) to a denormalized
+     distribution using random per-channel offsets and gains designed to keep the
+     99.9% CIs between 0 and 2x the offset.
+
+     This is useful for simulating realistic firing rates from normalized data.
+     """
+
+     def _reset_state(self, message: AxisArray) -> None:
+         ax_ix = message.get_axis_idx("ch")
+         nch = message.data.shape[ax_ix]
+         arr_size = (nch, 1) if ax_ix == 0 else (1, nch)
+         if self.settings.distribution == "uniform":
+             # Draw per-channel offsets uniformly between the configured bounds.
+             self.state.offsets = np.random.uniform(
+                 self.settings.low_rate, self.settings.high_rate, size=arr_size
+             )
+         elif self.settings.distribution == "normal":
+             self.state.offsets = np.random.normal(
+                 loc=(self.settings.low_rate + self.settings.high_rate) / 2.0,
+                 scale=(self.settings.high_rate - self.settings.low_rate) / 6.0,
+                 size=arr_size,
+             )
+             self.state.offsets = np.clip(
+                 self.state.offsets,
+                 a_min=self.settings.low_rate,
+                 a_max=self.settings.high_rate,
+             )
+         elif self.settings.distribution == "constant":
+             self.state.offsets = np.full(
+                 shape=arr_size,
+                 fill_value=(self.settings.low_rate + self.settings.high_rate) / 2.0,
+             )
+         else:
+             raise ValueError(f"Invalid distribution: {self.settings.distribution}")
+         # Input has std == 1.
+         # Desired output ranges from 0 to 2*self.state.offsets within the 99.9% confidence interval.
+         # For a standard normal distribution, 99.9% of the data lies within +/- 3.29 std devs,
+         # so gain = offset / 3.29 scales the std dev appropriately.
+         self.state.gains = self.state.offsets / 3.29
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         denorm = message.data * self.state.gains + self.state.offsets
+         return replace(
+             message,
+             data=np.clip(denorm, a_min=0.0, a_max=None),
+         )
+
+
+ class DenormalizeRateUnit(
+     BaseTransformerUnit[
+         DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer
+     ]
+ ):
+     SETTINGS = DenormalizeSettings
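
As a sanity check on the gain = offset / 3.29 choice above: for standard-normal input, offset + gain * x stays within [0, 2 * offset] roughly 99.9% of the time before the final clip at zero. A standalone numpy sketch, not part of the package (the offset value is arbitrary):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal(1_000_000)  # normalized input: mean 0, std 1
    offset = 20.0                       # hypothetical per-channel rate offset (Hz)
    gain = offset / 3.29                # 99.9% of N(0, 1) lies within +/- 3.29 std devs
    rate = offset + gain * x            # denormalized rate before clipping at 0
    inside = np.mean((rate >= 0.0) & (rate <= 2.0 * offset))
    print(f"fraction within [0, 2*offset]: {inside:.4f}")  # ~0.999
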
ezmsg/sigproc/fbcca.py ADDED
@@ -0,0 +1,332 @@
+ import typing
+ import math
+ from dataclasses import field
+
+ import numpy as np
+
+ import ezmsg.core as ez
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .sampler import SampleTriggerMessage
+ from .window import WindowTransformer, WindowSettings
+
+ from .base import (
+     BaseTransformer,
+     BaseTransformerUnit,
+     CompositeProcessor,
+     BaseProcessor,
+     BaseStatefulProcessor,
+ )
+
+ from .kaiser import KaiserFilterSettings
+ from .filterbankdesign import (
+     FilterbankDesignSettings,
+     FilterbankDesignTransformer,
+ )
+
+
+ class FBCCASettings(ez.Settings):
+     """
+     Settings for :obj:`FBCCATransformer`
+     """
+
+     time_dim: str
+     """
+     The time dim in the data array.
+     """
+
+     ch_dim: str
+     """
+     The channels dim in the data array.
+     """
+
+     filterbank_dim: str | None = None
+     """
+     The filter-bank subband dim in the data array. If unspecified, the method falls back to plain CCA.
+     None (default): the input has no subbands; just use CCA.
+     """
+
+     harmonics: int = 5
+     """
+     The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
+     5 (default): Evaluate 5 harmonics of the base frequency.
+     Many periodic signals are not pure sinusoids, and including higher harmonics helps detect
+     signals with higher-frequency harmonic content.
+     """
+
+     freqs: typing.List[float] = field(default_factory=list)
+     """
+     Frequencies (in Hz) to evaluate the presence of within the input signal.
+     [] (default): an empty list; frequencies will be found within the input SampleMessages.
+     AxisArrays have no good place to put this metadata, so specify frequencies here if only AxisArrays
+     will be passed as input to the generator. If the input has a `trigger` attr of type :obj:`SampleTriggerMessage`,
+     the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
+     This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from the ezmsg-tasks package.
+     NOTE: Avoid frequencies that have line noise (60 Hz / 50 Hz) as a harmonic.
+     """
+
+     softmax_beta: float = 1.0
+     """
+     Beta parameter for the softmax on the output --> "probabilities".
+     1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
+     If 0.0, the maximum singular value of the SVD for each design matrix is output.
+     """
+
+     target_freq_dim: str = "target_freq"
+     """
+     Name for the dim to put target frequency outputs on.
+     'target_freq' (default)
+     """
+
+     max_int_time: float = 0.0
+     """
+     Maximum integration time (in seconds) to use for the calculation.
+     0 (default): Use all time provided for the calculation.
+     Useful for artificially limiting the amount of data used by the CCA method, e.g. to evaluate
+     the integration time necessary for good decoding performance.
+     """
+
+
+ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
+     """
+     A canonical-correlation (CCA) signal decoder for detecting periodic activity in multi-channel timeseries
+     recordings. It is particularly useful for detecting the presence of steady-state evoked responses in multi-channel
+     EEG data. See Lin et al. 2007 for a description of the use of CCA to detect the presence of SSVEP in EEG
+     data.
+     This implementation also includes the "filterbank" extension of the CCA decoding approach, which uses a
+     filterbank to decompose the input multi-channel EEG data into several frequency sub-bands, each of which is
+     analyzed with CCA and then combined using a weighted sum, allowing CCA to more readily identify harmonic
+     content in EEG data. Read more about this approach in Chen et al. 2015.
+
+     ## Further reading:
+     * [Lin et al. 2007](https://ieeexplore.ieee.org/document/4015614)
+     * [Nakanishi et al. 2015](https://doi.org/10.1371%2Fjournal.pone.0140703)
+     * [Chen et al. 2015](http://dx.doi.org/10.1088/1741-2560/12/4/046008)
+     """
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         """
+         Input: AxisArray with at least a time_dim and ch_dim.
+         Output: AxisArray with time_dim, ch_dim (and filterbank_dim, if specified)
+         collapsed, with a new 'target_freq' dim of length len(freqs).
+         """
+
+         test_freqs: list[float] = self.settings.freqs
+         trigger = message.attrs.get("trigger", None)
+         if isinstance(trigger, SampleTriggerMessage):
+             if len(test_freqs) == 0:
+                 test_freqs = getattr(trigger, "freqs", [])
+
+         if len(test_freqs) == 0:
+             raise ValueError("no frequencies to test")
+
+         time_dim_idx = message.get_axis_idx(self.settings.time_dim)
+         ch_dim_idx = message.get_axis_idx(self.settings.ch_dim)
+
+         filterbank_dim_idx = None
+         if self.settings.filterbank_dim is not None:
+             filterbank_dim_idx = message.get_axis_idx(self.settings.filterbank_dim)
+
+         # Move (filterbank_dim), time, ch to the end of the array
+         rm_dims = [self.settings.time_dim, self.settings.ch_dim]
+         if self.settings.filterbank_dim is not None:
+             rm_dims = [self.settings.filterbank_dim] + rm_dims
+         new_order = [i for i, dim in enumerate(message.dims) if dim not in rm_dims]
+         if filterbank_dim_idx is not None:
+             new_order.append(filterbank_dim_idx)
+         new_order.extend([time_dim_idx, ch_dim_idx])
+         out_dims = [
+             message.dims[i] for i in new_order if message.dims[i] not in rm_dims
+         ]
+         data_arr = message.data.transpose(new_order)
+
+         # Add a singleton dim for the filterbank dim if we don't have one
+         if filterbank_dim_idx is None:
+             data_arr = data_arr[..., None, :, :]
+             filterbank_dim_idx = data_arr.ndim - 3
+
+         # data_arr is now (..., filterbank, time, ch)
+         # Get output shape for remaining dims and reshape data_arr for iterative processing
+         out_shape = list(data_arr.shape[:-3])
+         data_arr = data_arr.reshape([math.prod(out_shape), *data_arr.shape[-3:]])
+
+         # Create output dims and axes with the added target_freq_dim
+         out_shape.append(len(test_freqs))
+         out_dims.append(self.settings.target_freq_dim)
+         out_axes = {
+             axis_name: axis
+             for axis_name, axis in message.axes.items()
+             if axis_name not in rm_dims
+             and not (
+                 isinstance(axis, AxisArray.CoordinateAxis)
+                 and any(d in rm_dims for d in axis.dims)
+             )
+         }
+         out_axes[self.settings.target_freq_dim] = AxisArray.CoordinateAxis(
+             np.array(test_freqs), [self.settings.target_freq_dim]
+         )
+
+         if message.data.size == 0:
+             out_data = message.data.reshape(out_shape)
+             output = replace(message, data=out_data, dims=out_dims, axes=out_axes)
+             return output
+
+         # Get time axis
+         t_ax_info = message.ax(self.settings.time_dim)
+         t = t_ax_info.values
+         t -= t[0]
+         max_samp = len(t)
+         if self.settings.max_int_time > 0:
+             max_samp = int(abs(t_ax_info.values - self.settings.max_int_time).argmin())
+         t = t[:max_samp]
+
+         calc_output = np.zeros((*data_arr.shape[:-2], len(test_freqs)))
+
+         for test_freq_idx, test_freq in enumerate(test_freqs):
+             # Create the design matrix of the base frequency and the requested harmonics
+             Y = np.column_stack(
+                 [
+                     fn(2.0 * np.pi * k * test_freq * t)
+                     for k in range(1, self.settings.harmonics + 1)
+                     for fn in (np.sin, np.cos)
+                 ]
+             )
+
+             for test_idx, arr in enumerate(
+                 data_arr
+             ):  # iterate over first dim; arr is (filterbank x time x ch)
+                 for band_idx, band in enumerate(
+                     arr
+                 ):  # iterate over second dim; band is (time x ch)
+                     calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(
+                         band[:max_samp, ...], Y
+                     )
+
+         # Combine per-subband canonical correlations using a weighted sum
+         # https://iopscience.iop.org/article/10.1088/1741-2560/12/4/046008
+         freq_weights = (np.arange(1, calc_output.shape[1] + 1) ** -1.25) + 0.25
+         calc_output = ((calc_output**2) * freq_weights[None, :, None]).sum(axis=1)
+
+         if self.settings.softmax_beta != 0:
+             calc_output = calc_softmax(
+                 calc_output, axis=-1, beta=self.settings.softmax_beta
+             )
+
+         output = replace(
+             message,
+             data=calc_output.reshape(out_shape),
+             dims=out_dims,
+             axes=out_axes,
+         )
+
+         return output
+
+
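
A note on the weighted sum at the end of _process above: the per-subband canonical correlations r(k, n) for target frequency k and subband n are squared and combined with weights w(n) = n**-1.25 + 0.25, matching the values reported as optimal in Chen et al. 2015 (a = 1.25, b = 0.25). A small standalone numpy sketch of just that combination step (the shapes are hypothetical):

    import numpy as np

    n_subbands, n_freqs = 12, 4
    rho = np.random.default_rng(0).random((1, n_subbands, n_freqs))  # hypothetical per-subband correlations
    w = np.arange(1, n_subbands + 1) ** -1.25 + 0.25                 # w(n) = n**-1.25 + 0.25
    scores = (rho**2 * w[None, :, None]).sum(axis=1)                 # one score per target frequency
    print(scores.shape)  # (1, 4)
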
+ class FBCCA(BaseTransformerUnit[FBCCASettings, AxisArray, AxisArray, FBCCATransformer]):
+     SETTINGS = FBCCASettings
+
+
+ class StreamingFBCCASettings(FBCCASettings):
+     """
+     Perform rolling/streaming FBCCA on incoming EEG.
+     Decomposes the input multi-channel timeseries data into multiple sub-bands using a FilterbankDesignTransformer,
+     then accumulates data using a WindowTransformer into short-time observations for analysis with an FBCCATransformer.
+     """
+
+     window_dur: float = 4.0  # sec
+     window_shift: float = 0.5  # sec
+     window_dim: str = "fbcca_window"
+     filter_bw: float = 7.0  # Hz
+     filter_low: float = 7.0  # Hz
+     trans_bw: float = 2.0  # Hz
+     ripple_db: float = 20.0  # dB
+     subbands: int = 12
+
+
+ class StreamingFBCCATransformer(
+     CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]
+ ):
+     @staticmethod
+     def _initialize_processors(
+         settings: StreamingFBCCASettings,
+     ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
+         pipeline = {}
+
+         if settings.filterbank_dim is not None:
+             cut_freqs = (
+                 np.arange(settings.subbands + 1) * settings.filter_bw
+             ) + settings.filter_low
+             filters = [
+                 KaiserFilterSettings(
+                     axis=settings.time_dim,
+                     cutoff=(c - settings.trans_bw, cut_freqs[-1]),
+                     ripple=settings.ripple_db,
+                     width=settings.trans_bw,
+                     pass_zero=False,
+                 )
+                 for c in cut_freqs[:-1]
+             ]
+
+             pipeline["filterbank"] = FilterbankDesignTransformer(
+                 FilterbankDesignSettings(
+                     filters=filters, new_axis=settings.filterbank_dim
+                 )
+             )
+
+         pipeline["window"] = WindowTransformer(
+             WindowSettings(
+                 axis=settings.time_dim,
+                 newaxis=settings.window_dim,
+                 window_dur=settings.window_dur,
+                 window_shift=settings.window_shift,
+                 zero_pad_until="shift",
+             )
+         )
+
+         pipeline["fbcca"] = FBCCATransformer(settings)
+
+         return pipeline
+
+
+ class StreamingFBCCA(
+     BaseTransformerUnit[
+         StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer
+     ]
+ ):
+     SETTINGS = StreamingFBCCASettings
+
+
+ def cca_rho_max(X: np.ndarray, Y: np.ndarray) -> float:
+     """
+     X: (n_time, n_ch)
+     Y: (n_time, n_ref)  # design matrix for one frequency
+     returns: largest canonical correlation in [0, 1]
+     """
+     # Center columns
+     Xc = X - X.mean(axis=0, keepdims=True)
+     Yc = Y - Y.mean(axis=0, keepdims=True)
+
+     # Drop any zero-variance columns to avoid rank issues
+     Xc = Xc[:, Xc.std(axis=0) > 1e-12]
+     Yc = Yc[:, Yc.std(axis=0) > 1e-12]
+     if Xc.size == 0 or Yc.size == 0:
+         return 0.0
+
+     # Orthonormal bases
+     Qx, _ = np.linalg.qr(Xc, mode="reduced")  # (n_time, r_x)
+     Qy, _ = np.linalg.qr(Yc, mode="reduced")  # (n_time, r_y)
+
+     # Canonical correlations are the singular values of Qx^T Qy
+     with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
+         s = np.linalg.svd(Qx.T @ Qy, compute_uv=False)
+     return float(s[0]) if s.size else 0.0
+
+
+ def calc_softmax(cv: np.ndarray, axis: int, beta: float = 1.0):
+     # Calculate softmax with shifting to avoid overflow
+     # (https://doi.org/10.1093/imanum/draa038)
+     cv = cv - cv.max(axis=axis, keepdims=True)
+     cv = np.exp(beta * cv)
+     cv = cv / np.sum(cv, axis=axis, keepdims=True)
+     return cv
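
The cca_rho_max helper added above is easy to exercise offline. A minimal sketch (assuming ezmsg-sigproc 2.3.0 is installed; the sampling rate, duration, and channel count are arbitrary): build a noisy multi-channel signal containing a 12 Hz component, construct the sin/cos design matrix the same way _process does, and confirm the canonical correlation is highest at the true frequency.

    import numpy as np
    from ezmsg.sigproc.fbcca import cca_rho_max  # added in this release

    fs, dur, target = 250.0, 4.0, 12.0  # Hz, s, Hz -- arbitrary example values
    t = np.arange(int(fs * dur)) / fs
    rng = np.random.default_rng(0)

    # 8-channel "recording": a 12 Hz component (plus its 2nd harmonic) buried in noise
    x = (
        0.5 * np.sin(2 * np.pi * target * t)[:, None]
        + 0.2 * np.sin(2 * np.pi * 2 * target * t)[:, None]
        + rng.standard_normal((t.size, 8))
    )

    def design(freq: float, harmonics: int = 5) -> np.ndarray:
        # Same construction as the Y matrix in FBCCATransformer._process
        return np.column_stack(
            [fn(2 * np.pi * k * freq * t) for k in range(1, harmonics + 1) for fn in (np.sin, np.cos)]
        )

    print(cca_rho_max(x, design(12.0)))  # expected to be clearly larger ...
    print(cca_rho_max(x, design(17.0)))  # ... than the correlation at a non-target frequency
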
ezmsg/sigproc/filter.py CHANGED
@@ -263,6 +263,14 @@ class FilterByDesignTransformer(
  axis = self.state.filter.settings.axis
  fs = 1 / message.axes[axis].gain
  coefs = design_fun(fs)
+
+ # Convert BA to SOS if requested
+ if coefs is not None and self.settings.coef_type == "sos":
+     if isinstance(coefs, tuple) and len(coefs) == 2:
+         # It's BA format; convert to SOS
+         b, a = coefs
+         coefs = scipy.signal.tf2sos(b, a)
+
  self.state.filter.update_coefficients(
      coefs, coef_type=self.settings.coef_type
  )
@@ -282,6 +290,14 @@ class FilterByDesignTransformer(
  axis = message.dims[0] if self.settings.axis is None else self.settings.axis
  fs = 1 / message.axes[axis].gain
  coefs = design_fun(fs)
+
+ # Convert BA to SOS if requested
+ if coefs is not None and self.settings.coef_type == "sos":
+     if isinstance(coefs, tuple) and len(coefs) == 2:
+         # It's BA format; convert to SOS
+         b, a = coefs
+         coefs = scipy.signal.tf2sos(b, a)
+
  new_settings = FilterSettings(
      axis=axis, coef_type=self.settings.coef_type, coefs=coefs
  )
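
For context on the conversion added above, an illustrative scipy sketch (independent of ezmsg; the filter design is arbitrary): scipy.signal.tf2sos factors a (b, a) transfer function into cascaded second-order sections, which are much better conditioned numerically than high-order BA coefficients when filtering.

    import numpy as np
    import scipy.signal

    fs = 500.0  # Hz, arbitrary example rate
    # A band-pass design returned in BA (transfer-function) form
    b, a = scipy.signal.butter(4, [8.0, 30.0], btype="bandpass", fs=fs)
    sos = scipy.signal.tf2sos(b, a)  # same filter, factored into second-order sections
    print(sos.shape)  # (n_sections, 6): [b0, b1, b2, a0, a1, a2] per section

    x = np.random.default_rng(0).standard_normal(int(5 * fs))
    y_ba = scipy.signal.lfilter(b, a, x)   # direct-form BA filtering
    y_sos = scipy.signal.sosfilt(sos, x)   # cascaded biquads; better conditioned at high orders
    print(np.max(np.abs(y_ba - y_sos)))    # should agree closely at this modest order
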
ezmsg/sigproc/filterbankdesign.py ADDED
@@ -0,0 +1,136 @@
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ import numpy.typing as npt
+
+ from ezmsg.util.messages.util import replace
+ from ezmsg.util.messages.axisarray import AxisArray
+
+ from .base import (
+     BaseStatefulTransformer,
+     processor_state,
+ )
+
+ from .filterbank import (
+     FilterbankTransformer,
+     FilterbankSettings,
+     FilterbankMode,
+     MinPhaseMode,
+ )
+
+ from .kaiser import KaiserFilterSettings, kaiser_design_fun
+
+
+ class FilterbankDesignSettings(ez.Settings):
+     filters: typing.Iterable[KaiserFilterSettings]
+
+     mode: FilterbankMode = FilterbankMode.CONV
+     """
+     "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
+     fft mode is more efficient for long kernels; however, it uses non-overlapping windows and will
+     incur a delay equal to the window length, which is larger than the largest kernel.
+     conv mode is less efficient but returns data for every incoming chunk regardless of how small it is,
+     and thus can provide lower-latency updates.
+     """
+
+     min_phase: MinPhaseMode = MinPhaseMode.NONE
+     """
+     If not NONE, convert the kernels to minimum-phase equivalents. Valid options are
+     'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters are not supported.
+     See `scipy.signal.minimum_phase` for details.
+     """
+
+     axis: str = "time"
+     """The name of the axis to operate on. This should usually be "time"."""
+
+     new_axis: str = "kernel"
+     """The name of the new axis corresponding to the kernel index."""
+
+
+ @processor_state
+ class FilterbankDesignState:
+     filterbank: FilterbankTransformer | None = None
+     needs_redesign: bool = False
+
+
+ class FilterbankDesignTransformer(
+     BaseStatefulTransformer[
+         FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState
+     ],
+ ):
+     """
+     Transformer that designs and applies a filterbank based on Kaiser-windowed FIR filters.
+     """
+
+     @classmethod
+     def get_message_type(cls, dir: str) -> type[AxisArray]:
+         if dir in ("in", "out"):
+             return AxisArray
+         else:
+             raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
+
+     def update_settings(
+         self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs
+     ) -> None:
+         """
+         Update settings and mark that filter coefficients need to be recalculated.
+
+         Args:
+             new_settings: Complete new settings object to replace current settings
+             **kwargs: Individual settings to update
+         """
+         # Update settings
+         if new_settings is not None:
+             self.settings = new_settings
+         else:
+             self.settings = replace(self.settings, **kwargs)
+
+         # Set flag to trigger recalculation on next message
+         if self.state.filterbank is not None:
+             self.state.needs_redesign = True
+
+     def _calculate_kernels(self, fs: float) -> list[npt.NDArray]:
+         kernels = []
+         for filter in self.settings.filters:
+             output = kaiser_design_fun(
+                 fs,
+                 cutoff=filter.cutoff,
+                 ripple=filter.ripple,
+                 width=filter.width,
+                 pass_zero=filter.pass_zero,
+                 wn_hz=filter.wn_hz,
+             )
+
+             kernels.append(np.array([1.0]) if output is None else output[0])
+         return kernels
+
+     def __call__(self, message: AxisArray) -> AxisArray:
+         if self.state.filterbank is not None and self.state.needs_redesign:
+             self._reset_state(message)
+             self.state.needs_redesign = False
+         return super().__call__(message)
+
+     def _hash_message(self, message: AxisArray) -> int:
+         axis = message.dims[0] if self.settings.axis is None else self.settings.axis
+         gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
+         axis_idx = message.get_axis_idx(axis)
+         samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
+         return hash((message.key, samp_shape, gain))
+
+     def _reset_state(self, message: AxisArray) -> None:
+         axis_obj = message.axes[self.settings.axis]
+         assert isinstance(axis_obj, AxisArray.LinearAxis)
+         fs = 1 / axis_obj.gain
+         kernels = self._calculate_kernels(fs)
+         new_settings = FilterbankSettings(
+             kernels=kernels,
+             mode=self.settings.mode,
+             min_phase=self.settings.min_phase,
+             axis=self.settings.axis,
+             new_axis=self.settings.new_axis,
+         )
+         self.state.filterbank = FilterbankTransformer(settings=new_settings)
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         return self.state.filterbank(message)
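
The kernels that FilterbankDesignTransformer builds come from kaiser_design_fun, driven by the (cutoff, ripple, width, pass_zero) fields of KaiserFilterSettings. A generic scipy sketch of that style of Kaiser-window FIR design follows; the package's helper may differ in details, and the parameter values are arbitrary.

    import scipy.signal

    fs = 500.0               # Hz, example sampling rate
    ripple_db = 20.0         # ripple / stop-band attenuation spec
    width_hz = 2.0           # transition band width
    cutoff_hz = (5.0, 91.0)  # band-pass edges, like one sub-band in StreamingFBCCATransformer

    # Kaiser window parameters from the ripple and transition-width specs
    numtaps, beta = scipy.signal.kaiserord(ripple_db, width_hz / (0.5 * fs))
    numtaps |= 1             # odd tap count -> Type I linear-phase FIR (a common convention)
    taps = scipy.signal.firwin(
        numtaps, cutoff_hz, window=("kaiser", beta), pass_zero=False, fs=fs
    )
    print(taps.shape)        # one FIR kernel; FilterbankTransformer applies one kernel per sub-band
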