ezmsg-sigproc 1.8.2 (ezmsg_sigproc-1.8.2-py3-none-any.whl) → 2.1.0 (ezmsg_sigproc-2.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +133 -101
  6. ezmsg/sigproc/bandpower.py +64 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -84
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/filter.py CHANGED
@@ -1,15 +1,22 @@
1
+ from abc import abstractmethod, ABC
1
2
  from dataclasses import dataclass, field
2
3
  import typing
3
4
 
4
5
  import ezmsg.core as ez
5
6
  from ezmsg.util.messages.axisarray import AxisArray
6
7
  from ezmsg.util.messages.util import replace
7
- from ezmsg.util.generator import consumer
8
8
  import numpy as np
9
9
  import numpy.typing as npt
10
10
  import scipy.signal
11
11
 
12
- from ezmsg.sigproc.base import GenAxisArray
12
+ from ezmsg.sigproc.base import (
13
+ processor_state,
14
+ BaseStatefulTransformer,
15
+ BaseTransformerUnit,
16
+ SettingsType,
17
+ BaseConsumerUnit,
18
+ TransformerType,
19
+ )
13
20
 
14
21
 
15
22
  @dataclass
@@ -18,6 +25,12 @@ class FilterCoefficients:
18
25
  a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
19
26
 
20
27
 
28
+ # Type aliases
29
+ BACoeffs = tuple[npt.NDArray, npt.NDArray]
30
+ SOSCoeffs = npt.NDArray
31
+ FilterCoefsType = typing.TypeVar("FilterCoefsType", BACoeffs, SOSCoeffs)
32
+
33
+
21
34
  def _normalize_coefs(
22
35
  coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
23
36
  ) -> tuple[str, tuple[npt.NDArray, ...]]:
@@ -33,167 +46,270 @@ def _normalize_coefs(
33
46
  return coef_type, coefs
34
47
 
35
48
 
36
- @consumer
37
- def filtergen(
38
- axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str
39
- ) -> typing.Generator[AxisArray, AxisArray, None]:
40
- """
41
- Filter data using the provided coefficients.
42
-
43
- Args:
44
- axis: The name of the axis to operate on.
45
- coefs: The pre-calculated filter coefficients.
46
- coef_type: The type of filter coefficients. One of "ba" or "sos".
47
-
48
- Returns:
49
- A primed generator that, when passed an :obj:`AxisArray` via `.send(axis_array)`,
50
- yields an :obj:`AxisArray` with the data filtered.
51
- """
52
- # Massage inputs
53
- if coefs is not None and not isinstance(coefs, tuple):
54
- # scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
55
- coefs = (coefs,)
49
+ class FilterBaseSettings(ez.Settings):
50
+ axis: str | None = None
51
+ """The name of the axis to operate on."""
56
52
 
57
- # Init IO
58
- msg_out = AxisArray(np.array([]), dims=[""])
53
+ coef_type: str = "ba"
54
+ """The type of filter coefficients. One of "ba" or "sos"."""
59
55
 
60
- filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
61
- zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]
62
56
 
63
- # State variables
64
- zi: npt.NDArray | None = None
57
+ class FilterSettings(FilterBaseSettings):
58
+ coefs: FilterCoefficients | None = None
59
+ """The pre-calculated filter coefficients."""
65
60
 
66
- # Reset if these change.
67
- check_input = {"key": None, "shape": None}
68
- # fs changing will be handled by caller that creates coefficients.
69
-
70
- while True:
71
- msg_in: AxisArray = yield msg_out
72
-
73
- if coefs is None:
74
- # passthrough if we do not have a filter design.
75
- msg_out = msg_in
76
- continue
77
-
78
- axis = msg_in.dims[0] if axis is None else axis
79
- axis_idx = msg_in.get_axis_idx(axis)
80
-
81
- # Re-calculate/reset zi if necessary
82
- samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
83
- b_reset = samp_shape != check_input["shape"]
84
- b_reset = b_reset or msg_in.key != check_input["key"]
85
- if b_reset:
86
- check_input["shape"] = samp_shape
87
- check_input["key"] = msg_in.key
88
-
89
- n_tail = msg_in.data.ndim - axis_idx - 1
90
- zi = zi_func(*coefs)
91
- zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
92
- n_tile = (
93
- msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :]
94
- )
95
- if coef_type == "sos":
96
- # sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
97
- zi_expand = (slice(None),) + zi_expand
98
- n_tile = (1,) + n_tile
99
- zi = np.tile(zi[zi_expand], n_tile)
100
-
101
- if msg_in.data.size > 0:
102
- dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
103
- else:
104
- dat_out = msg_in.data
105
- msg_out = replace(msg_in, data=dat_out)
61
+ # Note: coef_type = "ba" is assumed for this class.
106
62
 
107
63
 
108
- # Type aliases
109
- BACoeffs = tuple[npt.NDArray, npt.NDArray]
110
- SOSCoeffs = npt.NDArray
111
- FilterCoefsMultiType = BACoeffs | SOSCoeffs
64
+ @processor_state
65
+ class FilterState:
66
+ zi: npt.NDArray | None = None
112
67
 
113
68
 
114
- @consumer
115
- def filter_gen_by_design(
116
- axis: str,
117
- coef_type: str,
118
- design_fun: typing.Callable[[float], FilterCoefsMultiType | None],
119
- ) -> typing.Generator[AxisArray, AxisArray, None]:
69
+ class FilterTransformer(
70
+ BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]
71
+ ):
120
72
  """
121
- Filter data using a filter whose coefficients are calculated using the provided design function.
122
-
123
- Args:
124
- axis: The name of the axis to filter.
125
- Note: The axis must be represented in the message .axes and be of type AxisArray.LinearAxis.
126
- coef_type: "ba" or "sos"
127
- design_fun: A callable that takes "fs" as its only argument and returns a tuple of filter coefficients.
128
- If the design_fun returns None then the filter will act as a passthrough.
129
- Hint: To make a design function that only requires fs, use functools.partial to set other parameters.
130
- See butterworthfilter for an example.
131
-
132
- Returns:
133
-
73
+ Filter data using the provided coefficients.
134
74
  """
135
- msg_out = AxisArray(np.array([]), dims=[""])
136
75
 
137
- # State variables
138
- # Initialize filtergen as passthrough until we receive a message that allows us to design the filter.
139
- filter_gen = filtergen(axis, None, coef_type)
140
-
141
- # Reset if these change.
142
- check_input = {"gain": None}
143
- # No need to check parameters that don't affect the design; filter_gen should check most of its parameters.
144
-
145
- while True:
146
- msg_in: AxisArray = yield msg_out
147
- axis = axis or msg_in.dims[0]
148
- b_reset = msg_in.axes[axis].gain != check_input["gain"]
149
- if b_reset:
150
- check_input["gain"] = msg_in.axes[axis].gain
151
- coefs = design_fun(1 / msg_in.axes[axis].gain)
152
- filter_gen = filtergen(axis, coefs, coef_type)
153
-
154
- msg_out = filter_gen.send(msg_in)
76
+ def __call__(self, message: AxisArray) -> AxisArray:
77
+ if self.settings.coefs is None:
78
+ return message
79
+ if self._state.zi is None:
80
+ self._reset_state(message)
81
+ self._hash = self._hash_message(message)
82
+ return super().__call__(message)
83
+
84
+ def _hash_message(self, message: AxisArray) -> int:
85
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
86
+ axis_idx = message.get_axis_idx(axis)
87
+ samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
88
+ return hash((message.key, samp_shape))
89
+
90
+ def _reset_state(self, message: AxisArray) -> None:
91
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
92
+ axis_idx = message.get_axis_idx(axis)
93
+ n_tail = message.data.ndim - axis_idx - 1
94
+ coefs = (
95
+ (self.settings.coefs,)
96
+ if self.settings.coefs is not None
97
+ and not isinstance(self.settings.coefs, tuple)
98
+ else self.settings.coefs
99
+ )
100
+ zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[
101
+ self.settings.coef_type
102
+ ]
103
+ zi = zi_func(*coefs)
104
+ zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
105
+ n_tile = (
106
+ message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
107
+ )
155
108
 
109
+ if self.settings.coef_type == "sos":
110
+ zi_expand = (slice(None),) + zi_expand
111
+ n_tile = (1,) + n_tile
156
112
 
157
- class FilterBaseSettings(ez.Settings):
158
- axis: str | None = None
159
- coef_type: str = "ba"
113
+ self.state.zi = np.tile(zi[zi_expand], n_tile)
160
114
 
115
+ def update_coefficients(
116
+ self,
117
+ coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
118
+ coef_type: str | None = None,
119
+ ) -> None:
120
+ """
121
+ Update filter coefficients.
122
+
123
+ If the new coefficients have the same length as the current ones, only the coefficients are updated.
124
+ If the lengths differ, the filter state is also reset to handle the new filter order.
125
+
126
+ Args:
127
+ coefs: New filter coefficients
128
+ """
129
+ old_coefs = self.settings.coefs
130
+
131
+ # Update settings with new coefficients
132
+ self.settings = replace(self.settings, coefs=coefs)
133
+ if coef_type is not None:
134
+ self.settings = replace(self.settings, coef_type=coef_type)
135
+
136
+ # Check if we need to reset the state
137
+ if self.state.zi is not None:
138
+ reset_needed = False
139
+
140
+ if self.settings.coef_type == "ba":
141
+ if isinstance(old_coefs, FilterCoefficients) and isinstance(
142
+ coefs, FilterCoefficients
143
+ ):
144
+ if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(
145
+ coefs.a
146
+ ):
147
+ reset_needed = True
148
+ elif isinstance(old_coefs, tuple) and isinstance(coefs, tuple):
149
+ if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(
150
+ coefs[1]
151
+ ):
152
+ reset_needed = True
153
+ else:
154
+ reset_needed = True
155
+ elif self.settings.coef_type == "sos":
156
+ if isinstance(old_coefs, np.ndarray) and isinstance(coefs, np.ndarray):
157
+ if old_coefs.shape != coefs.shape:
158
+ reset_needed = True
159
+ else:
160
+ reset_needed = True
161
+
162
+ if reset_needed:
163
+ self.state.zi = None # This will trigger _reset_state on the next call
164
+
165
+ def _process(self, message: AxisArray) -> AxisArray:
166
+ if message.data.size > 0:
167
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
168
+ axis_idx = message.get_axis_idx(axis)
169
+ coefs = (
170
+ (self.settings.coefs,)
171
+ if self.settings.coefs is not None
172
+ and not isinstance(self.settings.coefs, tuple)
173
+ else self.settings.coefs
174
+ )
175
+ filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[
176
+ self.settings.coef_type
177
+ ]
178
+ dat_out, self.state.zi = filt_func(
179
+ *coefs, message.data, axis=axis_idx, zi=self.state.zi
180
+ )
181
+ else:
182
+ dat_out = message.data
161
183
 
162
- class FilterBase(GenAxisArray):
163
- SETTINGS = FilterBaseSettings
184
+ return replace(message, data=dat_out)
164
185
 
165
- # Backwards-compatible with `Filter` unit
166
- INPUT_FILTER = ez.InputStream(FilterCoefsMultiType)
167
186
 
168
- def design_filter(
169
- self,
170
- ) -> typing.Callable[[float], FilterCoefsMultiType | None]:
171
- raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
187
+ class Filter(
188
+ BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]
189
+ ):
190
+ SETTINGS = FilterSettings
172
191
 
173
- def construct_generator(self):
174
- design_fun = self.design_filter()
175
- self.STATE.gen = filter_gen_by_design(
176
- self.SETTINGS.axis, self.SETTINGS.coef_type, design_fun
177
- )
178
192
 
179
- @ez.subscriber(INPUT_FILTER)
180
- async def redesign(self, message: FilterBaseSettings) -> None:
181
- self.apply_settings(message)
182
- self.construct_generator()
193
+ def filtergen(
194
+ axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str
195
+ ) -> FilterTransformer:
196
+ """
197
+ Filter data using the provided coefficients.
183
198
 
199
+ Returns:
200
+ :obj:`FilterTransformer`.
201
+ """
202
+ return FilterTransformer(
203
+ FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type)
204
+ )
184
205
 
185
- class FilterSettings(FilterBaseSettings):
186
- # If you'd like to statically design a filter, define it in settings
187
- coefs: FilterCoefficients | None = None
188
- # Note: coef_type = "ba" is assumed for this class.
189
206
 
207
+ @processor_state
208
+ class FilterByDesignState:
209
+ filter: FilterTransformer | None = None
210
+ needs_redesign: bool = False
190
211
 
191
- class Filter(FilterBase):
192
- SETTINGS = FilterSettings
193
212
 
194
- INPUT_FILTER = ez.InputStream(FilterCoefficients)
213
+ class FilterByDesignTransformer(
214
+ BaseStatefulTransformer[SettingsType, AxisArray, AxisArray, FilterByDesignState],
215
+ ABC,
216
+ typing.Generic[SettingsType, FilterCoefsType],
217
+ ):
218
+ """Abstract base class for filter design transformers."""
195
219
 
196
- def design_filter(self) -> typing.Callable[[float], BACoeffs | None]:
197
- if self.SETTINGS.coefs is None:
198
- return lambda fs: None
199
- return lambda fs: (self.SETTINGS.coefs.b, self.SETTINGS.coefs.a)
220
+ @classmethod
221
+ def get_message_type(cls, dir: str) -> type[AxisArray]:
222
+ if dir in ("in", "out"):
223
+ return AxisArray
224
+ else:
225
+ raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
226
+
227
+ @abstractmethod
228
+ def get_design_function(self) -> typing.Callable[[float], FilterCoefsType | None]:
229
+ """Return a function that takes sampling frequency and returns filter coefficients."""
230
+ ...
231
+
232
+ def update_settings(
233
+ self, new_settings: typing.Optional[SettingsType] = None, **kwargs
234
+ ) -> None:
235
+ """
236
+ Update settings and mark that filter coefficients need to be recalculated.
237
+
238
+ Args:
239
+ new_settings: Complete new settings object to replace current settings
240
+ **kwargs: Individual settings to update
241
+ """
242
+ # Update settings
243
+ if new_settings is not None:
244
+ self.settings = new_settings
245
+ else:
246
+ self.settings = replace(self.settings, **kwargs)
247
+
248
+ # Set flag to trigger recalculation on next message
249
+ if self.state.filter is not None:
250
+ self.state.needs_redesign = True
251
+
252
+ def __call__(self, message: AxisArray) -> AxisArray:
253
+ # Offer a shortcut when there is no design function or order is 0.
254
+ if hasattr(self.settings, "order") and not self.settings.order:
255
+ return message
256
+ design_fun = self.get_design_function()
257
+ if design_fun is None:
258
+ return message
259
+
260
+ # Check if filter exists but needs redesign due to settings change
261
+ if self.state.filter is not None and self.state.needs_redesign:
262
+ axis = self.state.filter.settings.axis
263
+ fs = 1 / message.axes[axis].gain
264
+ coefs = design_fun(fs)
265
+ self.state.filter.update_coefficients(
266
+ coefs, coef_type=self.settings.coef_type
267
+ )
268
+ self.state.needs_redesign = False
269
+
270
+ return super().__call__(message)
271
+
272
+ def _hash_message(self, message: AxisArray) -> int:
273
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
274
+ gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
275
+ axis_idx = message.get_axis_idx(axis)
276
+ samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
277
+ return hash((message.key, samp_shape, gain))
278
+
279
+ def _reset_state(self, message: AxisArray) -> None:
280
+ design_fun = self.get_design_function()
281
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
282
+ fs = 1 / message.axes[axis].gain
283
+ coefs = design_fun(fs)
284
+ new_settings = FilterSettings(
285
+ axis=axis, coef_type=self.settings.coef_type, coefs=coefs
286
+ )
287
+ self.state.filter = FilterTransformer(settings=new_settings)
288
+
289
+ def _process(self, message: AxisArray) -> AxisArray:
290
+ return self.state.filter(message)
291
+
292
+
293
+ class BaseFilterByDesignTransformerUnit(
294
+ BaseTransformerUnit[SettingsType, AxisArray, AxisArray, FilterByDesignTransformer],
295
+ typing.Generic[SettingsType, TransformerType],
296
+ ):
297
+ @ez.subscriber(BaseConsumerUnit.INPUT_SETTINGS)
298
+ async def on_settings(self, msg: SettingsType) -> None:
299
+ """
300
+ Receive a settings message, override self.SETTINGS, and re-create the processor.
301
+ Child classes that wish to have fine-grained control over whether the
302
+ core processor resets on settings changes should override this method.
303
+
304
+ Args:
305
+ msg: a settings message.
306
+ """
307
+ self.apply_settings(msg)
308
+
309
+ # Check if processor exists yet
310
+ if hasattr(self, "processor") and self.processor is not None:
311
+ # Update the existing processor with new settings
312
+ self.processor.update_settings(self.SETTINGS)
313
+ else:
314
+ # Processor doesn't exist yet, create a new one
315
+ self.create_processor()