ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import math
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
import numpy as np
|
|
7
|
+
import scipy.interpolate
|
|
8
|
+
from ezmsg.baseproc import (
|
|
9
|
+
BaseConsumerUnit,
|
|
10
|
+
BaseStatefulProcessor,
|
|
11
|
+
processor_state,
|
|
12
|
+
)
|
|
13
|
+
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
|
|
14
|
+
from ezmsg.util.messages.util import replace
|
|
15
|
+
|
|
16
|
+
from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
|
|
17
|
+
from .util.buffer import UpdateStrategy
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ResampleSettings(ez.Settings):
    """
    Settings for :obj:`ResampleProcessor` / :obj:`ResampleUnit`.

    Controls which axis is resampled, the target rate (fixed or reference-driven),
    out-of-bounds fill behavior, and internal buffering.
    """

    # Name of the axis along which to resample (typically the time axis).
    axis: str = "time"

    resample_rate: float | None = None
    """target resample rate in Hz. If None, the resample rate will be determined by the reference signal."""

    max_chunk_delay: float = np.inf
    """Maximum delay between outputs in seconds. If the delay exceeds this value, the transformer will extrapolate."""

    fill_value: str = "extrapolate"
    """
    Value to use for out-of-bounds samples.
    If 'extrapolate', the transformer will extrapolate.
    If 'last', the transformer will use the last sample.
    See scipy.interpolate.interp1d for more options.
    """

    # Duration (seconds) of the internal source and reference buffers.
    buffer_duration: float = 2.0

    buffer_update_strategy: UpdateStrategy = "immediate"
    """
    The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
    If you expect to push data much more frequently than it is resampled, then "on_demand"
    might be more efficient. For most other scenarios, "immediate" is best.
    """
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@processor_state
class ResampleState:
    """Mutable per-stream state for :obj:`ResampleProcessor`."""

    src_buffer: HybridAxisArrayBuffer | None = None
    """
    Buffer for the incoming signal data. This is the source for training the interpolation function.
    Its contents are rarely empty because we usually hold back some data to allow for accurate
    interpolation and optionally extrapolation.
    """

    ref_axis_buffer: HybridAxisBuffer | None = None
    """
    The buffer for the reference axis (usually a time axis). The interpolation function
    will be evaluated at the reference axis values.
    When resample_rate is None, this buffer will be filled with the axis from incoming
    _reference_ messages.
    When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
    is filled with a synthetic axis that is generated from the incoming signal messages.
    """

    last_ref_ax_val: float | None = None
    """
    The last value of the reference axis that was returned. This helps us to know
    what the _next_ returned value should be, and to avoid returning the same value.
    TODO: We can eliminate this variable if we maintain "by convention" that the
    reference axis always has 1 value at its start that we exclude from the resampling.
    """

    last_write_time: float = -np.inf
    """
    Wall clock time of the last write to the signal buffer.
    This is used to determine if we need to extrapolate the reference axis
    if we have not received an update within max_chunk_delay.
    """
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ResampleProcessor(BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]):
    """
    Resample streamed data along ``settings.axis`` onto a reference axis using linear
    interpolation (``scipy.interpolate.interp1d``).

    The reference axis is either synthesized at a fixed ``resample_rate`` from the
    incoming signal's own axis, or accumulated from external reference messages
    supplied via :meth:`push_reference` when ``resample_rate`` is None.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State is keyed on (stream key, axis gain, per-sample shape); a change in
        # any of these triggers a state reset in the base class.
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        ax = message.axes[self.settings.axis]
        # Irregular (data-backed) axes have no gain; hash None in that case.
        gain = ax.gain if hasattr(ax, "gain") else None
        return hash((message.key, gain) + sample_shape)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Reset the internal state based on the incoming message.
        """
        self.state.src_buffer = HybridAxisArrayBuffer(
            duration=self.settings.buffer_duration,
            axis=self.settings.axis,
            update_strategy=self.settings.buffer_update_strategy,
            overflow_strategy="grow",
        )
        if self.settings.resample_rate is not None:
            # If we are resampling at a prescribed rate, then we synthesize a reference axis
            self.state.ref_axis_buffer = HybridAxisBuffer(
                duration=self.settings.buffer_duration,
            )
            in_ax = message.axes[self.settings.axis]
            out_gain = 1 / self.settings.resample_rate
            # First axis value: data-backed (irregular) axes expose `.data`,
            # linear axes compute values via `.value(idx)`.
            t0 = in_ax.data[0] if hasattr(in_ax, "data") else in_ax.value(0)
            # Seed one output period before t0 so the first synthesized
            # reference value lands exactly on t0.
            self.state.last_ref_ax_val = t0 - out_gain
        self.state.last_write_time = -np.inf

    def push_reference(self, message: AxisArray) -> None:
        """
        Append the axis of an external reference message to the reference-axis buffer.

        Used when ``resample_rate`` is None, i.e. the output is aligned to the
        reference stream's axis values.
        """
        ax = message.axes[self.settings.axis]
        ax_idx = message.get_axis_idx(self.settings.axis)
        if self.state.ref_axis_buffer is None:
            # Lazily create the reference buffer on the first reference message.
            self.state.ref_axis_buffer = HybridAxisBuffer(
                duration=self.settings.buffer_duration,
                update_strategy=self.settings.buffer_update_strategy,
                overflow_strategy="grow",
            )
            t0 = ax.data[0] if hasattr(ax, "data") else ax.value(0)
            # NOTE(review): `ax.gain` is read unconditionally here, unlike
            # _hash_message which guards with hasattr. A data-backed reference
            # axis without `gain` would raise -- confirm reference axes are linear.
            self.state.last_ref_ax_val = t0 - ax.gain
        self.state.ref_axis_buffer.write(ax, n_samples=message.data.shape[ax_idx])

    def _process(self, message: AxisArray) -> None:
        """
        Add a new data message to the buffer and update the reference axis if needed.
        """
        # Note: The src_buffer will copy and permute message if ax_idx != 0
        self.state.src_buffer.write(message)

        # If we are resampling at a prescribed rate (i.e., not by reference msgs),
        # then we use this opportunity to extend our synthetic reference axis.
        ax_idx = message.get_axis_idx(self.settings.axis)
        if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
            in_ax = message.axes[self.settings.axis]
            # Last axis value of the incoming chunk.
            in_t_end = in_ax.data[-1] if hasattr(in_ax, "data") else in_ax.value(message.data.shape[ax_idx] - 1)
            out_gain = 1 / self.settings.resample_rate
            prev_t_end = self.state.last_ref_ax_val
            # Number of synthetic reference samples needed to cover up to in_t_end.
            n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
            synth_ref_axis = LinearAxis(unit="s", gain=out_gain, offset=prev_t_end + out_gain)
            self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)

        self.state.last_write_time = time.time()

    def __next__(self) -> AxisArray:
        """
        Produce the next resampled chunk, or an empty message when there is
        nothing to interpolate yet.
        """
        if self.state.src_buffer is None or self.state.ref_axis_buffer is None:
            # If we have not received any data, or we require reference data
            # that we do not yet have, then return an empty template.
            return AxisArray(data=np.array([]), dims=[""], axes={}, key="null")

        src = self.state.src_buffer
        ref = self.state.ref_axis_buffer

        # If we have no reference or the source is insufficient for interpolation
        # then return the empty template
        if ref.is_empty() or src.available() < 3:
            src_axarr = src.peek(0)
            return replace(
                src_axarr,
                axes={
                    **src_axarr.axes,
                    self.settings.axis: ref.peek(0),
                },
            )

        # Build the reference xvec.
        # Note: The reference axis buffer may grow upon `.peek()`
        # as it flushes data from its deque to its buffer.
        ref_ax = ref.peek()
        if hasattr(ref_ax, "data"):
            ref_xvec = ref_ax.data
        else:
            ref_xvec = ref_ax.value(np.arange(ref.available()))

        # If we do not rely on an external reference, and we have not received new data in a while,
        # then extrapolate our reference vector out beyond the delay limit.
        b_project = self.settings.resample_rate is not None and time.time() > (
            self.state.last_write_time + self.settings.max_chunk_delay
        )
        if b_project:
            n_append = math.ceil(self.settings.max_chunk_delay / ref_ax.gain)
            xvec_append = ref_xvec[-1] + np.arange(1, n_append + 1) * ref_ax.gain
            ref_xvec = np.hstack((ref_xvec, xvec_append))

        # Get source to train interpolation
        src_axarr = src.peek()
        src_axis = src_axarr.axes[self.settings.axis]
        # Source x-values; buffered data is arranged with the target axis first.
        x = src_axis.data if hasattr(src_axis, "data") else src_axis.value(np.arange(src_axarr.data.shape[0]))

        # Only resample at reference values that have not been interpolated over previously.
        b_ref = ref_xvec > self.state.last_ref_ax_val
        if not b_project:
            # Not extrapolating -- Do not resample beyond the end of the source buffer.
            b_ref = np.logical_and(b_ref, ref_xvec <= x[-1])
        ref_idx = np.where(b_ref)[0]

        if len(ref_idx) == 0:
            # Nothing to interpolate over; return empty data
            null_ref = replace(ref_ax, data=ref_ax.data[:0]) if hasattr(ref_ax, "data") else ref_ax
            return replace(
                src_axarr,
                data=src_axarr.data[:0, ...],
                axes={**src_axarr.axes, self.settings.axis: null_ref},
            )

        xnew = ref_xvec[ref_idx]

        # Identify source data indices around ref tvec with some padding for better interpolation.
        # (The conditional binds to the whole `np.where(...)[0][0] - 2` expression.)
        src_start_ix = max(0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0)

        x = x[src_start_ix:]
        y = src_axarr.data[src_start_ix:]

        # 'last' is translated to a (below, above) fill pair holding the edge samples;
        # anything else is passed straight through to interp1d.
        if isinstance(self.settings.fill_value, str) and self.settings.fill_value == "last":
            fill_value = (y[0], y[-1])
        else:
            fill_value = self.settings.fill_value
        f = scipy.interpolate.interp1d(
            x,
            y,
            kind="linear",
            axis=0,
            copy=False,
            bounds_error=False,
            fill_value=fill_value,
            assume_sorted=True,
        )

        # Calculate output
        resampled_data = f(xnew)

        # Create output message
        if hasattr(ref_ax, "data"):
            out_ax = replace(ref_ax, data=xnew)
        else:
            out_ax = replace(ref_ax, offset=xnew[0])
        result = replace(
            src_axarr,
            data=resampled_data,
            axes={
                **src_axarr.axes,
                self.settings.axis: out_ax,
            },
        )

        # Update the state. For state buffers, seek beyond samples that are no longer needed.
        # src: keep at least 1 sample before the final resampled value
        seek_ix = np.where(x >= xnew[-1])[0]
        if len(seek_ix) > 0:
            self.state.src_buffer.seek(max(0, src_start_ix + seek_ix[0] - 1))
        # ref: remove samples that have been sent to output
        self.state.ref_axis_buffer.seek(ref_idx[-1] + 1)
        self.state.last_ref_ax_val = xnew[-1]

        return result

    def send(self, message: AxisArray) -> AxisArray:
        """Convenience: ingest one message and immediately return the next output."""
        self(message)
        return next(self)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class ResampleUnit(BaseConsumerUnit[ResampleSettings, AxisArray, ResampleProcessor]):
    """
    ezmsg Unit wrapping :obj:`ResampleProcessor`.

    Reference-axis messages arriving on INPUT_REFERENCE are forwarded to the
    processor; resampled output is polled from the processor and published on
    OUTPUT_SIGNAL.
    """

    SETTINGS = ResampleSettings
    INPUT_REFERENCE = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_REFERENCE, zero_copy=True)
    async def on_reference(self, message: AxisArray):
        # Hand the reference message to the processor's reference buffer.
        self.processor.push_reference(message)

    @ez.publisher(OUTPUT_SIGNAL)
    async def gen_resampled(self):
        # Poll the processor; publish whenever it yields non-empty output,
        # otherwise back off briefly so the loop does not spin.
        while True:
            out: AxisArray = next(self.processor)
            if out.data.size == 0:
                await asyncio.sleep(0.001)
                continue
            yield self.OUTPUT_SIGNAL, out
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
from ezmsg.baseproc import (
|
|
7
|
+
BaseAdaptiveTransformer,
|
|
8
|
+
BaseAdaptiveTransformerUnit,
|
|
9
|
+
processor_state,
|
|
10
|
+
)
|
|
11
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
12
|
+
from ezmsg.util.messages.util import replace
|
|
13
|
+
|
|
14
|
+
from ezmsg.sigproc.sampler import SampleMessage
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RollingScalerSettings(ez.Settings):
    """
    Settings for :obj:`RollingScalerProcessor` / :obj:`RollingScalerUnit`.

    The rolling window may be specified either in samples (`k_samples`) or in
    seconds (`window_size`, which takes precedence when set).
    """

    axis: str = "time"
    """
    Axis along which samples are arranged.
    """

    k_samples: int | None = 20
    """
    Rolling window size in number of samples.
    """

    window_size: float | None = None
    """
    Rolling window size in seconds.
    If set, overrides `k_samples`.
    `update_with_signal` likely should be True if using this option.
    """

    update_with_signal: bool = False
    """
    If True, update rolling statistics using the incoming process stream.
    """

    min_samples: int = 1
    """
    Minimum number of samples required to compute statistics.
    Used when `window_size` is not set.
    """

    min_seconds: float = 1.0
    """
    Minimum duration in seconds required to compute statistics.
    Used when `window_size` is set.
    """

    artifact_z_thresh: float | None = None
    """
    Threshold for z-score based artifact detection.
    If set, samples with any channel exceeding this z-score will be excluded
    from updating the rolling statistics.
    """

    clip: float | None = 10.0
    """
    If set, clip the output values to the range [-clip, clip].
    """
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@processor_state
class RollingScalerState:
    """Mutable rolling-statistics state for :obj:`RollingScalerProcessor`."""

    # Per-channel rolling mean.
    mean: npt.NDArray | None = None
    # Total number of samples currently represented in the rolling statistics.
    N: int = 0
    # Per-channel sum of squared deviations from the mean (Welford-style M2).
    M2: npt.NDArray | None = None
    # Deque of (n, mean, M2) per-batch summaries; oldest is evicted when full.
    samples: deque | None = None
    # Effective window length in batches (resolved from settings at reset).
    k_samples: int | None = None
    # Effective minimum sample count before scaling begins (resolved at reset).
    min_samples: int | None = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance) over the last `k_samples`
    samples received via the `partial_fit()` method. When processing an `AxisArray` message,
    it normalizes the data using the current rolling statistics.

    The input `AxisArray` messages are expected to have shape `(time, ch)`, where `ch` is the
    channel axis. The processor computes the z-score for each channel independently.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash stream identity: (key, per-sample shape, sample-axis gain)."""
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        # Axes without a `gain` attribute (irregular axes) are treated as unit gain.
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Re-initialize rolling statistics for a new stream configuration."""
        # NOTE(review): channel count is taken from the last dim; assumes (time, ch) layout.
        ch = message.data.shape[-1]
        self._state.mean = np.zeros(ch)
        self._state.N = 0
        self._state.M2 = np.zeros(ch)
        if self.settings.window_size is not None:
            # Fix: guard the `gain` access the same way _hash_message does instead of
            # assuming the axis has a `gain` attribute (irregular axes would raise).
            ax = message.axes[self.settings.axis]
            gain = ax.gain if hasattr(ax, "gain") else 1
            self._state.k_samples = int(np.ceil(self.settings.window_size / gain))
            self._state.min_samples = int(np.ceil(self.settings.min_seconds / gain))
        else:
            self._state.k_samples = self.settings.k_samples
            self._state.min_samples = self.settings.min_samples
        if self._state.k_samples is not None and self._state.k_samples < 1:
            ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
            self._state.k_samples = 1
        elif self._state.k_samples is None:
            ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
        self._state.samples = deque(maxlen=self._state.k_samples)
        if self._state.k_samples is not None and self._state.min_samples > self._state.k_samples:
            ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
            self._state.min_samples = self._state.k_samples

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """
        Fold one batch of samples (rows) into the rolling mean/M2.

        Uses the Chan et al. parallel-combination formulas to merge the batch's
        (n, mean, M2) summary into the running totals, and the inverse update to
        subtract the oldest batch when the window is full.
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        # Evict the oldest batch's contribution when the window is full.
        if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old

            if N_new <= 0:
                # Window emptied entirely; restart from zero.
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                # Inverse of the parallel-combine update.
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)

        # Merge the new batch into the running totals.
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N

        self._state.samples.append((n_b, mean_b, M2_b))

    def partial_fit(self, message: SampleMessage) -> None:
        """Update rolling statistics from a training sample without producing output."""
        x = message.sample.data
        self._add_batch_stats(x)

    def _maybe_update_from_signal(self, data: npt.NDArray) -> None:
        """
        If `update_with_signal` is enabled, fold `data` into the rolling statistics,
        optionally rejecting rows whose z-score exceeds `artifact_z_thresh`.

        Factored out of `_process`, which previously duplicated this logic in both
        branches. The artifact check is skipped until at least one sample has been
        accumulated (N > 0), matching the original guard.
        """
        if not self.settings.update_with_signal:
            return
        x = data
        if self.settings.artifact_z_thresh is not None and self._state.N > 0:
            varis = self._state.M2 / self._state.N
            std = np.maximum(np.sqrt(varis), 1e-8)
            z = np.abs((x - self._state.mean) / std)
            mask = np.any(z > self.settings.artifact_z_thresh, axis=1)
            x = x[~mask]
        if x.size > 0:
            self._add_batch_stats(x)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Z-score `message.data` against the rolling statistics.

        Until `min_samples` samples have accumulated, the data passes through
        unscaled (while still optionally updating the statistics).
        """
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            self._maybe_update_from_signal(message.data)
            return message

        varis = self._state.M2 / self._state.N
        # Floor the std to avoid division blow-up on near-constant channels.
        std = np.maximum(np.sqrt(varis), 1e-8)
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
        result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        # Update statistics after computing the output so this chunk is scaled
        # against the pre-update statistics (as in the original implementation).
        self._maybe_update_from_signal(message.data)
        return replace(message, data=result)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class RollingScalerUnit(
    BaseAdaptiveTransformerUnit[
        RollingScalerSettings,
        AxisArray,
        AxisArray,
        RollingScalerProcessor,
    ]
):
    """
    Unit wrapper for :obj:`RollingScalerProcessor`.

    This unit performs rolling z-score normalization on incoming `AxisArray` messages. The unit maintains rolling
    statistics (mean and variance) over the last `k_samples` samples received. When processing an `AxisArray` message,
    it normalizes the data using the current rolling statistics.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings
|