ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +119 -104
- ezmsg/sigproc/bandpower.py +58 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -78
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/filter.py
CHANGED
|
@@ -1,15 +1,22 @@
|
|
|
1
|
+
from abc import abstractmethod, ABC
|
|
1
2
|
from dataclasses import dataclass, field
|
|
2
3
|
import typing
|
|
3
4
|
|
|
4
5
|
import ezmsg.core as ez
|
|
5
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
7
|
from ezmsg.util.messages.util import replace
|
|
7
|
-
from ezmsg.util.generator import consumer
|
|
8
8
|
import numpy as np
|
|
9
9
|
import numpy.typing as npt
|
|
10
10
|
import scipy.signal
|
|
11
11
|
|
|
12
|
-
from ezmsg.sigproc.base import
|
|
12
|
+
from ezmsg.sigproc.base import (
|
|
13
|
+
processor_state,
|
|
14
|
+
BaseStatefulTransformer,
|
|
15
|
+
BaseTransformerUnit,
|
|
16
|
+
SettingsType,
|
|
17
|
+
BaseConsumerUnit,
|
|
18
|
+
TransformerType,
|
|
19
|
+
)
|
|
13
20
|
|
|
14
21
|
|
|
15
22
|
@dataclass
|
|
@@ -18,6 +25,12 @@ class FilterCoefficients:
|
|
|
18
25
|
a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
19
26
|
|
|
20
27
|
|
|
28
|
+
# Type aliases
|
|
29
|
+
BACoeffs = tuple[npt.NDArray, npt.NDArray]
|
|
30
|
+
SOSCoeffs = npt.NDArray
|
|
31
|
+
FilterCoefsType = typing.TypeVar("FilterCoefsType", BACoeffs, SOSCoeffs)
|
|
32
|
+
|
|
33
|
+
|
|
21
34
|
def _normalize_coefs(
|
|
22
35
|
coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
|
|
23
36
|
) -> tuple[str, tuple[npt.NDArray, ...]]:
|
|
@@ -33,167 +46,270 @@ def _normalize_coefs(
|
|
|
33
46
|
return coef_type, coefs
|
|
34
47
|
|
|
35
48
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
40
|
-
"""
|
|
41
|
-
Filter data using the provided coefficients.
|
|
42
|
-
|
|
43
|
-
Args:
|
|
44
|
-
axis: The name of the axis to operate on.
|
|
45
|
-
coefs: The pre-calculated filter coefficients.
|
|
46
|
-
coef_type: The type of filter coefficients. One of "ba" or "sos".
|
|
47
|
-
|
|
48
|
-
Returns:
|
|
49
|
-
A primed generator that, when passed an :obj:`AxisArray` via `.send(axis_array)`,
|
|
50
|
-
yields an :obj:`AxisArray` with the data filtered.
|
|
51
|
-
"""
|
|
52
|
-
# Massage inputs
|
|
53
|
-
if coefs is not None and not isinstance(coefs, tuple):
|
|
54
|
-
# scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
|
|
55
|
-
coefs = (coefs,)
|
|
49
|
+
class FilterBaseSettings(ez.Settings):
|
|
50
|
+
axis: str | None = None
|
|
51
|
+
"""The name of the axis to operate on."""
|
|
56
52
|
|
|
57
|
-
|
|
58
|
-
|
|
53
|
+
coef_type: str = "ba"
|
|
54
|
+
"""The type of filter coefficients. One of "ba" or "sos"."""
|
|
59
55
|
|
|
60
|
-
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
|
|
61
|
-
zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
|
|
62
56
|
|
|
63
|
-
|
|
64
|
-
|
|
57
|
+
class FilterSettings(FilterBaseSettings):
|
|
58
|
+
coefs: FilterCoefficients | None = None
|
|
59
|
+
"""The pre-calculated filter coefficients."""
|
|
65
60
|
|
|
66
|
-
#
|
|
67
|
-
check_input = {"key": None, "shape": None}
|
|
68
|
-
# fs changing will be handled by caller that creates coefficients.
|
|
69
|
-
|
|
70
|
-
while True:
|
|
71
|
-
msg_in: AxisArray = yield msg_out
|
|
72
|
-
|
|
73
|
-
if coefs is None:
|
|
74
|
-
# passthrough if we do not have a filter design.
|
|
75
|
-
msg_out = msg_in
|
|
76
|
-
continue
|
|
77
|
-
|
|
78
|
-
axis = msg_in.dims[0] if axis is None else axis
|
|
79
|
-
axis_idx = msg_in.get_axis_idx(axis)
|
|
80
|
-
|
|
81
|
-
# Re-calculate/reset zi if necessary
|
|
82
|
-
samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
|
|
83
|
-
b_reset = samp_shape != check_input["shape"]
|
|
84
|
-
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
85
|
-
if b_reset:
|
|
86
|
-
check_input["shape"] = samp_shape
|
|
87
|
-
check_input["key"] = msg_in.key
|
|
88
|
-
|
|
89
|
-
n_tail = msg_in.data.ndim - axis_idx - 1
|
|
90
|
-
zi = zi_func(*coefs)
|
|
91
|
-
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
92
|
-
n_tile = (
|
|
93
|
-
msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :]
|
|
94
|
-
)
|
|
95
|
-
if coef_type == "sos":
|
|
96
|
-
# sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
|
|
97
|
-
zi_expand = (slice(None),) + zi_expand
|
|
98
|
-
n_tile = (1,) + n_tile
|
|
99
|
-
zi = np.tile(zi[zi_expand], n_tile)
|
|
100
|
-
|
|
101
|
-
if msg_in.data.size > 0:
|
|
102
|
-
dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
|
|
103
|
-
else:
|
|
104
|
-
dat_out = msg_in.data
|
|
105
|
-
msg_out = replace(msg_in, data=dat_out)
|
|
61
|
+
# Note: coef_type = "ba" is assumed for this class.
|
|
106
62
|
|
|
107
63
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
FilterCoefsMultiType = BACoeffs | SOSCoeffs
|
|
64
|
+
@processor_state
|
|
65
|
+
class FilterState:
|
|
66
|
+
zi: npt.NDArray | None = None
|
|
112
67
|
|
|
113
68
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
coef_type: str,
|
|
118
|
-
design_fun: typing.Callable[[float], FilterCoefsMultiType | None],
|
|
119
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
69
|
+
class FilterTransformer(
|
|
70
|
+
BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]
|
|
71
|
+
):
|
|
120
72
|
"""
|
|
121
|
-
Filter data using
|
|
122
|
-
|
|
123
|
-
Args:
|
|
124
|
-
axis: The name of the axis to filter.
|
|
125
|
-
Note: The axis must be represented in the message .axes and be of type AxisArray.LinearAxis.
|
|
126
|
-
coef_type: "ba" or "sos"
|
|
127
|
-
design_fun: A callable that takes "fs" as its only argument and returns a tuple of filter coefficients.
|
|
128
|
-
If the design_fun returns None then the filter will act as a passthrough.
|
|
129
|
-
Hint: To make a design function that only requires fs, use functools.partial to set other parameters.
|
|
130
|
-
See butterworthfilter for an example.
|
|
131
|
-
|
|
132
|
-
Returns:
|
|
133
|
-
|
|
73
|
+
Filter data using the provided coefficients.
|
|
134
74
|
"""
|
|
135
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
136
75
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
76
|
+
def __call__(self, message: AxisArray) -> AxisArray:
|
|
77
|
+
if self.settings.coefs is None:
|
|
78
|
+
return message
|
|
79
|
+
if self._state.zi is None:
|
|
80
|
+
self._reset_state(message)
|
|
81
|
+
self._hash = self._hash_message(message)
|
|
82
|
+
return super().__call__(message)
|
|
83
|
+
|
|
84
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
85
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
86
|
+
axis_idx = message.get_axis_idx(axis)
|
|
87
|
+
samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
88
|
+
return hash((message.key, samp_shape))
|
|
89
|
+
|
|
90
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
91
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
92
|
+
axis_idx = message.get_axis_idx(axis)
|
|
93
|
+
n_tail = message.data.ndim - axis_idx - 1
|
|
94
|
+
coefs = (
|
|
95
|
+
(self.settings.coefs,)
|
|
96
|
+
if self.settings.coefs is not None
|
|
97
|
+
and not isinstance(self.settings.coefs, tuple)
|
|
98
|
+
else self.settings.coefs
|
|
99
|
+
)
|
|
100
|
+
zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[
|
|
101
|
+
self.settings.coef_type
|
|
102
|
+
]
|
|
103
|
+
zi = zi_func(*coefs)
|
|
104
|
+
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
105
|
+
n_tile = (
|
|
106
|
+
message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
|
|
107
|
+
)
|
|
155
108
|
|
|
109
|
+
if self.settings.coef_type == "sos":
|
|
110
|
+
zi_expand = (slice(None),) + zi_expand
|
|
111
|
+
n_tile = (1,) + n_tile
|
|
156
112
|
|
|
157
|
-
|
|
158
|
-
axis: str | None = None
|
|
159
|
-
coef_type: str = "ba"
|
|
113
|
+
self.state.zi = np.tile(zi[zi_expand], n_tile)
|
|
160
114
|
|
|
115
|
+
def update_coefficients(
|
|
116
|
+
self,
|
|
117
|
+
coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
|
|
118
|
+
coef_type: str | None = None,
|
|
119
|
+
) -> None:
|
|
120
|
+
"""
|
|
121
|
+
Update filter coefficients.
|
|
122
|
+
|
|
123
|
+
If the new coefficients have the same length as the current ones, only the coefficients are updated.
|
|
124
|
+
If the lengths differ, the filter state is also reset to handle the new filter order.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
coefs: New filter coefficients
|
|
128
|
+
"""
|
|
129
|
+
old_coefs = self.settings.coefs
|
|
130
|
+
|
|
131
|
+
# Update settings with new coefficients
|
|
132
|
+
self.settings = replace(self.settings, coefs=coefs)
|
|
133
|
+
if coef_type is not None:
|
|
134
|
+
self.settings = replace(self.settings, coef_type=coef_type)
|
|
135
|
+
|
|
136
|
+
# Check if we need to reset the state
|
|
137
|
+
if self.state.zi is not None:
|
|
138
|
+
reset_needed = False
|
|
139
|
+
|
|
140
|
+
if self.settings.coef_type == "ba":
|
|
141
|
+
if isinstance(old_coefs, FilterCoefficients) and isinstance(
|
|
142
|
+
coefs, FilterCoefficients
|
|
143
|
+
):
|
|
144
|
+
if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(
|
|
145
|
+
coefs.a
|
|
146
|
+
):
|
|
147
|
+
reset_needed = True
|
|
148
|
+
elif isinstance(old_coefs, tuple) and isinstance(coefs, tuple):
|
|
149
|
+
if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(
|
|
150
|
+
coefs[1]
|
|
151
|
+
):
|
|
152
|
+
reset_needed = True
|
|
153
|
+
else:
|
|
154
|
+
reset_needed = True
|
|
155
|
+
elif self.settings.coef_type == "sos":
|
|
156
|
+
if isinstance(old_coefs, np.ndarray) and isinstance(coefs, np.ndarray):
|
|
157
|
+
if old_coefs.shape != coefs.shape:
|
|
158
|
+
reset_needed = True
|
|
159
|
+
else:
|
|
160
|
+
reset_needed = True
|
|
161
|
+
|
|
162
|
+
if reset_needed:
|
|
163
|
+
self.state.zi = None # This will trigger _reset_state on the next call
|
|
164
|
+
|
|
165
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
166
|
+
if message.data.size > 0:
|
|
167
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
168
|
+
axis_idx = message.get_axis_idx(axis)
|
|
169
|
+
coefs = (
|
|
170
|
+
(self.settings.coefs,)
|
|
171
|
+
if self.settings.coefs is not None
|
|
172
|
+
and not isinstance(self.settings.coefs, tuple)
|
|
173
|
+
else self.settings.coefs
|
|
174
|
+
)
|
|
175
|
+
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[
|
|
176
|
+
self.settings.coef_type
|
|
177
|
+
]
|
|
178
|
+
dat_out, self.state.zi = filt_func(
|
|
179
|
+
*coefs, message.data, axis=axis_idx, zi=self.state.zi
|
|
180
|
+
)
|
|
181
|
+
else:
|
|
182
|
+
dat_out = message.data
|
|
161
183
|
|
|
162
|
-
|
|
163
|
-
SETTINGS = FilterBaseSettings
|
|
184
|
+
return replace(message, data=dat_out)
|
|
164
185
|
|
|
165
|
-
# Backwards-compatible with `Filter` unit
|
|
166
|
-
INPUT_FILTER = ez.InputStream(FilterCoefsMultiType)
|
|
167
186
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
187
|
+
class Filter(
|
|
188
|
+
BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]
|
|
189
|
+
):
|
|
190
|
+
SETTINGS = FilterSettings
|
|
172
191
|
|
|
173
|
-
def construct_generator(self):
|
|
174
|
-
design_fun = self.design_filter()
|
|
175
|
-
self.STATE.gen = filter_gen_by_design(
|
|
176
|
-
self.SETTINGS.axis, self.SETTINGS.coef_type, design_fun
|
|
177
|
-
)
|
|
178
192
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
193
|
+
def filtergen(
|
|
194
|
+
axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str
|
|
195
|
+
) -> FilterTransformer:
|
|
196
|
+
"""
|
|
197
|
+
Filter data using the provided coefficients.
|
|
183
198
|
|
|
199
|
+
Returns:
|
|
200
|
+
:obj:`FilterTransformer`.
|
|
201
|
+
"""
|
|
202
|
+
return FilterTransformer(
|
|
203
|
+
FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type)
|
|
204
|
+
)
|
|
184
205
|
|
|
185
|
-
class FilterSettings(FilterBaseSettings):
|
|
186
|
-
# If you'd like to statically design a filter, define it in settings
|
|
187
|
-
coefs: FilterCoefficients | None = None
|
|
188
|
-
# Note: coef_type = "ba" is assumed for this class.
|
|
189
206
|
|
|
207
|
+
@processor_state
|
|
208
|
+
class FilterByDesignState:
|
|
209
|
+
filter: FilterTransformer | None = None
|
|
210
|
+
needs_redesign: bool = False
|
|
190
211
|
|
|
191
|
-
class Filter(FilterBase):
|
|
192
|
-
SETTINGS = FilterSettings
|
|
193
212
|
|
|
194
|
-
|
|
213
|
+
class FilterByDesignTransformer(
|
|
214
|
+
BaseStatefulTransformer[SettingsType, AxisArray, AxisArray, FilterByDesignState],
|
|
215
|
+
ABC,
|
|
216
|
+
typing.Generic[SettingsType, FilterCoefsType],
|
|
217
|
+
):
|
|
218
|
+
"""Abstract base class for filter design transformers."""
|
|
195
219
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
220
|
+
@classmethod
|
|
221
|
+
def get_message_type(cls, dir: str) -> type[AxisArray]:
|
|
222
|
+
if dir in ("in", "out"):
|
|
223
|
+
return AxisArray
|
|
224
|
+
else:
|
|
225
|
+
raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
|
|
226
|
+
|
|
227
|
+
@abstractmethod
|
|
228
|
+
def get_design_function(self) -> typing.Callable[[float], FilterCoefsType | None]:
|
|
229
|
+
"""Return a function that takes sampling frequency and returns filter coefficients."""
|
|
230
|
+
...
|
|
231
|
+
|
|
232
|
+
def update_settings(
|
|
233
|
+
self, new_settings: typing.Optional[SettingsType] = None, **kwargs
|
|
234
|
+
) -> None:
|
|
235
|
+
"""
|
|
236
|
+
Update settings and mark that filter coefficients need to be recalculated.
|
|
237
|
+
|
|
238
|
+
Args:
|
|
239
|
+
new_settings: Complete new settings object to replace current settings
|
|
240
|
+
**kwargs: Individual settings to update
|
|
241
|
+
"""
|
|
242
|
+
# Update settings
|
|
243
|
+
if new_settings is not None:
|
|
244
|
+
self.settings = new_settings
|
|
245
|
+
else:
|
|
246
|
+
self.settings = replace(self.settings, **kwargs)
|
|
247
|
+
|
|
248
|
+
# Set flag to trigger recalculation on next message
|
|
249
|
+
if self.state.filter is not None:
|
|
250
|
+
self.state.needs_redesign = True
|
|
251
|
+
|
|
252
|
+
def __call__(self, message: AxisArray) -> AxisArray:
|
|
253
|
+
# Offer a shortcut when there is no design function or order is 0.
|
|
254
|
+
if hasattr(self.settings, "order") and not self.settings.order:
|
|
255
|
+
return message
|
|
256
|
+
design_fun = self.get_design_function()
|
|
257
|
+
if design_fun is None:
|
|
258
|
+
return message
|
|
259
|
+
|
|
260
|
+
# Check if filter exists but needs redesign due to settings change
|
|
261
|
+
if self.state.filter is not None and self.state.needs_redesign:
|
|
262
|
+
axis = self.state.filter.settings.axis
|
|
263
|
+
fs = 1 / message.axes[axis].gain
|
|
264
|
+
coefs = design_fun(fs)
|
|
265
|
+
self.state.filter.update_coefficients(
|
|
266
|
+
coefs, coef_type=self.settings.coef_type
|
|
267
|
+
)
|
|
268
|
+
self.state.needs_redesign = False
|
|
269
|
+
|
|
270
|
+
return super().__call__(message)
|
|
271
|
+
|
|
272
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
273
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
274
|
+
gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
|
|
275
|
+
axis_idx = message.get_axis_idx(axis)
|
|
276
|
+
samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
277
|
+
return hash((message.key, samp_shape, gain))
|
|
278
|
+
|
|
279
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
280
|
+
design_fun = self.get_design_function()
|
|
281
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
282
|
+
fs = 1 / message.axes[axis].gain
|
|
283
|
+
coefs = design_fun(fs)
|
|
284
|
+
new_settings = FilterSettings(
|
|
285
|
+
axis=axis, coef_type=self.settings.coef_type, coefs=coefs
|
|
286
|
+
)
|
|
287
|
+
self.state.filter = FilterTransformer(settings=new_settings)
|
|
288
|
+
|
|
289
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
290
|
+
return self.state.filter(message)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class BaseFilterByDesignTransformerUnit(
|
|
294
|
+
BaseTransformerUnit[SettingsType, AxisArray, AxisArray, FilterByDesignTransformer],
|
|
295
|
+
typing.Generic[SettingsType, TransformerType],
|
|
296
|
+
):
|
|
297
|
+
@ez.subscriber(BaseConsumerUnit.INPUT_SETTINGS)
|
|
298
|
+
async def on_settings(self, msg: SettingsType) -> None:
|
|
299
|
+
"""
|
|
300
|
+
Receive a settings message, override self.SETTINGS, and re-create the processor.
|
|
301
|
+
Child classes that wish to have fine-grained control over whether the
|
|
302
|
+
core processor resets on settings changes should override this method.
|
|
303
|
+
|
|
304
|
+
Args:
|
|
305
|
+
msg: a settings message.
|
|
306
|
+
"""
|
|
307
|
+
self.apply_settings(msg)
|
|
308
|
+
|
|
309
|
+
# Check if processor exists yet
|
|
310
|
+
if hasattr(self, "processor") and self.processor is not None:
|
|
311
|
+
# Update the existing processor with new settings
|
|
312
|
+
self.processor.update_settings(self.SETTINGS)
|
|
313
|
+
else:
|
|
314
|
+
# Processor doesn't exist yet, create a new one
|
|
315
|
+
self.create_processor()
|