ezmsg-sigproc 2.4.0__py3-none-any.whl → 2.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/butterworthzerophase.py +132 -0
- ezmsg/sigproc/denormalize.py +3 -5
- ezmsg/sigproc/fir_hilbert.py +353 -0
- ezmsg/sigproc/fir_pmc.py +214 -0
- ezmsg/sigproc/rollingscaler.py +257 -0
- {ezmsg_sigproc-2.4.0.dist-info → ezmsg_sigproc-2.5.0.dist-info}/METADATA +2 -2
- {ezmsg_sigproc-2.4.0.dist-info → ezmsg_sigproc-2.5.0.dist-info}/RECORD +10 -6
- {ezmsg_sigproc-2.4.0.dist-info → ezmsg_sigproc-2.5.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-2.4.0.dist-info → ezmsg_sigproc-2.5.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.4.0'
|
|
32
|
-
__version_tuple__ = version_tuple = (2, 4, 0)
|
|
31
|
+
__version__ = version = '2.5.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 5, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
ezmsg/sigproc/butterworthzerophase.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
import numpy as np
|
|
6
|
+
import scipy.signal
|
|
7
|
+
from ezmsg.sigproc.base import SettingsType
|
|
8
|
+
from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
|
|
9
|
+
from ezmsg.sigproc.filter import (
|
|
10
|
+
BACoeffs,
|
|
11
|
+
BaseFilterByDesignTransformerUnit,
|
|
12
|
+
FilterByDesignTransformer,
|
|
13
|
+
SOSCoeffs,
|
|
14
|
+
)
|
|
15
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
16
|
+
from ezmsg.util.messages.util import replace
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
|
|
20
|
+
"""Settings for :obj:`ButterworthZeroPhase`."""
|
|
21
|
+
|
|
22
|
+
# axis, coef_type, order, cuton, cutoff, wn_hz are inherited from ButterworthFilterSettings
|
|
23
|
+
padtype: str | None = None
|
|
24
|
+
"""
|
|
25
|
+
Padding type to use in `scipy.signal.filtfilt`.
|
|
26
|
+
Must be one of {'odd', 'even', 'constant', None}.
|
|
27
|
+
Default is None for no padding.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
padlen: int | None = 0
|
|
31
|
+
"""
|
|
32
|
+
Length of the padding to use in `scipy.signal.filtfilt`.
|
|
33
|
+
If None, SciPy's default padding is used.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ButterworthZeroPhaseTransformer(
|
|
38
|
+
FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]
|
|
39
|
+
):
|
|
40
|
+
"""Zero-phase (filtfilt) Butterworth using your design function."""
|
|
41
|
+
|
|
42
|
+
def get_design_function(
|
|
43
|
+
self,
|
|
44
|
+
) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
|
|
45
|
+
return functools.partial(
|
|
46
|
+
butter_design_fun,
|
|
47
|
+
order=self.settings.order,
|
|
48
|
+
cuton=self.settings.cuton,
|
|
49
|
+
cutoff=self.settings.cutoff,
|
|
50
|
+
coef_type=self.settings.coef_type,
|
|
51
|
+
wn_hz=self.settings.wn_hz,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
def update_settings(
|
|
55
|
+
self, new_settings: typing.Optional[SettingsType] = None, **kwargs
|
|
56
|
+
) -> None:
|
|
57
|
+
"""
|
|
58
|
+
Update settings and mark that filter coefficients need to be recalculated.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
new_settings: Complete new settings object to replace current settings
|
|
62
|
+
**kwargs: Individual settings to update
|
|
63
|
+
"""
|
|
64
|
+
# Update settings
|
|
65
|
+
if new_settings is not None:
|
|
66
|
+
self.settings = new_settings
|
|
67
|
+
else:
|
|
68
|
+
self.settings = replace(self.settings, **kwargs)
|
|
69
|
+
|
|
70
|
+
# Set flag to trigger recalculation on next message
|
|
71
|
+
self._coefs_cache = None
|
|
72
|
+
self._fs_cache = None
|
|
73
|
+
self.state.needs_redesign = True
|
|
74
|
+
|
|
75
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
76
|
+
self._coefs_cache = None
|
|
77
|
+
self._fs_cache = None
|
|
78
|
+
self.state.needs_redesign = True
|
|
79
|
+
|
|
80
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
81
|
+
axis = message.dims[0] if self.settings.axis is None else self.settings.axis
|
|
82
|
+
ax_idx = message.get_axis_idx(axis)
|
|
83
|
+
fs = 1 / message.axes[axis].gain
|
|
84
|
+
|
|
85
|
+
if (
|
|
86
|
+
self._coefs_cache is None
|
|
87
|
+
or self.state.needs_redesign
|
|
88
|
+
or (self._fs_cache is None or not np.isclose(self._fs_cache, fs))
|
|
89
|
+
):
|
|
90
|
+
self._coefs_cache = self.get_design_function()(fs)
|
|
91
|
+
self._fs_cache = fs
|
|
92
|
+
self.state.needs_redesign = False
|
|
93
|
+
|
|
94
|
+
if (
|
|
95
|
+
self._coefs_cache is None
|
|
96
|
+
or self.settings.order <= 0
|
|
97
|
+
or message.data.size <= 0
|
|
98
|
+
):
|
|
99
|
+
return message
|
|
100
|
+
|
|
101
|
+
x = message.data
|
|
102
|
+
if self.settings.coef_type == "sos":
|
|
103
|
+
y = scipy.signal.sosfiltfilt(
|
|
104
|
+
self._coefs_cache,
|
|
105
|
+
x,
|
|
106
|
+
axis=ax_idx,
|
|
107
|
+
padtype=self.settings.padtype,
|
|
108
|
+
padlen=self.settings.padlen,
|
|
109
|
+
)
|
|
110
|
+
elif self.settings.coef_type == "ba":
|
|
111
|
+
b, a = self._coefs_cache
|
|
112
|
+
y = scipy.signal.filtfilt(
|
|
113
|
+
b,
|
|
114
|
+
a,
|
|
115
|
+
x,
|
|
116
|
+
axis=ax_idx,
|
|
117
|
+
padtype=self.settings.padtype,
|
|
118
|
+
padlen=self.settings.padlen,
|
|
119
|
+
)
|
|
120
|
+
else:
|
|
121
|
+
ez.logger.error("coef_type must be 'sos' or 'ba'.")
|
|
122
|
+
raise ValueError("coef_type must be 'sos' or 'ba'.")
|
|
123
|
+
|
|
124
|
+
return replace(message, data=y)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class ButterworthZeroPhase(
|
|
128
|
+
BaseFilterByDesignTransformerUnit[
|
|
129
|
+
ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer
|
|
130
|
+
]
|
|
131
|
+
):
|
|
132
|
+
SETTINGS = ButterworthZeroPhaseSettings
|
ezmsg/sigproc/denormalize.py
CHANGED
|
@@ -22,15 +22,13 @@ class DenormalizeSettings(ez.Settings):
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
@processor_state
|
|
25
|
-
class DenormalizeRateState:
|
|
25
|
+
class DenormalizeState:
|
|
26
26
|
gains: npt.NDArray | None = None
|
|
27
27
|
offsets: npt.NDArray | None = None
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class DenormalizeTransformer(
|
|
31
|
-
BaseStatefulTransformer[
|
|
32
|
-
DenormalizeSettings, AxisArray, AxisArray, DenormalizeRateState
|
|
33
|
-
]
|
|
31
|
+
BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]
|
|
34
32
|
):
|
|
35
33
|
"""
|
|
36
34
|
Scales data from a normalized distribution (mean=0, std=1) to a denormalized
|
|
@@ -78,7 +76,7 @@ class DenormalizeTransformer(
|
|
|
78
76
|
)
|
|
79
77
|
|
|
80
78
|
|
|
81
|
-
class
|
|
79
|
+
class DenormalizeUnit(
|
|
82
80
|
BaseTransformerUnit[
|
|
83
81
|
DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer
|
|
84
82
|
]
|
ezmsg/sigproc/fir_hilbert.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
import numpy as np
|
|
6
|
+
import scipy.signal as sps
|
|
7
|
+
from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state
|
|
8
|
+
from ezmsg.sigproc.filter import (
|
|
9
|
+
BACoeffs,
|
|
10
|
+
BaseFilterByDesignTransformerUnit,
|
|
11
|
+
BaseTransformerUnit,
|
|
12
|
+
FilterBaseSettings,
|
|
13
|
+
FilterByDesignTransformer,
|
|
14
|
+
)
|
|
15
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
16
|
+
from ezmsg.util.messages.util import replace
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class FIRHilbertFilterSettings(FilterBaseSettings):
|
|
20
|
+
"""Settings for :obj:`FIRHilbertFilter`."""
|
|
21
|
+
|
|
22
|
+
# axis inherited from FilterBaseSettings
|
|
23
|
+
|
|
24
|
+
coef_type: str = "ba"
|
|
25
|
+
"""
|
|
26
|
+
Coefficient type. Must be 'ba' for FIR.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
order: int = 170
|
|
30
|
+
"""
|
|
31
|
+
Filter order (taps = order + 1).
|
|
32
|
+
Hilbert (type-III) filters require even order (odd taps).
|
|
33
|
+
If odd order (even taps), order will be incremented by 1.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
f_lo: float = 1.0
|
|
37
|
+
"""
|
|
38
|
+
Lower corner of Hilbert “pass” band (Hz).
|
|
39
|
+
Transition starts at f_lo.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
f_hi: float | None = None
|
|
43
|
+
"""
|
|
44
|
+
Upper corner of Hilbert “pass” band (Hz).
|
|
45
|
+
Transition starts at f_hi.
|
|
46
|
+
If None, highpass from f_lo to Nyquist.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
trans_lo: float = 1.0
|
|
50
|
+
"""
|
|
51
|
+
Transition width (Hz) below f_lo.
|
|
52
|
+
Decrease to sharpen transition.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
trans_hi: float = 1.0
|
|
56
|
+
"""
|
|
57
|
+
Transition width (Hz) at high end.
|
|
58
|
+
Decrease to sharpen transition.
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
weight_pass: float = 1.0
|
|
62
|
+
"""
|
|
63
|
+
Weight for Hilbert pass region.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
weight_stop_lo: float = 1.0
|
|
67
|
+
"""
|
|
68
|
+
Weight for low stop band.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
weight_stop_hi: float = 1.0
|
|
72
|
+
"""
|
|
73
|
+
Weight for high stop band.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
norm_band: tuple[float, float] | None = None
|
|
77
|
+
"""
|
|
78
|
+
Optional normalization band (f_lo, f_hi) in Hz for gain normalization.
|
|
79
|
+
If None, no normalization is applied.
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
norm_freq: float | None = None
|
|
83
|
+
"""
|
|
84
|
+
Optional normalization frequency in Hz for gain normalization.
|
|
85
|
+
If None, no normalization is applied.
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def fir_hilbert_design_fun(
|
|
90
|
+
fs: float,
|
|
91
|
+
order: int = 170,
|
|
92
|
+
f_lo: float = 1.0,
|
|
93
|
+
f_hi: float | None = None,
|
|
94
|
+
trans_lo: float = 1.0,
|
|
95
|
+
trans_hi: float = 1.0,
|
|
96
|
+
weight_pass: float = 1.0,
|
|
97
|
+
weight_stop_lo: float = 1.0,
|
|
98
|
+
weight_stop_hi: float = 1.0,
|
|
99
|
+
norm_band: tuple[float, float] | None = None,
|
|
100
|
+
norm_freq: float | None = None,
|
|
101
|
+
) -> BACoeffs | None:
|
|
102
|
+
"""
|
|
103
|
+
Hilbert FIR filter design using the Remez exchange algorithm.
|
|
104
|
+
Design an `order`th-order FIR Hilbert filter and return the filter coefficients.
|
|
105
|
+
See :obj:`FIRHilbertFilterSettings` for argument description.
|
|
106
|
+
|
|
107
|
+
Returns:
|
|
108
|
+
The filter coefficients as a tuple of (b, a).
|
|
109
|
+
"""
|
|
110
|
+
if order <= 0:
|
|
111
|
+
return None
|
|
112
|
+
if order % 2 == 1:
|
|
113
|
+
order += 1
|
|
114
|
+
nyq = fs / 2.0
|
|
115
|
+
taps = order + 1
|
|
116
|
+
f1 = max(f_lo, 0.0) + trans_lo
|
|
117
|
+
f2 = (nyq - trans_hi) if (f_hi is None) else min(f_hi, nyq - trans_hi)
|
|
118
|
+
if not (0.0 < f1 < f2 < nyq):
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"Hilbert passband collapsed or invalid: "
|
|
121
|
+
f"f_lo={f_lo}, f_hi={f_hi}, trans_lo={trans_lo}, trans_hi={trans_hi}, fs={fs}"
|
|
122
|
+
)
|
|
123
|
+
# Bands: [0, f1-trans_lo] stop ; [f1, f2] pass (Hilbert) ; [f2+trans_hi, nyq] stop
|
|
124
|
+
bands = [0.0, max(f1 - trans_lo, 0.0), f1, f2, min(f2 + trans_hi, nyq), nyq]
|
|
125
|
+
desired = [0.0, 1.0, 0.0]
|
|
126
|
+
weight = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0), max(weight_stop_hi, 0.0)]
|
|
127
|
+
for i in range(1, len(bands) - 1):
|
|
128
|
+
if bands[i] <= bands[i - 1]:
|
|
129
|
+
bands[i] = np.nextafter(bands[i - 1], np.inf)
|
|
130
|
+
if bands[-2] >= nyq:
|
|
131
|
+
ez.logger.warning(
|
|
132
|
+
"Hilbert upper stopband collapsed; using 2-band (stop/pass) design."
|
|
133
|
+
)
|
|
134
|
+
bands = bands[:-3] + [nyq]
|
|
135
|
+
desired = desired[:-1]
|
|
136
|
+
weight = weight[:-1]
|
|
137
|
+
b = sps.remez(taps, bands, desired, weight=weight, type="hilbert", fs=fs)
|
|
138
|
+
a = np.array([1.0])
|
|
139
|
+
g = None
|
|
140
|
+
if norm_freq is not None:
|
|
141
|
+
if norm_freq < f1 or norm_freq > f2:
|
|
142
|
+
ez.logger.warning(
|
|
143
|
+
"Invalid normalization frequency specifications. Skipping normalization."
|
|
144
|
+
)
|
|
145
|
+
else:
|
|
146
|
+
f0 = float(norm_freq)
|
|
147
|
+
w = 2.0 * np.pi * (np.asarray([f0], dtype=np.float64) / fs)
|
|
148
|
+
_, H = sps.freqz(b, a, worN=w)
|
|
149
|
+
g = float(np.abs(H[0]))
|
|
150
|
+
elif norm_band is not None:
|
|
151
|
+
lo, hi = norm_band
|
|
152
|
+
if lo < f1 or hi > f2:
|
|
153
|
+
lo = max(lo, f1)
|
|
154
|
+
hi = min(hi, f2)
|
|
155
|
+
ez.logger.warning(
|
|
156
|
+
"Normalization band outside passband. Clipping to passband for normalization."
|
|
157
|
+
)
|
|
158
|
+
if lo >= hi:
|
|
159
|
+
ez.logger.warning(
|
|
160
|
+
"Invalid normalization band specifications. Skipping normalization."
|
|
161
|
+
)
|
|
162
|
+
else:
|
|
163
|
+
freqs = np.linspace(lo, hi, 2048, dtype=np.float64)
|
|
164
|
+
w = 2.0 * np.pi * (np.asarray(freqs, dtype=np.float64) / fs)
|
|
165
|
+
_, H = sps.freqz(b, a, worN=w)
|
|
166
|
+
g = float(np.median(np.abs(H)))
|
|
167
|
+
if g is not None and g > 0:
|
|
168
|
+
b = b / g
|
|
169
|
+
return (b, a)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class FIRHilbertFilterTransformer(
|
|
173
|
+
FilterByDesignTransformer[FIRHilbertFilterSettings, BACoeffs]
|
|
174
|
+
):
|
|
175
|
+
def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
|
|
176
|
+
if self.settings.coef_type != "ba":
|
|
177
|
+
ez.logger.error("FIRHilbert only supports coef_type='ba'.")
|
|
178
|
+
raise ValueError("FIRHilbert only supports coef_type='ba'.")
|
|
179
|
+
|
|
180
|
+
return functools.partial(
|
|
181
|
+
fir_hilbert_design_fun,
|
|
182
|
+
order=self.settings.order,
|
|
183
|
+
f_lo=self.settings.f_lo,
|
|
184
|
+
f_hi=self.settings.f_hi,
|
|
185
|
+
trans_lo=self.settings.trans_lo,
|
|
186
|
+
trans_hi=self.settings.trans_hi,
|
|
187
|
+
weight_pass=self.settings.weight_pass,
|
|
188
|
+
weight_stop_lo=self.settings.weight_stop_lo,
|
|
189
|
+
weight_stop_hi=self.settings.weight_stop_hi,
|
|
190
|
+
norm_band=self.settings.norm_band,
|
|
191
|
+
norm_freq=self.settings.norm_freq,
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
def get_taps(self) -> int | None:
|
|
195
|
+
if self._state.filter is None:
|
|
196
|
+
return None
|
|
197
|
+
b, _ = self._state.filter.settings.coefs
|
|
198
|
+
return b.size if b is not None else None
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class FIRHilbertFilterUnit(
|
|
202
|
+
BaseFilterByDesignTransformerUnit[
|
|
203
|
+
FIRHilbertFilterSettings, FIRHilbertFilterTransformer
|
|
204
|
+
]
|
|
205
|
+
):
|
|
206
|
+
SETTINGS = FIRHilbertFilterSettings
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@processor_state
|
|
210
|
+
class FIRHilbertEnvelopeState:
|
|
211
|
+
filter: FIRHilbertFilterTransformer | None = None
|
|
212
|
+
delay_buf: np.ndarray | None = None
|
|
213
|
+
dly: int | None = None
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class FIRHilbertEnvelopeTransformer(
|
|
217
|
+
BaseStatefulTransformer[
|
|
218
|
+
FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeState
|
|
219
|
+
]
|
|
220
|
+
):
|
|
221
|
+
"""
|
|
222
|
+
Processor for computing the envelope of a signal using the Hilbert transform.
|
|
223
|
+
|
|
224
|
+
This processor applies a Hilbert FIR filter to the input signal to obtain the analytic signal, from which the
|
|
225
|
+
envelope is computed.
|
|
226
|
+
|
|
227
|
+
The processor expects and outputs `AxisArray` messages with a `"time"` (time) axis.
|
|
228
|
+
|
|
229
|
+
Settings:
|
|
230
|
+
---------
|
|
231
|
+
order : int
|
|
232
|
+
Filter order (taps = order + 1).
|
|
233
|
+
Hilbert (type-III) filters require even order (odd taps).
|
|
234
|
+
If odd order (even taps), order will be incremented by 1.
|
|
235
|
+
f_lo : float
|
|
236
|
+
Lower corner of Hilbert “pass” band (Hz).
|
|
237
|
+
Transition starts at f_lo.
|
|
238
|
+
f_hi : float, optional
|
|
239
|
+
Upper corner of Hilbert “pass” band (Hz).
|
|
240
|
+
Transition starts at f_hi.
|
|
241
|
+
If None, highpass from f_lo to Nyquist.
|
|
242
|
+
trans_lo : float
|
|
243
|
+
Transition width (Hz) below f_lo.
|
|
244
|
+
Decrease to sharpen transition.
|
|
245
|
+
trans_hi : float
|
|
246
|
+
Transition width (Hz) above f_hi.
|
|
247
|
+
Decrease to sharpen transition.
|
|
248
|
+
weight_pass : float
|
|
249
|
+
Weight for Hilbert pass region.
|
|
250
|
+
weight_stop_lo : float
|
|
251
|
+
Weight for low stop band.
|
|
252
|
+
weight_stop_hi : float
|
|
253
|
+
Weight for high stop band.
|
|
254
|
+
norm_band : tuple(float, float), optional
|
|
255
|
+
Optional normalization band (f_lo, f_hi) in Hz for gain normalization.
|
|
256
|
+
If None, no normalization is applied.
|
|
257
|
+
norm_freq : float, optional
|
|
258
|
+
Optional normalization frequency in Hz for gain normalization.
|
|
259
|
+
If None, no normalization is applied.
|
|
260
|
+
|
|
261
|
+
Example:
|
|
262
|
+
-----------------------------
|
|
263
|
+
```python
|
|
264
|
+
processor = FIRHilbertEnvelopeTransformer(
|
|
265
|
+
settings=FIRHilbertFilterSettings(
|
|
266
|
+
order=170,
|
|
267
|
+
f_lo=1.0,
|
|
268
|
+
f_hi=50.0,
|
|
269
|
+
)
|
|
270
|
+
)
|
|
271
|
+
```
|
|
272
|
+
|
|
273
|
+
"""
|
|
274
|
+
|
|
275
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
276
|
+
axis = self.settings.axis or message.dims[0]
|
|
277
|
+
gain = getattr(self._state.filter, "gain", 0.0)
|
|
278
|
+
axis_idx = message.get_axis_idx(axis)
|
|
279
|
+
samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
280
|
+
return hash((message.key, samp_shape, gain))
|
|
281
|
+
|
|
282
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
283
|
+
self._state.filter = FIRHilbertFilterTransformer(settings=self.settings)
|
|
284
|
+
self._state.delay_buf = None
|
|
285
|
+
self._state.dly = None
|
|
286
|
+
|
|
287
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
288
|
+
y_imag_msg = self._state.filter(message)
|
|
289
|
+
y_imag = y_imag_msg.data
|
|
290
|
+
|
|
291
|
+
axis_name = self.settings.axis or message.dims[0]
|
|
292
|
+
axis_idx = message.get_axis_idx(axis_name)
|
|
293
|
+
if self._state.dly is None:
|
|
294
|
+
taps = self._state.filter.get_taps()
|
|
295
|
+
self._state.dly = (taps - 1) // 2
|
|
296
|
+
|
|
297
|
+
x = message.data
|
|
298
|
+
|
|
299
|
+
move_axis = False
|
|
300
|
+
if axis_idx != x.ndim - 1:
|
|
301
|
+
x = np.moveaxis(x, axis_idx, -1)
|
|
302
|
+
y_imag = np.moveaxis(y_imag, axis_idx, -1)
|
|
303
|
+
move_axis = True
|
|
304
|
+
|
|
305
|
+
if self._state.delay_buf is None:
|
|
306
|
+
lead_shape = x.shape[:-1]
|
|
307
|
+
self._state.delay_buf = np.zeros(
|
|
308
|
+
lead_shape + (self._state.dly,), dtype=x.dtype
|
|
309
|
+
)
|
|
310
|
+
|
|
311
|
+
x_cat = np.concatenate([self._state.delay_buf, x], axis=-1)
|
|
312
|
+
x_delayed_full = x_cat[..., : -self._state.dly]
|
|
313
|
+
y_real = x_delayed_full[..., -x.shape[-1] :]
|
|
314
|
+
|
|
315
|
+
self._state.delay_buf = x_cat[..., -self._state.dly :].copy()
|
|
316
|
+
|
|
317
|
+
analytic = y_real.astype(np.complex64) + 1j * y_imag.astype(np.complex64)
|
|
318
|
+
out = np.abs(analytic)
|
|
319
|
+
|
|
320
|
+
if move_axis:
|
|
321
|
+
out = np.moveaxis(out, -1, axis_idx)
|
|
322
|
+
|
|
323
|
+
return replace(message, data=out, axes=message.axes)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class FIRHilbertEnvelopeUnit(
|
|
327
|
+
BaseTransformerUnit[
|
|
328
|
+
FIRHilbertFilterSettings,
|
|
329
|
+
AxisArray,
|
|
330
|
+
AxisArray,
|
|
331
|
+
FIRHilbertEnvelopeTransformer,
|
|
332
|
+
]
|
|
333
|
+
):
|
|
334
|
+
"""
|
|
335
|
+
Unit wrapper for the `FIRHilbertEnvelopeTransformer`.
|
|
336
|
+
|
|
337
|
+
This unit provides a plug-and-play interface for calculating the envelope using the FIR Hilbert transform on a
|
|
338
|
+
signal in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs processed data in the same format.
|
|
339
|
+
|
|
340
|
+
Example:
|
|
341
|
+
--------
|
|
342
|
+
```python
|
|
343
|
+
unit = FIRHilbertEnvelopeUnit(
|
|
344
|
+
settings=FIRHilbertFilterSettings(
|
|
345
|
+
order=170,
|
|
346
|
+
f_lo=1.0,
|
|
347
|
+
f_hi=50.0,
|
|
348
|
+
)
|
|
349
|
+
)
|
|
350
|
+
```
|
|
351
|
+
"""
|
|
352
|
+
|
|
353
|
+
SETTINGS = FIRHilbertFilterSettings
|
ezmsg/sigproc/fir_pmc.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
import numpy as np
|
|
6
|
+
import scipy.signal
|
|
7
|
+
from ezmsg.sigproc.filter import (
|
|
8
|
+
BACoeffs,
|
|
9
|
+
BaseFilterByDesignTransformerUnit,
|
|
10
|
+
FilterBaseSettings,
|
|
11
|
+
FilterByDesignTransformer,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ParksMcClellanFIRSettings(FilterBaseSettings):
|
|
16
|
+
"""Settings for :obj:`ParksMcClellanFIR`."""
|
|
17
|
+
|
|
18
|
+
# axis inherited from FilterBaseSettings
|
|
19
|
+
|
|
20
|
+
coef_type: str = "ba"
|
|
21
|
+
"""
|
|
22
|
+
Coefficient type. Must be 'ba' for FIR.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
order: int = 0
|
|
26
|
+
"""
|
|
27
|
+
Filter order (taps = order + 1).
|
|
28
|
+
PMC FIR filters require even order (odd taps).
|
|
29
|
+
If odd order (even taps), order will be incremented by 1.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
cuton: float | None = None
|
|
33
|
+
"""
|
|
34
|
+
Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
|
|
35
|
+
if this is lower than `cutoff` then this is the beginning of the bandpass
|
|
36
|
+
or if this is greater than `cutoff` then this is the end of the bandstop.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
cutoff: float | None = None
|
|
40
|
+
"""
|
|
41
|
+
Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
|
|
42
|
+
if this is greater than `cuton` then this is the end of the bandpass,
|
|
43
|
+
or if this is less than `cuton` then this is the beginning of the bandstop.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
transition: float = 10.0
|
|
47
|
+
"""
|
|
48
|
+
Transition bandwidth (Hz) applied to each passband edge.
|
|
49
|
+
For low/high: single transition. For bands: both edges.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
weight_pass: float = 1.0
|
|
53
|
+
"""
|
|
54
|
+
Weight for the passband.
|
|
55
|
+
Used for both high and low passbands in bandstop filters.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
weight_stop_lo: float = 1.0
|
|
59
|
+
"""
|
|
60
|
+
Weight for the lower stopband.
|
|
61
|
+
Not used for bandstop filters.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
weight_stop_hi: float = 1.0
|
|
65
|
+
"""
|
|
66
|
+
Weight for the upper stopband.
|
|
67
|
+
Used as the central-stop weight for bandstop filters.
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
def filter_specs(
|
|
71
|
+
self,
|
|
72
|
+
) -> tuple[str, tuple[float, float] | float] | None:
|
|
73
|
+
"""
|
|
74
|
+
Determine the filter type given the corner frequencies.
|
|
75
|
+
|
|
76
|
+
Returns:
|
|
77
|
+
A tuple with the first element being a string indicating the filter type
|
|
78
|
+
(one of "lowpass", "highpass", "bandpass", "bandstop")
|
|
79
|
+
and the second element being the corner frequency or frequencies.
|
|
80
|
+
|
|
81
|
+
"""
|
|
82
|
+
if self.cuton is None and self.cutoff is None:
|
|
83
|
+
return None
|
|
84
|
+
elif self.cuton is None and self.cutoff is not None:
|
|
85
|
+
return "lowpass", self.cutoff
|
|
86
|
+
elif self.cuton is not None and self.cutoff is None:
|
|
87
|
+
return "highpass", self.cuton
|
|
88
|
+
elif self.cuton is not None and self.cutoff is not None:
|
|
89
|
+
if self.cuton <= self.cutoff:
|
|
90
|
+
return "bandpass", (self.cuton, self.cutoff)
|
|
91
|
+
else:
|
|
92
|
+
return "bandstop", (self.cutoff, self.cuton)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def parks_mcclellan_design_fun(
|
|
96
|
+
fs: float,
|
|
97
|
+
order: int = 0,
|
|
98
|
+
cuton: float | None = None,
|
|
99
|
+
cutoff: float | None = None,
|
|
100
|
+
transition: float = 10.0,
|
|
101
|
+
weight_pass: float = 1.0,
|
|
102
|
+
weight_stop_lo: float = 1.0,
|
|
103
|
+
weight_stop_hi: float = 1.0,
|
|
104
|
+
) -> BACoeffs | None:
|
|
105
|
+
"""
|
|
106
|
+
See :obj:`ParksMcClellanFIRSettings.filter_specs` for an explanation of specifying different
|
|
107
|
+
filter types (lowpass, highpass, bandpass, bandstop) from the parameters.
|
|
108
|
+
|
|
109
|
+
Designs a Parks-McClellan FIR filter via the Remez exchange algorithm using the given specifications.
|
|
110
|
+
PMC filters are equiripple and linear phase.
|
|
111
|
+
|
|
112
|
+
You are likely to want to use this function with :obj:`filter_by_design`, which only passes `fs` to the design
|
|
113
|
+
function (this), meaning that you should wrap this function with a lambda or prepare with functools.partial.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
fs: The sampling frequency of the data in Hz.
|
|
117
|
+
order: Filter order.
|
|
118
|
+
cuton: Corner frequency of the filter in Hz.
|
|
119
|
+
cutoff: Corner frequency of the filter in Hz.
|
|
120
|
+
transition: Transition bandwidth (Hz) applied to each passband edge.
|
|
121
|
+
weight_pass: Weight for the passband.
|
|
122
|
+
weight_stop_lo: Weight for the lower stopband.
|
|
123
|
+
weight_stop_hi: Weight for the upper stopband.
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
The filter coefficients as a tuple of (b, a).
|
|
127
|
+
"""
|
|
128
|
+
if order <= 0:
|
|
129
|
+
return None
|
|
130
|
+
if order % 2 == 1:
|
|
131
|
+
order += 1
|
|
132
|
+
|
|
133
|
+
specs = ParksMcClellanFIRSettings(cuton=cuton, cutoff=cutoff).filter_specs()
|
|
134
|
+
if specs is None:
|
|
135
|
+
# Under-specified: no filter
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
btype, corners = specs
|
|
139
|
+
nyq = fs / 2.0
|
|
140
|
+
tw = max(transition, 0.0)
|
|
141
|
+
|
|
142
|
+
def clip_hz(x: float) -> float:
|
|
143
|
+
return float(min(max(x, 0.0), nyq))
|
|
144
|
+
|
|
145
|
+
if btype == "lowpass":
|
|
146
|
+
b = [0.0, clip_hz(corners), clip_hz(corners + tw), nyq]
|
|
147
|
+
d = [1.0, 0.0]
|
|
148
|
+
w = [max(weight_pass, 0.0), max(weight_stop_hi, 0.0)]
|
|
149
|
+
|
|
150
|
+
elif btype == "highpass":
|
|
151
|
+
b = [0.0, clip_hz(corners - tw), clip_hz(corners), nyq]
|
|
152
|
+
d = [0.0, 1.0]
|
|
153
|
+
w = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0)]
|
|
154
|
+
|
|
155
|
+
elif btype == "bandpass":
|
|
156
|
+
b = [
|
|
157
|
+
0.0,
|
|
158
|
+
clip_hz(corners[0] - tw),
|
|
159
|
+
clip_hz(corners[0]),
|
|
160
|
+
clip_hz(corners[1]),
|
|
161
|
+
clip_hz(corners[1] + tw),
|
|
162
|
+
nyq,
|
|
163
|
+
]
|
|
164
|
+
d = [0.0, 1.0, 0.0]
|
|
165
|
+
w = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0), max(weight_stop_hi, 0.0)]
|
|
166
|
+
|
|
167
|
+
else:
|
|
168
|
+
b = [
|
|
169
|
+
0.0,
|
|
170
|
+
clip_hz(corners[0]),
|
|
171
|
+
clip_hz(corners[0] + tw),
|
|
172
|
+
clip_hz(corners[1] - tw),
|
|
173
|
+
clip_hz(corners[1]),
|
|
174
|
+
nyq,
|
|
175
|
+
]
|
|
176
|
+
d = [1.0, 0.0, 1.0]
|
|
177
|
+
# For bandstop we can reuse stop_hi as central-stop weight; stop_lo is the DC-side passband stop weight
|
|
178
|
+
w = [max(weight_pass, 0.0), max(weight_stop_hi, 0.0), max(weight_pass, 0.0)]
|
|
179
|
+
|
|
180
|
+
# Ensure bands strictly increase and have nonzero width per segment
|
|
181
|
+
# Adjust tiny overlaps due to clipping
|
|
182
|
+
for i in range(1, len(b)):
|
|
183
|
+
if b[i] <= b[i - 1]:
|
|
184
|
+
b[i] = min(b[i - 1] + 1e-6, nyq)
|
|
185
|
+
|
|
186
|
+
b = scipy.signal.remez(numtaps=order + 1, bands=b, desired=d, weight=w, fs=fs)
|
|
187
|
+
return (b, np.array([1.0]))
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class ParksMcClellanFIRTransformer(
|
|
191
|
+
FilterByDesignTransformer[ParksMcClellanFIRSettings, BACoeffs]
|
|
192
|
+
):
|
|
193
|
+
def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
|
|
194
|
+
if self.settings.coef_type != "ba":
|
|
195
|
+
ez.logger.error("ParksMcClellanFIR only supports coef_type='ba'.")
|
|
196
|
+
raise ValueError("ParksMcClellanFIR only supports coef_type='ba'.")
|
|
197
|
+
return functools.partial(
|
|
198
|
+
parks_mcclellan_design_fun,
|
|
199
|
+
order=self.settings.order,
|
|
200
|
+
cuton=self.settings.cuton,
|
|
201
|
+
cutoff=self.settings.cutoff,
|
|
202
|
+
transition=self.settings.transition,
|
|
203
|
+
weight_pass=self.settings.weight_pass,
|
|
204
|
+
weight_stop_lo=self.settings.weight_stop_lo,
|
|
205
|
+
weight_stop_hi=self.settings.weight_stop_hi,
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class ParksMcClellanFIR(
|
|
210
|
+
BaseFilterByDesignTransformerUnit[
|
|
211
|
+
ParksMcClellanFIRSettings, ParksMcClellanFIRTransformer
|
|
212
|
+
]
|
|
213
|
+
):
|
|
214
|
+
SETTINGS = ParksMcClellanFIRSettings
|
ezmsg/sigproc/rollingscaler.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
from ezmsg.sigproc.base import (
|
|
7
|
+
BaseAdaptiveTransformer,
|
|
8
|
+
BaseAdaptiveTransformerUnit,
|
|
9
|
+
processor_state,
|
|
10
|
+
)
|
|
11
|
+
from ezmsg.sigproc.sampler import SampleMessage
|
|
12
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
|
+
from ezmsg.util.messages.util import replace
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RollingScalerSettings(ez.Settings):
|
|
17
|
+
axis: str = "time"
|
|
18
|
+
"""
|
|
19
|
+
Axis along which samples are arranged.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
k_samples: int | None = 20
|
|
23
|
+
"""
|
|
24
|
+
Rolling window size in number of samples.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
window_size: float | None = None
|
|
28
|
+
"""
|
|
29
|
+
Rolling window size in seconds.
|
|
30
|
+
If set, overrides `k_samples`.
|
|
31
|
+
`update_with_signal` likely should be True if using this option.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
update_with_signal: bool = False
|
|
35
|
+
"""
|
|
36
|
+
If True, update rolling statistics using the incoming process stream.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
min_samples: int = 1
|
|
40
|
+
"""
|
|
41
|
+
Minimum number of samples required to compute statistics.
|
|
42
|
+
Used when `window_size` is not set.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
min_seconds: float = 1.0
|
|
46
|
+
"""
|
|
47
|
+
Minimum duration in seconds required to compute statistics.
|
|
48
|
+
Used when `window_size` is set.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
artifact_z_thresh: float | None = None
|
|
52
|
+
"""
|
|
53
|
+
Threshold for z-score based artifact detection.
|
|
54
|
+
If set, samples with any channel exceeding this z-score will be excluded
|
|
55
|
+
from updating the rolling statistics.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
clip: float | None = 10.0
|
|
59
|
+
"""
|
|
60
|
+
If set, clip the output values to the range [-clip, clip].
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@processor_state
|
|
65
|
+
class RollingScalerState:
|
|
66
|
+
mean: npt.NDArray | None = None
|
|
67
|
+
N: int = 0
|
|
68
|
+
M2: npt.NDArray | None = None
|
|
69
|
+
samples: deque | None = None
|
|
70
|
+
k_samples: int | None = None
|
|
71
|
+
min_samples: int | None = None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class RollingScalerProcessor(
    BaseAdaptiveTransformer[
        RollingScalerSettings, AxisArray, AxisArray, RollingScalerState
    ]
):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance) that are
    updated through `partial_fit()` and, when `update_with_signal` is True,
    from the processed signal itself. When processing an `AxisArray` message,
    it normalizes the data using the statistics available before the message
    arrived.

    The input `AxisArray` messages are expected to have shape `(time, ch)`,
    where `ch` is the channel axis. The processor computes the z-score for
    each channel independently.

    Statistics are stored per update batch: every `partial_fit()` call or
    signal update pushes one `(n, mean, M2)` entry, and once `k_samples`
    entries are held the oldest entry is subtracted out before the new one is
    merged in.
    NOTE(review): the window is therefore bounded by the number of update
    batches; this equals the number of samples only when each batch holds a
    single sample.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _resolve_axis(self, message: AxisArray) -> str:
        """Return the configured sample axis, falling back to the first dim when None."""
        return message.dims[0] if self.settings.axis is None else self.settings.axis

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message properties that require a state reset when they change."""
        axis = self._resolve_axis(message)
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """
        Clear the running statistics and resolve the window parameters.

        Args:
            message: Template message; its last data dimension sets the number
                of channels, and its sample-axis gain converts `window_size` /
                `min_seconds` into sample counts when `window_size` is set.
        """
        ch = message.data.shape[-1]
        self._state.mean = np.zeros(ch)
        self._state.N = 0
        self._state.M2 = np.zeros(ch)

        window_size = self.settings.window_size
        if window_size is not None:
            # Axis gain converts the seconds-based settings into sample counts.
            gain = message.axes[self._resolve_axis(message)].gain
            self._state.k_samples = int(np.ceil(window_size / gain))
        else:
            gain = None
            self._state.k_samples = self.settings.k_samples

        if self._state.k_samples is not None and self._state.k_samples < 1:
            ez.logger.warning(
                "window_size smaller than sample gain; setting k_samples to 1."
            )
            self._state.k_samples = 1
        elif self._state.k_samples is None:
            ez.logger.warning(
                "k_samples is None; z-score accumulation will be unbounded."
            )
        self._state.samples = deque(maxlen=self._state.k_samples)

        if window_size is not None:
            self._state.min_samples = int(np.ceil(self.settings.min_seconds / gain))
        else:
            self._state.min_samples = self.settings.min_samples

        if (
            self._state.k_samples is not None
            and self._state.min_samples > self._state.k_samples
        ):
            ez.logger.warning(
                "min_samples is greater than k_samples; adjusting min_samples to k_samples."
            )
            self._state.min_samples = self._state.k_samples

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """
        Merge a batch of rows into the running statistics.

        If the window is bounded and full, the oldest batch is first removed
        (inverse of the parallel mean/M2 merge), then `x` is merged in.

        Args:
            x: Batch with samples along axis 0 and channels along the last axis.
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        if n_b == 0:
            # An empty batch carries no information; np.mean would yield NaN.
            return
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        state = self._state
        if state.k_samples is not None and len(state.samples) == state.k_samples:
            n_old, mean_old, M2_old = state.samples.popleft()
            N_T = state.N
            N_new = N_T - n_old

            if N_new <= 0:
                state.N = 0
                state.mean = np.zeros_like(state.mean)
                state.M2 = np.zeros_like(state.M2)
            else:
                delta = mean_old - state.mean
                state.N = N_new
                state.mean = (N_T * state.mean - n_old * mean_old) / N_new
                # Subtracting a batch out can leave tiny negative values due to
                # floating-point cancellation; clamp so sqrt(M2 / N) stays real.
                state.M2 = np.maximum(
                    state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new),
                    0.0,
                )

        N_A = state.N
        N = N_A + n_b
        delta = mean_b - state.mean
        state.mean = state.mean + delta * (n_b / N)
        state.M2 = state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        state.N = N

        # Batch entries are only ever read for eviction; with an unbounded
        # window, retaining them would just grow memory without limit.
        if state.k_samples is not None:
            state.samples.append((n_b, mean_b, M2_b))

    def _current_std(self) -> npt.NDArray:
        """Per-channel standard deviation from the running stats, floored at 1e-8."""
        return np.maximum(np.sqrt(self._state.M2 / self._state.N), 1e-8)

    def _update_from_signal(self, x: npt.NDArray) -> None:
        """
        Fold processed data into the statistics, optionally rejecting artifacts.

        Rows with any channel whose absolute z-score (against the current
        statistics) exceeds `artifact_z_thresh` are dropped first; rejection is
        skipped while no statistics exist yet.
        """
        if self.settings.artifact_z_thresh is not None and self._state.N > 0:
            z = np.abs((x - self._state.mean) / self._current_std())
            mask = np.any(z > self.settings.artifact_z_thresh, axis=1)
            x = x[~mask]
        if x.size > 0:
            self._add_batch_stats(x)

    def partial_fit(self, message: SampleMessage) -> None:
        """Update the rolling statistics with the data in `message.sample`."""
        if self._state.mean is None:
            # partial_fit can arrive before any message went through the
            # hash-checked processing path, so the state is uninitialised.
            # NOTE(review): assumes the base class caches the message hash in
            # `self._hash`; confirm against the base processor.
            self._reset_state(message.sample)
            self._hash = self._hash_message(message.sample)
        self._add_batch_stats(message.sample.data)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Z-score `message` with the current statistics.

        During warm-up (no data, or fewer than `min_samples` samples) the
        message is returned unchanged. Otherwise the output is
        `(data - mean) / std`, with non-finite values replaced by 0 and
        optional clipping. When `update_with_signal` is True, the input data
        is folded into the statistics after normalization.
        """
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            if self.settings.update_with_signal:
                self._update_from_signal(message.data)
            return message

        std = self._current_std()
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
        result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        if self.settings.update_with_signal:
            self._update_from_signal(message.data)

        return replace(message, data=result)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class RollingScalerUnit(
    BaseAdaptiveTransformerUnit[
        RollingScalerSettings,
        AxisArray,
        AxisArray,
        RollingScalerProcessor,
    ]
):
    """
    ezmsg Unit that wraps :obj:`RollingScalerProcessor`.

    Incoming `AxisArray` messages are z-scored channel by channel against the
    rolling mean and variance held by the wrapped processor. The window length
    is governed by `k_samples` (or `window_size`, in seconds) on
    :obj:`RollingScalerSettings`.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Length of the rolling statistics window
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.5.0
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
|
-
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
|
|
5
|
+
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
|
|
6
6
|
License-Expression: MIT
|
|
7
7
|
License-File: LICENSE.txt
|
|
8
8
|
Requires-Python: >=3.10.15
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
|
|
2
|
-
ezmsg/sigproc/__version__.py,sha256=
|
|
2
|
+
ezmsg/sigproc/__version__.py,sha256=bdfCUdK0KgIdZbIExZ_otf09cMr8k-qiQbglewDXQI8,704
|
|
3
3
|
ezmsg/sigproc/activation.py,sha256=qWAhpbFBxSoqbGy4P9JKE5LY-5v8rQI1U81OvNxBG2Y,2820
|
|
4
4
|
ezmsg/sigproc/adaptive_lattice_notch.py,sha256=3M65PrZpdgBlQtE7Ph4Gu2ISIyWw4j8Xxhm5PpSkLFw,9102
|
|
5
5
|
ezmsg/sigproc/affinetransform.py,sha256=WU495KoDKZfHPS3Dumh65rgf639koNlfDIx_torIByg,8662
|
|
@@ -7,10 +7,11 @@ ezmsg/sigproc/aggregate.py,sha256=wHUP_aS9NgnOxBCPN1_tSxCqMMb8UPBEoKwGKX7-ASk,91
|
|
|
7
7
|
ezmsg/sigproc/bandpower.py,sha256=j-Y6iWjD2xkggfi-4HAFJVBPJHHBGvAZy1uM4murZkQ,2319
|
|
8
8
|
ezmsg/sigproc/base.py,sha256=PQr03O2P1v9LzcSR0GJLvPpBCLtnmGaz76gUeXphcH4,48753
|
|
9
9
|
ezmsg/sigproc/butterworthfilter.py,sha256=7ZP4CRsXBt3-5dzyUjD45vc0J3Fhpm4CLrk-ps28jhc,5305
|
|
10
|
+
ezmsg/sigproc/butterworthzerophase.py,sha256=B95FxHBk0uSXizsndR5yc8I2V_gXVNWZ9WVMS4m1Hek,4190
|
|
10
11
|
ezmsg/sigproc/cheby.py,sha256=-aSauAwxJmmSSiRaw5qGY9rvYFOmk1bZlS4gGrS0jls,3737
|
|
11
12
|
ezmsg/sigproc/combfilter.py,sha256=5UCfzGESpS5LSx6rxZv8_n25ZUvOOmws-mM_gpTZNhU,4777
|
|
12
13
|
ezmsg/sigproc/decimate.py,sha256=Lz46fBllWagu17QeQzgklm6GWCV-zPysiydiby2IElU,2347
|
|
13
|
-
ezmsg/sigproc/denormalize.py,sha256=
|
|
14
|
+
ezmsg/sigproc/denormalize.py,sha256=CujviBepGysjB5X7RZoDOMC5tUC97ryHnUdqhi-eMPo,3065
|
|
14
15
|
ezmsg/sigproc/detrend.py,sha256=7bpjFKdk2b6FdVn2GEtMbWtCuk7ToeiYKEBHVbN4Gd0,903
|
|
15
16
|
ezmsg/sigproc/diff.py,sha256=P5BBjR7KdaCL9aD3GG09cmC7a-3cxDeEUw4nKdQ1HY8,2895
|
|
16
17
|
ezmsg/sigproc/downsample.py,sha256=0X6EwPZ_XTwA2-nx5w-2HmMZUEDFuGAYF5EmPSuuVj8,3721
|
|
@@ -21,12 +22,15 @@ ezmsg/sigproc/fbcca.py,sha256=8NTJAOpHIvNFwQepui2_ZaJV4SMDFgXrqoWJyiQdF5U,12362
|
|
|
21
22
|
ezmsg/sigproc/filter.py,sha256=1MQUZDFIf6HAHuuhGQEvH4Yd6Jv_vv12PM25YaHjdxc,11921
|
|
22
23
|
ezmsg/sigproc/filterbank.py,sha256=pJzv_G6chgWa1ARmRjMAMgt9eEGnA-ZbMSge4EWrcYY,13633
|
|
23
24
|
ezmsg/sigproc/filterbankdesign.py,sha256=OfIXM0ushSqbdSQG9DZB1Mh57d-lqdJQX8aqfxNN67E,4734
|
|
25
|
+
ezmsg/sigproc/fir_hilbert.py,sha256=qqHTp-yIhAD3VBoENTxpBmy7TgF2lYqbZ65OSfqeWO4,11042
|
|
26
|
+
ezmsg/sigproc/fir_pmc.py,sha256=ApWMl7WNQ9Ihr-J74DrAVwxD1r8gvLcElYcEL0RtQ2U,7024
|
|
24
27
|
ezmsg/sigproc/firfilter.py,sha256=MCrwY3DLq-uMLX04JswVB9oHBSYJGbdUiQYW6eRdkxE,3805
|
|
25
28
|
ezmsg/sigproc/gaussiansmoothing.py,sha256=NaVezgNwdvp-kam1I_7lSID4Obi0UCxZshH7A2afaVg,2692
|
|
26
29
|
ezmsg/sigproc/kaiser.py,sha256=WsZB8a4DP7WwrYLlGczHS61L86TiH6qEStAB6zxODhY,3502
|
|
27
30
|
ezmsg/sigproc/messages.py,sha256=y_twVPK7TxRj8ajmuSuBuxwvLTgyv9OF7Y7v9bw1tfs,926
|
|
28
31
|
ezmsg/sigproc/quantize.py,sha256=VzaqE6PatibEjkk7XrGO-ubAXYurAed9FYOn4bcQZQk,2193
|
|
29
32
|
ezmsg/sigproc/resample.py,sha256=wqSM7g3QrcrklCeGVNN4l_qZLSXRUPHXCUxl1L47300,11654
|
|
33
|
+
ezmsg/sigproc/rollingscaler.py,sha256=RrVAoN7cRvFz7kHSyeQr1pjKiKkJDM_1ChQ5V9FWZKo,8860
|
|
30
34
|
ezmsg/sigproc/sampler.py,sha256=D5oMIZHAJS6XIKMdOHsDw97d4ZxfNP7iZwpc6J8Jmpk,10898
|
|
31
35
|
ezmsg/sigproc/scaler.py,sha256=fCLHvCNUSgv0XChf8iS9s5uHCSCVjCasM2TCvyG5BwQ,4111
|
|
32
36
|
ezmsg/sigproc/signalinjector.py,sha256=hGC837JyDLtAGrfsdMwzEoOqWXiwP7r7sGlUC9nahTY,2948
|
|
@@ -53,7 +57,7 @@ ezmsg/sigproc/util/message.py,sha256=l_b1b6bXX8N6VF9RbUELzsHs73cKkDURBdIr0lt3CY0
|
|
|
53
57
|
ezmsg/sigproc/util/profile.py,sha256=KNJ_QkKelQHNEp2C8MhqzdhYydMNULc_NQq3ccMfzIk,5775
|
|
54
58
|
ezmsg/sigproc/util/sparse.py,sha256=mE64p1tYb5A1shaRE1D-VnH-RshbLb8g8kXSXxnA-J4,4842
|
|
55
59
|
ezmsg/sigproc/util/typeresolution.py,sha256=5R7xmG-F4CkdqQ5aoQnqM-htQb-VwAJl58jJgxtClys,3146
|
|
56
|
-
ezmsg_sigproc-2.
|
|
57
|
-
ezmsg_sigproc-2.
|
|
58
|
-
ezmsg_sigproc-2.
|
|
59
|
-
ezmsg_sigproc-2.
|
|
60
|
+
ezmsg_sigproc-2.5.0.dist-info/METADATA,sha256=SiHigniH10jk8aeW-C7SLMqdldTbpphobjwVJbaBdX0,5019
|
|
61
|
+
ezmsg_sigproc-2.5.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
62
|
+
ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
|
|
63
|
+
ezmsg_sigproc-2.5.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|