ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,305 @@
1
+ """
2
+ Streaming zero-phase Butterworth filter implemented as a two-stage composite processor.
3
+
4
+ Stage 1: Forward causal Butterworth filter (from ezmsg.sigproc.butterworthfilter)
5
+ Stage 2: Backward acausal filter with buffering (ButterworthBackwardFilterTransformer)
6
+
7
+ The output is delayed by `pad_length` samples to ensure the backward pass has sufficient
8
+ future context. The pad_length is derived from the settling time of the filter's impulse response, with a scipy-style padding heuristic used as the lower bound.
9
+ """
10
+
11
+ import functools
12
+ import typing
13
+
14
+ import numpy as np
15
+ import scipy.signal
16
+ from ezmsg.baseproc import BaseTransformerUnit
17
+ from ezmsg.baseproc.composite import CompositeProcessor
18
+ from ezmsg.util.messages.axisarray import AxisArray
19
+ from ezmsg.util.messages.util import replace
20
+
21
+ from .butterworthfilter import (
22
+ ButterworthFilterSettings,
23
+ ButterworthFilterTransformer,
24
+ butter_design_fun,
25
+ )
26
+ from .filter import BACoeffs, FilterByDesignTransformer, SOSCoeffs
27
+ from .util.axisarray_buffer import HybridAxisArrayBuffer
28
+
29
+
30
class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
    """
    Configuration for :obj:`ButterworthZeroPhase`.

    The zero-phase filter runs a forward pass followed by a backward pass over
    streaming data. Because the backward pass needs future context, the output
    lags the input by `pad_length` samples.

    `pad_length` is taken from the point where the filter's impulse response has
    decayed to `settle_cutoff` times its peak magnitude, so it tracks the
    filter's actual time constant and not only its order.
    """

    # Design fields inherited from ButterworthFilterSettings:
    #   axis, coef_type, order, cuton, cutoff, wn_hz

    settle_cutoff: float = 0.01
    """
    Settling threshold, expressed as a fraction of the peak impulse response.
    `pad_length` is the number of samples needed for the impulse response to
    fall to this fraction of its peak. The default 0.01 means 1% of peak.
    """

    max_pad_duration: float | None = None
    """
    Upper bound on the pad duration, in seconds. When given, `pad_length` is
    limited to this value multiplied by the sampling rate, which bounds latency
    for filters with very long impulse responses. None (default) means no limit.
    """
59
+
60
+
61
class ButterworthBackwardFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
    """
    Backward (acausal) Butterworth filter with buffering.

    Incoming chunks are written to a buffer. On every call the buffered samples
    are time-reversed, filtered, and reversed back; only the leading "settled"
    samples (those followed by at least ``pad_length`` later samples) are
    emitted, while the trailing ``pad_length`` samples remain buffered for the
    next call. This introduces a lag of ``pad_length`` samples.

    Intended to be used as stage 2 in a zero-phase filter pipeline, receiving
    forward-filtered data from a ButterworthFilterTransformer.

    ``settle_cutoff`` and ``max_pad_duration`` are read from the settings when
    they are present (as on :obj:`ButterworthZeroPhaseSettings`). Settings
    objects without those fields fall back to 0.01 and "no cap" respectively.
    """

    # Instance attributes (initialized in _reset_state)
    _buffer: HybridAxisArrayBuffer | None
    _coefs_cache: BACoeffs | SOSCoeffs | None
    _zi_tiled: np.ndarray | None
    _pad_length: int

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return ``butter_design_fun`` with the design settings bound; the result takes ``fs``."""
        return functools.partial(
            butter_design_fun,
            order=self.settings.order,
            cuton=self.settings.cuton,
            cutoff=self.settings.cutoff,
            coef_type=self.settings.coef_type,
            wn_hz=self.settings.wn_hz,
        )

    def _compute_pad_length(self, fs: float) -> int:
        """
        Compute pad length based on the filter's impulse response settling time.

        The pad length is the index (plus one) of the last impulse-response
        sample whose magnitude exceeds ``settle_cutoff`` times the peak
        magnitude, but never less than an order-based minimum. The result is
        then optionally capped by ``max_pad_duration``.

        Args:
            fs: Sampling frequency in Hz.

        Returns:
            Number of samples for the pad length, as a plain ``int``. 0 when the
            design function returns no coefficients.
        """
        # Design the filter to compute impulse response
        coefs = self.get_design_function()(fs)
        if coefs is None:
            # Filter design failed or is disabled
            return 0

        # Order-based lower bound on the pad length.
        if self.settings.coef_type == "ba":
            min_length = 3 * (self.settings.order + 1)
        else:
            n_sections = (self.settings.order + 1) // 2
            min_length = 3 * n_sections * 2

        # The declared settings type does not define these two fields, so read
        # them defensively; the fallbacks equal the ButterworthZeroPhaseSettings
        # defaults.
        settle_cutoff = getattr(self.settings, "settle_cutoff", 0.01)
        max_pad_duration = getattr(self.settings, "max_pad_duration", None)

        # Clamp to >= 0 so that a negative duration cannot produce a
        # non-positive impulse length below.
        max_samples = None if max_pad_duration is None else max(int(max_pad_duration * fs), 0)

        # Use 10x the minimum as impulse length, or at least 10000 samples
        # (10000 samples allows for ~333ms at 30kHz, covering most practical cases).
        # The window is fixed: if the response has not settled within it, the
        # pad length saturates at the window length.
        impulse_length = max(min_length * 10, 10000)

        # No point simulating further than the cap (one extra sample so the
        # impulse itself always fits).
        if max_samples is not None:
            impulse_length = min(impulse_length, max_samples + 1)

        impulse = np.zeros(impulse_length)
        impulse[0] = 1.0

        if self.settings.coef_type == "ba":
            b, a = coefs
            h = scipy.signal.lfilter(b, a, impulse)
        else:
            h = scipy.signal.sosfilt(coefs, impulse)

        # Find where impulse response settles to settle_cutoff of peak
        abs_h = np.abs(h)
        peak = abs_h.max()
        if peak == 0:
            return min_length

        threshold = settle_cutoff * peak
        above_threshold = np.where(abs_h > threshold)[0]

        if len(above_threshold) == 0:
            pad_length = min_length
        else:
            pad_length = above_threshold[-1] + 1

        # Ensure at least the order-based minimum
        pad_length = max(pad_length, min_length)

        # Apply max_pad_duration cap if set
        if max_samples is not None:
            pad_length = min(pad_length, max_samples)

        # above_threshold holds numpy integers; honour the ``int`` return type.
        return int(pad_length)

    def _reset_state(self, message: AxisArray) -> None:
        """Reset filter state when stream changes."""
        self._coefs_cache = None
        self._zi_tiled = None
        self._buffer = None
        # Compute pad_length based on the message's sampling rate
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        fs = 1 / message.axes[axis].gain
        self._pad_length = self._compute_pad_length(fs)
        # Force _process to design the coefficients and create the buffer.
        self.state.needs_redesign = True

    def _compute_zi_tiled(self, data: np.ndarray, ax_idx: int) -> None:
        """Compute and cache the tiled zi for the given data shape.

        Called once per stream (or after filter redesign). The result is
        broadcast-ready for multiplication by the edge sample on each chunk.

        Args:
            data: Array whose shape (other than along ``ax_idx``) sets the tiling.
            ax_idx: Index of the axis being filtered within ``data``.
        """
        if self.settings.coef_type == "ba":
            b, a = self._coefs_cache
            zi_base = scipy.signal.lfilter_zi(b, a)
        else:  # sos
            zi_base = scipy.signal.sosfilt_zi(self._coefs_cache)

        n_tail = data.ndim - ax_idx - 1

        if self.settings.coef_type == "ba":
            # zi_base is 1-D: place it on the filtered axis, tile over the rest.
            zi_expand = (None,) * ax_idx + (slice(None),) + (None,) * n_tail
            n_tile = data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]
        else:  # sos
            # zi_base is (n_sections, 2): keep the section axis in front.
            zi_expand = (slice(None),) + (None,) * ax_idx + (slice(None),) + (None,) * n_tail
            n_tile = (1,) + data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]

        self._zi_tiled = np.tile(zi_base[zi_expand], n_tile)

    def _initialize_zi(self, data: np.ndarray, ax_idx: int) -> np.ndarray:
        """Initialize filter state (zi) scaled by edge value.

        Args:
            data: Array about to be filtered along ``ax_idx``.
            ax_idx: Index of the axis being filtered.

        Returns:
            The cached tiled zi multiplied by the first sample along ``ax_idx``.
        """
        if self._zi_tiled is None:
            self._compute_zi_tiled(data, ax_idx)
        # Index with a list so the filtered axis is kept (length 1) for broadcasting.
        first_sample = np.take(data, [0], axis=ax_idx)
        return self._zi_tiled * first_sample

    def _process(self, message: AxisArray) -> AxisArray:
        """Buffer ``message``, backward-filter the buffer and emit the settled samples.

        Args:
            message: Incoming chunk (forward-filtered data in the intended pipeline).

        Returns:
            The input message unchanged when the filter is disabled or the chunk
            is empty; otherwise a message holding the settled samples, which has
            length 0 along the filtered axis until more than ``pad_length``
            samples are buffered.
        """
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        ax_idx = message.get_axis_idx(axis)
        fs = 1 / message.axes[axis].gain

        # Check if we need to redesign filter
        if self._coefs_cache is None or self.state.needs_redesign:
            self._coefs_cache = self.get_design_function()(fs)
            self._pad_length = self._compute_pad_length(fs)
            self._zi_tiled = None  # Invalidate; recomputed on next use.
            self.state.needs_redesign = False

            # Initialize buffer with duration based on pad_length
            # Add some margin to handle variable chunk sizes
            buffer_duration = (self._pad_length + 1) / fs
            self._buffer = HybridAxisArrayBuffer(duration=buffer_duration, axis=axis)

        # Early exit if filter is effectively disabled
        if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
            return message

        # Write new data to buffer
        self._buffer.write(message)
        n_available = self._buffer.available()
        n_output = n_available - self._pad_length

        # If we don't have enough data yet, return empty
        if n_output <= 0:
            new_shape = list(message.data.shape)
            new_shape[ax_idx] = 0
            empty_data = np.empty(new_shape, dtype=message.data.dtype)
            return replace(message, data=empty_data)

        # Peek all available data from buffer
        # Note: HybridAxisArrayBuffer moves the target axis to position 0
        buffered = self._buffer.peek(n_available)
        combined = buffered.data
        buffer_ax_idx = 0  # Buffer always puts time axis at position 0

        # Backward filter on reversed data
        combined_rev = np.flip(combined, axis=buffer_ax_idx)
        backward_zi = self._initialize_zi(combined_rev, buffer_ax_idx)

        if self.settings.coef_type == "ba":
            b, a = self._coefs_cache
            y_bwd_rev, _ = scipy.signal.lfilter(b, a, combined_rev, axis=buffer_ax_idx, zi=backward_zi)
        else:  # sos
            y_bwd_rev, _ = scipy.signal.sosfilt(self._coefs_cache, combined_rev, axis=buffer_ax_idx, zi=backward_zi)

        # Reverse back to get output in correct time order
        y_bwd = np.flip(y_bwd_rev, axis=buffer_ax_idx)

        # Output the settled portion (first n_output samples)
        y = y_bwd[:n_output]

        # Advance buffer read head to discard output samples, keep pad_length
        self._buffer.seek(n_output)

        # Build output with adjusted time axis
        # LinearAxis offset is already correct from the buffer
        out_axis = buffered.axes[axis]

        # Move axis back to original position if needed
        if ax_idx != 0:
            y = np.moveaxis(y, 0, ax_idx)

        return replace(
            message,
            data=y,
            axes={**message.axes, axis: out_axis},
        )
273
+
274
+
275
class ButterworthZeroPhaseTransformer(CompositeProcessor[ButterworthZeroPhaseSettings, AxisArray, AxisArray]):
    """
    Two-stage composite implementing a streaming zero-phase Butterworth filter.

    Stage "forward":  standard causal Butterworth filter with state.
    Stage "backward": acausal Butterworth filter with buffering.

    The output is delayed by ``pad_length`` samples.
    """

    @staticmethod
    def _initialize_processors(
        settings: ButterworthZeroPhaseSettings,
    ) -> dict[str, typing.Any]:
        """Build the stage processors; both are constructed from the same settings object."""
        stages: dict[str, typing.Any] = {}
        stages["forward"] = ButterworthFilterTransformer(settings)
        stages["backward"] = ButterworthBackwardFilterTransformer(settings)
        return stages

    @classmethod
    def get_message_type(cls, dir: str) -> type[AxisArray]:
        """Return the message type for direction ``dir`` ('in' or 'out'); both are AxisArray.

        Raises:
            ValueError: If ``dir`` is neither 'in' nor 'out'.
        """
        if dir not in ("in", "out"):
            raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
        return AxisArray
300
+
301
+
302
class ButterworthZeroPhase(
    BaseTransformerUnit[ButterworthZeroPhaseSettings, AxisArray, AxisArray, ButterworthZeroPhaseTransformer]
):
    """Unit wrapper around :obj:`ButterworthZeroPhaseTransformer` (AxisArray in, AxisArray out)."""

    # Settings class used to configure this unit.
    SETTINGS = ButterworthZeroPhaseSettings
ezmsg/sigproc/cheby.py ADDED
@@ -0,0 +1,125 @@
1
+ import functools
2
+ import typing
3
+
4
+ import scipy.signal
5
+ from scipy.signal import normalize
6
+
7
+ from .filter import (
8
+ BACoeffs,
9
+ BaseFilterByDesignTransformerUnit,
10
+ FilterBaseSettings,
11
+ FilterByDesignTransformer,
12
+ SOSCoeffs,
13
+ )
14
+
15
+
16
class ChebyshevFilterSettings(FilterBaseSettings):
    """Configuration for :obj:`ChebyshevFilter`."""

    # `axis` and `coef_type` come from FilterBaseSettings.

    order: int = 0
    """
    Order of the filter.
    """

    ripple_tol: float | None = None
    """
    Maximum ripple permitted below unity gain in the passband, given in decibels as a positive number.
    """

    Wn: float | tuple[float, float] | None = None
    """
    Critical frequency or frequencies: a scalar, or a length-2 sequence.
    For Type I filters, this is the point in the transition band at which the gain first drops below -rp.
    For digital filters, Wn uses the same units as fs unless wn_hz is False.
    For analog filters, Wn is an angular frequency (e.g., rad/s).
    """

    btype: str = "lowpass"
    """
    One of {'lowpass', 'highpass', 'bandpass', 'bandstop'}.
    """

    analog: bool = False
    """
    If True an analog filter is returned; otherwise a digital filter is returned.
    """

    cheby_type: str = "cheby1"
    """
    Chebyshev variant to design: either "cheby1" or "cheby2".
    """

    wn_hz: bool = True
    """
    Set to False when the provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency.
    """
58
+
59
+
60
def cheby_design_fun(
    fs: float,
    order: int = 0,
    ripple_tol: float | None = None,
    Wn: float | tuple[float, float] | None = None,
    btype: str = "lowpass",
    analog: bool = False,
    coef_type: str = "ba",
    cheby_type: str = "cheby1",
    wn_hz: bool = True,
) -> BACoeffs | SOSCoeffs | None:
    """
    Chebyshev type I and type II digital and analog filter design.
    Design an `order`th-order digital or analog Chebyshev type I or type II filter and return the filter coefficients.
    See :obj:`ChebyshevFilterSettings` for argument description.

    Args:
        fs: Sampling frequency. Forwarded to scipy only when `wn_hz` is True;
            otherwise `Wn` is treated as normalized (1 = Nyquist) for both
            Chebyshev types.
        order: Filter order. Values <= 0 disable the design.
        ripple_tol: Second positional argument of the scipy designer
            (`rp` for cheby1, `rs` for cheby2), in decibels.
        Wn: Critical frequency or frequencies.
        btype: Band type passed to the scipy designer.
        analog: Passed through to the scipy designer.
        coef_type: scipy `output` format, e.g. "ba" or "sos".
        cheby_type: "cheby1" or "cheby2".
        wn_hz: Whether `Wn` is expressed in the same units as `fs`.

    Returns:
        The filter coefficients as a tuple of (b, a) for coef_type "ba" (normalized),
        or whatever scipy returns for the other output formats (a single ndarray for "sos").
        None when `order` <= 0 or `cheby_type` is not one of the two supported names.
    """
    if order <= 0:
        return None

    designers = {
        "cheby1": scipy.signal.cheby1,
        "cheby2": scipy.signal.cheby2,
    }
    design = designers.get(cheby_type)
    if design is None:
        # Unknown type: keep the historical "no filter" result.
        return None

    coefs = design(
        order,
        ripple_tol,
        Wn,
        btype=btype,
        analog=analog,
        output=coef_type,
        # Honour wn_hz for BOTH types. Previously the cheby2 branch always
        # passed fs, so normalized Wn was misinterpreted as Hz.
        fs=fs if wn_hz else None,
    )
    if coef_type == "ba":
        coefs = normalize(*coefs)
    return coefs
105
+
106
+
107
class ChebyshevFilterTransformer(FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]):
    """Filter-by-design transformer whose coefficients come from :func:`cheby_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return ``cheby_design_fun`` with every design setting bound, leaving only ``fs`` open."""
        cfg = self.settings
        bound_kwargs = {
            "order": cfg.order,
            "ripple_tol": cfg.ripple_tol,
            "Wn": cfg.Wn,
            "btype": cfg.btype,
            "analog": cfg.analog,
            "coef_type": cfg.coef_type,
            "cheby_type": cfg.cheby_type,
            "wn_hz": cfg.wn_hz,
        }
        return functools.partial(cheby_design_fun, **bound_kwargs)
122
+
123
+
124
class ChebyshevFilter(BaseFilterByDesignTransformerUnit[ChebyshevFilterSettings, ChebyshevFilterTransformer]):
    """Unit wrapper around :obj:`ChebyshevFilterTransformer`."""

    # Settings class used to configure this unit.
    SETTINGS = ChebyshevFilterSettings
@@ -0,0 +1,160 @@
1
+ import functools
2
+ import typing
3
+
4
+ import numpy as np
5
+ import scipy.signal
6
+ from scipy.signal import normalize
7
+
8
+ from .filter import (
9
+ BACoeffs,
10
+ BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
+ SOSCoeffs,
14
+ )
15
+
16
+
17
class CombFilterSettings(FilterBaseSettings):
    """Configuration for :obj:`CombFilter`."""

    # `axis` and `coef_type` come from FilterBaseSettings.

    fundamental_freq: float = 60.0
    """
    Fundamental frequency, in Hz.
    """

    num_harmonics: int = 3
    """
    How many harmonics to cover; the fundamental counts as the first.
    """

    q_factor: float = 35.0
    """
    Quality factor (Q) applied to each peak/notch.
    """

    filter_type: str = "notch"
    """
    Comb filter flavour: 'notch' removes the harmonics, 'peak' passes the harmonics at the expense of others.
    """

    quality_scaling: str = "constant"
    """
    'constant': the same Q at every harmonic, which gives wider bands at higher frequencies.
    'proportional': Q grows in proportion to frequency, which gives constant bandwidths.
    """
47
+
48
+
49
def comb_design_fun(
    fs: float,
    fundamental_freq: float = 60.0,
    num_harmonics: int = 3,
    q_factor: float = 35.0,
    filter_type: str = "notch",
    coef_type: str = "sos",
    quality_scaling: str = "constant",
) -> BACoeffs | SOSCoeffs | None:
    """
    Build a comb filter by cascading one second-order section per harmonic of `fundamental_freq`.

    Harmonics at or above the Nyquist frequency (`fs / 2`) are left out.

    Returns:
        An SOS array with one row per retained harmonic when `coef_type` is "sos",
        a normalized (b, a) pair when it is "ba", or None when no harmonic lies
        below Nyquist.

    Raises:
        ValueError: If `coef_type` is neither "sos" nor "ba".
    """
    if coef_type not in ("sos", "ba"):
        raise ValueError("Comb filter only supports 'sos' or 'ba' coefficient types")

    # Any filter_type other than "notch" selects the peak design.
    design_section = scipy.signal.iirnotch if filter_type == "notch" else scipy.signal.iirpeak
    scale_q_with_harmonic = quality_scaling == "proportional"
    nyquist = fs / 2

    sections = []
    for harmonic in range(1, num_harmonics + 1):
        center = fundamental_freq * harmonic
        if center >= nyquist:
            continue

        q = q_factor * harmonic if scale_q_with_harmonic else q_factor

        # iirnotch / iirpeak return a (b, a) pair of second-order polynomials, so
        # stacking them side by side yields a valid 6-element SOS row
        # (equivalent to scipy.signal.tf2sos(b, a)[0]).
        sections.append(np.hstack(design_section(w0=center, Q=q, fs=fs)))

    if not sections:
        return None

    sos_matrix = np.vstack(sections)
    if coef_type == "sos":
        return sos_matrix

    # "ba": collapse the cascade into a single transfer function.
    b, a = scipy.signal.sos2tf(sos_matrix)
    return normalize(b, a)
105
+
106
+
107
class CombFilterTransformer(FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]):
    """Filter-by-design transformer whose coefficients come from :func:`comb_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return ``comb_design_fun`` with every design setting bound, leaving only ``fs`` open."""
        cfg = self.settings
        bound_kwargs = {
            "fundamental_freq": cfg.fundamental_freq,
            "num_harmonics": cfg.num_harmonics,
            "q_factor": cfg.q_factor,
            "filter_type": cfg.filter_type,
            "coef_type": cfg.coef_type,
            "quality_scaling": cfg.quality_scaling,
        }
        return functools.partial(comb_design_fun, **bound_kwargs)
120
+
121
+
122
class CombFilterUnit(BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]):
    """Unit wrapper around :obj:`CombFilterTransformer`."""

    # Settings class used to configure this unit.
    SETTINGS = CombFilterSettings
124
+
125
+
126
def comb(
    axis: str | None,
    fundamental_freq: float = 50.0,
    num_harmonics: int = 3,
    q_factor: float = 35.0,
    filter_type: str = "notch",
    coef_type: str = "sos",
    quality_scaling: str = "constant",
) -> CombFilterTransformer:
    """
    Convenience factory for a comb filter that removes or enhances a fundamental frequency and its harmonics.

    Note: `fundamental_freq` defaults to 50.0 Hz here, whereas `CombFilterSettings`
    and `comb_design_fun` default to 60.0 Hz.

    Args:
        axis: Axis to filter along
        fundamental_freq: Base frequency in Hz
        num_harmonics: Number of harmonic peaks/notches (including fundamental)
        q_factor: Quality factor for peak/notch width
        filter_type: 'notch' to remove or 'peak' to enhance harmonics
        coef_type: Coefficient type ('sos' recommended for stability)
        quality_scaling: How to handle bandwidths across harmonics

    Returns:
        A CombFilterTransformer built from a CombFilterSettings holding the arguments.
    """
    settings = CombFilterSettings(
        axis=axis,
        fundamental_freq=fundamental_freq,
        num_harmonics=num_harmonics,
        q_factor=q_factor,
        filter_type=filter_type,
        coef_type=coef_type,
        quality_scaling=quality_scaling,
    )
    return CombFilterTransformer(settings)