ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +123 -0
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +336 -0
  25. ezmsg/sigproc/fir_pmc.py +209 -0
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +232 -0
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
  60. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,336 @@
1
+ import functools
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ import numpy as np
6
+ import scipy.signal as sps
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
9
+
10
+ from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state
11
+ from ezmsg.sigproc.filter import (
12
+ BACoeffs,
13
+ BaseFilterByDesignTransformerUnit,
14
+ BaseTransformerUnit,
15
+ FilterBaseSettings,
16
+ FilterByDesignTransformer,
17
+ )
18
+
19
+
20
class FIRHilbertFilterSettings(FilterBaseSettings):
    """
    Settings for the FIR Hilbert filter and envelope processors in this module.

    The band layout described below is the one built by
    :obj:`fir_hilbert_design_fun`::

        [0, f_lo]                       lower stopband
        [f_lo + trans_lo, f_hi]         Hilbert passband
        [f_hi + trans_hi, Nyquist]      upper stopband (omitted if it has no room)
    """

    # axis inherited from FilterBaseSettings

    coef_type: str = "ba"
    """
    Coefficient type. Must be 'ba' for FIR; any other value makes
    ``get_design_function`` raise ``ValueError``.
    """

    order: int = 170
    """
    Filter order (taps = order + 1).
    Hilbert (type-III) filters require even order (odd taps).
    If odd order (even taps), order will be incremented by 1.
    An order <= 0 yields no filter (the design function returns None).
    """

    f_lo: float = 1.0
    """
    Upper edge of the lower stopband (Hz); negative values are treated as 0.
    The lower transition starts at f_lo and the Hilbert passband begins at
    f_lo + trans_lo.
    """

    f_hi: float | None = None
    """
    Upper edge of the Hilbert passband (Hz), limited to Nyquist - trans_hi.
    The upper transition starts at f_hi.
    If None, the passband extends to Nyquist and no upper stopband is designed.
    """

    trans_lo: float = 1.0
    """
    Width (Hz) of the lower transition, i.e. the gap between f_lo and the start
    of the passband. Decrease to sharpen transition.
    """

    trans_hi: float = 1.0
    """
    Width (Hz) of the upper transition, i.e. the gap between the passband's
    upper edge and the start of the upper stopband.
    Decrease to sharpen transition.
    """

    weight_pass: float = 1.0
    """
    Weight for Hilbert pass region. Negative values are treated as 0.
    """

    weight_stop_lo: float = 1.0
    """
    Weight for low stop band. Negative values are treated as 0.
    """

    weight_stop_hi: float = 1.0
    """
    Weight for high stop band. Negative values are treated as 0.
    Unused when the upper stopband is omitted.
    """

    norm_band: tuple[float, float] | None = None
    """
    Optional normalization band (lo, hi) in Hz. The coefficients are divided by
    the median magnitude response over this band, after clipping the band to
    the passband. Ignored whenever norm_freq is set.
    If None (and norm_freq is None), no normalization is applied.
    """

    norm_freq: float | None = None
    """
    Optional normalization frequency in Hz. The coefficients are divided by the
    magnitude response at this frequency. It must lie inside the passband,
    otherwise normalization is skipped with a warning.
    Takes precedence over norm_band.
    """
88
+
89
+
90
def _fir_gain_at(b: np.ndarray, fs: float, freqs_hz) -> np.ndarray:
    """Return the magnitude response of FIR taps ``b`` at ``freqs_hz`` (Hz)."""
    # freqz is called without ``fs``, so frequencies must be in rad/sample.
    w = 2.0 * np.pi * (np.asarray(freqs_hz, dtype=np.float64) / fs)
    _, resp = sps.freqz(b, np.array([1.0]), worN=w)
    return np.abs(resp)


def fir_hilbert_design_fun(
    fs: float,
    order: int = 170,
    f_lo: float = 1.0,
    f_hi: float | None = None,
    trans_lo: float = 1.0,
    trans_hi: float = 1.0,
    weight_pass: float = 1.0,
    weight_stop_lo: float = 1.0,
    weight_stop_hi: float = 1.0,
    norm_band: tuple[float, float] | None = None,
    norm_freq: float | None = None,
) -> BACoeffs | None:
    """
    Hilbert FIR filter design using the Remez exchange algorithm.
    Design an `order`th-order FIR Hilbert filter and return the filter coefficients.
    See :obj:`FIRHilbertFilterSettings` for argument description.

    Band layout (Hz): ``[0, f_lo]`` stop, ``[f_lo + trans_lo, f_hi]`` pass,
    ``[f_hi + trans_hi, fs / 2]`` stop. When ``f_hi`` is None, or when the upper
    stopband has no room below Nyquist, a two-band (stop/pass) design with the
    passband running to Nyquist is used instead.

    If ``norm_freq`` is given, the taps are divided by the magnitude response at
    that frequency; otherwise, if ``norm_band`` is given, by the median magnitude
    over that band (clipped to the passband). ``norm_freq`` takes precedence.

    Args:
        fs: The sampling frequency of the data in Hz.

    Returns:
        The filter coefficients as a tuple of (b, a) with ``a == [1.0]``,
        or None if ``order <= 0``.

    Raises:
        ValueError: If the passband edges do not satisfy
            ``0 < f_lo + trans_lo < upper edge < fs / 2``.
    """
    if order <= 0:
        return None
    if order % 2 == 1:
        # Type-III (antisymmetric) designs need an odd number of taps.
        order += 1
    nyq = fs / 2.0
    taps = order + 1
    f1 = max(f_lo, 0.0) + trans_lo
    f2 = (nyq - trans_hi) if (f_hi is None) else min(f_hi, nyq - trans_hi)
    if not (0.0 < f1 < f2 < nyq):
        raise ValueError(
            f"Hilbert passband collapsed or invalid: "
            f"f_lo={f_lo}, f_hi={f_hi}, trans_lo={trans_lo}, trans_hi={trans_hi}, fs={fs}"
        )

    # Bands: [0, f1-trans_lo] stop ; [f1, f2] pass (Hilbert) ; [f2+trans_hi, nyq] stop
    bands = [0.0, max(f1 - trans_lo, 0.0), f1, f2, min(f2 + trans_hi, nyq), nyq]
    desired = [0.0, 1.0, 0.0]
    weight = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0), max(weight_stop_hi, 0.0)]

    # Nudge interior edges so the sequence is strictly increasing
    # (e.g. f_lo == 0 or a zero-width transition).
    for i in range(1, len(bands) - 1):
        if bands[i] <= bands[i - 1]:
            bands[i] = np.nextafter(bands[i - 1], np.inf)

    # f_hi=None is the documented highpass configuration: select the two-band
    # design explicitly instead of relying on (nyq - trans_hi) + trans_hi
    # rounding back to nyq, and do not warn about an intentional setup.
    highpass_requested = f_hi is None
    if highpass_requested or bands[-2] >= nyq:
        if not highpass_requested:
            ez.logger.warning("Hilbert upper stopband collapsed; using 2-band (stop/pass) design.")
        bands = bands[:-3] + [nyq]
        desired = desired[:-1]
        weight = weight[:-1]

    b = sps.remez(taps, bands, desired, weight=weight, type="hilbert", fs=fs)
    a = np.array([1.0])

    g = None
    if norm_freq is not None:
        if norm_freq < f1 or norm_freq > f2:
            ez.logger.warning("Invalid normalization frequency specifications. Skipping normalization.")
        else:
            g = float(_fir_gain_at(b, fs, [float(norm_freq)])[0])
    elif norm_band is not None:
        lo, hi = norm_band
        if lo < f1 or hi > f2:
            lo = max(lo, f1)
            hi = min(hi, f2)
            ez.logger.warning("Normalization band outside passband. Clipping to passband for normalization.")
        if lo >= hi:
            ez.logger.warning("Invalid normalization band specifications. Skipping normalization.")
        else:
            freqs = np.linspace(lo, hi, 2048, dtype=np.float64)
            g = float(np.median(_fir_gain_at(b, fs, freqs)))

    # Skip the division when no gain was measured or the gain is zero.
    if g is not None and g > 0:
        b = b / g
    return (b, a)
163
+
164
+
165
class FIRHilbertFilterTransformer(FilterByDesignTransformer[FIRHilbertFilterSettings, BACoeffs]):
    """Filter-by-design transformer whose coefficients come from :obj:`fir_hilbert_design_fun`."""

    def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
        """
        Bind the current settings to :obj:`fir_hilbert_design_fun`.

        Returns:
            A callable taking only the sampling rate ``fs`` and returning (b, a) or None.

        Raises:
            ValueError: If ``settings.coef_type`` is not ``"ba"``.
        """
        cfg = self.settings
        if cfg.coef_type != "ba":
            ez.logger.error("FIRHilbert only supports coef_type='ba'.")
            raise ValueError("FIRHilbert only supports coef_type='ba'.")

        design_kwargs = {
            "order": cfg.order,
            "f_lo": cfg.f_lo,
            "f_hi": cfg.f_hi,
            "trans_lo": cfg.trans_lo,
            "trans_hi": cfg.trans_hi,
            "weight_pass": cfg.weight_pass,
            "weight_stop_lo": cfg.weight_stop_lo,
            "weight_stop_hi": cfg.weight_stop_hi,
            "norm_band": cfg.norm_band,
            "norm_freq": cfg.norm_freq,
        }
        return functools.partial(fir_hilbert_design_fun, **design_kwargs)

    def get_taps(self) -> int | None:
        """Return the number of numerator taps, or None when no filter/coefficients exist yet."""
        inner = self._state.filter
        if inner is None:
            return None
        numerator, _ = inner.settings.coefs
        if numerator is None:
            return None
        return numerator.size
190
+
191
+
192
class FIRHilbertFilterUnit(
    BaseFilterByDesignTransformerUnit[FIRHilbertFilterSettings, FIRHilbertFilterTransformer]
):
    """Unit wrapper exposing :obj:`FIRHilbertFilterTransformer` with :obj:`FIRHilbertFilterSettings`."""

    SETTINGS = FIRHilbertFilterSettings
194
+
195
+
196
@processor_state
class FIRHilbertEnvelopeState:
    """State carried between messages by :obj:`FIRHilbertEnvelopeTransformer`."""

    # Inner Hilbert FIR filter; (re)created in ``_reset_state``.
    filter: FIRHilbertFilterTransformer | None = None

    # Trailing input samples kept so the raw signal can be delayed across messages.
    delay_buf: np.ndarray | None = None

    # Delay in samples, set to ``(taps - 1) // 2`` on first use.
    dly: int | None = None
201
+
202
+
203
class FIRHilbertEnvelopeTransformer(
    BaseStatefulTransformer[FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeState]
):
    """
    Processor for computing the envelope of a signal using the Hilbert transform.

    This processor applies a Hilbert FIR filter to the input signal to obtain the
    imaginary part of the analytic signal, delays the raw input by
    ``(taps - 1) // 2`` samples to form the matching real part, and returns the
    magnitude of the resulting analytic signal.

    The filter is applied along ``settings.axis`` or, if that is unset, along the
    first dimension of the message. Output messages reuse the axes of the input
    message unchanged.

    Settings:
    ---------
    order : int
        Filter order (taps = order + 1).
        Hilbert (type-III) filters require even order (odd taps).
        If odd order (even taps), order will be incremented by 1.
        Must be > 0 for the envelope to be computable.
    f_lo : float
        Upper edge of the lower stopband (Hz). Transition starts at f_lo.
    f_hi : float, optional
        Upper edge of the Hilbert passband (Hz). Transition starts at f_hi.
        If None, the passband extends to Nyquist.
    trans_lo : float
        Transition width (Hz) above f_lo.
        Decrease to sharpen transition.
    trans_hi : float
        Transition width (Hz) above f_hi.
        Decrease to sharpen transition.
    weight_pass : float
        Weight for Hilbert pass region.
    weight_stop_lo : float
        Weight for low stop band.
    weight_stop_hi : float
        Weight for high stop band.
    norm_band : tuple(float, float), optional
        Optional normalization band (lo, hi) in Hz for gain normalization.
        Ignored when norm_freq is set.
    norm_freq : float, optional
        Optional normalization frequency in Hz for gain normalization.

    Example:
    -----------------------------
    ```python
    processor = FIRHilbertEnvelopeTransformer(
        settings=FIRHilbertFilterSettings(
            order=170,
            f_lo=1.0,
            f_hi=50.0,
        )
    )
    ```

    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key and per-sample shape so a change in either triggers a state reset."""
        axis = self.settings.axis or message.dims[0]
        # NOTE(review): FIRHilbertFilterTransformer as defined in this module has no
        # ``gain`` attribute, so this falls back to 0.0 unless a base class provides
        # one - confirm whether the sample rate was meant to be hashed here.
        gain = getattr(self._state.filter, "gain", 0.0)
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Create a fresh inner Hilbert filter and clear the delay line."""
        self._state.filter = FIRHilbertFilterTransformer(settings=self.settings)
        self._state.delay_buf = None
        self._state.dly = None

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Return the envelope of ``message`` along the configured axis.

        Raises:
            ValueError: If the inner filter has no coefficients (e.g. ``order <= 0``),
                so no delay can be derived.
        """
        # Imaginary part of the analytic signal: the Hilbert-filtered input.
        y_imag_msg = self._state.filter(message)
        y_imag = y_imag_msg.data

        axis_name = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis_name)

        if self._state.dly is None:
            taps = self._state.filter.get_taps()
            if taps is None:
                # get_taps() is allowed to return None; fail with a clear message
                # instead of a TypeError from the arithmetic below.
                raise ValueError(
                    "FIRHilbertEnvelope requires a designed Hilbert filter; "
                    "no coefficients are available (is order > 0?)."
                )
            self._state.dly = (taps - 1) // 2
        dly = self._state.dly

        x = message.data

        # Work with the filtered axis last; restore the layout before returning.
        move_axis = axis_idx != x.ndim - 1
        if move_axis:
            x = np.moveaxis(x, axis_idx, -1)
            y_imag = np.moveaxis(y_imag, axis_idx, -1)

        if self._state.delay_buf is None:
            self._state.delay_buf = np.zeros(x.shape[:-1] + (dly,), dtype=x.dtype)

        # Delay line: prepend the carried-over samples, emit the first ``n_new``
        # samples as the delayed real part, and keep the remaining ``dly`` samples
        # for the next message. Positive indices keep this correct for dly == 0.
        n_new = x.shape[-1]
        x_cat = np.concatenate([self._state.delay_buf, x], axis=-1)
        y_real = x_cat[..., :n_new]
        self._state.delay_buf = x_cat[..., n_new:].copy()

        # NOTE(review): complex64 limits the result to single precision even for
        # float64 input; kept as-is to preserve the output dtype.
        analytic = y_real.astype(np.complex64) + 1j * y_imag.astype(np.complex64)
        out = np.abs(analytic)

        if move_axis:
            out = np.moveaxis(out, -1, axis_idx)

        return replace(message, data=out, axes=message.axes)
307
+
308
+
309
class FIRHilbertEnvelopeUnit(
    BaseTransformerUnit[
        FIRHilbertFilterSettings,
        AxisArray,
        AxisArray,
        FIRHilbertEnvelopeTransformer,
    ]
):
    """
    ezmsg unit that wraps :obj:`FIRHilbertEnvelopeTransformer`.

    Receives `AxisArray` messages and emits `AxisArray` messages holding the
    signal envelope computed with an FIR Hilbert transformer. It is configured
    with :obj:`FIRHilbertFilterSettings`.

    Example:
    --------
    ```python
    unit = FIRHilbertEnvelopeUnit(
        settings=FIRHilbertFilterSettings(
            order=170,
            f_lo=1.0,
            f_hi=50.0,
        )
    )
    ```
    """

    SETTINGS = FIRHilbertFilterSettings
@@ -0,0 +1,209 @@
1
+ import functools
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ import numpy as np
6
+ import scipy.signal
7
+
8
+ from ezmsg.sigproc.filter import (
9
+ BACoeffs,
10
+ BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
+ )
14
+
15
+
16
class ParksMcClellanFIRSettings(FilterBaseSettings):
    """Settings for :obj:`ParksMcClellanFIR`."""

    # axis inherited from FilterBaseSettings

    coef_type: str = "ba"
    """
    Coefficient type. Must be 'ba' for FIR.
    """

    order: int = 0
    """
    Filter order (taps = order + 1).
    PMC FIR filters require even order (odd taps).
    If odd order (even taps), order will be incremented by 1.
    """

    cuton: float | None = None
    """
    Cuton frequency (Hz). With `cutoff` unset this is the highpass corner.
    With both set: if `cuton` <= `cutoff` it is the start of the bandpass,
    otherwise it is the end of the bandstop.
    """

    cutoff: float | None = None
    """
    Cutoff frequency (Hz). With `cuton` unset this is the lowpass corner.
    With both set: if `cutoff` >= `cuton` it is the end of the bandpass,
    otherwise it is the start of the bandstop.
    """

    transition: float = 10.0
    """
    Transition bandwidth (Hz) applied to each passband edge.
    For low/high: single transition. For bands: both edges.
    """

    weight_pass: float = 1.0
    """
    Weight for the passband.
    Used for both high and low passbands in bandstop filters.
    """

    weight_stop_lo: float = 1.0
    """
    Weight for the lower stopband.
    Not used for bandstop filters.
    """

    weight_stop_hi: float = 1.0
    """
    Weight for the upper stopband.
    Used as the central-stop weight for bandstop filters.
    """

    def filter_specs(
        self,
    ) -> tuple[str, tuple[float, float] | float] | None:
        """
        Classify the filter from which corner frequencies are set.

        Returns:
            None when neither corner is set. Otherwise a ``(btype, corners)`` tuple
            where ``btype`` is one of "lowpass", "highpass", "bandpass", "bandstop"
            and ``corners`` is a single frequency (low/high) or an ascending
            ``(low, high)`` pair (band types).
        """
        on, off = self.cuton, self.cutoff
        if on is None:
            # Only `cutoff` (or nothing) specified.
            return None if off is None else ("lowpass", off)
        if off is None:
            return "highpass", on
        if on <= off:
            return "bandpass", (on, off)
        return "bandstop", (off, on)
94
+
95
+
96
def parks_mcclellan_design_fun(
    fs: float,
    order: int = 0,
    cuton: float | None = None,
    cutoff: float | None = None,
    transition: float = 10.0,
    weight_pass: float = 1.0,
    weight_stop_lo: float = 1.0,
    weight_stop_hi: float = 1.0,
) -> BACoeffs | None:
    """
    Design a Parks-McClellan (equiripple, linear-phase) FIR filter with the Remez exchange algorithm.

    The filter type (lowpass, highpass, bandpass, bandstop) follows from which of
    `cuton` / `cutoff` are given; see :obj:`ParksMcClellanFIRSettings.filter_specs`.

    :obj:`filter_by_design` passes only `fs` to its design function, so bind the
    remaining arguments first (lambda or functools.partial).

    Args:
        fs: The sampling frequency of the data in Hz.
        order: Filter order; odd values are incremented by 1. Values <= 0 give no filter.
        cuton: Corner frequency of the filter in Hz.
        cutoff: Corner frequency of the filter in Hz.
        transition: Transition bandwidth (Hz) applied to each passband edge; negatives are treated as 0.
        weight_pass: Weight for the passband(s).
        weight_stop_lo: Weight for the lower stopband (highpass and bandpass only).
        weight_stop_hi: Weight for the upper stopband, or the central stopband of a bandstop.

    Returns:
        The filter coefficients as a tuple of (b, a), or None when `order` <= 0
        or no corner frequency is given.
    """
    if order <= 0:
        return None
    # An odd order is bumped to the next even order, i.e. an odd number of taps.
    numtaps = order + 2 if order % 2 == 1 else order + 1

    spec = ParksMcClellanFIRSettings(cuton=cuton, cutoff=cutoff).filter_specs()
    if spec is None:
        # Under-specified: no filter
        return None

    kind, corners = spec
    nyq = fs / 2.0
    width = max(transition, 0.0)

    def clamp(freq: float) -> float:
        """Limit a frequency to the range [0, Nyquist]."""
        return float(min(max(freq, 0.0), nyq))

    def nonneg(value: float) -> float:
        """Replace negative weights with zero."""
        return max(value, 0.0)

    if kind == "lowpass":
        edges = [0.0, clamp(corners), clamp(corners + width), nyq]
        gains = [1.0, 0.0]
        weights = [nonneg(weight_pass), nonneg(weight_stop_hi)]
    elif kind == "highpass":
        edges = [0.0, clamp(corners - width), clamp(corners), nyq]
        gains = [0.0, 1.0]
        weights = [nonneg(weight_stop_lo), nonneg(weight_pass)]
    elif kind == "bandpass":
        low, high = corners
        edges = [0.0, clamp(low - width), clamp(low), clamp(high), clamp(high + width), nyq]
        gains = [0.0, 1.0, 0.0]
        weights = [nonneg(weight_stop_lo), nonneg(weight_pass), nonneg(weight_stop_hi)]
    else:
        # Bandstop: pass / stop / pass. Both passbands share weight_pass and the
        # central stopband uses weight_stop_hi; weight_stop_lo is not used.
        low, high = corners
        edges = [0.0, clamp(low), clamp(low + width), clamp(high - width), clamp(high), nyq]
        gains = [1.0, 0.0, 1.0]
        weights = [nonneg(weight_pass), nonneg(weight_stop_hi), nonneg(weight_pass)]

    # Clipping can make neighbouring edges coincide or cross; push each offending
    # edge just past its predecessor, never beyond Nyquist.
    for idx in range(1, len(edges)):
        previous = edges[idx - 1]
        if edges[idx] <= previous:
            edges[idx] = min(previous + 1e-6, nyq)

    coefs = scipy.signal.remez(numtaps=numtaps, bands=edges, desired=gains, weight=weights, fs=fs)
    return (coefs, np.array([1.0]))
189
+
190
+
191
class ParksMcClellanFIRTransformer(FilterByDesignTransformer[ParksMcClellanFIRSettings, BACoeffs]):
    """Filter-by-design transformer whose coefficients come from :obj:`parks_mcclellan_design_fun`."""

    def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
        """
        Bind the current settings to :obj:`parks_mcclellan_design_fun`.

        Returns:
            A callable taking only the sampling rate ``fs`` and returning (b, a) or None.

        Raises:
            ValueError: If ``settings.coef_type`` is not ``"ba"``.
        """
        cfg = self.settings
        if cfg.coef_type != "ba":
            ez.logger.error("ParksMcClellanFIR only supports coef_type='ba'.")
            raise ValueError("ParksMcClellanFIR only supports coef_type='ba'.")

        design_kwargs = {
            "order": cfg.order,
            "cuton": cfg.cuton,
            "cutoff": cfg.cutoff,
            "transition": cfg.transition,
            "weight_pass": cfg.weight_pass,
            "weight_stop_lo": cfg.weight_stop_lo,
            "weight_stop_hi": cfg.weight_stop_hi,
        }
        return functools.partial(parks_mcclellan_design_fun, **design_kwargs)
206
+
207
+
208
class ParksMcClellanFIR(
    BaseFilterByDesignTransformerUnit[ParksMcClellanFIRSettings, ParksMcClellanFIRTransformer]
):
    """Unit wrapper exposing :obj:`ParksMcClellanFIRTransformer` with :obj:`ParksMcClellanFIRSettings`."""

    SETTINGS = ParksMcClellanFIRSettings
@@ -6,10 +6,10 @@ import numpy.typing as npt
6
6
  import scipy.signal
7
7
 
8
8
  from .filter import (
9
- FilterBaseSettings,
10
- FilterByDesignTransformer,
11
9
  BACoeffs,
12
10
  BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
13
  )
14
14
 
15
15
 
@@ -25,16 +25,16 @@ class FIRFilterSettings(FilterBaseSettings):
25
25
 
26
26
  cutoff: float | npt.ArrayLike | None = None
27
27
  """
28
- Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
29
- (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
30
- the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
31
- cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
28
+ Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
29
+ (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
30
+ the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
31
+ cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
32
32
  not be included in cutoff.
33
33
  """
34
34
 
35
35
  width: float | None = None
36
36
  """
37
- If width is not None, then assume it is the approximate width of the transition region (expressed in
37
+ If width is not None, then assume it is the approximate width of the transition region (expressed in
38
38
  the same units as fs) for use in Kaiser FIR filter design. In this case, the window argument is ignored.
39
39
  """
40
40
 
@@ -45,18 +45,18 @@ class FIRFilterSettings(FilterBaseSettings):
45
45
 
46
46
  pass_zero: bool | str = True
47
47
  """
48
- If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
48
+ If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
49
49
  be a string argument for the desired filter type (equivalent to btype in IIR design functions).
50
50
  {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
51
51
  """
52
52
 
53
53
  scale: bool = True
54
54
  """
55
- Set to True to scale the coefficients so that the frequency response is exactly unity at a certain
55
+ Set to True to scale the coefficients so that the frequency response is exactly unity at a certain
56
56
  frequency. That frequency is either:
57
57
  * 0 (DC) if the first passband starts at 0 (i.e. pass_zero is True)
58
- * fs/2 (the Nyquist frequency) if the first passband ends at fs/2
59
- (i.e the filter is a single band highpass filter);
58
+ * fs/2 (the Nyquist frequency) if the first passband ends at fs/2
59
+ (i.e the filter is a single band highpass filter);
60
60
  center of first passband otherwise
61
61
  """
62
62
 
@@ -113,7 +113,5 @@ class FIRFilterTransformer(FilterByDesignTransformer[FIRFilterSettings, BACoeffs
113
113
  )
114
114
 
115
115
 
116
- class FIRFilter(
117
- BaseFilterByDesignTransformerUnit[FIRFilterSettings, FIRFilterTransformer]
118
- ):
116
+ class FIRFilter(BaseFilterByDesignTransformerUnit[FIRFilterSettings, FIRFilterTransformer]):
119
117
  SETTINGS = FIRFilterSettings
@@ -1,13 +1,13 @@
1
- from typing import Callable
2
1
  import warnings
2
+ from typing import Callable
3
3
 
4
4
  import numpy as np
5
5
 
6
6
  from .filter import (
7
- FilterBaseSettings,
8
7
  BACoeffs,
9
- FilterByDesignTransformer,
10
8
  BaseFilterByDesignTransformerUnit,
9
+ FilterBaseSettings,
10
+ FilterByDesignTransformer,
11
11
  )
12
12
 
13
13
 
@@ -68,9 +68,7 @@ def gaussian_smoothing_filter_design(
68
68
  return b, a
69
69
 
70
70
 
71
- class GaussianSmoothingFilterTransformer(
72
- FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]
73
- ):
71
+ class GaussianSmoothingFilterTransformer(FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]):
74
72
  def get_design_function(
75
73
  self,
76
74
  ) -> Callable[[float], BACoeffs]:
@@ -86,8 +84,6 @@ class GaussianSmoothingFilterTransformer(
86
84
 
87
85
 
88
86
  class GaussianSmoothingFilter(
89
- BaseFilterByDesignTransformerUnit[
90
- GaussianSmoothingSettings, GaussianSmoothingFilterTransformer
91
- ]
87
+ BaseFilterByDesignTransformerUnit[GaussianSmoothingSettings, GaussianSmoothingFilterTransformer]
92
88
  ):
93
89
  SETTINGS = GaussianSmoothingSettings