ezmsg-sigproc 2.9.0__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/butterworthzerophase.py +243 -61
- ezmsg/sigproc/math/pow.py +43 -0
- ezmsg/sigproc/scaler.py +55 -30
- ezmsg/sigproc/singlebandpow.py +116 -0
- {ezmsg_sigproc-2.9.0.dist-info → ezmsg_sigproc-2.11.0.dist-info}/METADATA +1 -1
- {ezmsg_sigproc-2.9.0.dist-info → ezmsg_sigproc-2.11.0.dist-info}/RECORD +9 -7
- {ezmsg_sigproc-2.9.0.dist-info → ezmsg_sigproc-2.11.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-2.9.0.dist-info → ezmsg_sigproc-2.11.0.dist-info}/licenses/LICENSE +0 -0
ezmsg/sigproc/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.
|
|
32
|
-
__version_tuple__ = version_tuple = (2,
|
|
31
|
+
__version__ = version = '2.11.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 11, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -1,42 +1,80 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Streaming zero-phase Butterworth filter implemented as a two-stage composite processor.
|
|
3
|
+
|
|
4
|
+
Stage 1: Forward causal Butterworth filter (from ezmsg.sigproc.butterworthfilter)
|
|
5
|
+
Stage 2: Backward acausal filter with buffering (ButterworthBackwardFilterTransformer)
|
|
6
|
+
|
|
7
|
+
The output is delayed by `pad_length` samples to ensure the backward pass has sufficient
|
|
8
|
+
future context. The pad_length is computed analytically using scipy's heuristic.
|
|
9
|
+
"""
|
|
10
|
+
|
|
1
11
|
import functools
|
|
2
12
|
import typing
|
|
3
13
|
|
|
4
|
-
import ezmsg.core as ez
|
|
5
14
|
import numpy as np
|
|
6
15
|
import scipy.signal
|
|
7
|
-
from ezmsg.baseproc import
|
|
16
|
+
from ezmsg.baseproc import BaseTransformerUnit
|
|
17
|
+
from ezmsg.baseproc.composite import CompositeProcessor
|
|
8
18
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
19
|
from ezmsg.util.messages.util import replace
|
|
10
20
|
|
|
11
|
-
from
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
FilterByDesignTransformer,
|
|
16
|
-
SOSCoeffs,
|
|
21
|
+
from .butterworthfilter import (
|
|
22
|
+
ButterworthFilterSettings,
|
|
23
|
+
ButterworthFilterTransformer,
|
|
24
|
+
butter_design_fun,
|
|
17
25
|
)
|
|
26
|
+
from .filter import BACoeffs, FilterByDesignTransformer, SOSCoeffs
|
|
27
|
+
from .util.axisarray_buffer import HybridAxisArrayBuffer
|
|
18
28
|
|
|
19
29
|
|
|
20
30
|
class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
|
|
21
|
-
"""
|
|
31
|
+
"""
|
|
32
|
+
Settings for :obj:`ButterworthZeroPhase`.
|
|
33
|
+
|
|
34
|
+
This implements a streaming zero-phase Butterworth filter using forward-backward
|
|
35
|
+
filtering. The output is delayed by `pad_length` samples to ensure the backward
|
|
36
|
+
pass has sufficient future context.
|
|
37
|
+
|
|
38
|
+
The pad_length is computed by finding where the filter's impulse response decays
|
|
39
|
+
to `settle_cutoff` fraction of its peak value. This accounts for the filter's
|
|
40
|
+
actual time constant rather than just its order.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
# Inherits from ButterworthFilterSettings:
|
|
44
|
+
# axis, coef_type, order, cuton, cutoff, wn_hz
|
|
22
45
|
|
|
23
|
-
|
|
24
|
-
padtype: str | None = None
|
|
46
|
+
settle_cutoff: float = 0.01
|
|
25
47
|
"""
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
Default is
|
|
48
|
+
Fraction of peak impulse response used to determine settling time.
|
|
49
|
+
The pad_length is set to the number of samples until the impulse response
|
|
50
|
+
decays to this fraction of its peak. Default is 0.01 (1% of peak).
|
|
29
51
|
"""
|
|
30
52
|
|
|
31
|
-
|
|
53
|
+
max_pad_duration: float | None = None
|
|
32
54
|
"""
|
|
33
|
-
|
|
34
|
-
|
|
55
|
+
Maximum pad duration in seconds. If set, the pad_length will be capped
|
|
56
|
+
at this value times the sampling rate. Use this to limit latency for
|
|
57
|
+
filters with very long impulse responses. Default is None (no limit).
|
|
35
58
|
"""
|
|
36
59
|
|
|
37
60
|
|
|
38
|
-
class
|
|
39
|
-
"""
|
|
61
|
+
class ButterworthBackwardFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
|
|
62
|
+
"""
|
|
63
|
+
Backward (acausal) Butterworth filter with buffering.
|
|
64
|
+
|
|
65
|
+
This transformer buffers its input and applies the filter in reverse,
|
|
66
|
+
outputting only the "settled" portion where transients have decayed.
|
|
67
|
+
This introduces a lag of ``pad_length`` samples.
|
|
68
|
+
|
|
69
|
+
Intended to be used as stage 2 in a zero-phase filter pipeline, receiving
|
|
70
|
+
forward-filtered data from a ButterworthFilterTransformer.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
# Instance attributes (initialized in _reset_state)
|
|
74
|
+
_buffer: HybridAxisArrayBuffer | None
|
|
75
|
+
_coefs_cache: BACoeffs | SOSCoeffs | None
|
|
76
|
+
_zi_tiled: np.ndarray | None
|
|
77
|
+
_pad_length: int
|
|
40
78
|
|
|
41
79
|
def get_design_function(
|
|
42
80
|
self,
|
|
@@ -50,74 +88,218 @@ class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroP
|
|
|
50
88
|
wn_hz=self.settings.wn_hz,
|
|
51
89
|
)
|
|
52
90
|
|
|
53
|
-
def
|
|
91
|
+
def _compute_pad_length(self, fs: float) -> int:
|
|
54
92
|
"""
|
|
55
|
-
|
|
93
|
+
Compute pad length based on the filter's impulse response settling time.
|
|
94
|
+
|
|
95
|
+
The pad_length is determined by finding where the impulse response decays
|
|
96
|
+
to `settle_cutoff` fraction of its peak value. This is then optionally
|
|
97
|
+
capped by `max_pad_duration`.
|
|
56
98
|
|
|
57
99
|
Args:
|
|
58
|
-
|
|
59
|
-
|
|
100
|
+
fs: Sampling frequency in Hz.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
Number of samples for the pad length.
|
|
60
104
|
"""
|
|
61
|
-
#
|
|
62
|
-
|
|
63
|
-
|
|
105
|
+
# Design the filter to compute impulse response
|
|
106
|
+
coefs = self.get_design_function()(fs)
|
|
107
|
+
if coefs is None:
|
|
108
|
+
# Filter design failed or is disabled
|
|
109
|
+
return 0
|
|
110
|
+
|
|
111
|
+
# Generate impulse response - use a generous length initially
|
|
112
|
+
# Start with scipy's heuristic as minimum, then extend if needed
|
|
113
|
+
if self.settings.coef_type == "ba":
|
|
114
|
+
min_length = 3 * (self.settings.order + 1)
|
|
64
115
|
else:
|
|
65
|
-
|
|
116
|
+
n_sections = (self.settings.order + 1) // 2
|
|
117
|
+
min_length = 3 * n_sections * 2
|
|
66
118
|
|
|
67
|
-
#
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
119
|
+
# Use 10x the minimum as initial impulse length, or at least 10000 samples
|
|
120
|
+
# (10000 samples allows for ~333ms at 30kHz, covering most practical cases)
|
|
121
|
+
impulse_length = max(min_length * 10, 10000)
|
|
122
|
+
|
|
123
|
+
# Cap impulse length computation if max_pad_duration is set
|
|
124
|
+
if self.settings.max_pad_duration is not None:
|
|
125
|
+
max_samples = int(self.settings.max_pad_duration * fs)
|
|
126
|
+
impulse_length = min(impulse_length, max_samples + 1)
|
|
127
|
+
|
|
128
|
+
impulse = np.zeros(impulse_length)
|
|
129
|
+
impulse[0] = 1.0
|
|
130
|
+
|
|
131
|
+
if self.settings.coef_type == "ba":
|
|
132
|
+
b, a = coefs
|
|
133
|
+
h = scipy.signal.lfilter(b, a, impulse)
|
|
134
|
+
else:
|
|
135
|
+
h = scipy.signal.sosfilt(coefs, impulse)
|
|
136
|
+
|
|
137
|
+
# Find where impulse response settles to settle_cutoff of peak
|
|
138
|
+
abs_h = np.abs(h)
|
|
139
|
+
peak = abs_h.max()
|
|
140
|
+
if peak == 0:
|
|
141
|
+
return min_length
|
|
142
|
+
|
|
143
|
+
threshold = self.settings.settle_cutoff * peak
|
|
144
|
+
above_threshold = np.where(abs_h > threshold)[0]
|
|
145
|
+
|
|
146
|
+
if len(above_threshold) == 0:
|
|
147
|
+
pad_length = min_length
|
|
148
|
+
else:
|
|
149
|
+
pad_length = above_threshold[-1] + 1
|
|
150
|
+
|
|
151
|
+
# Ensure at least the scipy heuristic minimum
|
|
152
|
+
pad_length = max(pad_length, min_length)
|
|
153
|
+
|
|
154
|
+
# Apply max_pad_duration cap if set
|
|
155
|
+
if self.settings.max_pad_duration is not None:
|
|
156
|
+
max_samples = int(self.settings.max_pad_duration * fs)
|
|
157
|
+
pad_length = min(pad_length, max_samples)
|
|
158
|
+
|
|
159
|
+
return pad_length
|
|
71
160
|
|
|
72
161
|
def _reset_state(self, message: AxisArray) -> None:
|
|
162
|
+
"""Reset filter state when stream changes."""
|
|
73
163
|
self._coefs_cache = None
|
|
74
|
-
self.
|
|
164
|
+
self._zi_tiled = None
|
|
165
|
+
self._buffer = None
|
|
166
|
+
# Compute pad_length based on the message's sampling rate
|
|
167
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
168
|
+
fs = 1 / message.axes[axis].gain
|
|
169
|
+
self._pad_length = self._compute_pad_length(fs)
|
|
75
170
|
self.state.needs_redesign = True
|
|
76
171
|
|
|
172
|
+
def _compute_zi_tiled(self, data: np.ndarray, ax_idx: int) -> None:
|
|
173
|
+
"""Compute and cache the tiled zi for the given data shape.
|
|
174
|
+
|
|
175
|
+
Called once per stream (or after filter redesign). The result is
|
|
176
|
+
broadcast-ready for multiplication by the edge sample on each chunk.
|
|
177
|
+
"""
|
|
178
|
+
if self.settings.coef_type == "ba":
|
|
179
|
+
b, a = self._coefs_cache
|
|
180
|
+
zi_base = scipy.signal.lfilter_zi(b, a)
|
|
181
|
+
else: # sos
|
|
182
|
+
zi_base = scipy.signal.sosfilt_zi(self._coefs_cache)
|
|
183
|
+
|
|
184
|
+
n_tail = data.ndim - ax_idx - 1
|
|
185
|
+
|
|
186
|
+
if self.settings.coef_type == "ba":
|
|
187
|
+
zi_expand = (None,) * ax_idx + (slice(None),) + (None,) * n_tail
|
|
188
|
+
n_tile = data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]
|
|
189
|
+
else: # sos
|
|
190
|
+
zi_expand = (slice(None),) + (None,) * ax_idx + (slice(None),) + (None,) * n_tail
|
|
191
|
+
n_tile = (1,) + data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]
|
|
192
|
+
|
|
193
|
+
self._zi_tiled = np.tile(zi_base[zi_expand], n_tile)
|
|
194
|
+
|
|
195
|
+
def _initialize_zi(self, data: np.ndarray, ax_idx: int) -> np.ndarray:
|
|
196
|
+
"""Initialize filter state (zi) scaled by edge value."""
|
|
197
|
+
if self._zi_tiled is None:
|
|
198
|
+
self._compute_zi_tiled(data, ax_idx)
|
|
199
|
+
first_sample = np.take(data, [0], axis=ax_idx)
|
|
200
|
+
return self._zi_tiled * first_sample
|
|
201
|
+
|
|
77
202
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
78
203
|
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
79
204
|
ax_idx = message.get_axis_idx(axis)
|
|
80
205
|
fs = 1 / message.axes[axis].gain
|
|
81
206
|
|
|
82
|
-
if
|
|
83
|
-
|
|
84
|
-
or self.state.needs_redesign
|
|
85
|
-
or (self._fs_cache is None or not np.isclose(self._fs_cache, fs))
|
|
86
|
-
):
|
|
207
|
+
# Check if we need to redesign filter
|
|
208
|
+
if self._coefs_cache is None or self.state.needs_redesign:
|
|
87
209
|
self._coefs_cache = self.get_design_function()(fs)
|
|
88
|
-
self.
|
|
210
|
+
self._pad_length = self._compute_pad_length(fs)
|
|
211
|
+
self._zi_tiled = None # Invalidate; recomputed on next use.
|
|
89
212
|
self.state.needs_redesign = False
|
|
90
213
|
|
|
214
|
+
# Initialize buffer with duration based on pad_length
|
|
215
|
+
# Add some margin to handle variable chunk sizes
|
|
216
|
+
buffer_duration = (self._pad_length + 1) / fs
|
|
217
|
+
self._buffer = HybridAxisArrayBuffer(duration=buffer_duration, axis=axis)
|
|
218
|
+
|
|
219
|
+
# Early exit if filter is effectively disabled
|
|
91
220
|
if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
|
|
92
221
|
return message
|
|
93
222
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
223
|
+
# Write new data to buffer
|
|
224
|
+
self._buffer.write(message)
|
|
225
|
+
n_available = self._buffer.available()
|
|
226
|
+
n_output = n_available - self._pad_length
|
|
227
|
+
|
|
228
|
+
# If we don't have enough data yet, return empty
|
|
229
|
+
if n_output <= 0:
|
|
230
|
+
new_shape = list(message.data.shape)
|
|
231
|
+
new_shape[ax_idx] = 0
|
|
232
|
+
empty_data = np.empty(new_shape, dtype=message.data.dtype)
|
|
233
|
+
return replace(message, data=empty_data)
|
|
234
|
+
|
|
235
|
+
# Peek all available data from buffer
|
|
236
|
+
# Note: HybridAxisArrayBuffer moves the target axis to position 0
|
|
237
|
+
buffered = self._buffer.peek(n_available)
|
|
238
|
+
combined = buffered.data
|
|
239
|
+
buffer_ax_idx = 0 # Buffer always puts time axis at position 0
|
|
240
|
+
|
|
241
|
+
# Backward filter on reversed data
|
|
242
|
+
combined_rev = np.flip(combined, axis=buffer_ax_idx)
|
|
243
|
+
backward_zi = self._initialize_zi(combined_rev, buffer_ax_idx)
|
|
244
|
+
|
|
245
|
+
if self.settings.coef_type == "ba":
|
|
104
246
|
b, a = self._coefs_cache
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
247
|
+
y_bwd_rev, _ = scipy.signal.lfilter(b, a, combined_rev, axis=buffer_ax_idx, zi=backward_zi)
|
|
248
|
+
else: # sos
|
|
249
|
+
y_bwd_rev, _ = scipy.signal.sosfilt(self._coefs_cache, combined_rev, axis=buffer_ax_idx, zi=backward_zi)
|
|
250
|
+
|
|
251
|
+
# Reverse back to get output in correct time order
|
|
252
|
+
y_bwd = np.flip(y_bwd_rev, axis=buffer_ax_idx)
|
|
253
|
+
|
|
254
|
+
# Output the settled portion (first n_output samples)
|
|
255
|
+
y = y_bwd[:n_output]
|
|
256
|
+
|
|
257
|
+
# Advance buffer read head to discard output samples, keep pad_length
|
|
258
|
+
self._buffer.seek(n_output)
|
|
259
|
+
|
|
260
|
+
# Build output with adjusted time axis
|
|
261
|
+
# LinearAxis offset is already correct from the buffer
|
|
262
|
+
out_axis = buffered.axes[axis]
|
|
263
|
+
|
|
264
|
+
# Move axis back to original position if needed
|
|
265
|
+
if ax_idx != 0:
|
|
266
|
+
y = np.moveaxis(y, 0, ax_idx)
|
|
267
|
+
|
|
268
|
+
return replace(
|
|
269
|
+
message,
|
|
270
|
+
data=y,
|
|
271
|
+
axes={**message.axes, axis: out_axis},
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class ButterworthZeroPhaseTransformer(CompositeProcessor[ButterworthZeroPhaseSettings, AxisArray, AxisArray]):
|
|
276
|
+
"""
|
|
277
|
+
Streaming zero-phase Butterworth filter as a composite of two stages.
|
|
278
|
+
|
|
279
|
+
Stage 1 (forward): Standard causal Butterworth filter with state
|
|
280
|
+
Stage 2 (backward): Acausal Butterworth filter with buffering
|
|
281
|
+
|
|
282
|
+
The output is delayed by ``pad_length`` samples.
|
|
283
|
+
"""
|
|
284
|
+
|
|
285
|
+
@staticmethod
|
|
286
|
+
def _initialize_processors(
|
|
287
|
+
settings: ButterworthZeroPhaseSettings,
|
|
288
|
+
) -> dict[str, typing.Any]:
|
|
289
|
+
# Both stages use the same filter design settings
|
|
290
|
+
return {
|
|
291
|
+
"forward": ButterworthFilterTransformer(settings),
|
|
292
|
+
"backward": ButterworthBackwardFilterTransformer(settings),
|
|
293
|
+
}
|
|
116
294
|
|
|
117
|
-
|
|
295
|
+
@classmethod
|
|
296
|
+
def get_message_type(cls, dir: str) -> type[AxisArray]:
|
|
297
|
+
if dir in ("in", "out"):
|
|
298
|
+
return AxisArray
|
|
299
|
+
raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
|
|
118
300
|
|
|
119
301
|
|
|
120
302
|
class ButterworthZeroPhase(
|
|
121
|
-
|
|
303
|
+
BaseTransformerUnit[ButterworthZeroPhaseSettings, AxisArray, AxisArray, ButterworthZeroPhaseTransformer]
|
|
122
304
|
):
|
|
123
305
|
SETTINGS = ButterworthZeroPhaseSettings
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Element-wise power of the data.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import ezmsg.core as ez
|
|
10
|
+
from array_api_compat import get_namespace
|
|
11
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
12
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
|
+
from ezmsg.util.messages.util import replace
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PowSettings(ez.Settings):
|
|
17
|
+
exponent: float = 2.0
|
|
18
|
+
"""The exponent to raise the data to. Default is 2.0 (squaring)."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class PowTransformer(BaseTransformer[PowSettings, AxisArray, AxisArray]):
|
|
22
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
23
|
+
xp = get_namespace(message.data)
|
|
24
|
+
return replace(message, data=xp.pow(message.data, self.settings.exponent))
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Pow(BaseTransformerUnit[PowSettings, AxisArray, AxisArray, PowTransformer]):
|
|
28
|
+
SETTINGS = PowSettings
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def pow(
|
|
32
|
+
exponent: float = 2.0,
|
|
33
|
+
) -> PowTransformer:
|
|
34
|
+
"""
|
|
35
|
+
Raise the data to an element-wise power. See :obj:`xp.pow` for more details.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
exponent: The exponent to raise the data to. Default is 2.0.
|
|
39
|
+
|
|
40
|
+
Returns: :obj:`PowTransformer`.
|
|
41
|
+
|
|
42
|
+
"""
|
|
43
|
+
return PowTransformer(PowSettings(exponent=exponent))
|
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -7,7 +7,6 @@ from ezmsg.baseproc import (
|
|
|
7
7
|
BaseTransformerUnit,
|
|
8
8
|
processor_state,
|
|
9
9
|
)
|
|
10
|
-
from ezmsg.util.generator import consumer
|
|
11
10
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
12
11
|
from ezmsg.util.messages.util import replace
|
|
13
12
|
|
|
@@ -18,50 +17,69 @@ from .ewma import _tau_from_alpha as _tau_from_alpha
|
|
|
18
17
|
from .ewma import ewma_step as ewma_step
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
"""
|
|
24
|
-
|
|
25
|
-
|
|
20
|
+
class RiverAdaptiveStandardScalerSettings(ez.Settings):
|
|
21
|
+
time_constant: float = 1.0
|
|
22
|
+
"""Decay constant ``tau`` in seconds."""
|
|
23
|
+
|
|
24
|
+
axis: str | None = None
|
|
25
|
+
"""The name of the axis to accumulate statistics over."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@processor_state
|
|
29
|
+
class RiverAdaptiveStandardScalerState:
|
|
30
|
+
scaler: typing.Any = None
|
|
31
|
+
axis: str | None = None
|
|
32
|
+
axis_idx: int = 0
|
|
26
33
|
|
|
27
|
-
Args:
|
|
28
|
-
time_constant: Decay constant `tau` in seconds.
|
|
29
|
-
axis: The name of the axis to accumulate statistics over.
|
|
30
34
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
35
|
+
class RiverAdaptiveStandardScalerTransformer(
|
|
36
|
+
BaseStatefulTransformer[
|
|
37
|
+
RiverAdaptiveStandardScalerSettings,
|
|
38
|
+
AxisArray,
|
|
39
|
+
AxisArray,
|
|
40
|
+
RiverAdaptiveStandardScalerState,
|
|
41
|
+
]
|
|
42
|
+
):
|
|
43
|
+
"""
|
|
44
|
+
Apply the adaptive standard scaler from
|
|
45
|
+
`river <https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/>`_.
|
|
46
|
+
|
|
47
|
+
This processes data sample-by-sample using River's online learning
|
|
48
|
+
implementation. For a vectorized EWMA-based alternative, see
|
|
49
|
+
:class:`AdaptiveStandardScalerTransformer`.
|
|
34
50
|
"""
|
|
35
|
-
from river import preprocessing
|
|
36
51
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
data = msg_in.data
|
|
52
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
53
|
+
from river import preprocessing
|
|
54
|
+
|
|
55
|
+
axis = self.settings.axis
|
|
42
56
|
if axis is None:
|
|
43
|
-
axis =
|
|
44
|
-
axis_idx = 0
|
|
57
|
+
axis = message.dims[0]
|
|
58
|
+
self._state.axis_idx = 0
|
|
45
59
|
else:
|
|
46
|
-
axis_idx =
|
|
47
|
-
|
|
48
|
-
|
|
60
|
+
self._state.axis_idx = message.get_axis_idx(axis)
|
|
61
|
+
self._state.axis = axis
|
|
62
|
+
|
|
63
|
+
alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
|
|
64
|
+
self._state.scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)
|
|
49
65
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
66
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
67
|
+
data = message.data
|
|
68
|
+
axis_idx = self._state.axis_idx
|
|
69
|
+
if axis_idx != 0:
|
|
70
|
+
data = np.moveaxis(data, axis_idx, 0)
|
|
53
71
|
|
|
54
72
|
result = []
|
|
55
73
|
for sample in data:
|
|
56
74
|
x = {k: v for k, v in enumerate(sample.flatten().tolist())}
|
|
57
|
-
|
|
58
|
-
y =
|
|
75
|
+
self._state.scaler.learn_one(x)
|
|
76
|
+
y = self._state.scaler.transform_one(x)
|
|
59
77
|
k = sorted(y.keys())
|
|
60
78
|
result.append(np.array([y[_] for _ in k]).reshape(sample.shape))
|
|
61
79
|
|
|
62
80
|
result = np.stack(result)
|
|
63
81
|
result = np.moveaxis(result, 0, axis_idx)
|
|
64
|
-
|
|
82
|
+
return replace(message, data=result)
|
|
65
83
|
|
|
66
84
|
|
|
67
85
|
class AdaptiveStandardScalerSettings(EWMASettings): ...
|
|
@@ -158,7 +176,14 @@ class AdaptiveStandardScaler(
|
|
|
158
176
|
self.processor.settings = msg
|
|
159
177
|
|
|
160
178
|
|
|
161
|
-
#
|
|
179
|
+
# Convenience functions to support deprecated generator API
|
|
180
|
+
def scaler(time_constant: float = 1.0, axis: str | None = None) -> RiverAdaptiveStandardScalerTransformer:
|
|
181
|
+
"""Create a :class:`RiverAdaptiveStandardScalerTransformer` with the given parameters."""
|
|
182
|
+
return RiverAdaptiveStandardScalerTransformer(
|
|
183
|
+
settings=RiverAdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
|
|
162
187
|
def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
|
|
163
188
|
return AdaptiveStandardScalerTransformer(
|
|
164
189
|
settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Time-domain single-band power estimation.
|
|
3
|
+
|
|
4
|
+
Two methods are provided:
|
|
5
|
+
|
|
6
|
+
1. **RMS Band Power** — Bandpass filter, square, window into bins, take the mean, optionally take the square root.
|
|
7
|
+
2. **Square-Law + LPF Band Power** — Bandpass filter, square, lowpass filter (smoothing), downsample.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from dataclasses import field
|
|
11
|
+
|
|
12
|
+
import ezmsg.core as ez
|
|
13
|
+
from ezmsg.baseproc import (
|
|
14
|
+
BaseProcessor,
|
|
15
|
+
BaseStatefulProcessor,
|
|
16
|
+
BaseTransformerUnit,
|
|
17
|
+
CompositeProcessor,
|
|
18
|
+
)
|
|
19
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
20
|
+
from ezmsg.util.messages.modify import modify_axis
|
|
21
|
+
|
|
22
|
+
from .aggregate import AggregateSettings, AggregateTransformer, AggregationFunction
|
|
23
|
+
from .butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer
|
|
24
|
+
from .downsample import DownsampleSettings, DownsampleTransformer
|
|
25
|
+
from .math.pow import PowSettings, PowTransformer
|
|
26
|
+
from .window import WindowTransformer
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RMSBandPowerSettings(ez.Settings):
|
|
30
|
+
"""Settings for :obj:`RMSBandPowerTransformer`."""
|
|
31
|
+
|
|
32
|
+
bandpass: ButterworthFilterSettings = field(
|
|
33
|
+
default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
|
|
34
|
+
)
|
|
35
|
+
"""Butterworth bandpass filter settings. Set ``cuton`` and ``cutoff`` to define the band."""
|
|
36
|
+
|
|
37
|
+
bin_duration: float = 0.05
|
|
38
|
+
"""Duration of each non-overlapping bin in seconds."""
|
|
39
|
+
|
|
40
|
+
apply_sqrt: bool = True
|
|
41
|
+
"""If True, output is RMS (root-mean-square). If False, output is mean-square power."""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class RMSBandPowerTransformer(CompositeProcessor[RMSBandPowerSettings, AxisArray, AxisArray]):
|
|
45
|
+
"""
|
|
46
|
+
RMS band power estimation.
|
|
47
|
+
|
|
48
|
+
Pipeline: bandpass -> square -> window(bins) -> mean(time) -> rename bin->time -> [sqrt]
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
@staticmethod
|
|
52
|
+
def _initialize_processors(
|
|
53
|
+
settings: RMSBandPowerSettings,
|
|
54
|
+
) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
|
|
55
|
+
procs: dict[str, BaseProcessor | BaseStatefulProcessor] = {
|
|
56
|
+
"bandpass": ButterworthFilterTransformer(settings.bandpass),
|
|
57
|
+
"square": PowTransformer(PowSettings(exponent=2.0)),
|
|
58
|
+
"window": WindowTransformer(
|
|
59
|
+
axis="time",
|
|
60
|
+
newaxis="bin",
|
|
61
|
+
window_dur=settings.bin_duration,
|
|
62
|
+
window_shift=settings.bin_duration,
|
|
63
|
+
zero_pad_until="none",
|
|
64
|
+
),
|
|
65
|
+
"aggregate": AggregateTransformer(AggregateSettings(axis="time", operation=AggregationFunction.MEAN)),
|
|
66
|
+
"rename": modify_axis(name_map={"bin": "time"}),
|
|
67
|
+
}
|
|
68
|
+
if settings.apply_sqrt:
|
|
69
|
+
procs["sqrt"] = PowTransformer(PowSettings(exponent=0.5))
|
|
70
|
+
return procs
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class RMSBandPower(BaseTransformerUnit[RMSBandPowerSettings, AxisArray, AxisArray, RMSBandPowerTransformer]):
|
|
74
|
+
SETTINGS = RMSBandPowerSettings
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class SquareLawBandPowerSettings(ez.Settings):
|
|
78
|
+
"""Settings for :obj:`SquareLawBandPowerTransformer`."""
|
|
79
|
+
|
|
80
|
+
bandpass: ButterworthFilterSettings = field(
|
|
81
|
+
default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
|
|
82
|
+
)
|
|
83
|
+
"""Butterworth bandpass filter settings. Set ``cuton`` and ``cutoff`` to define the band."""
|
|
84
|
+
|
|
85
|
+
lowpass: ButterworthFilterSettings = field(
|
|
86
|
+
default_factory=lambda: ButterworthFilterSettings(order=4, coef_type="sos")
|
|
87
|
+
)
|
|
88
|
+
"""Butterworth lowpass filter settings for smoothing the squared signal."""
|
|
89
|
+
|
|
90
|
+
downsample: DownsampleSettings = field(default_factory=DownsampleSettings)
|
|
91
|
+
"""Downsample settings for rate reduction after lowpass smoothing."""
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class SquareLawBandPowerTransformer(CompositeProcessor[SquareLawBandPowerSettings, AxisArray, AxisArray]):
|
|
95
|
+
"""
|
|
96
|
+
Square-law + LPF band power estimation.
|
|
97
|
+
|
|
98
|
+
Pipeline: bandpass -> square -> lowpass -> downsample
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
@staticmethod
|
|
102
|
+
def _initialize_processors(
|
|
103
|
+
settings: SquareLawBandPowerSettings,
|
|
104
|
+
) -> dict[str, BaseProcessor | BaseStatefulProcessor]:
|
|
105
|
+
return {
|
|
106
|
+
"bandpass": ButterworthFilterTransformer(settings.bandpass),
|
|
107
|
+
"square": PowTransformer(PowSettings(exponent=2.0)),
|
|
108
|
+
"lowpass": ButterworthFilterTransformer(settings.lowpass),
|
|
109
|
+
"downsample": DownsampleTransformer(settings.downsample),
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class SquareLawBandPower(
|
|
114
|
+
BaseTransformerUnit[SquareLawBandPowerSettings, AxisArray, AxisArray, SquareLawBandPowerTransformer]
|
|
115
|
+
):
|
|
116
|
+
SETTINGS = SquareLawBandPowerSettings
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.11.0
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
5
|
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
|
|
2
|
-
ezmsg/sigproc/__version__.py,sha256=
|
|
2
|
+
ezmsg/sigproc/__version__.py,sha256=eqKbWb9LnxuZWE9-pafopBz45ugg0beSlKLIOIjeSzc,706
|
|
3
3
|
ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
|
|
4
4
|
ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
|
|
5
5
|
ezmsg/sigproc/affinetransform.py,sha256=jl7DiSa5Yb0qsmFJbfSiSeGmvK1SGoBgycFC5JU5DVY,9434
|
|
@@ -7,7 +7,7 @@ ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,93
|
|
|
7
7
|
ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
|
|
8
8
|
ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
|
|
9
9
|
ezmsg/sigproc/butterworthfilter.py,sha256=NKTGkgjvlmC1Dc9gD2Z6UBzUq12KicfnczrzM5ZTosk,5255
|
|
10
|
-
ezmsg/sigproc/butterworthzerophase.py,sha256=
|
|
10
|
+
ezmsg/sigproc/butterworthzerophase.py,sha256=CU6cXkI6j1LQCEz0sr2IthAPCq_TEtbvSb7h2Nw1w74,11820
|
|
11
11
|
ezmsg/sigproc/cheby.py,sha256=B8pGt5_pOBpNZCmaibNl_NKkyuasd8ZEJXeTDCTaino,3711
|
|
12
12
|
ezmsg/sigproc/combfilter.py,sha256=MSxr1I-jBePW_9AuCiv3RQ1HUNxIsNhLk0q1Iu8ikAw,4766
|
|
13
13
|
ezmsg/sigproc/coordinatespaces.py,sha256=bp_0fTS9b27OQqLoFzgE3f9rb287P8y0S1dWWGrS08o,5298
|
|
@@ -34,8 +34,9 @@ ezmsg/sigproc/quantize.py,sha256=uSM2z2xXwL0dgSltyzLEmlKjaJZ2meA3PDWX8_bM0Hs,219
|
|
|
34
34
|
ezmsg/sigproc/resample.py,sha256=3mm9pvxryNVhQuTCIMW3ToUkUfbVOCsIgvXUiurit1Y,11389
|
|
35
35
|
ezmsg/sigproc/rollingscaler.py,sha256=e-smSKDhmDD2nWIf6I77CtRxQp_7sHS268SGPi7aXp8,8499
|
|
36
36
|
ezmsg/sigproc/sampler.py,sha256=iOk2YoUX22u9iTjFKimzP5V074RDBVcmswgfyxvZRZo,10761
|
|
37
|
-
ezmsg/sigproc/scaler.py,sha256=
|
|
37
|
+
ezmsg/sigproc/scaler.py,sha256=nCgShZufPId_b-Sbsc8Si31lbtOb3nPImNcnksd774w,6578
|
|
38
38
|
ezmsg/sigproc/signalinjector.py,sha256=mB62H2b-ScgPtH1jajEpxgDHqdb-RKekQfgyNncsE8Y,2874
|
|
39
|
+
ezmsg/sigproc/singlebandpow.py,sha256=BVlWhFI6zU3ME3EVdZbwf-FMz1d2sfuNFDKXs1hn5HM,4353
|
|
39
40
|
ezmsg/sigproc/slicer.py,sha256=xLXxWf722V08ytVwvPimYjDKKj0pkC2HjdgCVaoaOvs,5195
|
|
40
41
|
ezmsg/sigproc/spectral.py,sha256=wFzuihS7qJZMQcp0ds_qCG-zCbvh5DyhFRjn2wA9TWQ,322
|
|
41
42
|
ezmsg/sigproc/spectrogram.py,sha256=g8xYWENzle6O5uEF-vfjsF5gOSDnJTwiu3ZudicO470,2893
|
|
@@ -50,6 +51,7 @@ ezmsg/sigproc/math/clip.py,sha256=1D6mUlOzBB7L35G_KKYZmfg7nYlbuDdITV4EH0R-yUo,15
|
|
|
50
51
|
ezmsg/sigproc/math/difference.py,sha256=uUYZgbLe-GrFSN6EOFjs9fQZllp827IluxL6m8TJuH8,4791
|
|
51
52
|
ezmsg/sigproc/math/invert.py,sha256=nz8jbfvDoez6s9NmAprBtTAI5oSDj0wNUPk8j13XiVk,855
|
|
52
53
|
ezmsg/sigproc/math/log.py,sha256=JhjSqLnQnvx_3F4txRYHuUPSJ12Yj2HvRTsCMNvlxpo,2022
|
|
54
|
+
ezmsg/sigproc/math/pow.py,sha256=0sdlXFUEBXmpEV_i75oshGRjMguv8L13nLt7hlvdX3E,1284
|
|
53
55
|
ezmsg/sigproc/math/scale.py,sha256=4_xHcHNuf13E1fxIF5vbkPfkN4En6zkfPIKID7lCERk,1133
|
|
54
56
|
ezmsg/sigproc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
55
57
|
ezmsg/sigproc/util/asio.py,sha256=aAj0e7OoBvkRy28k05HL2s9YPCTxOddc05xMN-qd4lQ,577
|
|
@@ -59,7 +61,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
|
|
|
59
61
|
ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
|
|
60
62
|
ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
|
|
61
63
|
ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
|
|
62
|
-
ezmsg_sigproc-2.
|
|
63
|
-
ezmsg_sigproc-2.
|
|
64
|
-
ezmsg_sigproc-2.
|
|
65
|
-
ezmsg_sigproc-2.
|
|
64
|
+
ezmsg_sigproc-2.11.0.dist-info/METADATA,sha256=8XB8fu3sNqsrwV-ff8xtlWUKsFdERMSqqkotMhfNtu0,1909
|
|
65
|
+
ezmsg_sigproc-2.11.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
66
|
+
ezmsg_sigproc-2.11.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
|
|
67
|
+
ezmsg_sigproc-2.11.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|