ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,278 @@
1
+ import asyncio
2
+ import math
3
+ import time
4
+
5
+ import ezmsg.core as ez
6
+ import numpy as np
7
+ import scipy.interpolate
8
+ from ezmsg.baseproc import (
9
+ BaseConsumerUnit,
10
+ BaseStatefulProcessor,
11
+ processor_state,
12
+ )
13
+ from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
14
+ from ezmsg.util.messages.util import replace
15
+
16
+ from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
17
+ from .util.buffer import UpdateStrategy
18
+
19
+
20
class ResampleSettings(ez.Settings):
    """
    Settings for resampling a signal onto a reference axis, or onto a synthetic
    axis generated at a prescribed rate.
    """

    axis: str = "time"
    # Name of the axis along which resampling is performed.

    resample_rate: float | None = None
    """target resample rate in Hz. If None, the resample rate will be determined by the reference signal."""

    max_chunk_delay: float = np.inf
    """Maximum delay between outputs in seconds. If the delay exceeds this value, the transformer will extrapolate."""

    fill_value: str = "extrapolate"
    """
    Value to use for out-of-bounds samples.
    If 'extrapolate', the transformer will extrapolate.
    If 'last', the transformer will use the last sample.
    See scipy.interpolate.interp1d for more options.
    """

    buffer_duration: float = 2.0
    # Capacity, in seconds, of the internal signal and reference-axis buffers.

    buffer_update_strategy: UpdateStrategy = "immediate"
    """
    The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
    If you expect to push data much more frequently than it is resampled, then "on_demand"
    might be more efficient. For most other scenarios, "immediate" is best.
    """
45
+
46
+
47
@processor_state
class ResampleState:
    """Mutable state for :obj:`ResampleProcessor`."""

    src_buffer: HybridAxisArrayBuffer | None = None
    """
    Buffer for the incoming signal data. This is the source for training the interpolation function.
    Its contents are rarely empty because we usually hold back some data to allow for accurate
    interpolation and optionally extrapolation.
    """

    ref_axis_buffer: HybridAxisBuffer | None = None
    """
    The buffer for the reference axis (usually a time axis). The interpolation function
    will be evaluated at the reference axis values.
    When resample_rate is None, this buffer will be filled with the axis from incoming
    _reference_ messages.
    When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
    is filled with a synthetic axis that is generated from the incoming signal messages.
    """

    last_ref_ax_val: float | None = None
    """
    The last value of the reference axis that was returned. This helps us to know
    what the _next_ returned value should be, and to avoid returning the same value.
    TODO: We can eliminate this variable if we maintain "by convention" that the
    reference axis always has 1 value at its start that we exclude from the resampling.
    """

    last_write_time: float = -np.inf
    """
    Wall clock time of the last write to the signal buffer.
    This is used to determine if we need to extrapolate the reference axis
    if we have not received an update within max_chunk_delay.
    """
80
+
81
+
82
class ResampleProcessor(BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]):
    """
    Resample incoming AxisArray messages onto a reference axis.

    Signal data arrives via ``_process`` (i.e., calling the processor) and is
    accumulated in ``state.src_buffer``. The reference axis either comes from
    external reference messages (``push_reference``) or, when
    ``settings.resample_rate`` is set, is synthesized at that fixed rate from
    the incoming signal's own axis. ``__next__`` trains a linear interpolant on
    the buffered source data and evaluates it at the pending reference values.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message's key, axis gain, and non-target-axis shape; a changed
        hash triggers a state reset in the base class."""
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        # Shape of one sample, i.e., the full shape with the target axis removed.
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        ax = message.axes[self.settings.axis]
        # Irregular (coordinate) axes have no gain; use None so regular vs
        # irregular axes hash differently.
        gain = ax.gain if hasattr(ax, "gain") else None
        return hash((message.key, gain) + sample_shape)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Reset the internal state based on the incoming message.
        """
        self.state.src_buffer = HybridAxisArrayBuffer(
            duration=self.settings.buffer_duration,
            axis=self.settings.axis,
            update_strategy=self.settings.buffer_update_strategy,
            overflow_strategy="grow",
        )
        if self.settings.resample_rate is not None:
            # If we are resampling at a prescribed rate, then we synthesize a reference axis
            # NOTE(review): unlike push_reference, this buffer is created without
            # update_strategy / overflow_strategy="grow" -- confirm the defaults
            # of HybridAxisBuffer are acceptable here.
            self.state.ref_axis_buffer = HybridAxisBuffer(
                duration=self.settings.buffer_duration,
            )
            in_ax = message.axes[self.settings.axis]
            out_gain = 1 / self.settings.resample_rate
            # First axis value: coordinate axes expose .data; linear axes compute via .value().
            t0 = in_ax.data[0] if hasattr(in_ax, "data") else in_ax.value(0)
            # Back off by one output period so the first emitted value is exactly t0.
            self.state.last_ref_ax_val = t0 - out_gain
        self.state.last_write_time = -np.inf

    def push_reference(self, message: AxisArray) -> None:
        """Feed the target axis from an external reference message into the
        reference-axis buffer (used when settings.resample_rate is None)."""
        ax = message.axes[self.settings.axis]
        ax_idx = message.get_axis_idx(self.settings.axis)
        if self.state.ref_axis_buffer is None:
            # Lazily create the reference buffer on the first reference message.
            self.state.ref_axis_buffer = HybridAxisBuffer(
                duration=self.settings.buffer_duration,
                update_strategy=self.settings.buffer_update_strategy,
                overflow_strategy="grow",
            )
            t0 = ax.data[0] if hasattr(ax, "data") else ax.value(0)
            # Back off by one reference period so the first reference value is emitted.
            # NOTE(review): assumes the reference axis has a .gain attribute even
            # when it also has .data -- confirm for coordinate-axis references.
            self.state.last_ref_ax_val = t0 - ax.gain
        self.state.ref_axis_buffer.write(ax, n_samples=message.data.shape[ax_idx])

    def _process(self, message: AxisArray) -> None:
        """
        Add a new data message to the buffer and update the reference axis if needed.
        """
        # Note: The src_buffer will copy and permute message if ax_idx != 0
        self.state.src_buffer.write(message)

        # If we are resampling at a prescribed rate (i.e., not by reference msgs),
        # then we use this opportunity to extend our synthetic reference axis.
        ax_idx = message.get_axis_idx(self.settings.axis)
        if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
            in_ax = message.axes[self.settings.axis]
            # Last timestamp covered by the incoming chunk.
            in_t_end = in_ax.data[-1] if hasattr(in_ax, "data") else in_ax.value(message.data.shape[ax_idx] - 1)
            out_gain = 1 / self.settings.resample_rate
            prev_t_end = self.state.last_ref_ax_val
            # Number of synthetic reference ticks needed to cover up to in_t_end.
            n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
            synth_ref_axis = LinearAxis(unit="s", gain=out_gain, offset=prev_t_end + out_gain)
            self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)

        self.state.last_write_time = time.time()

    def __next__(self) -> AxisArray:
        """Produce the next resampled AxisArray, or an empty message when there
        is not yet enough signal/reference data to interpolate."""
        if self.state.src_buffer is None or self.state.ref_axis_buffer is None:
            # If we have not received any data, or we require reference data
            # that we do not yet have, then return an empty template.
            return AxisArray(data=np.array([]), dims=[""], axes={}, key="null")

        src = self.state.src_buffer
        ref = self.state.ref_axis_buffer

        # If we have no reference or the source is insufficient for interpolation
        # then return the empty template
        if ref.is_empty() or src.available() < 3:
            src_axarr = src.peek(0)
            return replace(
                src_axarr,
                axes={
                    **src_axarr.axes,
                    self.settings.axis: ref.peek(0),
                },
            )

        # Build the reference xvec.
        # Note: The reference axis buffer may grow upon `.peek()`
        # as it flushes data from its deque to its buffer.
        ref_ax = ref.peek()
        if hasattr(ref_ax, "data"):
            ref_xvec = ref_ax.data
        else:
            ref_xvec = ref_ax.value(np.arange(ref.available()))

        # If we do not rely on an external reference, and we have not received new data in a while,
        # then extrapolate our reference vector out beyond the delay limit.
        b_project = self.settings.resample_rate is not None and time.time() > (
            self.state.last_write_time + self.settings.max_chunk_delay
        )
        if b_project:
            # Append max_chunk_delay's worth of future ticks; b_project implies
            # resample_rate is set, so ref_ax is a synthetic LinearAxis with .gain.
            n_append = math.ceil(self.settings.max_chunk_delay / ref_ax.gain)
            xvec_append = ref_xvec[-1] + np.arange(1, n_append + 1) * ref_ax.gain
            ref_xvec = np.hstack((ref_xvec, xvec_append))

        # Get source to train interpolation
        src_axarr = src.peek()
        src_axis = src_axarr.axes[self.settings.axis]
        # x: source sample locations along the target axis (axis is dim 0 in the buffer).
        x = src_axis.data if hasattr(src_axis, "data") else src_axis.value(np.arange(src_axarr.data.shape[0]))

        # Only resample at reference values that have not been interpolated over previously.
        b_ref = ref_xvec > self.state.last_ref_ax_val
        if not b_project:
            # Not extrapolating -- Do not resample beyond the end of the source buffer.
            b_ref = np.logical_and(b_ref, ref_xvec <= x[-1])
        ref_idx = np.where(b_ref)[0]

        if len(ref_idx) == 0:
            # Nothing to interpolate over; return empty data
            null_ref = replace(ref_ax, data=ref_ax.data[:0]) if hasattr(ref_ax, "data") else ref_ax
            return replace(
                src_axarr,
                data=src_axarr.data[:0, ...],
                axes={**src_axarr.axes, self.settings.axis: null_ref},
            )

        # xnew: the reference values we will interpolate at this iteration.
        xnew = ref_xvec[ref_idx]

        # Identify source data indices around ref tvec with some padding for better interpolation.
        src_start_ix = max(0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0)

        x = x[src_start_ix:]
        y = src_axarr.data[src_start_ix:]

        if isinstance(self.settings.fill_value, str) and self.settings.fill_value == "last":
            # 'last' maps to interp1d's (below-range, above-range) fill pair:
            # hold the first/last source sample for out-of-bounds queries.
            fill_value = (y[0], y[-1])
        else:
            fill_value = self.settings.fill_value
        f = scipy.interpolate.interp1d(
            x,
            y,
            kind="linear",
            axis=0,
            copy=False,
            bounds_error=False,
            fill_value=fill_value,
            assume_sorted=True,
        )

        # Calculate output
        resampled_data = f(xnew)

        # Create output message
        if hasattr(ref_ax, "data"):
            out_ax = replace(ref_ax, data=xnew)
        else:
            out_ax = replace(ref_ax, offset=xnew[0])
        result = replace(
            src_axarr,
            data=resampled_data,
            axes={
                **src_axarr.axes,
                self.settings.axis: out_ax,
            },
        )

        # Update the state. For state buffers, seek beyond samples that are no longer needed.
        # src: keep at least 1 sample before the final resampled value
        seek_ix = np.where(x >= xnew[-1])[0]
        if len(seek_ix) > 0:
            self.state.src_buffer.seek(max(0, src_start_ix + seek_ix[0] - 1))
        # ref: remove samples that have been sent to output
        self.state.ref_axis_buffer.seek(ref_idx[-1] + 1)
        self.state.last_ref_ax_val = xnew[-1]

        return result

    def send(self, message: AxisArray) -> AxisArray:
        """Convenience: push one signal message and immediately pull one output."""
        self(message)
        return next(self)
260
+
261
+
262
class ResampleUnit(BaseConsumerUnit[ResampleSettings, AxisArray, ResampleProcessor]):
    """
    Unit wrapper for :obj:`ResampleProcessor`.

    Signal messages are consumed via the base class; reference-axis messages
    arrive on ``INPUT_REFERENCE`` and resampled output is published on
    ``OUTPUT_SIGNAL``.
    """

    SETTINGS = ResampleSettings
    INPUT_REFERENCE = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_REFERENCE, zero_copy=True)
    async def on_reference(self, message: AxisArray):
        """Forward reference messages to the processor's reference-axis buffer."""
        self.processor.push_reference(message)

    @ez.publisher(OUTPUT_SIGNAL)
    async def gen_resampled(self):
        """Continuously poll the processor; publish non-empty results, otherwise
        yield briefly to the event loop."""
        while True:
            result: AxisArray = next(self.processor)
            # `.size` is the integer product of the shape -- clearer and cheaper
            # than np.prod(result.data.shape), which returns a float for `()`.
            if result.data.size > 0:
                yield self.OUTPUT_SIGNAL, result
            else:
                # Nothing ready; sleep briefly so we don't busy-spin the loop.
                await asyncio.sleep(0.001)
@@ -0,0 +1,232 @@
1
+ from collections import deque
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ from ezmsg.baseproc import (
7
+ BaseAdaptiveTransformer,
8
+ BaseAdaptiveTransformerUnit,
9
+ processor_state,
10
+ )
11
+ from ezmsg.util.messages.axisarray import AxisArray
12
+ from ezmsg.util.messages.util import replace
13
+
14
+ from ezmsg.sigproc.sampler import SampleMessage
15
+
16
+
17
class RollingScalerSettings(ez.Settings):
    """Settings for :obj:`RollingScalerProcessor` rolling z-score normalization."""

    axis: str = "time"
    """
    Axis along which samples are arranged.
    """

    k_samples: int | None = 20
    """
    Rolling window size in number of samples.
    """

    window_size: float | None = None
    """
    Rolling window size in seconds.
    If set, overrides `k_samples`.
    `update_with_signal` likely should be True if using this option.
    """

    update_with_signal: bool = False
    """
    If True, update rolling statistics using the incoming process stream.
    """

    min_samples: int = 1
    """
    Minimum number of samples required to compute statistics.
    Used when `window_size` is not set.
    """

    min_seconds: float = 1.0
    """
    Minimum duration in seconds required to compute statistics.
    Used when `window_size` is set.
    """

    artifact_z_thresh: float | None = None
    """
    Threshold for z-score based artifact detection.
    If set, samples with any channel exceeding this z-score will be excluded
    from updating the rolling statistics.
    """

    clip: float | None = 10.0
    """
    If set, clip the output values to the range [-clip, clip].
    """
+
64
+
65
@processor_state
class RollingScalerState:
    """Mutable state for :obj:`RollingScalerProcessor`."""

    # Running per-channel mean over the current window.
    mean: npt.NDArray | None = None
    # Total number of samples currently contributing to the statistics.
    N: int = 0
    # Running per-channel sum of squared deviations (Welford/Chan M2).
    M2: npt.NDArray | None = None
    # Deque of per-batch (n, mean, M2) tuples; oldest batches are evicted
    # and their contribution subtracted from the running statistics.
    samples: deque | None = None
    # Resolved window length (from settings.k_samples or settings.window_size).
    k_samples: int | None = None
    # Resolved minimum sample count before normalization is applied.
    min_samples: int | None = None
+
74
+
75
class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance) over the last `k_samples`
    samples received via the `partial_fit()` method. When processing an `AxisArray` message,
    it normalizes the data using the current rolling statistics.

    The input `AxisArray` messages are expected to have shape `(time, ch)`, where `ch` is the
    channel axis. The processor computes the z-score for each channel independently.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash key, per-sample shape, and axis gain; a change triggers a state reset."""
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        # Irregular axes without a gain hash as gain=1.
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Initialize running statistics and resolve window / minimum sample counts
        from the settings (converting seconds to samples when window_size is set)."""
        ch = message.data.shape[-1]
        self._state.mean = np.zeros(ch)
        self._state.N = 0
        self._state.M2 = np.zeros(ch)
        # window_size (seconds) takes precedence over k_samples (count).
        self._state.k_samples = (
            int(np.ceil(self.settings.window_size / message.axes[self.settings.axis].gain))
            if self.settings.window_size is not None
            else self.settings.k_samples
        )
        if self._state.k_samples is not None and self._state.k_samples < 1:
            ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
            self._state.k_samples = 1
        elif self._state.k_samples is None:
            ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
        # NOTE(review): the deque holds one entry per *batch* passed to
        # _add_batch_stats, so with multi-sample batches the window eviction is
        # per-batch, not per-sample -- confirm intended granularity.
        self._state.samples = deque(maxlen=self._state.k_samples)
        self._state.min_samples = (
            int(np.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain))
            if self.settings.window_size is not None
            else self.settings.min_samples
        )
        if self._state.k_samples is not None and self._state.min_samples > self._state.k_samples:
            ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
            self._state.min_samples = self._state.k_samples

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """
        Fold a batch of samples (axis 0) into the running mean/M2 using Chan's
        parallel-combination formulas; when the window is full, the oldest
        batch's contribution is first subtracted (the reverse combination).
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        if n_b == 0:
            # Guard: an empty batch would yield NaN batch statistics (np.mean over
            # zero samples) and permanently poison the running mean/M2. _process
            # guards its calls with x.size > 0, but partial_fit does not.
            return
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
            # Window full: evict the oldest batch and remove its contribution.
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old

            if N_new <= 0:
                # Removing the old batch empties the window; reset to zero.
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                # Reverse of Chan's combine: recover stats of the remaining samples.
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)

        # Chan's combine: merge the new batch into the running statistics.
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N

        self._state.samples.append((n_b, mean_b, M2_b))

    def partial_fit(self, message: SampleMessage) -> None:
        """Update rolling statistics from a labeled sample message (no output)."""
        x = message.sample.data
        # Empty batches are rejected inside _add_batch_stats.
        self._add_batch_stats(x)

    def _process(self, message: AxisArray) -> AxisArray:
        """Normalize message data with the current rolling statistics; optionally
        update the statistics from the signal itself (artifact-rejected)."""
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            # Not enough history yet: optionally accumulate, but pass data through unscaled.
            if self.settings.update_with_signal:
                x = message.data
                if self.settings.artifact_z_thresh is not None and self._state.N > 0:
                    varis = self._state.M2 / self._state.N
                    std = np.maximum(np.sqrt(varis), 1e-8)
                    z = np.abs((x - self._state.mean) / std)
                    # Drop whole samples where any channel exceeds the threshold.
                    # assumes x is (time, ch) -- TODO confirm for >2D inputs.
                    mask = np.any(z > self.settings.artifact_z_thresh, axis=1)
                    x = x[~mask]
                if x.size > 0:
                    self._add_batch_stats(x)
            return message

        varis = self._state.M2 / self._state.N
        # Floor std to avoid division blow-up on near-constant channels.
        std = np.maximum(np.sqrt(varis), 1e-8)
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
            result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        if self.settings.update_with_signal:
            # Update statistics AFTER normalization so the current message is
            # scaled with the previous window's statistics.
            x = message.data
            if self.settings.artifact_z_thresh is not None:
                z_scores = np.abs((x - self._state.mean) / std)
                mask = np.any(z_scores > self.settings.artifact_z_thresh, axis=1)
                x = x[~mask]
            if x.size > 0:
                self._add_batch_stats(x)

        return replace(message, data=result)
204
+
205
+
206
class RollingScalerUnit(
    BaseAdaptiveTransformerUnit[
        RollingScalerSettings,
        AxisArray,
        AxisArray,
        RollingScalerProcessor,
    ]
):
    """
    Unit wrapper for :obj:`RollingScalerProcessor`.

    This unit performs rolling z-score normalization on incoming `AxisArray` messages. The unit maintains rolling
    statistics (mean and variance) over the last `k_samples` samples received. When processing an `AxisArray` message,
    it normalizes the data using the current rolling statistics.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings