ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/filter.py
CHANGED
|
@@ -1,13 +1,21 @@
|
|
|
1
|
-
|
|
1
|
+
import typing
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from dataclasses import dataclass, field
|
|
2
4
|
|
|
3
5
|
import ezmsg.core as ez
|
|
4
|
-
import scipy.signal
|
|
5
6
|
import numpy as np
|
|
6
|
-
import
|
|
7
|
-
|
|
7
|
+
import numpy.typing as npt
|
|
8
|
+
import scipy.signal
|
|
9
|
+
from ezmsg.baseproc import (
|
|
10
|
+
BaseConsumerUnit,
|
|
11
|
+
BaseStatefulTransformer,
|
|
12
|
+
BaseTransformerUnit,
|
|
13
|
+
SettingsType,
|
|
14
|
+
TransformerType,
|
|
15
|
+
processor_state,
|
|
16
|
+
)
|
|
8
17
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
-
|
|
10
|
-
from typing import AsyncGenerator, Optional, Tuple
|
|
18
|
+
from ezmsg.util.messages.util import replace
|
|
11
19
|
|
|
12
20
|
|
|
13
21
|
@dataclass
|
|
@@ -16,125 +24,282 @@ class FilterCoefficients:
|
|
|
16
24
|
a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
|
|
17
25
|
|
|
18
26
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
27
|
+
# Type aliases
|
|
28
|
+
BACoeffs = tuple[npt.NDArray, npt.NDArray]
|
|
29
|
+
SOSCoeffs = npt.NDArray
|
|
30
|
+
FilterCoefsType = typing.TypeVar("FilterCoefsType", BACoeffs, SOSCoeffs)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _normalize_coefs(
    coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray | None,
) -> tuple[str, tuple[npt.NDArray, ...] | None]:
    """
    Coerce any supported coefficient container into a tuple for ``scipy_func(*coefs, ...)``.

    Args:
        coefs: ``None``, a bare ndarray (treated as second-order sections), a
            ``FilterCoefficients`` instance, a tuple of arrays, or any other single object.

    Returns:
        ``(coef_type, coefs_tuple)``. ``coef_type`` is ``"sos"`` for a bare ndarray and ``"ba"``
        in every other case. ``coefs_tuple`` is ``None`` when ``coefs`` is ``None``.
    """
    if coefs is None:
        return "ba", None
    if isinstance(coefs, np.ndarray):
        # The sos functions take the section array as their single coefficient argument.
        return "sos", (coefs,)
    if isinstance(coefs, FilterCoefficients):
        return "ba", (coefs.b, coefs.a)
    if isinstance(coefs, tuple):
        # Already unpackable as-is.
        return "ba", coefs
    # Any other single object gets wrapped so `*coefs` still works.
    return "ba", (coefs,)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class FilterBaseSettings(ez.Settings):
    """Settings shared by the filter transformers in this module."""

    # When None, the transformers in this module fall back to the message's first dimension.
    axis: str | None = None
    """The name of the axis to operate on."""

    # FilterTransformer dispatches on this value: "ba" -> scipy.signal.lfilter,
    # "sos" -> scipy.signal.sosfilt.
    coef_type: str = "ba"
    """The type of filter coefficients. One of "ba" or "sos"."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class FilterSettings(FilterBaseSettings):
    """Settings for :class:`FilterTransformer`: fixed coefficients plus the inherited axis/coef_type."""

    # None makes FilterTransformer a passthrough.
    coefs: FilterCoefficients | None = None
    """The pre-calculated filter coefficients."""

    # Note: although annotated as FilterCoefficients, `coefs` is run through _normalize_coefs.
    # A (b, a) tuple also works with coef_type="ba", and an SOS ndarray works with
    # coef_type="sos" (FilterByDesignTransformer passes one). coef_type is inherited from
    # FilterBaseSettings and defaults to "ba".
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@processor_state
class FilterState:
    """Per-stream state carried by :class:`FilterTransformer` between messages."""

    # Delay-line state passed as `zi` to lfilter/sosfilt and replaced by the returned final state.
    # None until built from a message (and again after update_coefficients requests a reset).
    zi: npt.NDArray | None = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class FilterTransformer(BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]):
    """
    Filter data using the provided coefficients.

    Filtering runs along ``settings.axis``, or along the message's first dimension when that is
    ``None``. ``settings.coef_type`` selects :func:`scipy.signal.lfilter` ("ba") or
    :func:`scipy.signal.sosfilt` ("sos"). The final conditions returned by each call are stored
    in ``state.zi`` and fed into the next call, so consecutive messages are filtered as one
    continuous signal. When ``settings.coefs`` is ``None``, messages pass through unchanged.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        # No coefficients -> passthrough.
        if self.settings.coefs is None:
            return message
        if self._state.zi is None:
            # Build the filter state now (first message, or update_coefficients asked for a
            # reset) and record the matching hash.
            # NOTE(review): assumes the base class compares `_hash` with `_hash_message` to
            # decide whether to call `_reset_state` again — confirm.
            self._reset_state(message)
            self._hash = self._hash_message(message)
        return super().__call__(message)

    def _axis_and_index(self, message: AxisArray) -> tuple[str, int]:
        """Return the name and positional index of the axis to filter along."""
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        return axis, message.get_axis_idx(axis)

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key and the data shape with the filter axis removed."""
        _, axis_idx = self._axis_and_index(message)
        shape = message.data.shape
        samp_shape = shape[:axis_idx] + shape[axis_idx + 1 :]
        return hash((message.key, samp_shape))

    def _reset_state(self, message: AxisArray) -> None:
        """
        Build ``state.zi`` for the current coefficients, broadcast to ``message``'s data shape.

        "ba" filters whose denominator has no non-zero terms after the first (FIR) start from
        an all-zero state via ``lfiltic(b, a, [])``. Other "ba" filters use ``lfilter_zi``, and
        any other coef_type uses ``sosfilt_zi``. Per SciPy, those two return the steady-state
        conditions for a unit-step input; they are not scaled by the first sample here.
        """
        _, axis_idx = self._axis_and_index(message)
        shape = message.data.shape
        n_tail = message.data.ndim - axis_idx - 1
        _, coefs = _normalize_coefs(self.settings.coefs)

        if self.settings.coef_type == "ba":
            b, a = coefs
            if len(a) == 1 or np.allclose(a[1:], 0):
                # For FIR filters, use lfiltic with zero initial conditions
                zi = scipy.signal.lfiltic(b, a, [])
            else:
                # For IIR filters...
                zi = scipy.signal.lfilter_zi(b, a)
        else:
            # For second-order sections (SOS) filters, use sosfilt_zi
            zi = scipy.signal.sosfilt_zi(*coefs)

        # Put the delay dimension of zi at the filter axis. Then tile over every other data
        # dimension so each element of the per-sample shape gets its own copy.
        zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
        n_tile = shape[:axis_idx] + (1,) + shape[axis_idx + 1 :]

        if self.settings.coef_type == "sos":
            # sosfilt's zi has an extra leading sections dimension ahead of the data dimensions.
            zi_expand = (slice(None),) + zi_expand
            n_tile = (1,) + n_tile

        self.state.zi = np.tile(zi[zi_expand], n_tile)

    def update_coefficients(
        self,
        coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
        coef_type: str | None = None,
    ) -> None:
        """
        Update filter coefficients.

        If the new coefficients have the same shapes as the current ones (and ``coef_type`` is
        unchanged), only the coefficients are updated and the running filter state is kept.
        Otherwise the filter state is cleared and rebuilt on the next message to handle the new
        filter order. Shapes are compared after normalization, so swapping a
        ``FilterCoefficients`` for an equivalent ``(b, a)`` tuple does not force a reset.

        Args:
            coefs: New filter coefficients. Use ``FilterCoefficients`` or ``(b, a)`` for "ba",
                or an SOS ndarray for "sos".
            coef_type: Optional new coefficient type ("ba" or "sos"). The current type is kept
                when ``None``.
        """
        old_coef_type = self.settings.coef_type
        _, old_normed = _normalize_coefs(self.settings.coefs)

        # Update settings with new coefficients
        self.settings = replace(self.settings, coefs=coefs)
        if coef_type is not None:
            self.settings = replace(self.settings, coef_type=coef_type)

        if self.state.zi is None:
            # State will be built from scratch on the next call anyway.
            return

        _, new_normed = _normalize_coefs(coefs)
        if (
            self.settings.coef_type != old_coef_type
            or old_normed is None
            or new_normed is None
            or [np.shape(c) for c in old_normed] != [np.shape(c) for c in new_normed]
        ):
            # Filter order or representation changed, so the old zi no longer fits.
            # Clearing it makes __call__ rebuild it on the next message.
            self.state.zi = None

    def _process(self, message: AxisArray) -> AxisArray:
        """Filter ``message`` along the configured axis, carrying ``state.zi`` forward."""
        if message.data.size > 0:
            _, axis_idx = self._axis_and_index(message)
            _, coefs = _normalize_coefs(self.settings.coefs)
            filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[self.settings.coef_type]
            dat_out, self.state.zi = filt_func(*coefs, message.data, axis=axis_idx, zi=self.state.zi)
        else:
            # Empty chunks pass through without touching the filter state.
            dat_out = message.data

        return replace(message, data=dat_out)
|
|
22
171
|
|
|
23
172
|
|
|
24
|
-
class FilterSettings
|
|
25
|
-
|
|
26
|
-
filt: Optional[FilterCoefficients] = None
|
|
173
|
+
class Filter(BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]):
    """Transformer unit wrapping :class:`FilterTransformer`, configured by :class:`FilterSettings`."""

    SETTINGS = FilterSettings
|
|
27
175
|
|
|
28
176
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
filt_designed: bool = False
|
|
33
|
-
filt: Optional[FilterCoefficients] = None
|
|
34
|
-
filt_set: asyncio.Event = field(default_factory=asyncio.Event)
|
|
35
|
-
samp_shape: Optional[Tuple[int, ...]] = None
|
|
36
|
-
fs: Optional[float] = None # Hz
|
|
177
|
+
def filtergen(
    axis: str, coefs: FilterCoefficients | BACoeffs | SOSCoeffs | None, coef_type: str
) -> FilterTransformer:
    """
    Filter data using the provided coefficients.

    Convenience constructor for a :class:`FilterTransformer`.

    Args:
        axis: Name of the axis to filter along.
        coefs: Filter coefficients. Use a ``FilterCoefficients`` or ``(b, a)`` tuple with
            ``coef_type="ba"``, or an SOS ndarray with ``coef_type="sos"``. ``None`` makes the
            transformer a passthrough.
        coef_type: "ba" or "sos".

    Returns:
        :obj:`FilterTransformer`.
    """
    return FilterTransformer(FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type))
|
|
38
185
|
|
|
39
|
-
class Filter(ez.Unit):
|
|
40
|
-
SETTINGS: FilterSettingsBase
|
|
41
|
-
STATE: FilterState
|
|
42
186
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
187
|
+
@processor_state
class FilterByDesignState:
    """State carried by :class:`FilterByDesignTransformer` between messages."""

    # Inner transformer that applies the most recently designed coefficients.
    # Created in _reset_state; None before that.
    filter: FilterTransformer | None = None
    # Set by update_settings once a filter exists. The next __call__ redesigns and clears it.
    needs_redesign: bool = False
|
|
46
191
|
|
|
47
|
-
def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
48
|
-
raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
|
|
49
192
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
193
|
+
class FilterByDesignTransformer(
    BaseStatefulTransformer[SettingsType, AxisArray, AxisArray, FilterByDesignState],
    ABC,
    typing.Generic[SettingsType, FilterCoefsType],
):
    """
    Abstract base class for filter design transformers.

    Subclasses implement :meth:`get_design_function`. On reset, the design function is called
    with the sampling rate (``1 / gain``) of the filter axis. The resulting coefficients are
    applied by an inner :class:`FilterTransformer` stored in ``state.filter``.

    When ``settings.coef_type`` is "sos", designs returned as a ``(b, a)`` pair are converted
    with :func:`scipy.signal.tf2sos`. A falsy ``settings.order`` (when the settings have one),
    or a ``None`` design function, makes the transformer a passthrough.
    """

    @classmethod
    def get_message_type(cls, dir: str) -> type[AxisArray]:
        """Return the message type for direction ``dir``; both "in" and "out" are AxisArray."""
        if dir not in ("in", "out"):
            raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
        return AxisArray

    @abstractmethod
    def get_design_function(self) -> typing.Callable[[float], FilterCoefsType | None]:
        """Return a function that takes sampling frequency and returns filter coefficients."""
        ...

    def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
        """
        Update settings and mark that filter coefficients need to be recalculated.

        Args:
            new_settings: Complete new settings object to replace current settings
            **kwargs: Individual settings to update (used only when ``new_settings`` is None)
        """
        if new_settings is not None:
            self.settings = new_settings
        else:
            self.settings = replace(self.settings, **kwargs)

        # Flag a redesign for the next message. With no inner filter yet, the first reset
        # designs from the current settings anyway.
        if self.state.filter is not None:
            self.state.needs_redesign = True

    def _axis_name(self, message: AxisArray) -> str:
        """Return the configured axis name, or the message's first dimension when unset."""
        return message.dims[0] if self.settings.axis is None else self.settings.axis

    def _design_coefs(
        self, design_fun: typing.Callable[[float], FilterCoefsType | None], fs: float
    ) -> FilterCoefsType | None:
        """Run ``design_fun`` at ``fs``; convert a ``(b, a)`` result to SOS if coef_type is "sos"."""
        coefs = design_fun(fs)
        if coefs is not None and self.settings.coef_type == "sos":
            if isinstance(coefs, tuple) and len(coefs) == 2:
                # It's BA format, convert to SOS
                b, a = coefs
                coefs = scipy.signal.tf2sos(b, a)
        return coefs

    def __call__(self, message: AxisArray) -> AxisArray:
        # Offer a shortcut when there is no design function or order is 0.
        if hasattr(self.settings, "order") and not self.settings.order:
            return message
        design_fun = self.get_design_function()
        if design_fun is None:
            return message

        # Settings changed since the inner filter was built: redesign at this message's
        # sampling rate. The existing FilterTransformer then decides whether its state can
        # be kept.
        if self.state.filter is not None and self.state.needs_redesign:
            axis = self.state.filter.settings.axis
            fs = 1 / message.axes[axis].gain
            coefs = self._design_coefs(design_fun, fs)
            self.state.filter.update_coefficients(coefs, coef_type=self.settings.coef_type)
            self.state.needs_redesign = False

        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key, the per-sample data shape, and the filter-axis gain."""
        axis = self._axis_name(message)
        axis_info = message.axes[axis]
        # Axes without a `gain` attribute hash as if gain were 1.
        gain = axis_info.gain if hasattr(axis_info, "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        shape = message.data.shape
        samp_shape = shape[:axis_idx] + shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Design coefficients for the message's sampling rate and build a fresh inner filter."""
        design_fun = self.get_design_function()
        axis = self._axis_name(message)
        fs = 1 / message.axes[axis].gain
        coefs = self._design_coefs(design_fun, fs)

        new_settings = FilterSettings(axis=axis, coef_type=self.settings.coef_type, coefs=coefs)
        self.state.filter = FilterTransformer(settings=new_settings)

    def _process(self, message: AxisArray) -> AxisArray:
        """Delegate filtering to the inner :class:`FilterTransformer`."""
        return self.state.filter(message)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class BaseFilterByDesignTransformerUnit(
    BaseTransformerUnit[SettingsType, AxisArray, AxisArray, FilterByDesignTransformer],
    typing.Generic[SettingsType, TransformerType],
):
    """Transformer unit for filter-design transformers that keeps its processor across settings changes."""

    @ez.subscriber(BaseConsumerUnit.INPUT_SETTINGS)
    async def on_settings(self, msg: SettingsType) -> None:
        """
        Receive a settings message and apply it to self.SETTINGS.

        If a processor already exists, it is updated in place via ``update_settings``; for a
        FilterByDesignTransformer, that flags a redesign on the next message rather than
        discarding the filter. If no processor exists yet, one is created. Child classes that
        want different reset behavior on settings changes should override this method.

        Args:
            msg: a settings message.
        """
        self.apply_settings(msg)

        processor = getattr(self, "processor", None)
        if processor is None:
            # Processor doesn't exist yet, create a new one
            self.create_processor()
            return

        # Update the existing processor with new settings
        processor.update_settings(self.SETTINGS)
|