ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import math
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
import numpy as np
|
|
7
|
+
import numpy.typing as npt
|
|
8
|
+
import scipy.fft as sp_fft
|
|
9
|
+
import scipy.signal as sps
|
|
10
|
+
from ezmsg.baseproc import (
|
|
11
|
+
BaseStatefulTransformer,
|
|
12
|
+
BaseTransformerUnit,
|
|
13
|
+
processor_state,
|
|
14
|
+
)
|
|
15
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
16
|
+
from ezmsg.util.messages.util import replace
|
|
17
|
+
from scipy.special import lambertw
|
|
18
|
+
|
|
19
|
+
from .spectrum import OptionsEnum
|
|
20
|
+
from .window import WindowTransformer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FilterbankMode(OptionsEnum):
    """Convolution strategy used by the filterbank transformer."""

    # Convolve every chunk directly with each kernel (np.convolve + overlap-add); returns data
    # for every incoming chunk, however small.
    CONV = "Direct Convolution"
    # Overlap-add FFT convolution over non-overlapping windows cut by a WindowTransformer; more
    # efficient for long kernels but delays output by one window.
    FFT = "FFT Convolution"
    # Choose CONV or FFT when the state is reset, using scipy.signal.choose_conv_method.
    AUTO = "Automatic"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MinPhaseMode(OptionsEnum):
    """Whether (and how) to convert the filterbank kernels to minimum phase.

    Values other than NONE select the ``method`` argument passed to ``scipy.signal.minimum_phase``.
    """

    NONE = "No kernel modification"
    HILBERT = (
        "Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions"
    )
    HOMOMORPHIC = (
        "Works best with filters with an odd number of taps, and the resulting minimum phase filter "
        "will have a magnitude response that approximates the square root of the original filter’s "
        "magnitude response using half the number of taps"
    )
    # Disabled option, kept for reference (not selectable):
    # HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class FilterbankSettings(ez.Settings):
    """Settings for :obj:`FilterbankTransformer`."""

    # FIR kernels (1-D tap arrays) to convolve with the input. The output gains one entry per
    # kernel along ``new_axis``.
    kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]

    mode: FilterbankMode = FilterbankMode.CONV
    """
    ``FilterbankMode.CONV``, ``FilterbankMode.FFT``, or ``FilterbankMode.AUTO``.
    With AUTO, the mode is chosen whenever the state is reset, from the total kernel length and the
    number of samples in 100 msec of data (via `scipy.signal.choose_conv_method`).
    FFT mode is more efficient for long kernels. However, FFT mode uses non-overlapping windows and will
    incur a delay equal to the window length, which is larger than the largest kernel.
    CONV mode is less efficient but will return data for every incoming chunk regardless of how small it is
    and thus can provide shorter latency updates.
    """

    min_phase: MinPhaseMode = MinPhaseMode.NONE
    """
    If not ``MinPhaseMode.NONE``, convert the kernels to minimum-phase equivalents.
    ``MinPhaseMode.HILBERT`` and ``MinPhaseMode.HOMOMORPHIC`` map to the ``method`` argument of
    `scipy.signal.minimum_phase`. Complex filters not supported.
    See `scipy.signal.minimum_phase` for details, including the length of the returned kernels.
    """

    axis: str = "time"
    """
    The name of the axis to operate on. This should usually be "time".
    If empty, the message's first dimension is used.
    """

    new_axis: str = "kernel"
    """The name of the new axis corresponding to the kernel index."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@processor_state
class FilterbankState:
    # Overlap-add carry from the previous chunk, shape in_shape + (n_kernels, overlap).
    # In FFT mode a singleton window axis is inserted before the last dim.
    tail: npt.NDArray | None = None
    # Empty AxisArray (zero-length target axis) whose dims/axes/key are reused for every output.
    template: AxisArray | None = None
    # CONV mode only: reusable output buffer of shape in_shape + (n_kernels, >= overlap + chunk length);
    # it is grown when a longer chunk arrives.
    dest_arr: npt.NDArray | None = None
    # CONV mode: the time-domain kernels. FFT mode: kernel spectra of shape (n_kernels, 1, n_freqs).
    prep_kerns: npt.NDArray | list[npt.NDArray] | None = None
    # FFT mode only: cuts the input into non-overlapping windows of (nfft - overlap) samples.
    windower: WindowTransformer | None = None
    # FFT mode only: forward/inverse transforms (rfft/irfft for real data, fft/ifft for complex).
    fft: typing.Callable | None = None
    ifft: typing.Callable | None = None
    # FFT mode only: forward and inverse transform lengths (computed to be equal).
    nfft: int | None = None
    infft: int | None = None
    # Number of samples carried between chunks: longest kernel length - 1.
    overlap: int | None = None
    # The mode actually in use; AUTO is resolved to CONV or FFT when the state is reset.
    mode: FilterbankMode | None = None
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class FilterbankTransformer(BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]):
    """
    Convolve a streaming signal with every kernel in a filterbank.

    Consecutive chunks are joined with the overlap-add method, so the transformer is suitable for
    online processing. The output moves the target axis to the last position and inserts
    ``settings.new_axis`` (the kernel index) immediately before it; other dims keep their order.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """
        Hash the message properties that require :meth:`_reset_state` when they change.

        The sample period (gain) only influences the FFT window length and the AUTO mode decision,
        so it only contributes to the hash in those modes.
        """
        axis = self.settings.axis or message.dims[0]
        gain = message.axes[axis].gain if axis in message.axes else 1.0
        targ_ax_ix = message.get_axis_idx(axis)
        in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]

        return hash(
            (
                message.key,
                gain if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO] else None,
                message.data.dtype.kind,
                in_shape,
            )
        )

    def _reset_state(self, message: AxisArray) -> None:
        """
        Prepare kernels, the overlap-add tail, the output template and the mode-specific buffers.

        Raises:
            ValueError: If ``settings.kernels`` is empty.
        """
        axis = self.settings.axis or message.dims[0]
        gain = message.axes[axis].gain if axis in message.axes else 1.0
        targ_ax_ix = message.get_axis_idx(axis)
        in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]

        # Accept any array-like kernels (e.g. nested lists); ndarrays pass through unchanged.
        kernels = [np.asarray(k) for k in self.settings.kernels]
        if not kernels:
            raise ValueError("FilterbankSettings.kernels must contain at least one kernel.")

        if self.settings.min_phase != MinPhaseMode.NONE:
            method = {
                MinPhaseMode.HILBERT: "hilbert",
                MinPhaseMode.HOMOMORPHIC: "homomorphic",
            }[self.settings.min_phase]
            kernels = [sps.minimum_phase(k, method=method) for k in kernels]

        # Complex input or any complex kernel -> complex buffers and fft/ifft instead of rfft/irfft.
        b_complex = message.data.dtype.kind == "c" or any(k.dtype.kind == "c" for k in kernels)
        buf_dtype = "complex" if b_complex else "float"

        # Each chunk's full convolution spills up to (longest kernel - 1) samples into the next chunk.
        # From sps._calc_oa_lens, where s2=max_kernel_len: its fallback nfft = n_input + max_kernel_len - 1
        # is not usable here because n_input is unbounded for a stream.
        max_kernel_len = max(k.size for k in kernels)
        self._state.overlap = max_kernel_len - 1

        # Prepare previous iteration's overlap tail to add to input -- all zeros.
        tail_shape = in_shape + (len(kernels), self._state.overlap)
        self._state.tail = np.zeros(tail_shape, dtype=buf_dtype)

        # Prepare output template -- kernels axis immediately before the target axis (moved last).
        dummy_shape = in_shape + (len(kernels), 0)
        self._state.template = AxisArray(
            data=np.zeros(dummy_shape, dtype=buf_dtype),
            dims=message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [self.settings.new_axis, axis],
            axes=message.axes.copy(),
            key=message.key,
        )

        # Determine optimal mode. Assumes 100 msec chunks.
        self._state.mode = self.settings.mode
        if self._state.mode == FilterbankMode.AUTO:
            # Concatenate kernels into 1 mega kernel then check what's faster.
            # Will typically return fft when combined kernel length is > 1500.
            concat_kernel = np.concatenate(kernels)
            n_dummy = max(2 * len(concat_kernel), int(0.1 / gain))
            dummy_arr = np.zeros(n_dummy)
            self._state.mode = (
                FilterbankMode.CONV
                if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full") == "direct"
                else FilterbankMode.FFT
            )
        if self._state.mode == FilterbankMode.FFT and self._state.overlap == 0:
            # Only single-tap kernels: nothing is carried between chunks, and the FFT block-size formula
            # would divide by zero (overlap == 0). Direct convolution is exact and cheap in this case.
            self._state.mode = FilterbankMode.CONV

        if self._state.mode == FilterbankMode.CONV:
            self._prepare_conv(kernels, in_shape, message.data.shape[targ_ax_ix], buf_dtype)
        else:
            self._prepare_fft(kernels, axis, gain, b_complex)

    def _prepare_conv(
        self,
        kernels: list[npt.NDArray],
        in_shape: tuple[int, ...],
        n_samples: int,
        buf_dtype: str,
    ) -> None:
        """Preallocate the direct-convolution / overlap-add buffer and keep the time-domain kernels."""
        dest_shape = in_shape + (len(kernels), self._state.overlap + n_samples)
        self._state.dest_arr = np.zeros(dest_shape, dtype=buf_dtype)
        self._state.prep_kerns = kernels

    def _prepare_fft(self, kernels: list[npt.NDArray], axis: str, gain: float, b_complex: bool) -> None:
        """Choose the FFT block size, build the windower and precompute the kernel spectra."""
        # Optimal overlap-add block size (same formula as sps._calc_oa_lens), rounded up to a fast length.
        # Requires overlap > 0; _reset_state switches to CONV otherwise.
        opt_size = -self._state.overlap * lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
        self._state.nfft = sp_fft.next_fast_len(math.ceil(opt_size))
        win_len = self._state.nfft - self._state.overlap
        # infft same as nfft. Kept as a separate variable in case they ever need to differ.
        self._state.infft = win_len + self._state.overlap

        # Create windowing node.
        # Note: We could do windowing manually to avoid the overhead of the message structure,
        # but windowing is difficult to do correctly, so we lean on the heavily-tested `windowing` generator.
        win_dur = win_len * gain
        self._state.windower = WindowTransformer(
            axis=axis,
            newaxis="win",
            window_dur=win_dur,
            window_shift=win_dur,
            zero_pad_until="none",
        )

        # Windowing output has an extra "win" dimension, so we need our tail to match.
        self._state.tail = np.expand_dims(self._state.tail, -2)

        # Prepare fft functions.
        # Note: We could instead use `spectrum` but this adds overhead in creating the message structure
        # for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
        # more fft backends.
        if b_complex:
            self._state.fft = functools.partial(sp_fft.fft, n=self._state.nfft, norm="backward")
            self._state.ifft = functools.partial(sp_fft.ifft, n=self._state.infft, norm="backward")
        else:
            self._state.fft = functools.partial(sp_fft.rfft, n=self._state.nfft, norm="backward")
            self._state.ifft = functools.partial(sp_fft.irfft, n=self._state.infft, norm="backward")

        # Kernel spectra with a singleton window axis: shape (n_kernels, 1, n_freqs).
        kern_specs = np.array([self._state.fft(k) for k in kernels])
        self._state.prep_kerns = np.expand_dims(kern_specs, -2)
        # TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Filter one chunk with every kernel.

        Returns:
            AxisArray with dims ``(<other dims>, new_axis, axis)``. In CONV mode the output has as many
            samples as the input. In FFT mode the length depends on how many windows the windower emits,
            so it can differ from the input (including zero).
        """
        axis = self.settings.axis or message.dims[0]
        targ_ax_ix = message.get_axis_idx(axis)

        # Make sure target axis is in -1th position.
        if targ_ax_ix != (message.data.ndim - 1):
            in_dat = np.moveaxis(message.data, targ_ax_ix, -1)
            if self._state.mode == FilterbankMode.FFT:
                # Fix message.dims because we will pass it to windower.
                move_dims = message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
                message = replace(message, data=in_dat, dims=move_dims)
        else:
            in_dat = message.data

        if self._state.mode == FilterbankMode.CONV:
            res = self._process_conv(in_dat)
        else:
            res = self._process_fft(message)

        # NOTE(review): in FFT mode the output is delayed by the windowing and its length differs from the
        # input, yet the input's target axis (offset) is attached unchanged -- confirm whether the offset
        # should instead follow the windower's output.
        return replace(
            self._state.template,
            data=res,
            axes={**self._state.template.axes, axis: message.axes[axis]},
        )

    def _process_conv(self, in_dat: npt.NDArray) -> npt.NDArray:
        """Direct convolution + overlap-add for one chunk whose target axis is last."""
        n_in = in_dat.shape[-1]
        overlap = self._state.overlap
        n_dest = n_in + overlap
        if self._state.dest_arr.shape[-1] < n_dest:
            pad = np.zeros(
                self._state.dest_arr.shape[:-1] + (n_dest - self._state.dest_arr.shape[-1],),
                dtype=self._state.dest_arr.dtype,
            )
            self._state.dest_arr = np.concatenate([self._state.dest_arr, pad], axis=-1)
        self._state.dest_arr.fill(0)

        # np.convolve rejects empty inputs (and apply_along_axis rejects zero-size iteration dims), so only
        # convolve when there is data. An empty chunk yields an empty result and keeps the tail intact.
        if in_dat.size > 0:
            # Note: I tried several alternatives to this loop; all were slower than this.
            # numba.jit; stride_tricks + np.einsum; threading. Latter might be better with Python 3.13.
            for k_ix, k in enumerate(self._state.prep_kerns):
                n_out = n_in + k.shape[-1] - 1
                self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(np.convolve, -1, in_dat, k, mode="full")

        # Add the previous chunk's spill-over, then keep this chunk's spill-over for the next call.
        self._state.dest_arr[..., :overlap] += self._state.tail
        new_tail = self._state.dest_arr[..., n_in:n_dest]
        if new_tail.size > 0:
            # COPY overlap for next iteration (dest_arr is reused).
            self._state.tail = new_tail.copy()
        return self._state.dest_arr[..., :n_in].copy()

    def _process_fft(self, message: AxisArray) -> npt.NDArray:
        """FFT convolution + overlap-add over the non-overlapping windows emitted by the windower."""
        overlap = self._state.overlap

        # Slice into non-overlapping windows.
        win_msg = self._state.windower.send(message)
        # Calculate spectrum of each window.
        spec_dat = self._state.fft(win_msg.data, axis=-1)
        # Insert axis for filters.
        spec_dat = np.expand_dims(spec_dat, -3)

        # Do the FFT convolution.
        # TODO: handle fft_kernels being sparse. Maybe need np.dot.
        conv_spec = spec_dat * self._state.prep_kerns
        overlapped = self._state.ifft(conv_spec, axis=-1)

        # Do the overlap-add on the `axis` axis.
        # Previous iteration's tail:
        overlapped[..., :1, :overlap] += self._state.tail
        # window-to-window:
        overlapped[..., 1:, :overlap] += overlapped[..., :-1, -overlap:]
        # Save tail:
        new_tail = overlapped[..., -1:, -overlap:]
        if new_tail.size > 0:
            # All of the above code works if input is size-zero, but we don't want to save a zero-size tail.
            # Copy so the (much larger) `overlapped` buffer is not kept alive by a view.
            self._state.tail = new_tail.copy()
        # Concat over win axis, without overlap.
        return overlapped[..., :-overlap].reshape(overlapped.shape[:-2] + (-1,))
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class Filterbank(BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]):
    """ezmsg Unit for :obj:`FilterbankTransformer` (its transformer type parameter), configured by :obj:`FilterbankSettings`."""

    SETTINGS = FilterbankSettings
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def filterbank(
    kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
    mode: FilterbankMode = FilterbankMode.CONV,
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
    new_axis: str = "kernel",
) -> FilterbankTransformer:
    """
    Build a transformer that convolves a signal with every kernel in a bank.

    Both the direct and the FFT strategies stitch consecutive chunks together with the
    overlap-add method, so the transformer is suited to online (chunked) processing.

    Args:
        kernels: FIR kernels to apply; the output gains one entry per kernel along ``new_axis``.
        mode: Convolution strategy (see :obj:`FilterbankMode`).
        min_phase: Optional minimum-phase conversion of the kernels (see :obj:`MinPhaseMode`).
        axis: Name of the axis to convolve along.
        new_axis: Name of the output axis that indexes the kernels.

    Returns: :obj:`FilterbankTransformer`.
    """
    fb_settings = FilterbankSettings(
        kernels=kernels,
        mode=mode,
        min_phase=min_phase,
        axis=axis,
        new_axis=new_axis,
    )
    return FilterbankTransformer(settings=fb_settings)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
from ezmsg.baseproc import (
|
|
7
|
+
BaseStatefulTransformer,
|
|
8
|
+
processor_state,
|
|
9
|
+
)
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
12
|
+
|
|
13
|
+
from .filterbank import (
|
|
14
|
+
FilterbankMode,
|
|
15
|
+
FilterbankSettings,
|
|
16
|
+
FilterbankTransformer,
|
|
17
|
+
MinPhaseMode,
|
|
18
|
+
)
|
|
19
|
+
from .kaiser import KaiserFilterSettings, kaiser_design_fun
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FilterbankDesignSettings(ez.Settings):
    """Settings for :obj:`FilterbankDesignTransformer`."""

    # One Kaiser-window FIR design per output kernel. Use a re-iterable collection (e.g. a list):
    # the designs are iterated again every time the filterbank is (re)designed, so a one-shot
    # iterator/generator would be exhausted after the first design.
    filters: typing.Iterable[KaiserFilterSettings]

    mode: FilterbankMode = FilterbankMode.CONV
    """
    ``FilterbankMode.CONV``, ``FilterbankMode.FFT``, or ``FilterbankMode.AUTO``.
    With AUTO, the mode is chosen from the total kernel length and the sampling rate
    (assuming 100 msec chunks).
    FFT mode is more efficient for long kernels. However, FFT mode uses non-overlapping windows and will
    incur a delay equal to the window length, which is larger than the largest kernel.
    CONV mode is less efficient but will return data for every incoming chunk regardless of how small it is
    and thus can provide shorter latency updates.
    """

    min_phase: MinPhaseMode = MinPhaseMode.NONE
    """
    If not ``MinPhaseMode.NONE``, convert the designed kernels to minimum-phase equivalents.
    ``MinPhaseMode.HILBERT`` and ``MinPhaseMode.HOMOMORPHIC`` map to the ``method`` argument of
    `scipy.signal.minimum_phase`. Complex filters not supported.
    See `scipy.signal.minimum_phase` for details.
    """

    axis: str = "time"
    """The name of the axis to operate on. This should usually be "time"."""

    new_axis: str = "kernel"
    """The name of the new axis corresponding to the kernel index."""
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@processor_state
class FilterbankDesignState:
    # Inner filterbank built from the most recent kernel design; None until the first state reset.
    filterbank: FilterbankTransformer | None = None
    # Set by update_settings (once a filterbank exists) so the next message triggers a redesign.
    needs_redesign: bool = False
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class FilterbankDesignTransformer(
    BaseStatefulTransformer[FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState],
):
    """
    Transformer that designs and applies a filterbank based on Kaiser windowed FIR filters.

    Kernels are designed from ``settings.filters`` at the sampling rate of the target axis and
    handed to an inner :obj:`FilterbankTransformer`. :meth:`_hash_message` covers the message key,
    the non-target shape and the sample period; the kernels are also redesigned on the next message
    after :meth:`update_settings`.
    """

    @classmethod
    def get_message_type(cls, dir: str) -> type[AxisArray]:
        """
        Return the message type for direction ``"in"`` or ``"out"`` (both :obj:`AxisArray`).

        Raises:
            ValueError: For any other direction.
        """
        if dir in ("in", "out"):
            return AxisArray
        else:
            raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")

    def update_settings(self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs) -> None:
        """
        Update settings and mark that filter coefficients need to be recalculated.

        Args:
            new_settings: Complete new settings object to replace current settings
            **kwargs: Individual settings to update
        """
        # Update settings
        if new_settings is not None:
            self.settings = new_settings
        else:
            self.settings = replace(self.settings, **kwargs)

        # Set flag to trigger recalculation on next message
        if self.state.filterbank is not None:
            self.state.needs_redesign = True

    def _resolve_axis(self, message: AxisArray) -> str:
        """Name of the axis to filter along; falls back to the first dim, as FilterbankTransformer does."""
        return self.settings.axis or message.dims[0]

    def _calculate_kernels(self, fs: float) -> list[npt.NDArray]:
        """
        Design one FIR kernel per entry of ``settings.filters`` at sampling rate ``fs``.

        A design that returns ``None`` is replaced by the identity kernel ``[1.0]``.
        """
        kernels = []
        for filt in self.settings.filters:
            output = kaiser_design_fun(
                fs,
                cutoff=filt.cutoff,
                ripple=filt.ripple,
                width=filt.width,
                pass_zero=filt.pass_zero,
                wn_hz=filt.wn_hz,
            )
            # Presumably (b, a) coefficients; only the first element (FIR taps) is used -- confirm in kaiser.py.
            kernels.append(np.array([1.0]) if output is None else output[0])
        return kernels

    def __call__(self, message: AxisArray) -> AxisArray:
        """Redesign the kernels first if :meth:`update_settings` flagged it, then run the base-class call."""
        if self.state.filterbank is not None and self.state.needs_redesign:
            self._reset_state(message)
            self.state.needs_redesign = False
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key, the shape excluding the target axis, and the target axis' gain."""
        axis = self._resolve_axis(message)
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Design the kernels for the message's sampling rate and build the inner filterbank."""
        # Resolve the axis exactly as _hash_message does, so a falsy settings.axis works here too.
        axis = self._resolve_axis(message)
        axis_obj = message.axes[axis]
        assert isinstance(axis_obj, AxisArray.LinearAxis), (
            f"Axis '{axis}' must be a LinearAxis to derive a sampling rate."
        )
        fs = 1 / axis_obj.gain
        kernels = self._calculate_kernels(fs)
        new_settings = FilterbankSettings(
            kernels=kernels,
            mode=self.settings.mode,
            min_phase=self.settings.min_phase,
            axis=axis,
            new_axis=self.settings.new_axis,
        )
        self.state.filterbank = FilterbankTransformer(settings=new_settings)

    def _process(self, message: AxisArray) -> AxisArray:
        """Delegate to the inner filterbank built by :meth:`_reset_state`."""
        return self.state.filterbank(message)
|