ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Aggregation operations over arrays.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
:obj:`AggregateTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
:obj:`RangedAggregateTransformer` currently requires NumPy arrays.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import typing
|
|
11
|
+
|
|
12
|
+
import ezmsg.core as ez
|
|
13
|
+
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
15
|
+
from array_api_compat import get_namespace
|
|
16
|
+
from ezmsg.baseproc import (
|
|
17
|
+
BaseStatefulTransformer,
|
|
18
|
+
BaseTransformer,
|
|
19
|
+
BaseTransformerUnit,
|
|
20
|
+
processor_state,
|
|
21
|
+
)
|
|
22
|
+
from ezmsg.util.messages.axisarray import (
|
|
23
|
+
AxisArray,
|
|
24
|
+
AxisBase,
|
|
25
|
+
replace,
|
|
26
|
+
slice_along_axis,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from .spectral import OptionsEnum
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AggregationFunction(OptionsEnum):
    """
    Enum for aggregation functions available to be used in :obj:`ranged_aggregate` operation.

    Each member is mapped to a NumPy reduction in ``AGGREGATORS``. Except for ``NONE``,
    a member's value is also the name of that NumPy function; ``AggregateTransformer``
    relies on this by looking up ``op.value`` on the array namespace first.
    """

    # Maps to np.all in AGGREGATORS; AggregateTransformer rejects this option.
    NONE = "None (all)"
    MAX = "max"
    MIN = "min"
    MEAN = "mean"
    MEDIAN = "median"
    STD = "std"
    SUM = "sum"
    # NaN-ignoring variants of the reductions above.
    NANMAX = "nanmax"
    NANMIN = "nanmin"
    NANMEAN = "nanmean"
    NANMEDIAN = "nanmedian"
    NANSTD = "nanstd"
    NANSUM = "nansum"
    # Position of the extreme value. RangedAggregateTransformer converts the position
    # into a coordinate of the aggregated axis; AggregateTransformer returns the raw index.
    ARGMIN = "argmin"
    ARGMAX = "argmax"
    # Trapezoidal integration; the axis coordinates are passed as ``x``.
    TRAPEZOID = "trapezoid"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Lookup table from AggregationFunction to the NumPy callable that implements it.
# ``np.trapezoid`` was introduced in NumPy 2.0 as the replacement for ``np.trapz``;
# fall back so that importing this module does not fail on NumPy 1.x.
AGGREGATORS = {
    AggregationFunction.NONE: np.all,
    AggregationFunction.MAX: np.max,
    AggregationFunction.MIN: np.min,
    AggregationFunction.MEAN: np.mean,
    AggregationFunction.MEDIAN: np.median,
    AggregationFunction.STD: np.std,
    AggregationFunction.SUM: np.sum,
    AggregationFunction.NANMAX: np.nanmax,
    AggregationFunction.NANMIN: np.nanmin,
    AggregationFunction.NANMEAN: np.nanmean,
    AggregationFunction.NANMEDIAN: np.nanmedian,
    AggregationFunction.NANSTD: np.nanstd,
    AggregationFunction.NANSUM: np.nansum,
    AggregationFunction.ARGMIN: np.argmin,
    AggregationFunction.ARGMAX: np.argmax,
    # Note: Some methods require x-coordinates and
    # are handled specially in `_process`.
    AggregationFunction.TRAPEZOID: np.trapezoid if hasattr(np, "trapezoid") else np.trapz,
}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class RangedAggregateSettings(ez.Settings):
    """
    Settings for ``RangedAggregate`` and :obj:`RangedAggregateTransformer`.
    """

    # If None, the transformer uses the message's first dimension.
    axis: str | None = None
    """The name of the axis along which to apply the bands."""

    # Band limits are compared against the axis coordinates and are inclusive at both ends.
    bands: list[tuple[float, float]] | None = None
    """
    [(band1_min, band1_max), (band2_min, band2_max), ...]
    If not set then this acts as a passthrough node.
    """

    operation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply to each band."""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@processor_state
class RangedAggregateState:
    """Cached per-stream state for :obj:`RangedAggregateTransformer`, rebuilt by its ``_reset_state``."""

    # One index slice per configured band along the target axis
    # (these are ``slice`` objects built with ``np.s_``, despite the annotation).
    slices: list[tuple[typing.Any, ...]] | None = None
    # Replacement axis for the output: one coordinate (or label) per band.
    out_axis: AxisBase | None = None
    # Coordinate values of the target axis for the message that triggered the last reset.
    ax_vec: npt.NDArray | None = None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class RangedAggregateTransformer(
    BaseStatefulTransformer[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateState]
):
    """
    Aggregate a message over one or more coordinate ranges ("bands") of a single axis.

    The target axis is kept in the output but shrinks to one entry per band. Each entry's
    coordinate is the mean coordinate of the band, or a ``"first - last"`` label when the
    axis has string coordinates. When ``settings.bands`` is ``None``, messages pass through
    unchanged.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        # Override for shortcut passthrough mode (skips hashing and state entirely).
        if self.settings.bands is None:
            return message
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """
        Hash the message properties that determine the cached band slices.

        The number of samples along the target axis is always included. Without it, a
        LinearAxis message whose gain/offset stays constant but whose length changes would
        reuse a stale ``ax_vec`` and stale ``slices``.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)
        ax_idx = message.get_axis_idx(axis)

        hash_components = (message.key, message.data.shape[ax_idx])
        if hasattr(target_axis, "data"):
            hash_components += (len(target_axis.data),)
        elif isinstance(target_axis, AxisArray.LinearAxis):
            hash_components += (target_axis.gain, target_axis.offset)
        return hash(hash_components)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Precompute the per-band slices and the output axis for the current message layout.

        Raises:
            ValueError: If a band does not contain any coordinate of the target axis.
        """
        axis = self.settings.axis or message.dims[0]
        target_axis = message.get_axis(axis)
        ax_idx = message.get_axis_idx(axis)
        has_coords = hasattr(target_axis, "data")

        if has_coords:
            self._state.ax_vec = target_axis.data
        else:
            self._state.ax_vec = target_axis.value(np.arange(message.data.shape[ax_idx]))
        ax_vec = self._state.ax_vec

        ax_dat = []
        slices = []
        for start, stop in self.settings.bands:
            # Band limits are inclusive at both ends.
            inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
            if inds.size == 0:
                raise ValueError(f"Band ({start}, {stop}) contains no values of axis '{axis}'.")
            # The slice spans the first..last matching index.
            # NOTE(review): this assumes monotonic coordinates, so every sample in between is in band.
            slices.append(np.s_[inds[0] : inds[-1] + 1])
            if has_coords:
                if ax_vec.dtype.type is np.str_:
                    # String coordinates cannot be averaged: label the band by its end-points.
                    ax_dat.append(f"{ax_vec[inds[0]]} - {ax_vec[inds[-1]]}")
                else:
                    ax_dat.append(np.mean(ax_vec[inds]))
            else:
                ax_dat.append(target_axis.value(np.mean(inds)))

        self._state.slices = slices
        self._state.out_axis = AxisArray.CoordinateAxis(
            data=np.array(ax_dat),
            dims=[axis],
            unit=target_axis.unit,
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply the configured aggregation to every band and stack the results along the axis."""
        axis = self.settings.axis or message.dims[0]
        ax_idx = message.get_axis_idx(axis)
        op = self.settings.operation
        agg_func = AGGREGATORS[op]
        slices = self._state.slices
        ax_vec = self._state.ax_vec

        if op == AggregationFunction.TRAPEZOID:
            # Integration needs each band's x-coordinates, not just its samples.
            out_data = [
                agg_func(slice_along_axis(message.data, sl, axis=ax_idx), x=ax_vec[sl], axis=ax_idx)
                for sl in slices
            ]
        else:
            out_data = [agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx) for sl in slices]

        if op in (AggregationFunction.ARGMIN, AggregationFunction.ARGMAX):
            # Convert band-relative indices into coordinate values of the target axis.
            out_data = [ax_vec[sl][band_idx] for sl, band_idx in zip(slices, out_data)]

        return replace(
            message,
            data=np.stack(out_data, axis=ax_idx),
            axes={**message.axes, axis: self._state.out_axis},
        )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class RangedAggregate(BaseTransformerUnit[RangedAggregateSettings, AxisArray, AxisArray, RangedAggregateTransformer]):
    """Unit wrapper around :obj:`RangedAggregateTransformer`."""

    SETTINGS = RangedAggregateSettings
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def ranged_aggregate(
    axis: str | None = None,
    bands: list[tuple[float, float]] | None = None,
    operation: AggregationFunction = AggregationFunction.MEAN,
) -> RangedAggregateTransformer:
    """
    Build a transformer that aggregates each band of an axis into a single value.

    Args:
        axis: Name of the axis carrying the bands. ``None`` selects the message's first dimension.
        bands: Inclusive ``(min, max)`` coordinate limits, one tuple per band.
            ``None`` makes the transformer a passthrough.
        operation: :obj:`AggregationFunction` applied within each band.

    Returns:
        :obj:`RangedAggregateTransformer`
    """
    settings = RangedAggregateSettings(axis=axis, bands=bands, operation=operation)
    return RangedAggregateTransformer(settings)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class AggregateSettings(ez.Settings):
    """Settings for :obj:`AggregateUnit` / :obj:`AggregateTransformer`."""

    # Required: there is no default, so the axis must always be named explicitly.
    axis: str
    """The name of the axis to aggregate over. This axis will be removed from the output."""

    # AggregationFunction.NONE is rejected by AggregateTransformer.
    operation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply."""
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class AggregateTransformer(BaseTransformer[AggregateSettings, AxisArray, AxisArray]):
    """
    Transformer that aggregates an entire axis using a specified operation.

    Unlike :obj:`RangedAggregateTransformer` which aggregates over specific ranges/bands
    and preserves the axis (with one value per band), this transformer aggregates the
    entire axis and removes it from the output, reducing dimensionality by one.

    Note that ARGMIN/ARGMAX return raw positions along the axis, not axis coordinates.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Reduce ``message`` along ``settings.axis`` and drop that axis from dims and axes.

        Raises:
            ValueError: If ``settings.operation`` is :obj:`AggregationFunction.NONE`.
        """
        xp = get_namespace(message.data)
        axis_idx = message.get_axis_idx(self.settings.axis)
        op = self.settings.operation

        if op == AggregationFunction.NONE:
            raise ValueError("AggregationFunction.NONE is not supported for full-axis aggregation")

        if op == AggregationFunction.TRAPEZOID:
            # Trapezoid integration requires x-coordinates and is performed with NumPy.
            target_axis = message.get_axis(self.settings.axis)
            if hasattr(target_axis, "data"):
                x = target_axis.data
            else:
                x = target_axis.value(np.arange(message.data.shape[axis_idx]))
            # np.trapezoid only exists in NumPy >= 2.0; older releases provide np.trapz.
            trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
            agg_data = trapezoid(np.asarray(message.data), x=x, axis=axis_idx)
        else:
            # Prefer the array-API namespace (keeps data on its native backend);
            # fall back to the NumPy implementation for functions it lacks (e.g. nan*).
            func_name = op.value
            if hasattr(xp, func_name):
                agg_data = getattr(xp, func_name)(message.data, axis=axis_idx)
            else:
                agg_data = AGGREGATORS[op](message.data, axis=axis_idx)

        new_dims = [dim for i, dim in enumerate(message.dims) if i != axis_idx]
        new_axes = {name: ax for name, ax in message.axes.items() if name != self.settings.axis}

        return replace(
            message,
            data=agg_data,
            dims=new_dims,
            axes=new_axes,
        )
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class AggregateUnit(BaseTransformerUnit[AggregateSettings, AxisArray, AxisArray, AggregateTransformer]):
    """Unit that aggregates an entire axis using a specified operation (wraps :obj:`AggregateTransformer`)."""

    SETTINGS = AggregateSettings
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from dataclasses import field
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseProcessor,
|
|
6
|
+
BaseStatefulProcessor,
|
|
7
|
+
BaseTransformerUnit,
|
|
8
|
+
CompositeProcessor,
|
|
9
|
+
)
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
|
|
12
|
+
from .aggregate import (
|
|
13
|
+
AggregationFunction,
|
|
14
|
+
RangedAggregateSettings,
|
|
15
|
+
RangedAggregateTransformer,
|
|
16
|
+
)
|
|
17
|
+
from .spectrogram import SpectrogramSettings, SpectrogramTransformer
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BandPowerSettings(ez.Settings):
    """
    Settings for ``BandPower`` / :obj:`BandPowerTransformer`.
    """

    # default_factory gives each settings instance its own SpectrogramSettings object.
    spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
    """
    Settings for spectrogram calculation.
    """

    # default_factory builds a fresh list per instance, so defaults are never shared.
    # None is forwarded to RangedAggregateSettings.bands, which makes that stage a passthrough.
    bands: list[tuple[float, float]] | None = field(default_factory=lambda: [(17, 30), (70, 170)])
    """
    (min, max) tuples of band limits in Hz.
    """

    aggregation: AggregationFunction = AggregationFunction.MEAN
    """:obj:`AggregationFunction` to apply to each band."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
    """
    Composite of a :obj:`SpectrogramTransformer` (key ``"spectrogram"``) and a
    :obj:`RangedAggregateTransformer` (key ``"aggregate"``) over the ``"freq"`` axis.
    """

    @staticmethod
    def _initialize_processors(
        settings: BandPowerSettings,
    ) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
        spectrogram = SpectrogramTransformer(settings=settings.spectrogram_settings)

        # Band aggregation always runs over the spectrogram's frequency axis.
        aggregate_settings = RangedAggregateSettings(
            axis="freq",
            bands=settings.bands,
            operation=settings.aggregation,
        )
        aggregate = RangedAggregateTransformer(settings=aggregate_settings)

        # NOTE(review): presumably the composite runs processors in dict insertion order.
        return {"spectrogram": spectrogram, "aggregate": aggregate}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class BandPower(BaseTransformerUnit[BandPowerSettings, AxisArray, AxisArray, BandPowerTransformer]):
    """Unit wrapper around :obj:`BandPowerTransformer`."""

    SETTINGS = BandPowerSettings
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def bandpower(
    spectrogram_settings: SpectrogramSettings,
    bands: list[tuple[float, float]] | tuple[tuple[float, float], ...] | None = ((17, 30), (70, 170)),
    aggregation: AggregationFunction = AggregationFunction.MEAN,
) -> BandPowerTransformer:
    """
    Build a transformer that computes a spectrogram and aggregates its power within frequency bands.

    With the default ``aggregation`` (MEAN) this is the average spectral power in each band.

    Args:
        spectrogram_settings: Settings for the spectrogram stage.
        bands: ``(min, max)`` band limits in Hz. A new list is always stored in the settings,
            so neither the default nor a caller-owned list is shared. ``None`` makes the
            aggregation stage a passthrough.
        aggregation: :obj:`AggregationFunction` applied within each band.

    Returns:
        :obj:`BandPowerTransformer`
    """
    return BandPowerTransformer(
        settings=BandPowerSettings(
            spectrogram_settings=spectrogram_settings,
            # The default is an immutable tuple (avoids a shared mutable default argument);
            # copy into a list so settings.bands keeps its documented list type.
            bands=None if bands is None else list(bands),
            aggregation=aggregation,
        )
    )
|
ezmsg/sigproc/base.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.
|
|
3
|
+
|
|
4
|
+
This module re-exports all symbols from ezmsg.baseproc to maintain backwards
|
|
5
|
+
compatibility for code that imports from ezmsg.sigproc.base.
|
|
6
|
+
|
|
7
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
# Runs when this shim module is executed (i.e. on import) to steer users to ezmsg.baseproc.
warnings.warn(
    "Importing from 'ezmsg.sigproc.base' is deprecated. Please import from 'ezmsg.baseproc' instead.",
    DeprecationWarning,
    # NOTE(review): stacklevel=2 attributes the warning to the frame above this module;
    # at import time that frame may be import machinery rather than the user's import
    # statement — confirm the warning is visible under default warning filters.
    stacklevel=2,
)
|
|
17
|
+
|
|
18
|
+
# Re-export everything from ezmsg.baseproc for backwards compatibility
|
|
19
|
+
from ezmsg.baseproc import ( # noqa: E402
|
|
20
|
+
# Protocols
|
|
21
|
+
AdaptiveTransformer,
|
|
22
|
+
# Type variables
|
|
23
|
+
AdaptiveTransformerType,
|
|
24
|
+
# Stateful classes
|
|
25
|
+
BaseAdaptiveTransformer,
|
|
26
|
+
# Unit classes
|
|
27
|
+
BaseAdaptiveTransformerUnit,
|
|
28
|
+
BaseAsyncTransformer,
|
|
29
|
+
# Base processor classes
|
|
30
|
+
BaseConsumer,
|
|
31
|
+
BaseConsumerUnit,
|
|
32
|
+
BaseProcessor,
|
|
33
|
+
BaseProcessorUnit,
|
|
34
|
+
BaseProducer,
|
|
35
|
+
BaseProducerUnit,
|
|
36
|
+
BaseStatefulConsumer,
|
|
37
|
+
BaseStatefulProcessor,
|
|
38
|
+
BaseStatefulProducer,
|
|
39
|
+
BaseStatefulTransformer,
|
|
40
|
+
BaseTransformer,
|
|
41
|
+
BaseTransformerUnit,
|
|
42
|
+
# Composite classes
|
|
43
|
+
CompositeProcessor,
|
|
44
|
+
CompositeProducer,
|
|
45
|
+
CompositeStateful,
|
|
46
|
+
Consumer,
|
|
47
|
+
ConsumerType,
|
|
48
|
+
GenAxisArray,
|
|
49
|
+
MessageInType,
|
|
50
|
+
MessageOutType,
|
|
51
|
+
Processor,
|
|
52
|
+
Producer,
|
|
53
|
+
ProducerType,
|
|
54
|
+
# Message types
|
|
55
|
+
SampleMessage,
|
|
56
|
+
SettingsType,
|
|
57
|
+
Stateful,
|
|
58
|
+
StatefulConsumer,
|
|
59
|
+
StatefulProcessor,
|
|
60
|
+
StatefulProducer,
|
|
61
|
+
StatefulTransformer,
|
|
62
|
+
StateType,
|
|
63
|
+
Transformer,
|
|
64
|
+
TransformerType,
|
|
65
|
+
# Type resolution helpers
|
|
66
|
+
_get_base_processor_message_in_type,
|
|
67
|
+
_get_base_processor_message_out_type,
|
|
68
|
+
_get_base_processor_settings_type,
|
|
69
|
+
_get_base_processor_state_type,
|
|
70
|
+
_get_processor_message_type,
|
|
71
|
+
# Type utilities
|
|
72
|
+
check_message_type_compatibility,
|
|
73
|
+
get_base_adaptive_transformer_type,
|
|
74
|
+
get_base_consumer_type,
|
|
75
|
+
get_base_producer_type,
|
|
76
|
+
get_base_transformer_type,
|
|
77
|
+
is_sample_message,
|
|
78
|
+
# Decorators
|
|
79
|
+
processor_state,
|
|
80
|
+
# Profiling
|
|
81
|
+
profile_subpub,
|
|
82
|
+
resolve_typevar,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# Public names of this deprecated shim: exactly the symbols re-imported from
# ezmsg.baseproc in the import statement above, so `from ezmsg.sigproc.base import *`
# keeps working for legacy code.
__all__ = [
    # Protocols
    "Processor",
    "Producer",
    "Consumer",
    "Transformer",
    "StatefulProcessor",
    "StatefulProducer",
    "StatefulConsumer",
    "StatefulTransformer",
    "AdaptiveTransformer",
    # Type variables
    "MessageInType",
    "MessageOutType",
    "SettingsType",
    "StateType",
    "ProducerType",
    "ConsumerType",
    "TransformerType",
    "AdaptiveTransformerType",
    # Decorators
    "processor_state",
    # Base processor classes
    "BaseProcessor",
    "BaseProducer",
    "BaseConsumer",
    "BaseTransformer",
    # Stateful classes
    "Stateful",
    "BaseStatefulProcessor",
    "BaseStatefulProducer",
    "BaseStatefulConsumer",
    "BaseStatefulTransformer",
    "BaseAdaptiveTransformer",
    "BaseAsyncTransformer",
    # Composite classes
    "CompositeStateful",
    "CompositeProcessor",
    "CompositeProducer",
    # Unit classes
    "BaseProducerUnit",
    "BaseProcessorUnit",
    "BaseConsumerUnit",
    "BaseTransformerUnit",
    "BaseAdaptiveTransformerUnit",
    "GenAxisArray",
    # Type resolution helpers
    "get_base_producer_type",
    "get_base_consumer_type",
    "get_base_transformer_type",
    "get_base_adaptive_transformer_type",
    "_get_base_processor_settings_type",
    "_get_base_processor_message_in_type",
    "_get_base_processor_message_out_type",
    "_get_base_processor_state_type",
    "_get_processor_message_type",
    # Message types
    "SampleMessage",
    "is_sample_message",
    # Profiling
    "profile_subpub",
    # Type utilities
    "check_message_type_compatibility",
    "resolve_typevar",
]
|
|
@@ -1,18 +1,59 @@
|
|
|
1
|
-
import
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
|
|
2
4
|
import scipy.signal
|
|
3
|
-
|
|
5
|
+
from scipy.signal import normalize
|
|
6
|
+
|
|
7
|
+
from .filter import (
|
|
8
|
+
BACoeffs,
|
|
9
|
+
BaseFilterByDesignTransformerUnit,
|
|
10
|
+
FilterBaseSettings,
|
|
11
|
+
FilterByDesignTransformer,
|
|
12
|
+
SOSCoeffs,
|
|
13
|
+
)
|
|
4
14
|
|
|
5
|
-
from .filter import Filter, FilterState, FilterSettingsBase
|
|
6
15
|
|
|
7
|
-
|
|
16
|
+
class ButterworthFilterSettings(FilterBaseSettings):
|
|
17
|
+
"""Settings for :obj:`ButterworthFilter`."""
|
|
8
18
|
|
|
19
|
+
# axis and coef_type are inherited from FilterBaseSettings
|
|
9
20
|
|
|
10
|
-
class ButterworthFilterSettings(FilterSettingsBase):
|
|
11
21
|
order: int = 0
|
|
12
|
-
|
|
13
|
-
|
|
22
|
+
"""
|
|
23
|
+
Filter order
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
cuton: float | None = None
|
|
27
|
+
"""
|
|
28
|
+
Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
|
|
29
|
+
if this is lower than `cutoff` then this is the beginning of the bandpass
|
|
30
|
+
or if this is greater than `cutoff` then this is the end of the bandstop.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
cutoff: float | None = None
|
|
34
|
+
"""
|
|
35
|
+
Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
|
|
36
|
+
if this is greater than `cuton` then this is the end of the bandpass,
|
|
37
|
+
or if this is less than `cuton` then this is the beginning of the bandstop.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
wn_hz: bool = True
|
|
41
|
+
"""
|
|
42
|
+
Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def filter_specs(
|
|
46
|
+
self,
|
|
47
|
+
) -> tuple[str, float | tuple[float, float]] | None:
|
|
48
|
+
"""
|
|
49
|
+
Determine the filter type given the corner frequencies.
|
|
14
50
|
|
|
15
|
-
|
|
51
|
+
Returns:
|
|
52
|
+
A tuple with the first element being a string indicating the filter type
|
|
53
|
+
(one of "lowpass", "highpass", "bandpass", "bandstop")
|
|
54
|
+
and the second element being the corner frequency or frequencies.
|
|
55
|
+
|
|
56
|
+
"""
|
|
16
57
|
if self.cuton is None and self.cutoff is None:
|
|
17
58
|
return None
|
|
18
59
|
elif self.cuton is None and self.cutoff is not None:
|
|
@@ -26,41 +67,90 @@ class ButterworthFilterSettings(FilterSettingsBase):
|
|
|
26
67
|
return "bandstop", (self.cutoff, self.cuton)
|
|
27
68
|
|
|
28
69
|
|
|
29
|
-
|
|
30
|
-
|
|
70
|
+
def butter_design_fun(
    fs: float,
    order: int = 0,
    cuton: float | None = None,
    cutoff: float | None = None,
    coef_type: str = "ba",
    wn_hz: bool = True,
) -> BACoeffs | SOSCoeffs | None:
    """
    See :obj:`ButterworthFilterSettings.filter_specs` for an explanation of specifying different
    filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
    You are likely to want to use this function with :obj:`filter_by_design`, which only passes `fs` to the design
    function (this), meaning that you should wrap this function with a lambda or prepare with functools.partial.

    Args:
        fs: The sampling frequency of the data in Hz.
        order: Filter order. Values <= 0 mean "no filter".
        cuton: Corner frequency of the filter in Hz.
        cutoff: Corner frequency of the filter in Hz.
        coef_type: "ba", "sos", or "zpk"
        wn_hz: Set False if provided Wn are normalized from 0 to 1, where 1 is the Nyquist frequency

    Returns:
        The filter coefficients as a tuple of (b, a) for coef_type "ba", or as a single ndarray for "sos",
        or (z, p, k) for "zpk". Returns ``None`` when there is nothing to design: ``order`` is not
        positive, or neither ``cuton`` nor ``cutoff`` is given.
    """
    if order <= 0:
        return None

    specs = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs()
    if specs is None:
        # Neither corner frequency was provided. Previously this fell through to a
        # tuple unpack of None and raised TypeError.
        return None
    btype, cutoffs = specs

    coefs = scipy.signal.butter(
        order,
        Wn=cutoffs,
        btype=btype,
        fs=fs if wn_hz else None,
        output=coef_type,
    )
    if coef_type == "ba":
        # Normalize the (b, a) pair via scipy.signal.normalize (leading denominator coefficient == 1).
        coefs = normalize(*coefs)
    return coefs
|
|
38
110
|
|
|
39
|
-
def initialize(self) -> None:
|
|
40
|
-
self.STATE.design = self.SETTINGS
|
|
41
|
-
self.STATE.filt_designed = True
|
|
42
|
-
super().initialize()
|
|
43
111
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
112
|
+
class ButterworthFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
    """Streaming Butterworth filter whose coefficients are produced by :func:`butter_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """Return ``butter_design_fun`` with every argument except ``fs`` bound from the settings."""
        s = self.settings
        design_kwargs = dict(
            order=s.order,
            cuton=s.cuton,
            cutoff=s.cutoff,
            coef_type=s.coef_type,
            wn_hz=s.wn_hz,
        )
        return functools.partial(butter_design_fun, **design_kwargs)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class ButterworthFilter(BaseFilterByDesignTransformerUnit[ButterworthFilterSettings, ButterworthFilterTransformer]):
    """Unit wrapper around :obj:`ButterworthFilterTransformer`."""

    SETTINGS = ButterworthFilterSettings
|
|
128
|
+
|
|
57
129
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
130
|
+
def butter(
    axis: str | None,
    order: int = 0,
    cuton: float | None = None,
    cutoff: float | None = None,
    coef_type: str = "ba",
    wn_hz: bool = True,
) -> ButterworthFilterTransformer:
    """
    Build a transformer that applies a Butterworth filter to streaming data.

    Coefficients are designed with :obj:`scipy.signal.butter` (through :func:`butter_design_fun`).
    See :obj:`ButterworthFilterSettings.filter_specs` for how ``cuton``/``cutoff`` select the
    filter type (lowpass, highpass, bandpass, bandstop).

    Args:
        axis: Name of the axis to filter along (forwarded to the settings).
        order: Filter order; 0 means no filter is designed.
        cuton: Corner frequency (Hz unless ``wn_hz`` is False).
        cutoff: Corner frequency (Hz unless ``wn_hz`` is False).
        coef_type: "ba", "sos", or "zpk".
        wn_hz: Set False if the corner frequencies are normalized so that 1 is Nyquist.

    Returns:
        :obj:`ButterworthFilterTransformer`
    """
    settings = ButterworthFilterSettings(
        axis=axis,
        order=order,
        cuton=cuton,
        cutoff=cutoff,
        coef_type=coef_type,
        wn_hz=wn_hz,
    )
    return ButterworthFilterTransformer(settings)
|