ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,276 @@
1
+ """
2
+ Aggregation operations over arrays.
3
+
4
+ .. note::
5
+ :obj:`AggregateTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ :obj:`RangedAggregateTransformer` currently requires NumPy arrays.
8
+ """
9
+
10
+ import typing
11
+
12
+ import ezmsg.core as ez
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ from array_api_compat import get_namespace
16
+ from ezmsg.baseproc import (
17
+ BaseStatefulTransformer,
18
+ BaseTransformer,
19
+ BaseTransformerUnit,
20
+ processor_state,
21
+ )
22
+ from ezmsg.util.messages.axisarray import (
23
+ AxisArray,
24
+ AxisBase,
25
+ replace,
26
+ slice_along_axis,
27
+ )
28
+
29
+ from .spectral import OptionsEnum
30
+
31
+
32
class AggregationFunction(OptionsEnum):
    """
    Enum for aggregation functions available to :obj:`ranged_aggregate` /
    :obj:`RangedAggregateTransformer` and :obj:`AggregateTransformer`.

    Each member maps to a NumPy function in ``AGGREGATORS``. Except for ``NONE``, the
    member values are also looked up by name on the data's array namespace by
    :obj:`AggregateTransformer`, so they must match the array-library function names.
    """

    # Mapped to np.all in AGGREGATORS; AggregateTransformer rejects it with ValueError.
    NONE = "None (all)"
    MAX = "max"
    MIN = "min"
    MEAN = "mean"
    MEDIAN = "median"
    STD = "std"
    SUM = "sum"
    # NaN-ignoring variants of the reductions above.
    NANMAX = "nanmax"
    NANMIN = "nanmin"
    NANMEAN = "nanmean"
    NANMEDIAN = "nanmedian"
    NANSTD = "nanstd"
    NANSUM = "nansum"
    # Position of the extremum: RangedAggregateTransformer converts it to a coordinate
    # value, AggregateTransformer returns the raw index.
    ARGMIN = "argmin"
    ARGMAX = "argmax"
    # Trapezoidal integration; uses the axis coordinates as x-values.
    TRAPEZOID = "trapezoid"
+
52
+
53
# ``np.trapezoid`` was introduced in NumPy 2.0 as the replacement for ``np.trapz``.
# Fall back to the old name so that this module can still be imported on NumPy 1.x
# (on NumPy >= 2.0 this is exactly ``np.trapezoid``).
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# NumPy implementation for every AggregationFunction member.
AGGREGATORS = {
    AggregationFunction.NONE: np.all,
    AggregationFunction.MAX: np.max,
    AggregationFunction.MIN: np.min,
    AggregationFunction.MEAN: np.mean,
    AggregationFunction.MEDIAN: np.median,
    AggregationFunction.STD: np.std,
    AggregationFunction.SUM: np.sum,
    AggregationFunction.NANMAX: np.nanmax,
    AggregationFunction.NANMIN: np.nanmin,
    AggregationFunction.NANMEAN: np.nanmean,
    AggregationFunction.NANMEDIAN: np.nanmedian,
    AggregationFunction.NANSTD: np.nanstd,
    AggregationFunction.NANSUM: np.nansum,
    AggregationFunction.ARGMIN: np.argmin,
    AggregationFunction.ARGMAX: np.argmax,
    # Note: Some methods require x-coordinates and
    # are handled specially in `_process`.
    AggregationFunction.TRAPEZOID: _trapezoid,
}
73
+
74
+
75
class RangedAggregateSettings(ez.Settings):
    """
    Settings for ``RangedAggregate`` / :obj:`RangedAggregateTransformer`.
    """

    # When None, RangedAggregateTransformer falls back to the message's first dimension.
    axis: str | None = None
    """The name of the axis along which to apply the bands."""

    # Each band selects coordinates c with band_min <= c <= band_max (both ends inclusive).
    bands: list[tuple[float, float]] | None = None
    """
    [(band1_min, band1_max), (band2_min, band2_max), ...]
    If not set then this acts as a passthrough node.
    """

    operation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply to each band."""
91
+
92
+
93
@processor_state
class RangedAggregateState:
    """Cached state of :obj:`RangedAggregateTransformer`, rebuilt in its ``_reset_state``."""

    # One slice per band (built with np.s_[first:last + 1]) along the target axis.
    slices: list[slice] | None = None
    # Output coordinate axis with one entry per band.
    out_axis: AxisBase | None = None
    # Coordinate value of every sample along the target axis.
    ax_vec: npt.NDArray | None = None
98
+
99
+
100
class RangedAggregateTransformer(
    BaseStatefulTransformer[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState]
):
    """
    Aggregate data within coordinate ranges ("bands") along a single axis.

    For every ``(lo, hi)`` band in the settings, the samples whose coordinate lies in the
    inclusive range ``[lo, hi]`` are reduced with the configured :obj:`AggregationFunction`.
    The target axis is kept in the output with one entry per band. It is described by a new
    :obj:`AxisArray.CoordinateAxis` whose values are the mean coordinate of each band, or a
    ``"first - last"`` label when the input coordinates are strings.

    For ``ARGMIN``/``ARGMAX`` the output holds the coordinate value at the extremum, not its
    integer index. Requires NumPy arrays.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        """Return ``message`` untouched when no bands are configured, else process it."""
        # Override for shortcut passthrough mode: skip state hashing/reset entirely.
        if self.settings.bands is None:
            return message
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """
        Hash the message properties that determine the cached band slices.

        NOTE(review): presumably the base class re-runs ``_reset_state`` whenever this value
        changes between messages -- confirm against ``BaseStatefulTransformer``.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)

        hash_components = (message.key,)
        if hasattr(target_axis, "data"):
            hash_components += (len(target_axis.data),)
        elif isinstance(target_axis, AxisArray.LinearAxis):
            # The coordinates of a LinearAxis are derived from gain, offset AND the number
            # of samples (see `_reset_state`), so the length must be part of the hash;
            # otherwise a change in block length would keep using stale slices.
            ax_idx = message.get_axis_idx(axis)
            hash_components += (
                target_axis.gain,
                target_axis.offset,
                message.data.shape[ax_idx],
            )
        return hash(hash_components)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Precompute one slice per band and the output coordinate axis.

        Raises:
            ValueError: If a band does not contain any coordinate of the target axis.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)
        ax_idx = message.get_axis_idx(axis)
        has_coords = hasattr(target_axis, "data")

        if has_coords:
            self._state.ax_vec = target_axis.data
        else:
            self._state.ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))
        ax_vec = self._state.ax_vec

        ax_dat = []
        slices = []
        for start, stop in self.settings.bands:
            inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
            if inds.size == 0:
                raise ValueError(f"Band ({start}, {stop}) does not contain any coordinates of axis '{axis}'.")
            # Slice from the first to the last matching sample.
            # NOTE(review): assumes coordinates are monotonic so the matches are contiguous.
            slices.append(np.s_[inds[0] : inds[-1] + 1])
            if has_coords:
                if ax_vec.dtype.type is np.str_:
                    # String coordinates cannot be averaged; label the band by its first
                    # and last member instead.
                    ax_dat.append(f"{ax_vec[inds[0]]} - {ax_vec[inds[-1]]}")
                else:
                    ax_dat.append(np.mean(ax_vec[inds]))
            else:
                # LinearAxis: coordinate of the (fractional) centre index of the band.
                ax_dat.append(target_axis.value(np.mean(inds)))

        self._state.slices = slices
        self._state.out_axis = AxisArray.CoordinateAxis(
            data=np.array(ax_dat),
            dims=[axis],
            unit=target_axis.unit,
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply the configured aggregation to every band and stack the results."""
        axis = self.settings.axis or message.dims[0]
        ax_idx = message.get_axis_idx(axis)
        operation = self.settings.operation
        agg_func = AGGREGATORS[operation]
        ax_vec = self._state.ax_vec

        out_data = []
        for sl in self._state.slices:
            band = slice_along_axis(message.data, sl, axis=ax_idx)
            if operation == AggregationFunction.TRAPEZOID:
                # Integration needs the band's coordinates as x-values.
                result = agg_func(band, x=ax_vec[sl], axis=ax_idx)
            else:
                result = agg_func(band, axis=ax_idx)
                if operation in (AggregationFunction.ARGMIN, AggregationFunction.ARGMAX):
                    # Convert band-relative indices into coordinate values.
                    result = ax_vec[sl][result]
            out_data.append(result)

        return replace(
            message,
            data=np.stack(out_data, axis=ax_idx),
            axes={**message.axes, axis: self._state.out_axis},
        )
190
+
191
+
192
class RangedAggregate(BaseTransformerUnit[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer]):
    """ezmsg Unit wrapping :obj:`RangedAggregateTransformer`, configured by :obj:`RangedAggregateSettings`."""

    SETTINGS = RangedAggregateSettings
194
+
195
+
196
def ranged_aggregate(
    axis: str | None = None,
    bands: list[tuple[float, float]] | None = None,
    operation: AggregationFunction = AggregationFunction.MEAN,
) -> RangedAggregateTransformer:
    """
    Build a transformer that aggregates data over one or more coordinate bands.

    Args:
        axis: Name of the axis the bands refer to. ``None`` means the message's first dimension.
        bands: Inclusive ``(min, max)`` coordinate ranges, e.g.
            ``[(band1_min, band1_max), (band2_min, band2_max), ...]``.
            ``None`` turns the transformer into a passthrough.
        operation: :obj:`AggregationFunction` applied within each band.

    Returns:
        A configured :obj:`RangedAggregateTransformer`.
    """
    settings = RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
    return RangedAggregateTransformer(settings)
214
+
215
+
216
class AggregateSettings(ez.Settings):
    """Settings for :obj:`AggregateTransformer` / :obj:`AggregateUnit`."""

    # Required: there is no default axis for full-axis aggregation.
    axis: str
    """The name of the axis to aggregate over. This axis will be removed from the output."""

    # AggregationFunction.NONE is rejected with ValueError by AggregateTransformer.
    operation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply."""
224
+
225
+
226
class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
    """
    Collapse one whole axis with a single :obj:`AggregationFunction`.

    :obj:`RangedAggregateTransformer` keeps the axis and emits one value per band. In
    contrast, this transformer reduces the entire axis and drops it from both ``dims`` and
    ``axes``, so the output has one dimension fewer than the input.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Aggregate ``message`` over ``settings.axis`` and remove that axis.

        Raises:
            ValueError: If ``settings.operation`` is ``AggregationFunction.NONE``.
        """
        xp = get_namespace(message.data)
        axis_name = self.settings.axis
        ax = message.get_axis_idx(axis_name)
        operation = self.settings.operation

        if operation == AggregationFunction.NONE:
            raise ValueError("AggregationFunction.NONE is not supported for full-axis aggregation")

        if operation != AggregationFunction.TRAPEZOID:
            # Prefer the function of the data's own array namespace; fall back to NumPy.
            name = operation.value
            reducer = getattr(xp, name) if hasattr(xp, name) else AGGREGATORS[operation]
            reduced = reducer(message.data, axis=ax)
        else:
            # Trapezoidal integration needs the sample coordinates as x-values.
            coords = message.get_axis(axis_name)
            if hasattr(coords, "data"):
                x = coords.data
            else:
                x = coords.value(np.arange(message.data.shape[ax]))
            integrate = AGGREGATORS[AggregationFunction.TRAPEZOID]
            reduced = integrate(np.asarray(message.data), x=x, axis=ax)

        kept_dims = [dim for pos, dim in enumerate(message.dims) if pos != ax]
        kept_axes = {key: val for key, val in message.axes.items() if key != axis_name}
        return replace(message, data=reduced, dims=kept_dims, axes=kept_axes)
271
+
272
+
273
class AggregateUnit(BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]):
    """Unit that aggregates an entire axis using a specified operation."""

    # The Unit is configured with AggregateSettings and runs an AggregateTransformer.
    SETTINGS = AggregateSettings
@@ -0,0 +1,80 @@
1
+ from dataclasses import field
2
+
3
+ import ezmsg.core as ez
4
+ from ezmsg.baseproc import (
5
+ BaseProcessor,
6
+ BaseStatefulProcessor,
7
+ BaseTransformerUnit,
8
+ CompositeProcessor,
9
+ )
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
+
12
+ from .aggregate import (
13
+ AggregationFunction,
14
+ RangedAggregateSettings,
15
+ RangedAggregateTransformer,
16
+ )
17
+ from .spectrogram import SpectrogramSettings, SpectrogramTransformer
18
+
19
+
20
class BandPowerSettings(ez.Settings):
    """
    Settings for ``BandPower`` / :obj:`BandPowerTransformer`.
    """

    # Passed unchanged to the SpectrogramTransformer stage.
    spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
    """
    Settings for spectrogram calculation.
    """

    # default_factory gives every settings instance its own list. Forwarded as the bands of
    # the RangedAggregateTransformer on the "freq" axis; None makes that stage a passthrough.
    bands: list[tuple[float, float]] | None = field(default_factory=lambda: [(17, 30), (70, 170)])
    """
    (min, max) tuples of band limits in Hz.
    """

    aggregation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply to each band."""
37
+
38
+
39
class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
    """
    Composite of a spectrogram stage followed by per-band aggregation over ``"freq"``.
    """

    @staticmethod
    def _initialize_processors(
        settings: BandPowerSettings,
    ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
        """Build the ``spectrogram`` and ``aggregate`` sub-processors, in that order."""
        spectrogram = SpectrogramTransformer(settings=settings.spectrogram_settings)
        # NOTE(review): assumes the spectrogram output names its frequency axis "freq".
        aggregate_settings = RangedAggregateSettings(
            axis="freq",
            bands=settings.bands,
            operation=settings.aggregation,
        )
        aggregate = RangedAggregateTransformer(settings=aggregate_settings)
        return {"spectrogram": spectrogram, "aggregate": aggregate}
54
+
55
+
56
class BandPower(BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]):
    """ezmsg Unit wrapping :obj:`BandPowerTransformer`, configured by :obj:`BandPowerSettings`."""

    SETTINGS = BandPowerSettings
58
+
59
+
60
def bandpower(
    spectrogram_settings: SpectrogramSettings,
    bands: list[tuple[float, float]] | tuple[tuple[float, float], ...] | None = (
        (17, 30),
        (70, 170),
    ),
    aggregation: AggregationFunction = AggregationFunction.MEAN,
) -> BandPowerTransformer:
    """
    Calculate the spectral power in each frequency band.

    A spectrogram is computed first; each band is then reduced along the ``"freq"`` axis
    with ``aggregation`` (the mean by default).

    Args:
        spectrogram_settings: Settings for the spectrogram stage.
        bands: ``(min, max)`` band limits in Hz, both ends inclusive. The sequence is copied
            into a new list, so neither the default nor a caller's list is shared with the
            resulting settings. ``None`` makes the aggregation stage a passthrough.
        aggregation: :obj:`AggregationFunction` to apply to each band.

    Returns:
        :obj:`BandPowerTransformer`
    """
    # Copy into a fresh list: avoids the shared-mutable-default pitfall and decouples the
    # settings from any list the caller may mutate later.
    band_list = None if bands is None else list(bands)
    return BandPowerTransformer(
        settings=BandPowerSettings(
            spectrogram_settings=spectrogram_settings,
            bands=band_list,
            aggregation=aggregation,
        )
    )
ezmsg/sigproc/base.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Backwards-compatible re-exports from ezmsg.baseproc.
3
+
4
+ This module re-exports all symbols from ezmsg.baseproc to maintain backwards
5
+ compatibility for code that imports from ezmsg.sigproc.base.
6
+
7
+ New code should import directly from ezmsg.baseproc instead.
8
+ """
9
+
10
+ import warnings
11
+
12
# Module-level deprecation notice: runs when this module body executes, i.e. on first
# import (subsequent imports are served from the module cache).
# stacklevel=2 is intended to attribute the warning to the importing code rather than
# to this file.
warnings.warn(
    "Importing from 'ezmsg.sigproc.base' is deprecated. Please import from 'ezmsg.baseproc' instead.",
    DeprecationWarning,
    stacklevel=2,
)
17
+
18
+ # Re-export everything from ezmsg.baseproc for backwards compatibility
19
+ from ezmsg.baseproc import ( # noqa: E402
20
+ # Protocols
21
+ AdaptiveTransformer,
22
+ # Type variables
23
+ AdaptiveTransformerType,
24
+ # Stateful classes
25
+ BaseAdaptiveTransformer,
26
+ # Unit classes
27
+ BaseAdaptiveTransformerUnit,
28
+ BaseAsyncTransformer,
29
+ # Base processor classes
30
+ BaseConsumer,
31
+ BaseConsumerUnit,
32
+ BaseProcessor,
33
+ BaseProcessorUnit,
34
+ BaseProducer,
35
+ BaseProducerUnit,
36
+ BaseStatefulConsumer,
37
+ BaseStatefulProcessor,
38
+ BaseStatefulProducer,
39
+ BaseStatefulTransformer,
40
+ BaseTransformer,
41
+ BaseTransformerUnit,
42
+ # Composite classes
43
+ CompositeProcessor,
44
+ CompositeProducer,
45
+ CompositeStateful,
46
+ Consumer,
47
+ ConsumerType,
48
+ GenAxisArray,
49
+ MessageInType,
50
+ MessageOutType,
51
+ Processor,
52
+ Producer,
53
+ ProducerType,
54
+ # Message types
55
+ SampleMessage,
56
+ SettingsType,
57
+ Stateful,
58
+ StatefulConsumer,
59
+ StatefulProcessor,
60
+ StatefulProducer,
61
+ StatefulTransformer,
62
+ StateType,
63
+ Transformer,
64
+ TransformerType,
65
+ # Type resolution helpers
66
+ _get_base_processor_message_in_type,
67
+ _get_base_processor_message_out_type,
68
+ _get_base_processor_settings_type,
69
+ _get_base_processor_state_type,
70
+ _get_processor_message_type,
71
+ # Type utilities
72
+ check_message_type_compatibility,
73
+ get_base_adaptive_transformer_type,
74
+ get_base_consumer_type,
75
+ get_base_producer_type,
76
+ get_base_transformer_type,
77
+ is_sample_message,
78
+ # Decorators
79
+ processor_state,
80
+ # Profiling
81
+ profile_subpub,
82
+ resolve_typevar,
83
+ )
84
+
85
# Public surface of this compatibility shim: every name re-imported from ezmsg.baseproc
# above, including the private `_get_*` helpers, so that existing
# `from ezmsg.sigproc.base import ...` statements keep working.
__all__ = [
    # Protocols
    "Processor",
    "Producer",
    "Consumer",
    "Transformer",
    "StatefulProcessor",
    "StatefulProducer",
    "StatefulConsumer",
    "StatefulTransformer",
    "AdaptiveTransformer",
    # Type variables
    "MessageInType",
    "MessageOutType",
    "SettingsType",
    "StateType",
    "ProducerType",
    "ConsumerType",
    "TransformerType",
    "AdaptiveTransformerType",
    # Decorators
    "processor_state",
    # Base processor classes
    "BaseProcessor",
    "BaseProducer",
    "BaseConsumer",
    "BaseTransformer",
    # Stateful classes
    "Stateful",
    "BaseStatefulProcessor",
    "BaseStatefulProducer",
    "BaseStatefulConsumer",
    "BaseStatefulTransformer",
    "BaseAdaptiveTransformer",
    "BaseAsyncTransformer",
    # Composite classes
    "CompositeStateful",
    "CompositeProcessor",
    "CompositeProducer",
    # Unit classes
    "BaseProducerUnit",
    "BaseProcessorUnit",
    "BaseConsumerUnit",
    "BaseTransformerUnit",
    "BaseAdaptiveTransformerUnit",
    "GenAxisArray",
    # Type resolution helpers
    "get_base_producer_type",
    "get_base_consumer_type",
    "get_base_transformer_type",
    "get_base_adaptive_transformer_type",
    "_get_base_processor_settings_type",
    "_get_base_processor_message_in_type",
    "_get_base_processor_message_out_type",
    "_get_base_processor_state_type",
    "_get_processor_message_type",
    # Message types
    "SampleMessage",
    "is_sample_message",
    # Profiling
    "profile_subpub",
    # Type utilities
    "check_message_type_compatibility",
    "resolve_typevar",
]
@@ -1,18 +1,59 @@
1
- import ezmsg.core as ez
1
+ import functools
2
+ import typing
3
+
2
4
  import scipy.signal
3
- import numpy as np
5
+ from scipy.signal import normalize
6
+
7
+ from .filter import (
8
+ BACoeffs,
9
+ BaseFilterByDesignTransformerUnit,
10
+ FilterBaseSettings,
11
+ FilterByDesignTransformer,
12
+ SOSCoeffs,
13
+ )
4
14
 
5
- from .filter import Filter, FilterState, FilterSettingsBase
6
15
 
7
- from typing import Optional, Tuple, Union
16
+ class ButterworthFilterSettings(FilterBaseSettings):
17
+ """Settings for :obj:`ButterworthFilter`."""
8
18
 
19
+ # axis and coef_type are inherited from FilterBaseSettings
9
20
 
10
- class ButterworthFilterSettings(FilterSettingsBase):
11
21
  order: int = 0
12
- cuton: Optional[float] = None # Hz
13
- cutoff: Optional[float] = None # Hz
22
+ """
23
+ Filter order
24
+ """
25
+
26
+ cuton: float | None = None
27
+ """
28
+ Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
29
+ if this is lower than `cutoff` then this is the beginning of the bandpass
30
+ or if this is greater than `cutoff` then this is the end of the bandstop.
31
+ """
32
+
33
+ cutoff: float | None = None
34
+ """
35
+ Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
36
+ if this is greater than `cuton` then this is the end of the bandpass,
37
+ or if this is less than `cuton` then this is the beginning of the bandstop.
38
+ """
39
+
40
+ wn_hz: bool = True
41
+ """
42
+ Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
43
+ """
44
+
45
+ def filter_specs(
46
+ self,
47
+ ) -> tuple[str, float | tuple[float, float]] | None:
48
+ """
49
+ Determine the filter type given the corner frequencies.
14
50
 
15
- def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
51
+ Returns:
52
+ A tuple with the first element being a string indicating the filter type
53
+ (one of "lowpass", "highpass", "bandpass", "bandstop")
54
+ and the second element being the corner frequency or frequencies.
55
+
56
+ """
16
57
  if self.cuton is None and self.cutoff is None:
17
58
  return None
18
59
  elif self.cuton is None and self.cutoff is not None:
@@ -26,41 +67,90 @@ class ButterworthFilterSettings(FilterSettingsBase):
26
67
  return "bandstop", (self.cutoff, self.cuton)
27
68
 
28
69
 
29
- class ButterworthFilterState(FilterState):
30
- design: ButterworthFilterSettings
70
def butter_design_fun(
    fs: float,
    order: int = 0,
    cuton: float | None = None,
    cutoff: float | None = None,
    coef_type: str = "ba",
    wn_hz: bool = True,
) -> BACoeffs | SOSCoeffs | None:
    """
    Design Butterworth filter coefficients with :obj:`scipy.signal.butter`.

    See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
    filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
    You are likely to want to use this function with :obj:`filter_by_design`, which only passes `fs` to the design
    function (this), meaning that you should wrap this function with a lambda or prepare with functools.partial.

    Args:
        fs: The sampling frequency of the data in Hz.
        order: Filter order. Values <= 0 disable the filter.
        cuton: Corner frequency of the filter in Hz.
        cutoff: Corner frequency of the filter in Hz.
        coef_type: "ba", "sos", or "zpk"
        wn_hz: Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency

    Returns:
        The filter coefficients as a tuple of (b, a) for coef_type "ba" (passed through
        :obj:`scipy.signal.normalize`), as a single ndarray for "sos", or (z, p, k) for "zpk".
        ``None`` if ``order`` is not positive or if neither ``cuton`` nor ``cutoff`` is given.
    """
    if order <= 0:
        return None

    specs = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs()
    if specs is None:
        # Neither corner frequency given: there is nothing to design. Unpacking None
        # here would otherwise raise TypeError.
        return None
    btype, cutoffs = specs

    coefs = scipy.signal.butter(
        order,
        Wn=cutoffs,
        btype=btype,
        # fs=None tells scipy that Wn is already normalized to the Nyquist frequency.
        fs=fs if wn_hz else None,
        output=coef_type,
    )
    if coef_type == "ba":
        coefs = normalize(*coefs)
    return coefs
38
110
 
39
- def initialize(self) -> None:
40
- self.STATE.design = self.SETTINGS
41
- self.STATE.filt_designed = True
42
- super().initialize()
43
111
 
44
- def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
45
- specs = self.STATE.design.filter_specs()
46
- if self.STATE.design.order > 0 and specs is not None:
47
- btype, cut = specs
48
- return scipy.signal.butter(
49
- self.STATE.design.order,
50
- Wn=cut,
51
- btype=btype,
52
- fs=self.STATE.fs,
53
- output="ba",
54
- )
55
- else:
56
- return None
112
class ButterworthFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
    """Filter-by-design transformer whose design function is :obj:`butter_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return ``butter_design_fun`` with every argument except ``fs`` bound from the settings."""
        cfg = self.settings
        design_kwargs = dict(
            order=cfg.order,
            cuton=cfg.cuton,
            cutoff=cfg.cutoff,
            coef_type=cfg.coef_type,
            wn_hz=cfg.wn_hz,
        )
        return functools.partial(butter_design_fun, **design_kwargs)
124
+
125
+
126
class ButterworthFilter(BaseFilterByDesignTransformerUnit[ButterworthFilterSettings, ButterworthFilterTransformer]):
    """ezmsg Unit wrapping :obj:`ButterworthFilterTransformer`, configured by :obj:`ButterworthFilterSettings`."""

    SETTINGS = ButterworthFilterSettings
128
+
57
129
 
58
- @ez.subscriber(INPUT_FILTER)
59
- async def redesign(self, message: ButterworthFilterSettings) -> None:
60
- if type(message) is not ButterworthFilterSettings:
61
- return
130
def butter(
    axis: str | None,
    order: int = 0,
    cuton: float | None = None,
    cutoff: float | None = None,
    coef_type: str = "ba",
    wn_hz: bool = True,
) -> ButterworthFilterTransformer:
    """
    Create a transformer that applies a Butterworth filter to streaming data.

    The filter is designed with :obj:`scipy.signal.butter` (via :obj:`butter_design_fun`).
    See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
    filter types (lowpass, highpass, bandpass, bandstop) from the parameters.

    Args:
        axis: Name of the axis to filter along.
        order: Filter order. Values <= 0 disable the filter.
        cuton: Cuton corner frequency (Hz, or normalized if ``wn_hz`` is False).
        cutoff: Cutoff corner frequency (Hz, or normalized if ``wn_hz`` is False).
        coef_type: "ba", "sos", or "zpk".
        wn_hz: Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency.

    Returns:
        :obj:`ButterworthFilterTransformer`
    """
    return ButterworthFilterTransformer(
        ButterworthFilterSettings(
            axis=axis,
            order=order,
            cuton=cuton,
            cutoff=cutoff,
            coef_type=coef_type,
            wn_hz=wn_hz,
        )
    )