ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69):
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/filter.py CHANGED
@@ -1,13 +1,21 @@
1
- from dataclasses import dataclass, replace, field
1
+ import typing
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass, field
2
4
 
3
5
  import ezmsg.core as ez
4
- import scipy.signal
5
6
  import numpy as np
6
- import asyncio
7
-
7
+ import numpy.typing as npt
8
+ import scipy.signal
9
+ from ezmsg.baseproc import (
10
+ BaseConsumerUnit,
11
+ BaseStatefulTransformer,
12
+ BaseTransformerUnit,
13
+ SettingsType,
14
+ TransformerType,
15
+ processor_state,
16
+ )
8
17
  from ezmsg.util.messages.axisarray import AxisArray
9
-
10
- from typing import AsyncGenerator, Optional, Tuple
18
+ from ezmsg.util.messages.util import replace
11
19
 
12
20
 
13
21
  @dataclass
@@ -16,125 +24,282 @@ class FilterCoefficients:
16
24
  a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
17
25
 
18
26
 
19
- class FilterSettingsBase(ez.Settings):
20
- axis: Optional[str] = None
21
- fs: Optional[float] = None
27
+ # Type aliases
28
+ BACoeffs = tuple[npt.NDArray, npt.NDArray]
29
+ SOSCoeffs = npt.NDArray
30
+ FilterCoefsType = typing.TypeVar("FilterCoefsType", BACoeffs, SOSCoeffs)
31
+
32
+
33
+ def _normalize_coefs(
34
+ coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray | None,
35
+ ) -> tuple[str, tuple[npt.NDArray, ...] | None]:
36
+ coef_type = "ba"
37
+ if coefs is not None:
38
+ # scipy.signal functions called with first arg `*coefs`.
39
+ # Make sure we have a tuple of coefficients.
40
+ if isinstance(coefs, np.ndarray):
41
+ coef_type = "sos"
42
+ coefs = (coefs,) # sos funcs just want a single ndarray.
43
+ elif isinstance(coefs, FilterCoefficients):
44
+ coefs = (coefs.b, coefs.a)
45
+ elif not isinstance(coefs, tuple):
46
+ coefs = (coefs,)
47
+ return coef_type, coefs
48
+
49
+
50
+ class FilterBaseSettings(ez.Settings):
51
+ axis: str | None = None
52
+ """The name of the axis to operate on."""
53
+
54
+ coef_type: str = "ba"
55
+ """The type of filter coefficients. One of "ba" or "sos"."""
56
+
57
+
58
+ class FilterSettings(FilterBaseSettings):
59
+ coefs: FilterCoefficients | None = None
60
+ """The pre-calculated filter coefficients."""
61
+
62
+ # Note: coef_type = "ba" is assumed for this class.
63
+
64
+
65
+ @processor_state
66
+ class FilterState:
67
+ zi: npt.NDArray | None = None
68
+
69
+
70
+ class FilterTransformer(BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]):
71
+ """
72
+ Filter data using the provided coefficients.
73
+ """
74
+
75
+ def __call__(self, message: AxisArray) -> AxisArray:
76
+ if self.settings.coefs is None:
77
+ return message
78
+ if self._state.zi is None:
79
+ self._reset_state(message)
80
+ self._hash = self._hash_message(message)
81
+ return super().__call__(message)
82
+
83
+ def _hash_message(self, message: AxisArray) -> int:
84
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
85
+ axis_idx = message.get_axis_idx(axis)
86
+ samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
87
+ return hash((message.key, samp_shape))
88
+
89
+ def _reset_state(self, message: AxisArray) -> None:
90
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
91
+ axis_idx = message.get_axis_idx(axis)
92
+ n_tail = message.data.ndim - axis_idx - 1
93
+ _, coefs = _normalize_coefs(self.settings.coefs)
94
+
95
+ if self.settings.coef_type == "ba":
96
+ b, a = coefs
97
+ if len(a) == 1 or np.allclose(a[1:], 0):
98
+ # For FIR filters, use lfiltic with zero initial conditions
99
+ zi = scipy.signal.lfiltic(b, a, [])
100
+ else:
101
+ # For IIR filters...
102
+ zi = scipy.signal.lfilter_zi(b, a)
103
+ else:
104
+ # For second-order sections (SOS) filters, use sosfilt_zi
105
+ zi = scipy.signal.sosfilt_zi(*coefs)
106
+
107
+ zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
108
+ n_tile = message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]
109
+
110
+ if self.settings.coef_type == "sos":
111
+ zi_expand = (slice(None),) + zi_expand
112
+ n_tile = (1,) + n_tile
113
+
114
+ self.state.zi = np.tile(zi[zi_expand], n_tile)
115
+
116
+ def update_coefficients(
117
+ self,
118
+ coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
119
+ coef_type: str | None = None,
120
+ ) -> None:
121
+ """
122
+ Update filter coefficients.
123
+
124
+ If the new coefficients have the same length as the current ones, only the coefficients are updated.
125
+ If the lengths differ, the filter state is also reset to handle the new filter order.
126
+
127
+ Args:
128
+ coefs: New filter coefficients
129
+ """
130
+ old_coefs = self.settings.coefs
131
+
132
+ # Update settings with new coefficients
133
+ self.settings = replace(self.settings, coefs=coefs)
134
+ if coef_type is not None:
135
+ self.settings = replace(self.settings, coef_type=coef_type)
136
+
137
+ # Check if we need to reset the state
138
+ if self.state.zi is not None:
139
+ reset_needed = False
140
+
141
+ if self.settings.coef_type == "ba":
142
+ if isinstance(old_coefs, FilterCoefficients) and isinstance(coefs, FilterCoefficients):
143
+ if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(coefs.a):
144
+ reset_needed = True
145
+ elif isinstance(old_coefs, tuple) and isinstance(coefs, tuple):
146
+ if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(coefs[1]):
147
+ reset_needed = True
148
+ else:
149
+ reset_needed = True
150
+ elif self.settings.coef_type == "sos":
151
+ if isinstance(old_coefs, np.ndarray) and isinstance(coefs, np.ndarray):
152
+ if old_coefs.shape != coefs.shape:
153
+ reset_needed = True
154
+ else:
155
+ reset_needed = True
156
+
157
+ if reset_needed:
158
+ self.state.zi = None # This will trigger _reset_state on the next call
159
+
160
+ def _process(self, message: AxisArray) -> AxisArray:
161
+ if message.data.size > 0:
162
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
163
+ axis_idx = message.get_axis_idx(axis)
164
+ _, coefs = _normalize_coefs(self.settings.coefs)
165
+ filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[self.settings.coef_type]
166
+ dat_out, self.state.zi = filt_func(*coefs, message.data, axis=axis_idx, zi=self.state.zi)
167
+ else:
168
+ dat_out = message.data
169
+
170
+ return replace(message, data=dat_out)
22
171
 
23
172
 
24
- class FilterSettings(FilterSettingsBase):
25
- # If you'd like to statically design a filter, define it in settings
26
- filt: Optional[FilterCoefficients] = None
173
+ class Filter(BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]):
174
+ SETTINGS = FilterSettings
27
175
 
28
176
 
29
- class FilterState(ez.State):
30
- axis: Optional[str] = None
31
- zi: Optional[np.ndarray] = None
32
- filt_designed: bool = False
33
- filt: Optional[FilterCoefficients] = None
34
- filt_set: asyncio.Event = field(default_factory=asyncio.Event)
35
- samp_shape: Optional[Tuple[int, ...]] = None
36
- fs: Optional[float] = None # Hz
177
+ def filtergen(axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str) -> FilterTransformer:
178
+ """
179
+ Filter data using the provided coefficients.
37
180
 
181
+ Returns:
182
+ :obj:`FilterTransformer`.
183
+ """
184
+ return FilterTransformer(FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type))
38
185
 
39
- class Filter(ez.Unit):
40
- SETTINGS: FilterSettingsBase
41
- STATE: FilterState
42
186
 
43
- INPUT_FILTER = ez.InputStream(FilterCoefficients)
44
- INPUT_SIGNAL = ez.InputStream(AxisArray)
45
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
187
+ @processor_state
188
+ class FilterByDesignState:
189
+ filter: FilterTransformer | None = None
190
+ needs_redesign: bool = False
46
191
 
47
- def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
48
- raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
49
192
 
50
- # Set up filter with static initialization if specified
51
- def initialize(self) -> None:
52
- if self.SETTINGS.axis is not None:
53
- self.STATE.axis = self.SETTINGS.axis
193
+ class FilterByDesignTransformer(
194
+ BaseStatefulTransformer[SettingsType, AxisArray, AxisArray, FilterByDesignState],
195
+ ABC,
196
+ typing.Generic[SettingsType, FilterCoefsType],
197
+ ):
198
+ """Abstract base class for filter design transformers."""
54
199
 
55
- if isinstance(self.SETTINGS, FilterSettings):
56
- if self.SETTINGS.filt is not None:
57
- self.STATE.filt = self.SETTINGS.filt
58
- self.STATE.filt_set.set()
200
+ @classmethod
201
+ def get_message_type(cls, dir: str) -> type[AxisArray]:
202
+ if dir in ("in", "out"):
203
+ return AxisArray
204
+ else:
205
+ raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
206
+
207
+ @abstractmethod
208
+ def get_design_function(self) -> typing.Callable[[float], FilterCoefsType | None]:
209
+ """Return a function that takes sampling frequency and returns filter coefficients."""
210
+ ...
211
+
212
+ def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
213
+ """
214
+ Update settings and mark that filter coefficients need to be recalculated.
215
+
216
+ Args:
217
+ new_settings: Complete new settings object to replace current settings
218
+ **kwargs: Individual settings to update
219
+ """
220
+ # Update settings
221
+ if new_settings is not None:
222
+ self.settings = new_settings
223
+ else:
224
+ self.settings = replace(self.settings, **kwargs)
225
+
226
+ # Set flag to trigger recalculation on next message
227
+ if self.state.filter is not None:
228
+ self.state.needs_redesign = True
229
+
230
+ def __call__(self, message: AxisArray) -> AxisArray:
231
+ # Offer a shortcut when there is no design function or order is 0.
232
+ if hasattr(self.settings, "order") and not self.settings.order:
233
+ return message
234
+ design_fun = self.get_design_function()
235
+ if design_fun is None:
236
+ return message
237
+
238
+ # Check if filter exists but needs redesign due to settings change
239
+ if self.state.filter is not None and self.state.needs_redesign:
240
+ axis = self.state.filter.settings.axis
241
+ fs = 1 / message.axes[axis].gain
242
+ coefs = design_fun(fs)
243
+
244
+ # Convert BA to SOS if requested
245
+ if coefs is not None and self.settings.coef_type == "sos":
246
+ if isinstance(coefs, tuple) and len(coefs) == 2:
247
+ # It's BA format, convert to SOS
248
+ b, a = coefs
249
+ coefs = scipy.signal.tf2sos(b, a)
250
+
251
+ self.state.filter.update_coefficients(coefs, coef_type=self.settings.coef_type)
252
+ self.state.needs_redesign = False
253
+
254
+ return super().__call__(message)
255
+
256
+ def _hash_message(self, message: AxisArray) -> int:
257
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
258
+ gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
259
+ axis_idx = message.get_axis_idx(axis)
260
+ samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
261
+ return hash((message.key, samp_shape, gain))
262
+
263
+ def _reset_state(self, message: AxisArray) -> None:
264
+ design_fun = self.get_design_function()
265
+ axis = message.dims[0] if self.settings.axis is None else self.settings.axis
266
+ fs = 1 / message.axes[axis].gain
267
+ coefs = design_fun(fs)
268
+
269
+ # Convert BA to SOS if requested
270
+ if coefs is not None and self.settings.coef_type == "sos":
271
+ if isinstance(coefs, tuple) and len(coefs) == 2:
272
+ # It's BA format, convert to SOS
273
+ b, a = coefs
274
+ coefs = scipy.signal.tf2sos(b, a)
275
+
276
+ new_settings = FilterSettings(axis=axis, coef_type=self.settings.coef_type, coefs=coefs)
277
+ self.state.filter = FilterTransformer(settings=new_settings)
278
+
279
+ def _process(self, message: AxisArray) -> AxisArray:
280
+ return self.state.filter(message)
281
+
282
+
283
+ class BaseFilterByDesignTransformerUnit(
284
+ BaseTransformerUnit[SettingsType, AxisArray, AxisArray, FilterByDesignTransformer],
285
+ typing.Generic[SettingsType, TransformerType],
286
+ ):
287
+ @ez.subscriber(BaseConsumerUnit.INPUT_SETTINGS)
288
+ async def on_settings(self, msg: SettingsType) -> None:
289
+ """
290
+ Receive a settings message, override self.SETTINGS, and re-create the processor.
291
+ Child classes that wish to have fine-grained control over whether the
292
+ core processor resets on settings changes should override this method.
293
+
294
+ Args:
295
+ msg: a settings message.
296
+ """
297
+ self.apply_settings(msg)
298
+
299
+ # Check if processor exists yet
300
+ if hasattr(self, "processor") and self.processor is not None:
301
+ # Update the existing processor with new settings
302
+ self.processor.update_settings(self.SETTINGS)
59
303
  else:
60
- self.STATE.filt_set.clear()
61
-
62
- if self.SETTINGS.fs is not None:
63
- try:
64
- self.update_filter()
65
- except NotImplementedError:
66
- ez.logger.debug("Using filter coefficients.")
67
-
68
- @ez.subscriber(INPUT_FILTER)
69
- async def redesign(self, message: FilterCoefficients):
70
- self.STATE.filt = message
71
-
72
- def update_filter(self):
73
- try:
74
- coefs = self.design_filter()
75
- self.STATE.filt = (
76
- FilterCoefficients() if coefs is None else FilterCoefficients(*coefs)
77
- )
78
- self.STATE.filt_set.set()
79
- self.STATE.filt_designed = True
80
- except NotImplementedError as e:
81
- raise e
82
- except Exception as e:
83
- ez.logger.warning(f"Error when designing filter: {e}")
84
-
85
- @ez.subscriber(INPUT_SIGNAL)
86
- @ez.publisher(OUTPUT_SIGNAL)
87
- async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
88
- axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
89
- axis_idx = msg.get_axis_idx(axis_name)
90
- axis = msg.get_axis(axis_name)
91
- fs = 1.0 / axis.gain
92
-
93
- if self.STATE.fs != fs and self.STATE.filt_designed is True:
94
- self.STATE.fs = fs
95
- self.update_filter()
96
-
97
- # Ensure filter is defined
98
- # TODO: Maybe have me be a passthrough filter until coefficients are received
99
- if self.STATE.filt is None:
100
- self.STATE.filt_set.clear()
101
- ez.logger.info("Awaiting filter coefficients...")
102
- await self.STATE.filt_set.wait()
103
- ez.logger.info("Filter coefficients received.")
104
-
105
- assert self.STATE.filt is not None
106
-
107
- arr_in = msg.data
108
-
109
- # If the array is one dimensional, add a temporary second dimension so that the math works out
110
- one_dimensional = False
111
- if arr_in.ndim == 1:
112
- arr_in = np.expand_dims(arr_in, axis=1)
113
- one_dimensional = True
114
-
115
- # We will perform filter with time dimension as last axis
116
- arr_in = np.moveaxis(arr_in, axis_idx, -1)
117
- samp_shape = arr_in[..., 0].shape
118
-
119
- # Re-calculate/reset zi if necessary
120
- if self.STATE.zi is None or samp_shape != self.STATE.samp_shape:
121
- zi: np.ndarray = scipy.signal.lfilter_zi(
122
- self.STATE.filt.b, self.STATE.filt.a
123
- )
124
- self.STATE.samp_shape = samp_shape
125
- self.STATE.zi = np.array([zi] * np.prod(self.STATE.samp_shape))
126
- self.STATE.zi = self.STATE.zi.reshape(
127
- tuple(list(self.STATE.samp_shape) + [zi.shape[0]])
128
- )
129
-
130
- arr_out, self.STATE.zi = scipy.signal.lfilter(
131
- self.STATE.filt.b, self.STATE.filt.a, arr_in, zi=self.STATE.zi
132
- )
133
-
134
- arr_out = np.moveaxis(arr_out, -1, axis_idx)
135
-
136
- # Remove temporary first dimension if necessary
137
- if one_dimensional:
138
- arr_out = np.squeeze(arr_out, axis=1)
139
-
140
- yield self.OUTPUT_SIGNAL, replace(msg, data=arr_out),
304
+ # Processor doesn't exist yet, create a new one
305
+ self.create_processor()