ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
- ezmsg/sigproc/affinetransform.py +16 -42
- ezmsg/sigproc/aggregate.py +17 -34
- ezmsg/sigproc/bandpower.py +12 -20
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +7 -16
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/coordinatespaces.py +142 -0
- ezmsg/sigproc/decimate.py +3 -7
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +11 -20
- ezmsg/sigproc/ewma.py +11 -28
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +34 -59
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +37 -74
- ezmsg/sigproc/filterbankdesign.py +7 -14
- ezmsg/sigproc/fir_hilbert.py +13 -30
- ezmsg/sigproc/fir_pmc.py +5 -10
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +4 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +4 -1
- ezmsg/sigproc/math/difference.py +100 -36
- ezmsg/sigproc/math/invert.py +3 -3
- ezmsg/sigproc/math/log.py +5 -6
- ezmsg/sigproc/math/scale.py +2 -0
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +3 -6
- ezmsg/sigproc/resample.py +17 -38
- ezmsg/sigproc/rollingscaler.py +12 -37
- ezmsg/sigproc/sampler.py +19 -37
- ezmsg/sigproc/scaler.py +11 -22
- ezmsg/sigproc/signalinjector.py +7 -18
- ezmsg/sigproc/slicer.py +14 -34
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +17 -38
- ezmsg/sigproc/transpose.py +12 -24
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +12 -26
- ezmsg/sigproc/util/buffer.py +22 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +7 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +10 -19
- ezmsg/sigproc/window.py +29 -83
- ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
- ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
- {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/fbcca.py
CHANGED
|
@@ -1,29 +1,26 @@
|
|
|
1
|
-
import typing
|
|
2
1
|
import math
|
|
2
|
+
import typing
|
|
3
3
|
from dataclasses import field
|
|
4
4
|
|
|
5
|
-
import numpy as np
|
|
6
|
-
|
|
7
5
|
import ezmsg.core as ez
|
|
8
|
-
|
|
9
|
-
from ezmsg.
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from .window import WindowTransformer, WindowSettings
|
|
13
|
-
|
|
14
|
-
from .base import (
|
|
6
|
+
import numpy as np
|
|
7
|
+
from ezmsg.baseproc import (
|
|
8
|
+
BaseProcessor,
|
|
9
|
+
BaseStatefulProcessor,
|
|
15
10
|
BaseTransformer,
|
|
16
11
|
BaseTransformerUnit,
|
|
17
12
|
CompositeProcessor,
|
|
18
|
-
BaseProcessor,
|
|
19
|
-
BaseStatefulProcessor,
|
|
20
13
|
)
|
|
14
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
15
|
+
from ezmsg.util.messages.util import replace
|
|
21
16
|
|
|
22
|
-
from .kaiser import KaiserFilterSettings
|
|
23
17
|
from .filterbankdesign import (
|
|
24
18
|
FilterbankDesignSettings,
|
|
25
19
|
FilterbankDesignTransformer,
|
|
26
20
|
)
|
|
21
|
+
from .kaiser import KaiserFilterSettings
|
|
22
|
+
from .sampler import SampleTriggerMessage
|
|
23
|
+
from .window import WindowSettings, WindowTransformer
|
|
27
24
|
|
|
28
25
|
|
|
29
26
|
class FBCCASettings(ez.Settings):
|
|
@@ -33,7 +30,7 @@ class FBCCASettings(ez.Settings):
|
|
|
33
30
|
|
|
34
31
|
time_dim: str
|
|
35
32
|
"""
|
|
36
|
-
The time dim in the data array.
|
|
33
|
+
The time dim in the data array.
|
|
37
34
|
"""
|
|
38
35
|
|
|
39
36
|
ch_dim: str
|
|
@@ -49,40 +46,41 @@ class FBCCASettings(ez.Settings):
|
|
|
49
46
|
|
|
50
47
|
harmonics: int = 5
|
|
51
48
|
"""
|
|
52
|
-
The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
|
|
53
|
-
5 (default): Evaluate 5 harmonics of the base frequency.
|
|
54
|
-
Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
|
|
49
|
+
The number of additional harmonics beyond the fundamental to use for the 'design' matrix.
|
|
50
|
+
5 (default): Evaluate 5 harmonics of the base frequency.
|
|
51
|
+
Many periodic signals are not pure sinusoids, and inclusion of higher harmonics can help evaluate the
|
|
55
52
|
presence of signals with higher frequency harmonic content
|
|
56
53
|
"""
|
|
57
54
|
|
|
58
55
|
freqs: typing.List[float] = field(default_factory=list)
|
|
59
56
|
"""
|
|
60
|
-
Frequencies (in hz) to evaluate the presence of within the input signal.
|
|
61
|
-
[] (default): an empty list; frequencies will be found within the input SampleMessages.
|
|
57
|
+
Frequencies (in hz) to evaluate the presence of within the input signal.
|
|
58
|
+
[] (default): an empty list; frequencies will be found within the input SampleMessages.
|
|
62
59
|
AxisArrays have no good place to put this metadata, so specify frequencies here if only AxisArrays
|
|
63
60
|
will be passed as input to the generator. If the input has a `trigger` attr of type :obj:`SampleTriggerMessage`,
|
|
64
|
-
the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
|
|
65
|
-
This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from
|
|
61
|
+
the processor looks for the `freqs` attribute within that trigger for a list of frequencies to evaluate.
|
|
62
|
+
This field is present in the :obj:`SSVEPSampleTriggerMessage` defined in ezmsg.tasks.ssvep from
|
|
63
|
+
the ezmsg-tasks package.
|
|
66
64
|
NOTE: Avoid frequencies that have line-noise (60 Hz/50 Hz) as a harmonic.
|
|
67
65
|
"""
|
|
68
66
|
|
|
69
67
|
softmax_beta: float = 1.0
|
|
70
68
|
"""
|
|
71
|
-
Beta parameter for softmax on output --> "probabilities".
|
|
72
|
-
1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
|
|
69
|
+
Beta parameter for softmax on output --> "probabilities".
|
|
70
|
+
1.0 (default): Use the shifted softmax transformation to output 0-1 probabilities.
|
|
73
71
|
If 0.0, the maximum singular value of the SVD for each design matrix is output
|
|
74
72
|
"""
|
|
75
73
|
|
|
76
74
|
target_freq_dim: str = "target_freq"
|
|
77
75
|
"""
|
|
78
|
-
Name for dim to put target frequency outputs on.
|
|
76
|
+
Name for dim to put target frequency outputs on.
|
|
79
77
|
'target_freq' (default)
|
|
80
78
|
"""
|
|
81
79
|
|
|
82
80
|
max_int_time: float = 0.0
|
|
83
81
|
"""
|
|
84
|
-
Maximum integration time (in seconds) to use for calculation.
|
|
85
|
-
0 (default): Use all time provided for the calculation.
|
|
82
|
+
Maximum integration time (in seconds) to use for calculation.
|
|
83
|
+
0 (default): Use all time provided for the calculation.
|
|
86
84
|
Useful for artificially limiting the amount of data used for the CCA method to evaluate
|
|
87
85
|
the necessary integration time for good decoding performance
|
|
88
86
|
"""
|
|
@@ -136,9 +134,7 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
|
|
|
136
134
|
if filterbank_dim_idx is not None:
|
|
137
135
|
new_order.append(filterbank_dim_idx)
|
|
138
136
|
new_order.extend([time_dim_idx, ch_dim_idx])
|
|
139
|
-
out_dims = [
|
|
140
|
-
message.dims[i] for i in new_order if message.dims[i] not in rm_dims
|
|
141
|
-
]
|
|
137
|
+
out_dims = [message.dims[i] for i in new_order if message.dims[i] not in rm_dims]
|
|
142
138
|
data_arr = message.data.transpose(new_order)
|
|
143
139
|
|
|
144
140
|
# Add a singleton dim for filterbank dim if we don't have one
|
|
@@ -158,10 +154,7 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
|
|
|
158
154
|
axis_name: axis
|
|
159
155
|
for axis_name, axis in message.axes.items()
|
|
160
156
|
if axis_name not in rm_dims
|
|
161
|
-
and not (
|
|
162
|
-
isinstance(axis, AxisArray.CoordinateAxis)
|
|
163
|
-
and any(d in rm_dims for d in axis.dims)
|
|
164
|
-
)
|
|
157
|
+
and not (isinstance(axis, AxisArray.CoordinateAxis) and any(d in rm_dims for d in axis.dims))
|
|
165
158
|
}
|
|
166
159
|
out_axes[self.settings.target_freq_dim] = AxisArray.CoordinateAxis(
|
|
167
160
|
np.array(test_freqs), [self.settings.target_freq_dim]
|
|
@@ -193,15 +186,9 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
|
|
|
193
186
|
]
|
|
194
187
|
)
|
|
195
188
|
|
|
196
|
-
for test_idx, arr in enumerate(
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
for band_idx, band in enumerate(
|
|
200
|
-
arr
|
|
201
|
-
): # iterate over second dim: arr is (time x ch)
|
|
202
|
-
calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(
|
|
203
|
-
band[:max_samp, ...], Y
|
|
204
|
-
)
|
|
189
|
+
for test_idx, arr in enumerate(data_arr): # iterate over first dim; arr is (filterbank x time x ch)
|
|
190
|
+
for band_idx, band in enumerate(arr): # iterate over second dim: arr is (time x ch)
|
|
191
|
+
calc_output[test_idx, band_idx, test_freq_idx] = cca_rho_max(band[:max_samp, ...], Y)
|
|
205
192
|
|
|
206
193
|
# Combine per-subband canonical correlations using a weighted sum
|
|
207
194
|
# https://iopscience.iop.org/article/10.1088/1741-2560/12/4/046008
|
|
@@ -209,9 +196,7 @@ class FBCCATransformer(BaseTransformer[FBCCASettings, AxisArray, AxisArray]):
|
|
|
209
196
|
calc_output = ((calc_output**2) * freq_weights[None, :, None]).sum(axis=1)
|
|
210
197
|
|
|
211
198
|
if self.settings.softmax_beta != 0:
|
|
212
|
-
calc_output = calc_softmax(
|
|
213
|
-
calc_output, axis=-1, beta=self.settings.softmax_beta
|
|
214
|
-
)
|
|
199
|
+
calc_output = calc_softmax(calc_output, axis=-1, beta=self.settings.softmax_beta)
|
|
215
200
|
|
|
216
201
|
output = replace(
|
|
217
202
|
message,
|
|
@@ -244,9 +229,7 @@ class StreamingFBCCASettings(FBCCASettings):
|
|
|
244
229
|
subbands: int = 12
|
|
245
230
|
|
|
246
231
|
|
|
247
|
-
class StreamingFBCCATransformer(
|
|
248
|
-
CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]
|
|
249
|
-
):
|
|
232
|
+
class StreamingFBCCATransformer(CompositeProcessor[StreamingFBCCASettings, AxisArray, AxisArray]):
|
|
250
233
|
@staticmethod
|
|
251
234
|
def _initialize_processors(
|
|
252
235
|
settings: StreamingFBCCASettings,
|
|
@@ -254,9 +237,7 @@ class StreamingFBCCATransformer(
|
|
|
254
237
|
pipeline = {}
|
|
255
238
|
|
|
256
239
|
if settings.filterbank_dim is not None:
|
|
257
|
-
cut_freqs = (
|
|
258
|
-
np.arange(settings.subbands + 1) * settings.filter_bw
|
|
259
|
-
) + settings.filter_low
|
|
240
|
+
cut_freqs = (np.arange(settings.subbands + 1) * settings.filter_bw) + settings.filter_low
|
|
260
241
|
filters = [
|
|
261
242
|
KaiserFilterSettings(
|
|
262
243
|
axis=settings.time_dim,
|
|
@@ -269,9 +250,7 @@ class StreamingFBCCATransformer(
|
|
|
269
250
|
]
|
|
270
251
|
|
|
271
252
|
pipeline["filterbank"] = FilterbankDesignTransformer(
|
|
272
|
-
FilterbankDesignSettings(
|
|
273
|
-
filters=filters, new_axis=settings.filterbank_dim
|
|
274
|
-
)
|
|
253
|
+
FilterbankDesignSettings(filters=filters, new_axis=settings.filterbank_dim)
|
|
275
254
|
)
|
|
276
255
|
|
|
277
256
|
pipeline["window"] = WindowTransformer(
|
|
@@ -289,11 +268,7 @@ class StreamingFBCCATransformer(
|
|
|
289
268
|
return pipeline
|
|
290
269
|
|
|
291
270
|
|
|
292
|
-
class StreamingFBCCA(
|
|
293
|
-
BaseTransformerUnit[
|
|
294
|
-
StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer
|
|
295
|
-
]
|
|
296
|
-
):
|
|
271
|
+
class StreamingFBCCA(BaseTransformerUnit[StreamingFBCCASettings, AxisArray, AxisArray, StreamingFBCCATransformer]):
|
|
297
272
|
SETTINGS = StreamingFBCCASettings
|
|
298
273
|
|
|
299
274
|
|
ezmsg/sigproc/filter.py
CHANGED
|
@@ -1,21 +1,21 @@
|
|
|
1
|
-
from abc import abstractmethod, ABC
|
|
2
|
-
from dataclasses import dataclass, field
|
|
3
1
|
import typing
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
4
|
|
|
5
5
|
import ezmsg.core as ez
|
|
6
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
from ezmsg.util.messages.util import replace
|
|
8
6
|
import numpy as np
|
|
9
7
|
import numpy.typing as npt
|
|
10
8
|
import scipy.signal
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
11
11
|
|
|
12
12
|
from ezmsg.sigproc.base import (
|
|
13
|
-
|
|
13
|
+
BaseConsumerUnit,
|
|
14
14
|
BaseStatefulTransformer,
|
|
15
15
|
BaseTransformerUnit,
|
|
16
16
|
SettingsType,
|
|
17
|
-
BaseConsumerUnit,
|
|
18
17
|
TransformerType,
|
|
18
|
+
processor_state,
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
|
|
@@ -68,9 +68,7 @@ class FilterState:
|
|
|
68
68
|
zi: npt.NDArray | None = None
|
|
69
69
|
|
|
70
70
|
|
|
71
|
-
class FilterTransformer(
|
|
72
|
-
BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]
|
|
73
|
-
):
|
|
71
|
+
class FilterTransformer(BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]):
|
|
74
72
|
"""
|
|
75
73
|
Filter data using the provided coefficients.
|
|
76
74
|
"""
|
|
@@ -108,9 +106,7 @@ class FilterTransformer(
|
|
|
108
106
|
zi = scipy.signal.sosfilt_zi(*coefs)
|
|
109
107
|
|
|
110
108
|
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
111
|
-
n_tile = (
|
|
112
|
-
message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
|
|
113
|
-
)
|
|
109
|
+
n_tile = message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
|
|
114
110
|
|
|
115
111
|
if self.settings.coef_type == "sos":
|
|
116
112
|
zi_expand = (slice(None),) + zi_expand
|
|
@@ -144,17 +140,11 @@ class FilterTransformer(
|
|
|
144
140
|
reset_needed = False
|
|
145
141
|
|
|
146
142
|
if self.settings.coef_type == "ba":
|
|
147
|
-
if isinstance(old_coefs, FilterCoefficients) and isinstance(
|
|
148
|
-
coefs
|
|
149
|
-
):
|
|
150
|
-
if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(
|
|
151
|
-
coefs.a
|
|
152
|
-
):
|
|
143
|
+
if isinstance(old_coefs, FilterCoefficients) and isinstance(coefs, FilterCoefficients):
|
|
144
|
+
if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(coefs.a):
|
|
153
145
|
reset_needed = True
|
|
154
146
|
elif isinstance(old_coefs, tuple) and isinstance(coefs, tuple):
|
|
155
|
-
if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(
|
|
156
|
-
coefs[1]
|
|
157
|
-
):
|
|
147
|
+
if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(coefs[1]):
|
|
158
148
|
reset_needed = True
|
|
159
149
|
else:
|
|
160
150
|
reset_needed = True
|
|
@@ -173,36 +163,26 @@ class FilterTransformer(
|
|
|
173
163
|
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
174
164
|
axis_idx = message.get_axis_idx(axis)
|
|
175
165
|
_, coefs = _normalize_coefs(self.settings.coefs)
|
|
176
|
-
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[
|
|
177
|
-
|
|
178
|
-
]
|
|
179
|
-
dat_out, self.state.zi = filt_func(
|
|
180
|
-
*coefs, message.data, axis=axis_idx, zi=self.state.zi
|
|
181
|
-
)
|
|
166
|
+
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[self.settings.coef_type]
|
|
167
|
+
dat_out, self.state.zi = filt_func(*coefs, message.data, axis=axis_idx, zi=self.state.zi)
|
|
182
168
|
else:
|
|
183
169
|
dat_out = message.data
|
|
184
170
|
|
|
185
171
|
return replace(message, data=dat_out)
|
|
186
172
|
|
|
187
173
|
|
|
188
|
-
class Filter(
|
|
189
|
-
BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]
|
|
190
|
-
):
|
|
174
|
+
class Filter(BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]):
|
|
191
175
|
SETTINGS = FilterSettings
|
|
192
176
|
|
|
193
177
|
|
|
194
|
-
def filtergen(
|
|
195
|
-
axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str
|
|
196
|
-
) -> FilterTransformer:
|
|
178
|
+
def filtergen(axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str) -> FilterTransformer:
|
|
197
179
|
"""
|
|
198
180
|
Filter data using the provided coefficients.
|
|
199
181
|
|
|
200
182
|
Returns:
|
|
201
183
|
:obj:`FilterTransformer`.
|
|
202
184
|
"""
|
|
203
|
-
return FilterTransformer(
|
|
204
|
-
FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type)
|
|
205
|
-
)
|
|
185
|
+
return FilterTransformer(FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type))
|
|
206
186
|
|
|
207
187
|
|
|
208
188
|
@processor_state
|
|
@@ -230,9 +210,7 @@ class FilterByDesignTransformer(
|
|
|
230
210
|
"""Return a function that takes sampling frequency and returns filter coefficients."""
|
|
231
211
|
...
|
|
232
212
|
|
|
233
|
-
def update_settings(
|
|
234
|
-
self, new_settings: typing.Optional[SettingsType] = None, **kwargs
|
|
235
|
-
) -> None:
|
|
213
|
+
def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
|
|
236
214
|
"""
|
|
237
215
|
Update settings and mark that filter coefficients need to be recalculated.
|
|
238
216
|
|
|
@@ -271,9 +249,7 @@ class FilterByDesignTransformer(
|
|
|
271
249
|
b, a = coefs
|
|
272
250
|
coefs = scipy.signal.tf2sos(b, a)
|
|
273
251
|
|
|
274
|
-
self.state.filter.update_coefficients(
|
|
275
|
-
coefs, coef_type=self.settings.coef_type
|
|
276
|
-
)
|
|
252
|
+
self.state.filter.update_coefficients(coefs, coef_type=self.settings.coef_type)
|
|
277
253
|
self.state.needs_redesign = False
|
|
278
254
|
|
|
279
255
|
return super().__call__(message)
|
|
@@ -298,9 +274,7 @@ class FilterByDesignTransformer(
|
|
|
298
274
|
b, a = coefs
|
|
299
275
|
coefs = scipy.signal.tf2sos(b, a)
|
|
300
276
|
|
|
301
|
-
new_settings = FilterSettings(
|
|
302
|
-
axis=axis, coef_type=self.settings.coef_type, coefs=coefs
|
|
303
|
-
)
|
|
277
|
+
new_settings = FilterSettings(axis=axis, coef_type=self.settings.coef_type, coefs=coefs)
|
|
304
278
|
self.state.filter = FilterTransformer(settings=new_settings)
|
|
305
279
|
|
|
306
280
|
def _process(self, message: AxisArray) -> AxisArray:
|
ezmsg/sigproc/filterbank.py
CHANGED
|
@@ -2,20 +2,20 @@ import functools
|
|
|
2
2
|
import math
|
|
3
3
|
import typing
|
|
4
4
|
|
|
5
|
+
import ezmsg.core as ez
|
|
5
6
|
import numpy as np
|
|
6
|
-
import scipy.signal as sps
|
|
7
|
-
import scipy.fft as sp_fft
|
|
8
|
-
from scipy.special import lambertw
|
|
9
7
|
import numpy.typing as npt
|
|
10
|
-
import
|
|
11
|
-
|
|
12
|
-
from ezmsg.
|
|
13
|
-
|
|
14
|
-
from .base import (
|
|
8
|
+
import scipy.fft as sp_fft
|
|
9
|
+
import scipy.signal as sps
|
|
10
|
+
from ezmsg.baseproc import (
|
|
15
11
|
BaseStatefulTransformer,
|
|
16
12
|
BaseTransformerUnit,
|
|
17
13
|
processor_state,
|
|
18
14
|
)
|
|
15
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
16
|
+
from ezmsg.util.messages.util import replace
|
|
17
|
+
from scipy.special import lambertw
|
|
18
|
+
|
|
19
19
|
from .spectrum import OptionsEnum
|
|
20
20
|
from .window import WindowTransformer
|
|
21
21
|
|
|
@@ -32,8 +32,14 @@ class MinPhaseMode(OptionsEnum):
|
|
|
32
32
|
"""The mode of operation for the filterbank."""
|
|
33
33
|
|
|
34
34
|
NONE = "No kernel modification"
|
|
35
|
-
HILBERT =
|
|
36
|
-
|
|
35
|
+
HILBERT = (
|
|
36
|
+
"Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions"
|
|
37
|
+
)
|
|
38
|
+
HOMOMORPHIC = (
|
|
39
|
+
"Works best with filters with an odd number of taps, and the resulting minimum phase filter "
|
|
40
|
+
"will have a magnitude response that approximates the square root of the original filter’s "
|
|
41
|
+
"magnitude response using half the number of taps"
|
|
42
|
+
)
|
|
37
43
|
# HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
|
|
38
44
|
|
|
39
45
|
|
|
@@ -78,23 +84,17 @@ class FilterbankState:
|
|
|
78
84
|
mode: FilterbankMode | None = None
|
|
79
85
|
|
|
80
86
|
|
|
81
|
-
class FilterbankTransformer(
|
|
82
|
-
BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]
|
|
83
|
-
):
|
|
87
|
+
class FilterbankTransformer(BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]):
|
|
84
88
|
def _hash_message(self, message: AxisArray) -> int:
|
|
85
89
|
axis = self.settings.axis or message.dims[0]
|
|
86
90
|
gain = message.axes[axis].gain if axis in message.axes else 1.0
|
|
87
91
|
targ_ax_ix = message.get_axis_idx(axis)
|
|
88
|
-
in_shape =
|
|
89
|
-
message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
90
|
-
)
|
|
92
|
+
in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
91
93
|
|
|
92
94
|
return hash(
|
|
93
95
|
(
|
|
94
96
|
message.key,
|
|
95
|
-
gain
|
|
96
|
-
if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
|
|
97
|
-
else None,
|
|
97
|
+
gain if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO] else None,
|
|
98
98
|
message.data.dtype.kind,
|
|
99
99
|
in_shape,
|
|
100
100
|
)
|
|
@@ -104,9 +104,7 @@ class FilterbankTransformer(
|
|
|
104
104
|
axis = self.settings.axis or message.dims[0]
|
|
105
105
|
gain = message.axes[axis].gain if axis in message.axes else 1.0
|
|
106
106
|
targ_ax_ix = message.get_axis_idx(axis)
|
|
107
|
-
in_shape =
|
|
108
|
-
message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
109
|
-
)
|
|
107
|
+
in_shape = message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
|
|
110
108
|
|
|
111
109
|
kernels = self.settings.kernels
|
|
112
110
|
if self.settings.min_phase != MinPhaseMode.NONE:
|
|
@@ -118,9 +116,7 @@ class FilterbankTransformer(
|
|
|
118
116
|
kernels = [sps.minimum_phase(k, method=method) for k in kernels]
|
|
119
117
|
|
|
120
118
|
# Determine if this will be operating with complex data.
|
|
121
|
-
b_complex = message.data.dtype.kind == "c" or any(
|
|
122
|
-
[_.dtype.kind == "c" for _ in kernels]
|
|
123
|
-
)
|
|
119
|
+
b_complex = message.data.dtype.kind == "c" or any([_.dtype.kind == "c" for _ in kernels])
|
|
124
120
|
|
|
125
121
|
# Calculate window_dur, window_shift, nfft
|
|
126
122
|
max_kernel_len = max([_.size for _ in kernels])
|
|
@@ -130,17 +126,13 @@ class FilterbankTransformer(
|
|
|
130
126
|
|
|
131
127
|
# Prepare previous iteration's overlap tail to add to input -- all zeros.
|
|
132
128
|
tail_shape = in_shape + (len(kernels), self._state.overlap)
|
|
133
|
-
self._state.tail = np.zeros(
|
|
134
|
-
tail_shape, dtype="complex" if b_complex else "float"
|
|
135
|
-
)
|
|
129
|
+
self._state.tail = np.zeros(tail_shape, dtype="complex" if b_complex else "float")
|
|
136
130
|
|
|
137
131
|
# Prepare output template -- kernels axis immediately before the target axis
|
|
138
132
|
dummy_shape = in_shape + (len(kernels), 0)
|
|
139
133
|
self._state.template = AxisArray(
|
|
140
134
|
data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"),
|
|
141
|
-
dims=message.dims[:targ_ax_ix]
|
|
142
|
-
+ message.dims[targ_ax_ix + 1 :]
|
|
143
|
-
+ [self.settings.new_axis, axis],
|
|
135
|
+
dims=message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [self.settings.new_axis, axis],
|
|
144
136
|
axes=message.axes.copy(),
|
|
145
137
|
key=message.key,
|
|
146
138
|
)
|
|
@@ -155,8 +147,7 @@ class FilterbankTransformer(
|
|
|
155
147
|
dummy_arr = np.zeros(n_dummy)
|
|
156
148
|
self._state.mode = (
|
|
157
149
|
FilterbankMode.CONV
|
|
158
|
-
if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full")
|
|
159
|
-
== "direct"
|
|
150
|
+
if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full") == "direct"
|
|
160
151
|
else FilterbankMode.FFT
|
|
161
152
|
)
|
|
162
153
|
|
|
@@ -166,16 +157,11 @@ class FilterbankTransformer(
|
|
|
166
157
|
len(kernels),
|
|
167
158
|
self._state.overlap + message.data.shape[targ_ax_ix],
|
|
168
159
|
)
|
|
169
|
-
self._state.dest_arr = np.zeros(
|
|
170
|
-
dest_shape, dtype="complex" if b_complex else "float"
|
|
171
|
-
)
|
|
160
|
+
self._state.dest_arr = np.zeros(dest_shape, dtype="complex" if b_complex else "float")
|
|
172
161
|
self._state.prep_kerns = kernels
|
|
173
162
|
else: # FFT mode
|
|
174
163
|
# Calculate optimal nfft and windowing size.
|
|
175
|
-
opt_size = (
|
|
176
|
-
-self._state.overlap
|
|
177
|
-
* lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
|
|
178
|
-
)
|
|
164
|
+
opt_size = -self._state.overlap * lambertw(-1 / (2 * math.e * self._state.overlap), k=-1).real
|
|
179
165
|
self._state.nfft = sp_fft.next_fast_len(math.ceil(opt_size))
|
|
180
166
|
win_len = self._state.nfft - self._state.overlap
|
|
181
167
|
# infft same as nfft. Keeping as separate variable because I might need it again.
|
|
@@ -201,19 +187,11 @@ class FilterbankTransformer(
|
|
|
201
187
|
# for a rather simple calculation. We may revisit if `spectrum` gets additional features, such as
|
|
202
188
|
# more fft backends.
|
|
203
189
|
if b_complex:
|
|
204
|
-
self._state.fft = functools.partial(
|
|
205
|
-
|
|
206
|
-
)
|
|
207
|
-
self._state.ifft = functools.partial(
|
|
208
|
-
sp_fft.ifft, n=self._state.infft, norm="backward"
|
|
209
|
-
)
|
|
190
|
+
self._state.fft = functools.partial(sp_fft.fft, n=self._state.nfft, norm="backward")
|
|
191
|
+
self._state.ifft = functools.partial(sp_fft.ifft, n=self._state.infft, norm="backward")
|
|
210
192
|
else:
|
|
211
|
-
self._state.fft = functools.partial(
|
|
212
|
-
|
|
213
|
-
)
|
|
214
|
-
self._state.ifft = functools.partial(
|
|
215
|
-
sp_fft.irfft, n=self._state.infft, norm="backward"
|
|
216
|
-
)
|
|
193
|
+
self._state.fft = functools.partial(sp_fft.rfft, n=self._state.nfft, norm="backward")
|
|
194
|
+
self._state.ifft = functools.partial(sp_fft.irfft, n=self._state.infft, norm="backward")
|
|
217
195
|
|
|
218
196
|
# Calculate fft of kernels
|
|
219
197
|
self._state.prep_kerns = np.array([self._state.fft(_) for _ in kernels])
|
|
@@ -229,9 +207,7 @@ class FilterbankTransformer(
|
|
|
229
207
|
in_dat = np.moveaxis(message.data, targ_ax_ix, -1)
|
|
230
208
|
if self._state.mode == FilterbankMode.FFT:
|
|
231
209
|
# Fix message.dims because we will pass it to windower
|
|
232
|
-
move_dims =
|
|
233
|
-
message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
|
|
234
|
-
)
|
|
210
|
+
move_dims = message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
|
|
235
211
|
message = replace(message, data=in_dat, dims=move_dims)
|
|
236
212
|
else:
|
|
237
213
|
in_dat = message.data
|
|
@@ -239,22 +215,15 @@ class FilterbankTransformer(
|
|
|
239
215
|
if self._state.mode == FilterbankMode.CONV:
|
|
240
216
|
n_dest = in_dat.shape[-1] + self._state.overlap
|
|
241
217
|
if self._state.dest_arr.shape[-1] < n_dest:
|
|
242
|
-
pad = np.zeros(
|
|
243
|
-
|
|
244
|
-
+ (n_dest - self._state.dest_arr.shape[-1],)
|
|
245
|
-
)
|
|
246
|
-
self._state.dest_arr = np.concatenate(
|
|
247
|
-
[self._state.dest_arr, pad], axis=-1
|
|
248
|
-
)
|
|
218
|
+
pad = np.zeros(self._state.dest_arr.shape[:-1] + (n_dest - self._state.dest_arr.shape[-1],))
|
|
219
|
+
self._state.dest_arr = np.concatenate([self._state.dest_arr, pad], axis=-1)
|
|
249
220
|
self._state.dest_arr.fill(0)
|
|
250
221
|
|
|
251
222
|
# Note: I tried several alternatives to this loop; all were slower than this.
|
|
252
223
|
# numba.jit; stride_tricks + np.einsum; threading. Latter might be better with Python 3.13.
|
|
253
224
|
for k_ix, k in enumerate(self._state.prep_kerns):
|
|
254
225
|
n_out = in_dat.shape[-1] + k.shape[-1] - 1
|
|
255
|
-
self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(
|
|
256
|
-
np.convolve, -1, in_dat, k, mode="full"
|
|
257
|
-
)
|
|
226
|
+
self._state.dest_arr[..., k_ix, :n_out] = np.apply_along_axis(np.convolve, -1, in_dat, k, mode="full")
|
|
258
227
|
self._state.dest_arr[..., : self._state.overlap] += self._state.tail
|
|
259
228
|
new_tail = self._state.dest_arr[..., in_dat.shape[-1] : n_dest]
|
|
260
229
|
if new_tail.size > 0:
|
|
@@ -278,18 +247,14 @@ class FilterbankTransformer(
|
|
|
278
247
|
# Previous iteration's tail:
|
|
279
248
|
overlapped[..., :1, : self._state.overlap] += self._state.tail
|
|
280
249
|
# window-to-window:
|
|
281
|
-
overlapped[..., 1:, : self._state.overlap] += overlapped[
|
|
282
|
-
..., :-1, -self._state.overlap :
|
|
283
|
-
]
|
|
250
|
+
overlapped[..., 1:, : self._state.overlap] += overlapped[..., :-1, -self._state.overlap :]
|
|
284
251
|
# Save tail:
|
|
285
252
|
new_tail = overlapped[..., -1:, -self._state.overlap :]
|
|
286
253
|
if new_tail.size > 0:
|
|
287
254
|
# All of the above code works if input is size-zero, but we don't want to save a zero-size tail.
|
|
288
255
|
self._state.tail = new_tail
|
|
289
256
|
# Concat over win axis, without overlap.
|
|
290
|
-
res = overlapped[..., : -self._state.overlap].reshape(
|
|
291
|
-
overlapped.shape[:-2] + (-1,)
|
|
292
|
-
)
|
|
257
|
+
res = overlapped[..., : -self._state.overlap].reshape(overlapped.shape[:-2] + (-1,))
|
|
293
258
|
|
|
294
259
|
return replace(
|
|
295
260
|
self._state.template,
|
|
@@ -298,9 +263,7 @@ class FilterbankTransformer(
|
|
|
298
263
|
)
|
|
299
264
|
|
|
300
265
|
|
|
301
|
-
class Filterbank(
|
|
302
|
-
BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]
|
|
303
|
-
):
|
|
266
|
+
class Filterbank(BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]):
|
|
304
267
|
SETTINGS = FilterbankSettings
|
|
305
268
|
|
|
306
269
|
|
|
@@ -3,22 +3,19 @@ import typing
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
4
|
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
-
|
|
7
|
-
from ezmsg.util.messages.util import replace
|
|
8
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
-
|
|
10
|
-
from .base import (
|
|
6
|
+
from ezmsg.baseproc import (
|
|
11
7
|
BaseStatefulTransformer,
|
|
12
8
|
processor_state,
|
|
13
9
|
)
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
14
12
|
|
|
15
13
|
from .filterbank import (
|
|
16
|
-
FilterbankTransformer,
|
|
17
|
-
FilterbankSettings,
|
|
18
14
|
FilterbankMode,
|
|
15
|
+
FilterbankSettings,
|
|
16
|
+
FilterbankTransformer,
|
|
19
17
|
MinPhaseMode,
|
|
20
18
|
)
|
|
21
|
-
|
|
22
19
|
from .kaiser import KaiserFilterSettings, kaiser_design_fun
|
|
23
20
|
|
|
24
21
|
|
|
@@ -55,9 +52,7 @@ class FilterbankDesignState:
|
|
|
55
52
|
|
|
56
53
|
|
|
57
54
|
class FilterbankDesignTransformer(
|
|
58
|
-
BaseStatefulTransformer[
|
|
59
|
-
FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState
|
|
60
|
-
],
|
|
55
|
+
BaseStatefulTransformer[FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState],
|
|
61
56
|
):
|
|
62
57
|
"""
|
|
63
58
|
Transformer that designs and applies a filterbank based on Kaiser windowed FIR filters.
|
|
@@ -70,9 +65,7 @@ class FilterbankDesignTransformer(
|
|
|
70
65
|
else:
|
|
71
66
|
raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
|
|
72
67
|
|
|
73
|
-
def update_settings(
|
|
74
|
-
self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs
|
|
75
|
-
) -> None:
|
|
68
|
+
def update_settings(self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs) -> None:
|
|
76
69
|
"""
|
|
77
70
|
Update settings and mark that filter coefficients need to be recalculated.
|
|
78
71
|
|