ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/filter.py
CHANGED
|
@@ -1,15 +1,21 @@
|
|
|
1
|
-
from dataclasses import dataclass, field
|
|
2
1
|
import typing
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from dataclasses import dataclass, field
|
|
3
4
|
|
|
4
5
|
import ezmsg.core as ez
|
|
5
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
-
from ezmsg.util.messages.util import replace
|
|
7
|
-
from ezmsg.util.generator import consumer
|
|
8
6
|
import numpy as np
|
|
9
7
|
import numpy.typing as npt
|
|
10
8
|
import scipy.signal
|
|
11
|
-
|
|
12
|
-
|
|
9
|
+
from ezmsg.baseproc import (
|
|
10
|
+
BaseConsumerUnit,
|
|
11
|
+
BaseStatefulTransformer,
|
|
12
|
+
BaseTransformerUnit,
|
|
13
|
+
SettingsType,
|
|
14
|
+
TransformerType,
|
|
15
|
+
processor_state,
|
|
16
|
+
)
|
|
17
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
18
|
+
from ezmsg.util.messages.util import replace
|
|
13
19
|
|
|
14
20
|
|
|
15
21
|
@dataclass
|
|
@@ -18,182 +24,282 @@ class FilterCoefficients:
|
|
|
18
24
|
a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
19
25
|
|
|
20
26
|
|
|
27
|
+
# Type aliases
|
|
28
|
+
BACoeffs = tuple[npt.NDArray, npt.NDArray]
|
|
29
|
+
SOSCoeffs = npt.NDArray
|
|
30
|
+
FilterCoefsType = typing.TypeVar("FilterCoefsType", BACoeffs, SOSCoeffs)
|
|
31
|
+
|
|
32
|
+
|
|
21
33
|
def _normalize_coefs(
|
|
22
|
-
coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
|
|
23
|
-
) -> tuple[str, tuple[npt.NDArray, ...]]:
|
|
34
|
+
coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray | None,
|
|
35
|
+
) -> tuple[str, tuple[npt.NDArray, ...] | None]:
|
|
24
36
|
coef_type = "ba"
|
|
25
37
|
if coefs is not None:
|
|
26
38
|
# scipy.signal functions called with first arg `*coefs`.
|
|
27
39
|
# Make sure we have a tuple of coefficients.
|
|
28
|
-
if isinstance(coefs,
|
|
40
|
+
if isinstance(coefs, np.ndarray):
|
|
29
41
|
coef_type = "sos"
|
|
30
42
|
coefs = (coefs,) # sos funcs just want a single ndarray.
|
|
31
43
|
elif isinstance(coefs, FilterCoefficients):
|
|
32
|
-
coefs = (
|
|
44
|
+
coefs = (coefs.b, coefs.a)
|
|
45
|
+
elif not isinstance(coefs, tuple):
|
|
46
|
+
coefs = (coefs,)
|
|
33
47
|
return coef_type, coefs
|
|
34
48
|
|
|
35
49
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
40
|
-
"""
|
|
41
|
-
Filter data using the provided coefficients.
|
|
42
|
-
|
|
43
|
-
Args:
|
|
44
|
-
axis: The name of the axis to operate on.
|
|
45
|
-
coefs: The pre-calculated filter coefficients.
|
|
46
|
-
coef_type: The type of filter coefficients. One of "ba" or "sos".
|
|
47
|
-
|
|
48
|
-
Returns:
|
|
49
|
-
A primed generator that, when passed an :obj:`AxisArray` via `.send(axis_array)`,
|
|
50
|
-
yields an :obj:`AxisArray` with the data filtered.
|
|
51
|
-
"""
|
|
52
|
-
# Massage inputs
|
|
53
|
-
if coefs is not None and not isinstance(coefs, tuple):
|
|
54
|
-
# scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
|
|
55
|
-
coefs = (coefs,)
|
|
50
|
+
class FilterBaseSettings(ez.Settings):
|
|
51
|
+
axis: str | None = None
|
|
52
|
+
"""The name of the axis to operate on."""
|
|
56
53
|
|
|
57
|
-
|
|
58
|
-
|
|
54
|
+
coef_type: str = "ba"
|
|
55
|
+
"""The type of filter coefficients. One of "ba" or "sos"."""
|
|
59
56
|
|
|
60
|
-
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
|
|
61
|
-
zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
|
|
62
57
|
|
|
63
|
-
|
|
64
|
-
|
|
58
|
+
class FilterSettings(FilterBaseSettings):
|
|
59
|
+
coefs: FilterCoefficients | None = None
|
|
60
|
+
"""The pre-calculated filter coefficients."""
|
|
65
61
|
|
|
66
|
-
#
|
|
67
|
-
check_input = {"key": None, "shape": None}
|
|
68
|
-
# fs changing will be handled by caller that creates coefficients.
|
|
69
|
-
|
|
70
|
-
while True:
|
|
71
|
-
msg_in: AxisArray = yield msg_out
|
|
72
|
-
|
|
73
|
-
if coefs is None:
|
|
74
|
-
# passthrough if we do not have a filter design.
|
|
75
|
-
msg_out = msg_in
|
|
76
|
-
continue
|
|
77
|
-
|
|
78
|
-
axis = msg_in.dims[0] if axis is None else axis
|
|
79
|
-
axis_idx = msg_in.get_axis_idx(axis)
|
|
80
|
-
|
|
81
|
-
# Re-calculate/reset zi if necessary
|
|
82
|
-
samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
|
|
83
|
-
b_reset = samp_shape != check_input["shape"]
|
|
84
|
-
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
85
|
-
if b_reset:
|
|
86
|
-
check_input["shape"] = samp_shape
|
|
87
|
-
check_input["key"] = msg_in.key
|
|
88
|
-
|
|
89
|
-
n_tail = msg_in.data.ndim - axis_idx - 1
|
|
90
|
-
zi = zi_func(*coefs)
|
|
91
|
-
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
92
|
-
n_tile = (
|
|
93
|
-
msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :]
|
|
94
|
-
)
|
|
95
|
-
if coef_type == "sos":
|
|
96
|
-
# sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
|
|
97
|
-
zi_expand = (slice(None),) + zi_expand
|
|
98
|
-
n_tile = (1,) + n_tile
|
|
99
|
-
zi = np.tile(zi[zi_expand], n_tile)
|
|
100
|
-
|
|
101
|
-
if msg_in.data.size > 0:
|
|
102
|
-
dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
|
|
103
|
-
else:
|
|
104
|
-
dat_out = msg_in.data
|
|
105
|
-
msg_out = replace(msg_in, data=dat_out)
|
|
62
|
+
# Note: coef_type = "ba" is assumed for this class.
|
|
106
63
|
|
|
107
64
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
FilterCoefsMultiType = BACoeffs | SOSCoeffs
|
|
65
|
+
@processor_state
|
|
66
|
+
class FilterState:
|
|
67
|
+
zi: npt.NDArray | None = None
|
|
112
68
|
|
|
113
69
|
|
|
114
|
-
|
|
115
|
-
def filter_gen_by_design(
|
|
116
|
-
axis: str,
|
|
117
|
-
coef_type: str,
|
|
118
|
-
design_fun: typing.Callable[[float], FilterCoefsMultiType | None],
|
|
119
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
70
|
+
class FilterTransformer(BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]):
|
|
120
71
|
"""
|
|
121
|
-
Filter data using
|
|
122
|
-
|
|
123
|
-
Args:
|
|
124
|
-
axis: The name of the axis to filter.
|
|
125
|
-
Note: The axis must be represented in the message .axes and be of type AxisArray.LinearAxis.
|
|
126
|
-
coef_type: "ba" or "sos"
|
|
127
|
-
design_fun: A callable that takes "fs" as its only argument and returns a tuple of filter coefficients.
|
|
128
|
-
If the design_fun returns None then the filter will act as a passthrough.
|
|
129
|
-
Hint: To make a design function that only requires fs, use functools.partial to set other parameters.
|
|
130
|
-
See butterworthfilter for an example.
|
|
131
|
-
|
|
132
|
-
Returns:
|
|
133
|
-
|
|
72
|
+
Filter data using the provided coefficients.
|
|
134
73
|
"""
|
|
135
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
136
74
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
75
|
+
def __call__(self, message: AxisArray) -> AxisArray:
|
|
76
|
+
if self.settings.coefs is None:
|
|
77
|
+
return message
|
|
78
|
+
if self._state.zi is None:
|
|
79
|
+
self._reset_state(message)
|
|
80
|
+
self._hash = self._hash_message(message)
|
|
81
|
+
return super().__call__(message)
|
|
82
|
+
|
|
83
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
84
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
85
|
+
axis_idx = message.get_axis_idx(axis)
|
|
86
|
+
samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
87
|
+
return hash((message.key, samp_shape))
|
|
88
|
+
|
|
89
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
90
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
91
|
+
axis_idx = message.get_axis_idx(axis)
|
|
92
|
+
n_tail = message.data.ndim - axis_idx - 1
|
|
93
|
+
_, coefs = _normalize_coefs(self.settings.coefs)
|
|
94
|
+
|
|
95
|
+
if self.settings.coef_type == "ba":
|
|
96
|
+
b, a = coefs
|
|
97
|
+
if len(a) == 1 or np.allclose(a[1:], 0):
|
|
98
|
+
# For FIR filters, use lfiltic with zero initial conditions
|
|
99
|
+
zi = scipy.signal.lfiltic(b, a, [])
|
|
100
|
+
else:
|
|
101
|
+
# For IIR filters...
|
|
102
|
+
zi = scipy.signal.lfilter_zi(b, a)
|
|
103
|
+
else:
|
|
104
|
+
# For second-order sections (SOS) filters, use sosfilt_zi
|
|
105
|
+
zi = scipy.signal.sosfilt_zi(*coefs)
|
|
144
106
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
axis = axis or msg_in.dims[0]
|
|
148
|
-
b_reset = msg_in.axes[axis].gain != check_input["gain"]
|
|
149
|
-
if b_reset:
|
|
150
|
-
check_input["gain"] = msg_in.axes[axis].gain
|
|
151
|
-
coefs = design_fun(1 / msg_in.axes[axis].gain)
|
|
152
|
-
filter_gen = filtergen(axis, coefs, coef_type)
|
|
107
|
+
zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
|
|
108
|
+
n_tile = message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
|
|
153
109
|
|
|
154
|
-
|
|
110
|
+
if self.settings.coef_type == "sos":
|
|
111
|
+
zi_expand = (slice(None),) + zi_expand
|
|
112
|
+
n_tile = (1,) + n_tile
|
|
155
113
|
|
|
114
|
+
self.state.zi = np.tile(zi[zi_expand], n_tile)
|
|
156
115
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
116
|
+
def update_coefficients(
|
|
117
|
+
self,
|
|
118
|
+
coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
|
|
119
|
+
coef_type: str | None = None,
|
|
120
|
+
) -> None:
|
|
121
|
+
"""
|
|
122
|
+
Update filter coefficients.
|
|
123
|
+
|
|
124
|
+
If the new coefficients have the same length as the current ones, only the coefficients are updated.
|
|
125
|
+
If the lengths differ, the filter state is also reset to handle the new filter order.
|
|
126
|
+
|
|
127
|
+
Args:
|
|
128
|
+
coefs: New filter coefficients
|
|
129
|
+
"""
|
|
130
|
+
old_coefs = self.settings.coefs
|
|
131
|
+
|
|
132
|
+
# Update settings with new coefficients
|
|
133
|
+
self.settings = replace(self.settings, coefs=coefs)
|
|
134
|
+
if coef_type is not None:
|
|
135
|
+
self.settings = replace(self.settings, coef_type=coef_type)
|
|
136
|
+
|
|
137
|
+
# Check if we need to reset the state
|
|
138
|
+
if self.state.zi is not None:
|
|
139
|
+
reset_needed = False
|
|
140
|
+
|
|
141
|
+
if self.settings.coef_type == "ba":
|
|
142
|
+
if isinstance(old_coefs, FilterCoefficients) and isinstance(coefs, FilterCoefficients):
|
|
143
|
+
if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(coefs.a):
|
|
144
|
+
reset_needed = True
|
|
145
|
+
elif isinstance(old_coefs, tuple) and isinstance(coefs, tuple):
|
|
146
|
+
if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(coefs[1]):
|
|
147
|
+
reset_needed = True
|
|
148
|
+
else:
|
|
149
|
+
reset_needed = True
|
|
150
|
+
elif self.settings.coef_type == "sos":
|
|
151
|
+
if isinstance(old_coefs, np.ndarray) and isinstance(coefs, np.ndarray):
|
|
152
|
+
if old_coefs.shape != coefs.shape:
|
|
153
|
+
reset_needed = True
|
|
154
|
+
else:
|
|
155
|
+
reset_needed = True
|
|
156
|
+
|
|
157
|
+
if reset_needed:
|
|
158
|
+
self.state.zi = None # This will trigger _reset_state on the next call
|
|
159
|
+
|
|
160
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
161
|
+
if message.data.size > 0:
|
|
162
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
163
|
+
axis_idx = message.get_axis_idx(axis)
|
|
164
|
+
_, coefs = _normalize_coefs(self.settings.coefs)
|
|
165
|
+
filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[self.settings.coef_type]
|
|
166
|
+
dat_out, self.state.zi = filt_func(*coefs, message.data, axis=axis_idx, zi=self.state.zi)
|
|
167
|
+
else:
|
|
168
|
+
dat_out = message.data
|
|
161
169
|
|
|
162
|
-
|
|
163
|
-
SETTINGS = FilterBaseSettings
|
|
170
|
+
return replace(message, data=dat_out)
|
|
164
171
|
|
|
165
|
-
# Backwards-compatible with `Filter` unit
|
|
166
|
-
INPUT_FILTER = ez.InputStream(FilterCoefsMultiType)
|
|
167
172
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
) -> typing.Callable[[float], FilterCoefsMultiType | None]:
|
|
171
|
-
raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
|
|
173
|
+
class Filter(BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]):
|
|
174
|
+
SETTINGS = FilterSettings
|
|
172
175
|
|
|
173
|
-
def construct_generator(self):
|
|
174
|
-
design_fun = self.design_filter()
|
|
175
|
-
self.STATE.gen = filter_gen_by_design(
|
|
176
|
-
self.SETTINGS.axis, self.SETTINGS.coef_type, design_fun
|
|
177
|
-
)
|
|
178
176
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
self.construct_generator()
|
|
177
|
+
def filtergen(axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str) -> FilterTransformer:
|
|
178
|
+
"""
|
|
179
|
+
Filter data using the provided coefficients.
|
|
183
180
|
|
|
181
|
+
Returns:
|
|
182
|
+
:obj:`FilterTransformer`.
|
|
183
|
+
"""
|
|
184
|
+
return FilterTransformer(FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type))
|
|
184
185
|
|
|
185
|
-
class FilterSettings(FilterBaseSettings):
|
|
186
|
-
# If you'd like to statically design a filter, define it in settings
|
|
187
|
-
coefs: FilterCoefficients | None = None
|
|
188
|
-
# Note: coef_type = "ba" is assumed for this class.
|
|
189
186
|
|
|
187
|
+
@processor_state
|
|
188
|
+
class FilterByDesignState:
|
|
189
|
+
filter: FilterTransformer | None = None
|
|
190
|
+
needs_redesign: bool = False
|
|
190
191
|
|
|
191
|
-
class Filter(FilterBase):
|
|
192
|
-
SETTINGS = FilterSettings
|
|
193
192
|
|
|
194
|
-
|
|
193
|
+
class FilterByDesignTransformer(
|
|
194
|
+
BaseStatefulTransformer[SettingsType, AxisArray, AxisArray, FilterByDesignState],
|
|
195
|
+
ABC,
|
|
196
|
+
typing.Generic[SettingsType, FilterCoefsType],
|
|
197
|
+
):
|
|
198
|
+
"""Abstract base class for filter design transformers."""
|
|
195
199
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
+
@classmethod
|
|
201
|
+
def get_message_type(cls, dir: str) -> type[AxisArray]:
|
|
202
|
+
if dir in ("in", "out"):
|
|
203
|
+
return AxisArray
|
|
204
|
+
else:
|
|
205
|
+
raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
|
|
206
|
+
|
|
207
|
+
@abstractmethod
|
|
208
|
+
def get_design_function(self) -> typing.Callable[[float], FilterCoefsType | None]:
|
|
209
|
+
"""Return a function that takes sampling frequency and returns filter coefficients."""
|
|
210
|
+
...
|
|
211
|
+
|
|
212
|
+
def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
|
|
213
|
+
"""
|
|
214
|
+
Update settings and mark that filter coefficients need to be recalculated.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
new_settings: Complete new settings object to replace current settings
|
|
218
|
+
**kwargs: Individual settings to update
|
|
219
|
+
"""
|
|
220
|
+
# Update settings
|
|
221
|
+
if new_settings is not None:
|
|
222
|
+
self.settings = new_settings
|
|
223
|
+
else:
|
|
224
|
+
self.settings = replace(self.settings, **kwargs)
|
|
225
|
+
|
|
226
|
+
# Set flag to trigger recalculation on next message
|
|
227
|
+
if self.state.filter is not None:
|
|
228
|
+
self.state.needs_redesign = True
|
|
229
|
+
|
|
230
|
+
def __call__(self, message: AxisArray) -> AxisArray:
|
|
231
|
+
# Offer a shortcut when there is no design function or order is 0.
|
|
232
|
+
if hasattr(self.settings, "order") and not self.settings.order:
|
|
233
|
+
return message
|
|
234
|
+
design_fun = self.get_design_function()
|
|
235
|
+
if design_fun is None:
|
|
236
|
+
return message
|
|
237
|
+
|
|
238
|
+
# Check if filter exists but needs redesign due to settings change
|
|
239
|
+
if self.state.filter is not None and self.state.needs_redesign:
|
|
240
|
+
axis = self.state.filter.settings.axis
|
|
241
|
+
fs = 1 / message.axes[axis].gain
|
|
242
|
+
coefs = design_fun(fs)
|
|
243
|
+
|
|
244
|
+
# Convert BA to SOS if requested
|
|
245
|
+
if coefs is not None and self.settings.coef_type == "sos":
|
|
246
|
+
if isinstance(coefs, tuple) and len(coefs) == 2:
|
|
247
|
+
# It's BA format, convert to SOS
|
|
248
|
+
b, a = coefs
|
|
249
|
+
coefs = scipy.signal.tf2sos(b, a)
|
|
250
|
+
|
|
251
|
+
self.state.filter.update_coefficients(coefs, coef_type=self.settings.coef_type)
|
|
252
|
+
self.state.needs_redesign = False
|
|
253
|
+
|
|
254
|
+
return super().__call__(message)
|
|
255
|
+
|
|
256
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
257
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
258
|
+
gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
|
|
259
|
+
axis_idx = message.get_axis_idx(axis)
|
|
260
|
+
samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
261
|
+
return hash((message.key, samp_shape, gain))
|
|
262
|
+
|
|
263
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
264
|
+
design_fun = self.get_design_function()
|
|
265
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
266
|
+
fs = 1 / message.axes[axis].gain
|
|
267
|
+
coefs = design_fun(fs)
|
|
268
|
+
|
|
269
|
+
# Convert BA to SOS if requested
|
|
270
|
+
if coefs is not None and self.settings.coef_type == "sos":
|
|
271
|
+
if isinstance(coefs, tuple) and len(coefs) == 2:
|
|
272
|
+
# It's BA format, convert to SOS
|
|
273
|
+
b, a = coefs
|
|
274
|
+
coefs = scipy.signal.tf2sos(b, a)
|
|
275
|
+
|
|
276
|
+
new_settings = FilterSettings(axis=axis, coef_type=self.settings.coef_type, coefs=coefs)
|
|
277
|
+
self.state.filter = FilterTransformer(settings=new_settings)
|
|
278
|
+
|
|
279
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
280
|
+
return self.state.filter(message)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class BaseFilterByDesignTransformerUnit(
|
|
284
|
+
BaseTransformerUnit[SettingsType, AxisArray, AxisArray, FilterByDesignTransformer],
|
|
285
|
+
typing.Generic[SettingsType, TransformerType],
|
|
286
|
+
):
|
|
287
|
+
@ez.subscriber(BaseConsumerUnit.INPUT_SETTINGS)
|
|
288
|
+
async def on_settings(self, msg: SettingsType) -> None:
|
|
289
|
+
"""
|
|
290
|
+
Receive a settings message, override self.SETTINGS, and re-create the processor.
|
|
291
|
+
Child classes that wish to have fine-grained control over whether the
|
|
292
|
+
core processor resets on settings changes should override this method.
|
|
293
|
+
|
|
294
|
+
Args:
|
|
295
|
+
msg: a settings message.
|
|
296
|
+
"""
|
|
297
|
+
self.apply_settings(msg)
|
|
298
|
+
|
|
299
|
+
# Check if processor exists yet
|
|
300
|
+
if hasattr(self, "processor") and self.processor is not None:
|
|
301
|
+
# Update the existing processor with new settings
|
|
302
|
+
self.processor.update_settings(self.SETTINGS)
|
|
303
|
+
else:
|
|
304
|
+
# Processor doesn't exist yet, create a new one
|
|
305
|
+
self.create_processor()
|