ezmsg-sigproc 2.4.0__py3-none-any.whl → 2.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.4.0'
32
- __version_tuple__ = version_tuple = (2, 4, 0)
31
+ __version__ = version = '2.5.0'
32
+ __version_tuple__ = version_tuple = (2, 5, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -0,0 +1,132 @@
1
+ import functools
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ import numpy as np
6
+ import scipy.signal
7
+ from ezmsg.sigproc.base import SettingsType
8
+ from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
9
+ from ezmsg.sigproc.filter import (
10
+ BACoeffs,
11
+ BaseFilterByDesignTransformerUnit,
12
+ FilterByDesignTransformer,
13
+ SOSCoeffs,
14
+ )
15
+ from ezmsg.util.messages.axisarray import AxisArray
16
+ from ezmsg.util.messages.util import replace
17
+
18
+
19
class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
    """Configuration for :obj:`ButterworthZeroPhase`.

    The filter-design fields (axis, coef_type, order, cuton, cutoff, wn_hz)
    come from :obj:`ButterworthFilterSettings`; this class only adds the
    edge-padding options forwarded to SciPy's forward-backward filters.
    """

    padtype: str | None = None
    """
    Edge-extension mode passed to `scipy.signal.filtfilt` / `sosfiltfilt`.
    One of {'odd', 'even', 'constant', None}.
    The default, None, disables padding.
    """

    padlen: int | None = 0
    """
    Number of samples by which each edge is extended, passed to
    `scipy.signal.filtfilt` / `sosfiltfilt`.
    None defers to SciPy's default pad length.
    """
35
+
36
+
37
class ButterworthZeroPhaseTransformer(
    FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]
):
    """Butterworth filter run forward and backward (zero phase) on each message.

    Coefficients are produced by :obj:`butter_design_fun` and cached together
    with the sampling rate they were designed for. Every message is filtered
    independently with `scipy.signal.sosfiltfilt` or `scipy.signal.filtfilt`.
    """

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return `butter_design_fun` with everything but `fs` bound from settings."""
        cfg = self.settings
        return functools.partial(
            butter_design_fun,
            order=cfg.order,
            cuton=cfg.cuton,
            cutoff=cfg.cutoff,
            coef_type=cfg.coef_type,
            wn_hz=cfg.wn_hz,
        )

    def _invalidate_design(self) -> None:
        """Forget cached coefficients so the next message triggers a redesign."""
        self._coefs_cache = None
        self._fs_cache = None
        self.state.needs_redesign = True

    def update_settings(
        self, new_settings: typing.Optional[SettingsType] = None, **kwargs
    ) -> None:
        """
        Swap in new settings and schedule a coefficient redesign.

        Args:
            new_settings: Full settings object that replaces the current one.
                When given, `kwargs` are ignored.
            **kwargs: Individual fields to override on the current settings.
        """
        self.settings = (
            replace(self.settings, **kwargs) if new_settings is None else new_settings
        )
        self._invalidate_design()

    def _reset_state(self, message: AxisArray) -> None:
        """Drop all cached design information; `message` is not inspected."""
        self._invalidate_design()

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Zero-phase filter `message` along the configured axis.

        Returns the message untouched when no coefficients are available, the
        order is not positive, or the data array is empty.

        Raises:
            ValueError: if `coef_type` is neither 'sos' nor 'ba'.
        """
        cfg = self.settings
        axis_name = cfg.axis if cfg.axis is not None else message.dims[0]
        axis_index = message.get_axis_idx(axis_name)
        rate = 1 / message.axes[axis_name].gain

        # Redesign when nothing is cached, a redesign was requested, or the
        # sampling rate differs from the one the cache was built for.
        stale = (
            self._coefs_cache is None
            or self.state.needs_redesign
            or self._fs_cache is None
            or not np.isclose(self._fs_cache, rate)
        )
        if stale:
            self._coefs_cache = self.get_design_function()(rate)
            self._fs_cache = rate
            self.state.needs_redesign = False

        if self._coefs_cache is None or cfg.order <= 0 or message.data.size <= 0:
            return message

        samples = message.data
        shared = dict(axis=axis_index, padtype=cfg.padtype, padlen=cfg.padlen)
        if cfg.coef_type == "sos":
            filtered = scipy.signal.sosfiltfilt(self._coefs_cache, samples, **shared)
        elif cfg.coef_type == "ba":
            numerator, denominator = self._coefs_cache
            filtered = scipy.signal.filtfilt(numerator, denominator, samples, **shared)
        else:
            ez.logger.error("coef_type must be 'sos' or 'ba'.")
            raise ValueError("coef_type must be 'sos' or 'ba'.")

        return replace(message, data=filtered)
125
+
126
+
127
class ButterworthZeroPhase(
    BaseFilterByDesignTransformerUnit[
        ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer
    ]
):
    """ezmsg unit wrapper around :obj:`ButterworthZeroPhaseTransformer`."""

    SETTINGS = ButterworthZeroPhaseSettings
@@ -22,15 +22,13 @@ class DenormalizeSettings(ez.Settings):
22
22
 
23
23
 
24
24
  @processor_state
25
- class DenormalizeRateState:
25
+ class DenormalizeState:
26
26
  gains: npt.NDArray | None = None
27
27
  offsets: npt.NDArray | None = None
28
28
 
29
29
 
30
30
  class DenormalizeTransformer(
31
- BaseStatefulTransformer[
32
- DenormalizeSettings, AxisArray, AxisArray, DenormalizeRateState
33
- ]
31
+ BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]
34
32
  ):
35
33
  """
36
34
  Scales data from a normalized distribution (mean=0, std=1) to a denormalized
@@ -78,7 +76,7 @@ class DenormalizeTransformer(
78
76
  )
79
77
 
80
78
 
81
- class DenormalizeRateUnit(
79
+ class DenormalizeUnit(
82
80
  BaseTransformerUnit[
83
81
  DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer
84
82
  ]
@@ -0,0 +1,353 @@
1
+ import functools
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ import numpy as np
6
+ import scipy.signal as sps
7
+ from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state
8
+ from ezmsg.sigproc.filter import (
9
+ BACoeffs,
10
+ BaseFilterByDesignTransformerUnit,
11
+ BaseTransformerUnit,
12
+ FilterBaseSettings,
13
+ FilterByDesignTransformer,
14
+ )
15
+ from ezmsg.util.messages.axisarray import AxisArray
16
+ from ezmsg.util.messages.util import replace
17
+
18
+
19
class FIRHilbertFilterSettings(FilterBaseSettings):
    """Settings for :obj:`FIRHilbertFilterUnit` and :obj:`FIRHilbertEnvelopeUnit`.

    :obj:`fir_hilbert_design_fun` turns these values into the band layout::

        stop [0, f_lo] | pass [f_lo + trans_lo, f_hi] | stop [f_hi + trans_hi, Nyquist]

    i.e. both transition regions lie *above* their corner frequency.
    """

    # axis inherited from FilterBaseSettings

    coef_type: str = "ba"
    """
    Coefficient type. Must be 'ba' for FIR.
    """

    order: int = 170
    """
    Filter order (taps = order + 1).
    Hilbert (type-III) filters require even order (odd taps).
    If odd order (even taps), order will be incremented by 1.
    An order <= 0 disables the design (the design function returns None).
    """

    f_lo: float = 1.0
    """
    Upper edge of the low stop band (Hz); negative values are treated as 0.
    The lower transition starts here and the Hilbert pass band starts at
    `f_lo + trans_lo`.
    """

    f_hi: float | None = None
    """
    Upper corner of the Hilbert pass band (Hz); the upper transition starts here.
    Clipped to `Nyquist - trans_hi`.
    If None, the pass band extends to `Nyquist - trans_hi` and no upper stop
    band is specified.
    """

    trans_lo: float = 1.0
    """
    Width (Hz) of the lower transition region, `[f_lo, f_lo + trans_lo]`.
    Decrease to sharpen transition.
    """

    trans_hi: float = 1.0
    """
    Width (Hz) of the upper transition region, `[f_hi, f_hi + trans_hi]`.
    Decrease to sharpen transition.
    """

    weight_pass: float = 1.0
    """
    Weight for Hilbert pass region. Negative values are treated as 0.
    """

    weight_stop_lo: float = 1.0
    """
    Weight for low stop band. Negative values are treated as 0.
    """

    weight_stop_hi: float = 1.0
    """
    Weight for high stop band. Negative values are treated as 0.
    Unused when the upper stop band is not part of the design.
    """

    norm_band: tuple[float, float] | None = None
    """
    Optional normalization band (lo, hi) in Hz. The taps are divided by the
    median magnitude response over this band (clipped to the pass band).
    Ignored when `norm_freq` is set.
    If both `norm_band` and `norm_freq` are None, no normalization is applied.
    """

    norm_freq: float | None = None
    """
    Optional normalization frequency in Hz. The taps are divided by the
    magnitude response at this frequency, which must lie inside the pass band.
    Takes precedence over `norm_band`.
    """
87
+
88
+
89
def fir_hilbert_design_fun(
    fs: float,
    order: int = 170,
    f_lo: float = 1.0,
    f_hi: float | None = None,
    trans_lo: float = 1.0,
    trans_hi: float = 1.0,
    weight_pass: float = 1.0,
    weight_stop_lo: float = 1.0,
    weight_stop_hi: float = 1.0,
    norm_band: tuple[float, float] | None = None,
    norm_freq: float | None = None,
) -> BACoeffs | None:
    """
    Hilbert FIR filter design using the Remez exchange algorithm.
    Design an `order`th-order FIR Hilbert filter and return the filter coefficients.
    See :obj:`FIRHilbertFilterSettings` for argument description.

    Band layout (Hz): stop ``[0, f_lo]``, pass ``[f_lo + trans_lo, f2]`` and,
    when it fits below Nyquist, stop ``[f2 + trans_hi, nyq]`` where
    ``f2 = min(f_hi, nyq - trans_hi)``. When `f_hi` is None or the upper stop
    band has no room, a two-band (stop/pass) design is used whose pass band
    ends at ``f2``; the region up to Nyquist is left unconstrained because an
    odd-length antisymmetric filter is necessarily zero at Nyquist.

    Returns:
        The filter coefficients as a tuple of (b, a), or None when `order <= 0`.

    Raises:
        ValueError: if the pass band is empty or falls outside (0, Nyquist).
    """
    if order <= 0:
        return None
    if order % 2 == 1:
        # Type-III (antisymmetric, odd length) requires an even order.
        order += 1
    nyq = fs / 2.0
    taps = order + 1
    f1 = max(f_lo, 0.0) + trans_lo
    f2 = (nyq - trans_hi) if (f_hi is None) else min(f_hi, nyq - trans_hi)
    if not (0.0 < f1 < f2 < nyq):
        raise ValueError(
            f"Hilbert passband collapsed or invalid: "
            f"f_lo={f_lo}, f_hi={f_hi}, trans_lo={trans_lo}, trans_hi={trans_hi}, fs={fs}"
        )
    # Bands: [0, f1-trans_lo] stop ; [f1, f2] pass (Hilbert) ; [f2+trans_hi, nyq] stop
    bands = [0.0, max(f1 - trans_lo, 0.0), f1, f2, min(f2 + trans_hi, nyq), nyq]
    desired = [0.0, 1.0, 0.0]
    weight = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0), max(weight_stop_hi, 0.0)]
    # Nudge coincident interior edges apart so every band has nonzero width.
    for i in range(1, len(bands) - 1):
        if bands[i] <= bands[i - 1]:
            bands[i] = np.nextafter(bands[i - 1], np.inf)
    # f_hi=None is tested explicitly: (nyq - trans_hi) + trans_hi may round to
    # just below nyq, which would otherwise leave a degenerate stop band.
    if f_hi is None or bands[-2] >= nyq:
        if f_hi is not None:
            # Only an explicitly requested upper stop band is worth a warning;
            # f_hi=None is the documented "no upper stop band" mode.
            ez.logger.warning(
                "Hilbert upper stopband collapsed; using 2-band (stop/pass) design."
            )
        # Keep f2 as the pass-band edge. Extending the pass band to Nyquist
        # would demand unit gain where the response is forced to zero.
        bands = bands[:4]
        desired = desired[:-1]
        weight = weight[:-1]
    b = sps.remez(taps, bands, desired, weight=weight, type="hilbert", fs=fs)
    a = np.array([1.0])

    # Optional gain normalisation; norm_freq takes precedence over norm_band.
    g = None
    if norm_freq is not None:
        if norm_freq < f1 or norm_freq > f2:
            ez.logger.warning(
                "Invalid normalization frequency specifications. Skipping normalization."
            )
        else:
            f0 = float(norm_freq)
            w = 2.0 * np.pi * (np.asarray([f0], dtype=np.float64) / fs)
            _, H = sps.freqz(b, a, worN=w)
            g = float(np.abs(H[0]))
    elif norm_band is not None:
        lo, hi = norm_band
        if lo < f1 or hi > f2:
            lo = max(lo, f1)
            hi = min(hi, f2)
            ez.logger.warning(
                "Normalization band outside passband. Clipping to passband for normalization."
            )
        if lo >= hi:
            ez.logger.warning(
                "Invalid normalization band specifications. Skipping normalization."
            )
        else:
            freqs = np.linspace(lo, hi, 2048, dtype=np.float64)
            w = 2.0 * np.pi * (np.asarray(freqs, dtype=np.float64) / fs)
            _, H = sps.freqz(b, a, worN=w)
            g = float(np.median(np.abs(H)))
    if g is not None and g > 0:
        b = b / g
    return (b, a)
170
+
171
+
172
class FIRHilbertFilterTransformer(
    FilterByDesignTransformer[FIRHilbertFilterSettings, BACoeffs]
):
    """Filter-by-design transformer whose taps come from :obj:`fir_hilbert_design_fun`."""

    def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
        """
        Return `fir_hilbert_design_fun` with every argument except `fs` bound
        from the current settings.

        Raises:
            ValueError: if `coef_type` is anything other than 'ba'.
        """
        if self.settings.coef_type != "ba":
            ez.logger.error("FIRHilbert only supports coef_type='ba'.")
            raise ValueError("FIRHilbert only supports coef_type='ba'.")

        return functools.partial(
            fir_hilbert_design_fun,
            order=self.settings.order,
            f_lo=self.settings.f_lo,
            f_hi=self.settings.f_hi,
            trans_lo=self.settings.trans_lo,
            trans_hi=self.settings.trans_hi,
            weight_pass=self.settings.weight_pass,
            weight_stop_lo=self.settings.weight_stop_lo,
            weight_stop_hi=self.settings.weight_stop_hi,
            norm_band=self.settings.norm_band,
            norm_freq=self.settings.norm_freq,
        )

    def get_taps(self) -> int | None:
        """
        Return the number of numerator taps of the designed filter.

        Returns None when no inner filter exists yet or when it carries no
        coefficients (the design function returns None for `order <= 0`).
        """
        inner = self._state.filter
        if inner is None:
            return None
        coefs = inner.settings.coefs
        if coefs is None:
            # Nothing was designed; unpacking None below would raise TypeError.
            return None
        b, _ = coefs
        return b.size if b is not None else None
199
+
200
+
201
class FIRHilbertFilterUnit(
    BaseFilterByDesignTransformerUnit[
        FIRHilbertFilterSettings, FIRHilbertFilterTransformer
    ]
):
    """ezmsg unit wrapper around :obj:`FIRHilbertFilterTransformer`."""

    SETTINGS = FIRHilbertFilterSettings
207
+
208
+
209
@processor_state
class FIRHilbertEnvelopeState:
    """Mutable state carried between messages by :obj:`FIRHilbertEnvelopeTransformer`."""

    # Inner Hilbert FIR filter; created when the state is reset.
    filter: FIRHilbertFilterTransformer | None = None
    # Trailing input samples kept so the real path can be delayed by `dly`.
    delay_buf: np.ndarray | None = None
    # Delay in samples applied to the real path; (taps - 1) // 2 once known.
    dly: int | None = None
214
+
215
+
216
class FIRHilbertEnvelopeTransformer(
    BaseStatefulTransformer[
        FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeState
    ]
):
    """
    Processor for computing the envelope of a signal using the Hilbert transform.

    The input is passed through a Hilbert FIR filter to obtain the imaginary
    part of the analytic signal. The real part is the input itself, delayed by
    ``(taps - 1) // 2`` samples so that both parts are time-aligned. The output
    is the magnitude of the resulting analytic signal, so it lags the input by
    the same number of samples.

    The processor works along `settings.axis` (falling back to the first
    dimension of the message) and outputs `AxisArray` messages of the same shape.

    Settings:
    ---------
    order : int
        Filter order (taps = order + 1).
        Hilbert (type-III) filters require even order (odd taps).
        If odd order (even taps), order will be incremented by 1.
        Must be positive for the envelope to be computable.
    f_lo : float
        Upper edge of the low stop band (Hz).
        The lower transition starts at f_lo.
    f_hi : float, optional
        Upper corner of Hilbert “pass” band (Hz).
        The upper transition starts at f_hi.
        If None, the pass band extends to Nyquist - trans_hi.
    trans_lo : float
        Transition width (Hz) above f_lo; the pass band starts at f_lo + trans_lo.
        Decrease to sharpen transition.
    trans_hi : float
        Transition width (Hz) above f_hi.
        Decrease to sharpen transition.
    weight_pass : float
        Weight for Hilbert pass region.
    weight_stop_lo : float
        Weight for low stop band.
    weight_stop_hi : float
        Weight for high stop band.
    norm_band : tuple(float, float), optional
        Optional normalization band (lo, hi) in Hz for gain normalization.
        Ignored when norm_freq is set.
    norm_freq : float, optional
        Optional normalization frequency in Hz for gain normalization.
        If both norm_band and norm_freq are None, no normalization is applied.

    Example:
    -----------------------------
    ```python
    processor = FIRHilbertEnvelopeTransformer(
        settings=FIRHilbertFilterSettings(
            order=170,
            f_lo=1.0,
            f_hi=50.0,
        )
    )
    ```

    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key and per-sample shape to detect stream changes."""
        axis = self.settings.axis or message.dims[0]
        # NOTE(review): this reads `gain` from the inner filter object, not
        # from the message axis; it falls back to 0.0 when the attribute is
        # absent. Confirm whether the message's sampling gain was intended.
        gain = getattr(self._state.filter, "gain", 0.0)
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Create a fresh inner filter and forget the delay line."""
        self._state.filter = FIRHilbertFilterTransformer(settings=self.settings)
        self._state.delay_buf = None
        self._state.dly = None

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Return the Hilbert envelope of `message`.

        Raises:
            ValueError: if the inner filter has no taps (e.g. `order <= 0`),
                in which case no envelope can be formed.
        """
        y_imag = self._state.filter(message).data

        axis_name = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis_name)
        if self._state.dly is None:
            taps = self._state.filter.get_taps()
            if taps is None:
                err = (
                    "FIRHilbertEnvelope requires a designed filter (order > 0); "
                    "no filter taps are available."
                )
                ez.logger.error(err)
                raise ValueError(err)
            self._state.dly = (taps - 1) // 2

        x = message.data

        # Work with the processing axis last so the delay line is a plain
        # concatenation along the final dimension.
        move_axis = axis_idx != x.ndim - 1
        if move_axis:
            x = np.moveaxis(x, axis_idx, -1)
            y_imag = np.moveaxis(y_imag, axis_idx, -1)

        if self._state.delay_buf is None:
            self._state.delay_buf = np.zeros(
                x.shape[:-1] + (self._state.dly,), dtype=x.dtype
            )

        # x_cat holds `dly` buffered samples followed by the new chunk. Its
        # first n samples are the delayed real path and the remainder is the
        # next buffer. Positive-index slicing stays correct even if dly is 0.
        n = x.shape[-1]
        x_cat = np.concatenate([self._state.delay_buf, x], axis=-1)
        y_real = x_cat[..., :n]
        self._state.delay_buf = x_cat[..., n:].copy()

        analytic = y_real.astype(np.complex64) + 1j * y_imag.astype(np.complex64)
        out = np.abs(analytic)

        if move_axis:
            out = np.moveaxis(out, -1, axis_idx)

        return replace(message, data=out, axes=message.axes)
324
+
325
+
326
class FIRHilbertEnvelopeUnit(
    BaseTransformerUnit[
        FIRHilbertFilterSettings,
        AxisArray,
        AxisArray,
        FIRHilbertEnvelopeTransformer,
    ]
):
    """
    ezmsg unit exposing :obj:`FIRHilbertEnvelopeTransformer`.

    Drop this unit into an ezmsg graph to turn a stream of `AxisArray`
    messages into their FIR-Hilbert envelope; the output messages are
    `AxisArray` as well.

    Example:
    --------
    ```python
    unit = FIRHilbertEnvelopeUnit(
        settings=FIRHilbertFilterSettings(
            order=170,
            f_lo=1.0,
            f_hi=50.0,
        )
    )
    ```
    """

    SETTINGS = FIRHilbertFilterSettings
@@ -0,0 +1,214 @@
1
+ import functools
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ import numpy as np
6
+ import scipy.signal
7
+ from ezmsg.sigproc.filter import (
8
+ BACoeffs,
9
+ BaseFilterByDesignTransformerUnit,
10
+ FilterBaseSettings,
11
+ FilterByDesignTransformer,
12
+ )
13
+
14
+
15
class ParksMcClellanFIRSettings(FilterBaseSettings):
    """Configuration for :obj:`ParksMcClellanFIR`."""

    # axis inherited from FilterBaseSettings

    coef_type: str = "ba"
    """
    Coefficient representation; FIR designs only support 'ba'.
    """

    order: int = 0
    """
    Filter order (number of taps is order + 1).
    An odd order is bumped to the next even value so the tap count is odd.
    The default of 0 yields no filter.
    """

    cuton: float | None = None
    """
    Cut-on frequency in Hz.
    Alone (no `cutoff`): highpass corner.
    With `cutoff` above it: lower edge of a bandpass.
    With `cutoff` below it: upper edge of a bandstop.
    """

    cutoff: float | None = None
    """
    Cut-off frequency in Hz.
    Alone (no `cuton`): lowpass corner.
    With `cuton` below it: upper edge of a bandpass.
    With `cuton` above it: lower edge of a bandstop.
    """

    transition: float = 10.0
    """
    Width (Hz) of the transition region placed at every passband edge:
    one for lowpass/highpass, two for bandpass/bandstop.
    """

    weight_pass: float = 1.0
    """
    Remez weight of the passband (both passbands of a bandstop design).
    """

    weight_stop_lo: float = 1.0
    """
    Remez weight of the lower stopband. Ignored by bandstop designs.
    """

    weight_stop_hi: float = 1.0
    """
    Remez weight of the upper stopband; doubles as the weight of the
    central stopband in bandstop designs.
    """

    def filter_specs(
        self,
    ) -> tuple[str, tuple[float, float] | float] | None:
        """
        Classify the filter from which corner frequencies are set.

        Returns:
            None when neither corner is set. Otherwise a pair whose first
            item is "lowpass", "highpass", "bandpass" or "bandstop" and whose
            second item is the corner frequency (single value) or the
            ascending pair of corner frequencies.
        """
        lo, hi = self.cuton, self.cutoff
        if lo is None:
            return None if hi is None else ("lowpass", hi)
        if hi is None:
            return "highpass", lo
        if lo <= hi:
            return "bandpass", (lo, hi)
        return "bandstop", (hi, lo)
93
+
94
+
95
def parks_mcclellan_design_fun(
    fs: float,
    order: int = 0,
    cuton: float | None = None,
    cutoff: float | None = None,
    transition: float = 10.0,
    weight_pass: float = 1.0,
    weight_stop_lo: float = 1.0,
    weight_stop_hi: float = 1.0,
) -> BACoeffs | None:
    """
    Design an equiripple, linear-phase FIR filter with the Remez exchange
    (Parks-McClellan) algorithm.

    The filter type (lowpass, highpass, bandpass, bandstop) follows from which
    of `cuton` / `cutoff` are given; see
    :obj:`ParksMcClellanFIRSettings.filter_specs`.

    :obj:`filter_by_design` calls the design function with `fs` only, so bind
    the remaining arguments first (functools.partial or a lambda).

    Args:
        fs: Sampling frequency of the data in Hz.
        order: Filter order; odd values are raised to the next even value.
        cuton: Cut-on corner frequency in Hz.
        cutoff: Cut-off corner frequency in Hz.
        transition: Transition bandwidth (Hz) at each passband edge;
            negative values are treated as 0.
        weight_pass: Weight for the passband(s).
        weight_stop_lo: Weight for the lower stopband.
        weight_stop_hi: Weight for the upper stopband (central stopband of a
            bandstop design).

    Returns:
        (b, a) with `a == [1.0]`, or None when `order <= 0` or neither corner
        frequency is given.
    """
    if order <= 0:
        return None
    if order % 2 == 1:
        order += 1

    specs = ParksMcClellanFIRSettings(cuton=cuton, cutoff=cutoff).filter_specs()
    if specs is None:
        # Neither corner given, so there is nothing to design.
        return None

    kind, corners = specs
    nyq = fs / 2.0
    width = max(transition, 0.0)

    def clamp(freq: float) -> float:
        """Limit a frequency to the representable range [0, nyq]."""
        return float(min(max(freq, 0.0), nyq))

    if kind == "lowpass":
        edges = [0.0, clamp(corners), clamp(corners + width), nyq]
        gains = [1.0, 0.0]
        weights = [max(weight_pass, 0.0), max(weight_stop_hi, 0.0)]
    elif kind == "highpass":
        edges = [0.0, clamp(corners - width), clamp(corners), nyq]
        gains = [0.0, 1.0]
        weights = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0)]
    elif kind == "bandpass":
        low, high = corners[0], corners[1]
        edges = [
            0.0,
            clamp(low - width),
            clamp(low),
            clamp(high),
            clamp(high + width),
            nyq,
        ]
        gains = [0.0, 1.0, 0.0]
        weights = [
            max(weight_stop_lo, 0.0),
            max(weight_pass, 0.0),
            max(weight_stop_hi, 0.0),
        ]
    else:
        low, high = corners[0], corners[1]
        edges = [
            0.0,
            clamp(low),
            clamp(low + width),
            clamp(high - width),
            clamp(high),
            nyq,
        ]
        gains = [1.0, 0.0, 1.0]
        # Bandstop: weight_stop_hi weights the central stopband and both
        # passbands share weight_pass; weight_stop_lo is not used here.
        weights = [
            max(weight_pass, 0.0),
            max(weight_stop_hi, 0.0),
            max(weight_pass, 0.0),
        ]

    # Clipping can leave edges equal or out of order; push each offending edge
    # just past its predecessor, but never beyond Nyquist.
    for idx in range(1, len(edges)):
        previous = edges[idx - 1]
        if edges[idx] <= previous:
            edges[idx] = min(previous + 1e-6, nyq)

    taps = scipy.signal.remez(
        numtaps=order + 1, bands=edges, desired=gains, weight=weights, fs=fs
    )
    return (taps, np.array([1.0]))
188
+
189
+
190
class ParksMcClellanFIRTransformer(
    FilterByDesignTransformer[ParksMcClellanFIRSettings, BACoeffs]
):
    """Filter-by-design transformer backed by :obj:`parks_mcclellan_design_fun`."""

    def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
        """
        Return `parks_mcclellan_design_fun` with all arguments except `fs`
        taken from the current settings.

        Raises:
            ValueError: if `coef_type` is not 'ba'.
        """
        cfg = self.settings
        if cfg.coef_type != "ba":
            ez.logger.error("ParksMcClellanFIR only supports coef_type='ba'.")
            raise ValueError("ParksMcClellanFIR only supports coef_type='ba'.")
        bound = dict(
            order=cfg.order,
            cuton=cfg.cuton,
            cutoff=cfg.cutoff,
            transition=cfg.transition,
            weight_pass=cfg.weight_pass,
            weight_stop_lo=cfg.weight_stop_lo,
            weight_stop_hi=cfg.weight_stop_hi,
        )
        return functools.partial(parks_mcclellan_design_fun, **bound)
207
+
208
+
209
class ParksMcClellanFIR(
    BaseFilterByDesignTransformerUnit[
        ParksMcClellanFIRSettings, ParksMcClellanFIRTransformer
    ]
):
    """ezmsg unit wrapper around :obj:`ParksMcClellanFIRTransformer`."""

    SETTINGS = ParksMcClellanFIRSettings
@@ -0,0 +1,257 @@
1
+ from collections import deque
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ from ezmsg.sigproc.base import (
7
+ BaseAdaptiveTransformer,
8
+ BaseAdaptiveTransformerUnit,
9
+ processor_state,
10
+ )
11
+ from ezmsg.sigproc.sampler import SampleMessage
12
+ from ezmsg.util.messages.axisarray import AxisArray
13
+ from ezmsg.util.messages.util import replace
14
+
15
+
16
class RollingScalerSettings(ez.Settings):
    """Configuration for :obj:`RollingScalerProcessor`."""

    axis: str = "time"
    """
    Name of the axis along which samples are laid out.
    """

    k_samples: int | None = 20
    """
    Size of the rolling window, expressed as a sample count.
    NOTE(review): the processor bounds its history with one entry per
    ingested batch, so this equals a sample count only when every batch
    holds a single sample. Confirm against intended use.
    """

    window_size: float | None = None
    """
    Size of the rolling window in seconds; when given it takes the place of
    `k_samples`.
    You will most likely want `update_with_signal=True` alongside it.
    """

    update_with_signal: bool = False
    """
    When True, the processed stream itself feeds the rolling statistics.
    """

    min_samples: int = 1
    """
    Fewest samples that must have been seen before statistics are applied.
    Only consulted when `window_size` is not set.
    """

    min_seconds: float = 1.0
    """
    Shortest duration (seconds) that must have been seen before statistics
    are applied. Only consulted when `window_size` is set.
    """

    artifact_z_thresh: float | None = None
    """
    Z-score limit for artifact rejection. When set, any sample in which at
    least one channel exceeds the limit is left out of the statistics update.
    """

    clip: float | None = 10.0
    """
    When set, output values are limited to [-clip, clip].
    """
62
+
63
+
64
@processor_state
class RollingScalerState:
    """Running statistics kept by :obj:`RollingScalerProcessor`."""

    # Per-channel mean over the batches currently in the window.
    mean: npt.NDArray | None = None
    # Number of samples currently contributing to `mean` and `M2`.
    N: int = 0
    # Per-channel sum of squared deviations from `mean`.
    M2: npt.NDArray | None = None
    # History of (count, mean, M2) tuples, one per ingested batch.
    samples: deque | None = None
    # Resolved maximum length of `samples`; None means unbounded.
    k_samples: int | None = None
    # Resolved number of samples required before output is normalized.
    min_samples: int | None = None
72
+
73
+
74
class RollingScalerProcessor(
    BaseAdaptiveTransformer[
        RollingScalerSettings, AxisArray, AxisArray, RollingScalerState
    ]
):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    Rolling statistics (mean and variance) are accumulated from batches
    supplied through `partial_fit()` and, when `update_with_signal` is set,
    from the processed stream itself. Each batch is summarised as
    (count, mean, M2) and merged into the running totals. Once the history
    holds `k_samples` batches, the oldest batch is removed before a new one
    is added. Note that the history is bounded per *batch*: the window equals
    `k_samples` samples only when each batch contains a single sample.

    The input `AxisArray` messages are expected to have shape `(time, ch)`, where `ch` is the
    channel axis. The processor computes the z-score for each channel independently.
    Until `min_samples` samples have been accumulated, messages pass through unchanged.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash key, per-sample shape and sampling gain to detect stream changes."""
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Zero the statistics and resolve window/minimum sizes for this stream."""
        ch = message.data.shape[-1]
        self._state.mean = np.zeros(ch)
        self._state.N = 0
        self._state.M2 = np.zeros(ch)
        # A window given in seconds is converted with the axis gain
        # (seconds per sample); otherwise k_samples is used as is.
        self._state.k_samples = (
            int(
                np.ceil(
                    self.settings.window_size / message.axes[self.settings.axis].gain
                )
            )
            if self.settings.window_size is not None
            else self.settings.k_samples
        )
        if self._state.k_samples is not None and self._state.k_samples < 1:
            ez.logger.warning(
                "window_size smaller than sample gain; setting k_samples to 1."
            )
            self._state.k_samples = 1
        elif self._state.k_samples is None:
            ez.logger.warning(
                "k_samples is None; z-score accumulation will be unbounded."
            )
        self._state.samples = deque(maxlen=self._state.k_samples)
        self._state.min_samples = (
            int(
                np.ceil(
                    self.settings.min_seconds / message.axes[self.settings.axis].gain
                )
            )
            if self.settings.window_size is not None
            else self.settings.min_samples
        )
        if (
            self._state.k_samples is not None
            and self._state.min_samples > self._state.k_samples
        ):
            ez.logger.warning(
                "min_samples is greater than k_samples; adjusting min_samples to k_samples."
            )
            self._state.min_samples = self._state.k_samples

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """
        Merge one batch (rows = samples) into the rolling statistics.

        When the history is full, the oldest batch is removed first. Empty
        batches are ignored.
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        if n_b == 0:
            # The mean of an empty batch is NaN and n_b / N divides by zero
            # when N is 0; either would poison the running statistics.
            return
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        if (
            self._state.k_samples is not None
            and len(self._state.samples) == self._state.k_samples
        ):
            # Undo the pairwise merge for the oldest batch.
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old

            if N_new <= 0:
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                # The subtraction can cancel to a tiny negative value; clamp so
                # the derived variance never becomes negative (NaN std).
                self._state.M2 = np.maximum(
                    self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new),
                    0.0,
                )

        # Pairwise merge of the running totals with the new batch.
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N

        self._state.samples.append((n_b, mean_b, M2_b))

    def _current_std(self) -> npt.NDArray:
        """Per-channel standard deviation of the window, floored at 1e-8."""
        varis = self._state.M2 / self._state.N
        return np.maximum(np.sqrt(varis), 1e-8)

    def _update_from_signal(self, x: npt.NDArray, std: npt.NDArray | None) -> None:
        """
        Feed stream data into the statistics.

        When an artifact threshold is configured and `std` is given, rows in
        which any channel's |z| exceeds the threshold are dropped first.
        """
        thresh = self.settings.artifact_z_thresh
        if thresh is not None and std is not None:
            z = np.abs((x - self._state.mean) / std)
            x = x[~np.any(z > thresh, axis=1)]
        if x.size > 0:
            self._add_batch_stats(x)

    def partial_fit(self, message: SampleMessage) -> None:
        """Add the data of one `SampleMessage` to the rolling statistics."""
        x = message.sample.data
        self._add_batch_stats(x)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Z-score `message` with the current statistics.

        While fewer than `min_samples` samples have been accumulated, the
        message is returned unchanged (it still updates the statistics when
        `update_with_signal` is set).
        """
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            if self.settings.update_with_signal:
                # Artifact rejection needs existing statistics to compare against.
                std = self._current_std() if self._state.N > 0 else None
                self._update_from_signal(message.data, std)
            return message

        std = self._current_std()
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
        result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        if self.settings.update_with_signal:
            # The statistics are updated after normalising, so this message is
            # scored against the window that preceded it.
            self._update_from_signal(message.data, std)

        return replace(message, data=result)
229
+
230
+
231
class RollingScalerUnit(
    BaseAdaptiveTransformerUnit[
        RollingScalerSettings,
        AxisArray,
        AxisArray,
        RollingScalerProcessor,
    ]
):
    """
    ezmsg unit exposing :obj:`RollingScalerProcessor`.

    Incoming `AxisArray` messages are z-scored with rolling statistics (mean
    and variance) that the wrapped processor maintains over its `k_samples`
    window.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.4.0
3
+ Version: 2.5.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
- Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
5
+ Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
7
7
  License-File: LICENSE.txt
8
8
  Requires-Python: >=3.10.15
@@ -1,5 +1,5 @@
1
1
  ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
2
- ezmsg/sigproc/__version__.py,sha256=69_lBCO99qjONN2phoPwQ0THjdLm1VftCSvDKqELbGk,704
2
+ ezmsg/sigproc/__version__.py,sha256=bdfCUdK0KgIdZbIExZ_otf09cMr8k-qiQbglewDXQI8,704
3
3
  ezmsg/sigproc/activation.py,sha256=qWAhpbFBxSoqbGy4P9JKE5LY-5v8rQI1U81OvNxBG2Y,2820
4
4
  ezmsg/sigproc/adaptive_lattice_notch.py,sha256=3M65PrZpdgBlQtE7Ph4Gu2ISIyWw4j8Xxhm5PpSkLFw,9102
5
5
  ezmsg/sigproc/affinetransform.py,sha256=WU495KoDKZfHPS3Dumh65rgf639koNlfDIx_torIByg,8662
@@ -7,10 +7,11 @@ ezmsg/sigproc/aggregate.py,sha256=wHUP_aS9NgnOxBCPN1_tSxCqMMb8UPBEoKwGKX7-ASk,91
7
7
  ezmsg/sigproc/bandpower.py,sha256=j-Y6iWjD2xkggfi-4HAFJVBPJHHBGvAZy1uM4murZkQ,2319
8
8
  ezmsg/sigproc/base.py,sha256=PQr03O2P1v9LzcSR0GJLvPpBCLtnmGaz76gUeXphcH4,48753
9
9
  ezmsg/sigproc/butterworthfilter.py,sha256=7ZP4CRsXBt3-5dzyUjD45vc0J3Fhpm4CLrk-ps28jhc,5305
10
+ ezmsg/sigproc/butterworthzerophase.py,sha256=B95FxHBk0uSXizsndR5yc8I2V_gXVNWZ9WVMS4m1Hek,4190
10
11
  ezmsg/sigproc/cheby.py,sha256=-aSauAwxJmmSSiRaw5qGY9rvYFOmk1bZlS4gGrS0jls,3737
11
12
  ezmsg/sigproc/combfilter.py,sha256=5UCfzGESpS5LSx6rxZv8_n25ZUvOOmws-mM_gpTZNhU,4777
12
13
  ezmsg/sigproc/decimate.py,sha256=Lz46fBllWagu17QeQzgklm6GWCV-zPysiydiby2IElU,2347
13
- ezmsg/sigproc/denormalize.py,sha256=qMXkxpNoEACHzEfluA0wV4716HQyGE_1tcFAa8uzhIc,3091
14
+ ezmsg/sigproc/denormalize.py,sha256=CujviBepGysjB5X7RZoDOMC5tUC97ryHnUdqhi-eMPo,3065
14
15
  ezmsg/sigproc/detrend.py,sha256=7bpjFKdk2b6FdVn2GEtMbWtCuk7ToeiYKEBHVbN4Gd0,903
15
16
  ezmsg/sigproc/diff.py,sha256=P5BBjR7KdaCL9aD3GG09cmC7a-3cxDeEUw4nKdQ1HY8,2895
16
17
  ezmsg/sigproc/downsample.py,sha256=0X6EwPZ_XTwA2-nx5w-2HmMZUEDFuGAYF5EmPSuuVj8,3721
@@ -21,12 +22,15 @@ ezmsg/sigproc/fbcca.py,sha256=8NTJAOpHIvNFwQepui2_ZaJV4SMDFgXrqoWJyiQdF5U,12362
21
22
  ezmsg/sigproc/filter.py,sha256=1MQUZDFIf6HAHuuhGQEvH4Yd6Jv_vv12PM25YaHjdxc,11921
22
23
  ezmsg/sigproc/filterbank.py,sha256=pJzv_G6chgWa1ARmRjMAMgt9eEGnA-ZbMSge4EWrcYY,13633
23
24
  ezmsg/sigproc/filterbankdesign.py,sha256=OfIXM0ushSqbdSQG9DZB1Mh57d-lqdJQX8aqfxNN67E,4734
25
+ ezmsg/sigproc/fir_hilbert.py,sha256=qqHTp-yIhAD3VBoENTxpBmy7TgF2lYqbZ65OSfqeWO4,11042
26
+ ezmsg/sigproc/fir_pmc.py,sha256=ApWMl7WNQ9Ihr-J74DrAVwxD1r8gvLcElYcEL0RtQ2U,7024
24
27
  ezmsg/sigproc/firfilter.py,sha256=MCrwY3DLq-uMLX04JswVB9oHBSYJGbdUiQYW6eRdkxE,3805
25
28
  ezmsg/sigproc/gaussiansmoothing.py,sha256=NaVezgNwdvp-kam1I_7lSID4Obi0UCxZshH7A2afaVg,2692
26
29
  ezmsg/sigproc/kaiser.py,sha256=WsZB8a4DP7WwrYLlGczHS61L86TiH6qEStAB6zxODhY,3502
27
30
  ezmsg/sigproc/messages.py,sha256=y_twVPK7TxRj8ajmuSuBuxwvLTgyv9OF7Y7v9bw1tfs,926
28
31
  ezmsg/sigproc/quantize.py,sha256=VzaqE6PatibEjkk7XrGO-ubAXYurAed9FYOn4bcQZQk,2193
29
32
  ezmsg/sigproc/resample.py,sha256=wqSM7g3QrcrklCeGVNN4l_qZLSXRUPHXCUxl1L47300,11654
33
+ ezmsg/sigproc/rollingscaler.py,sha256=RrVAoN7cRvFz7kHSyeQr1pjKiKkJDM_1ChQ5V9FWZKo,8860
30
34
  ezmsg/sigproc/sampler.py,sha256=D5oMIZHAJS6XIKMdOHsDw97d4ZxfNP7iZwpc6J8Jmpk,10898
31
35
  ezmsg/sigproc/scaler.py,sha256=fCLHvCNUSgv0XChf8iS9s5uHCSCVjCasM2TCvyG5BwQ,4111
32
36
  ezmsg/sigproc/signalinjector.py,sha256=hGC837JyDLtAGrfsdMwzEoOqWXiwP7r7sGlUC9nahTY,2948
@@ -53,7 +57,7 @@ ezmsg/sigproc/util/message.py,sha256=l_b1b6bXX8N6VF9RbUELzsHs73cKkDURBdIr0lt3CY0
53
57
  ezmsg/sigproc/util/profile.py,sha256=KNJ_QkKelQHNEp2C8MhqzdhYydMNULc_NQq3ccMfzIk,5775
54
58
  ezmsg/sigproc/util/sparse.py,sha256=mE64p1tYb5A1shaRE1D-VnH-RshbLb8g8kXSXxnA-J4,4842
55
59
  ezmsg/sigproc/util/typeresolution.py,sha256=5R7xmG-F4CkdqQ5aoQnqM-htQb-VwAJl58jJgxtClys,3146
56
- ezmsg_sigproc-2.4.0.dist-info/METADATA,sha256=FcsrFuRHBBbdrHsdlVGJjU7hUGkX-ql3xYWGAPdkD1M,4977
57
- ezmsg_sigproc-2.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
58
- ezmsg_sigproc-2.4.0.dist-info/licenses/LICENSE.txt,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
59
- ezmsg_sigproc-2.4.0.dist-info/RECORD,,
60
+ ezmsg_sigproc-2.5.0.dist-info/METADATA,sha256=SiHigniH10jk8aeW-C7SLMqdldTbpphobjwVJbaBdX0,5019
61
+ ezmsg_sigproc-2.5.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
62
+ ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
63
+ ezmsg_sigproc-2.5.0.dist-info/RECORD,,