ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
from ezmsg.baseproc import (
|
|
7
|
+
BaseAdaptiveTransformer,
|
|
8
|
+
BaseAdaptiveTransformerUnit,
|
|
9
|
+
processor_state,
|
|
10
|
+
)
|
|
11
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
12
|
+
from ezmsg.util.messages.util import replace
|
|
13
|
+
|
|
14
|
+
from ezmsg.sigproc.sampler import SampleMessage
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RollingScalerSettings(ez.Settings):
    """Configuration for the rolling z-score processor and its ezmsg unit wrapper."""

    axis: str = "time"
    """Name of the axis that indexes successive samples."""

    k_samples: int | None = 20
    """
    Size of the rolling window as a count of samples. ``None`` lets the
    statistics accumulate without any bound.
    """

    window_size: float | None = None
    """
    Size of the rolling window in seconds. When provided it takes precedence over
    `k_samples` (it is converted to a count using the axis gain). In this mode
    `update_with_signal` should usually be enabled as well.
    """

    update_with_signal: bool = False
    """When True, the signal being normalized is also folded into the rolling statistics."""

    min_samples: int = 1
    """
    Number of accumulated samples required before normalization begins.
    Only consulted when `window_size` is None.
    """

    min_seconds: float = 1.0
    """
    Accumulated duration, in seconds, required before normalization begins.
    Only consulted when `window_size` is set.
    """

    artifact_z_thresh: float | None = None
    """
    Optional z-score threshold for artifact rejection. Signal samples in which any
    channel's absolute z-score exceeds this value are left out of the statistics update.
    """

    clip: float | None = 10.0
    """Optional bound; when set, the normalized output is clipped to ``[-clip, clip]``."""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@processor_state
class RollingScalerState:
    # Running mean of every sample currently inside the rolling window.
    mean: npt.NDArray | None = None
    # Total number of samples currently contributing to `mean` / `M2`.
    N: int = 0
    # Running sum of squared deviations from `mean` (variance = M2 / N).
    M2: npt.NDArray | None = None
    # Oldest-first (count, mean, M2) records, one per added batch, used for eviction.
    samples: deque | None = None
    # Effective window length: the maximum number of batch records retained.
    k_samples: int | None = None
    # Effective sample count required before normalization starts.
    min_samples: int | None = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance). They are accumulated
    from training data passed to `partial_fit()`. When `update_with_signal` is True, they
    are also accumulated from the processed signal itself.

    Each update contributes one (count, mean, M2) record. Only the most recent
    `k_samples` records are retained, and older records are removed from the statistics
    as new ones arrive. When every update holds a single time sample, this is a window
    of `k_samples` samples.

    Until at least `min_samples` samples have been accumulated, messages are passed
    through unchanged. After that, each message is normalized with the current
    statistics and optionally clipped.

    The input `AxisArray` messages are expected to have shape `(time, ch)`, where `ch` is the
    channel axis. The processor computes the z-score for each channel independently.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _resolve_axis(self, message: AxisArray) -> str:
        """Return the configured sample axis, falling back to the message's first dim when unset."""
        return message.dims[0] if self.settings.axis is None else self.settings.axis

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message properties whose change requires the statistics to be reset."""
        axis = self._resolve_axis(message)
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Clear the statistics and derive window / warm-up lengths from settings and `message`."""
        # NOTE(review): assumes channels lie on the last axis (documented shape is (time, ch)).
        ch = message.data.shape[-1]
        self._state.mean = np.zeros(ch)
        self._state.N = 0
        self._state.M2 = np.zeros(ch)

        if self.settings.window_size is not None:
            # Convert durations to sample counts using the sample-axis gain (seconds per sample).
            gain = message.axes[self._resolve_axis(message)].gain
            k_samples = int(np.ceil(self.settings.window_size / gain))
            min_samples = int(np.ceil(self.settings.min_seconds / gain))
        else:
            k_samples = self.settings.k_samples
            min_samples = self.settings.min_samples

        if k_samples is not None and k_samples < 1:
            ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
            k_samples = 1
        elif k_samples is None:
            ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")

        if k_samples is not None and min_samples > k_samples:
            ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
            min_samples = k_samples

        self._state.k_samples = k_samples
        self._state.min_samples = min_samples
        # NOTE(review): the deque stores one record per *batch* added by _add_batch_stats, so
        # k_samples bounds the number of batches, while N / min_samples count time samples.
        self._state.samples = deque(maxlen=k_samples)

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """
        Merge a batch (samples along axis 0) into the rolling statistics.

        If the window is full, the oldest batch record is removed first. Both steps use
        Chan et al.'s pairwise update of (N, mean, M2).
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        if n_b == 0:
            # An empty batch carries no information. np.mean would give NaN, and n_b / N
            # below would divide by zero while the window is empty.
            return
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
            # Window full: remove the oldest batch's contribution.
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old

            if N_new <= 0:
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                # Downdating can cancel catastrophically. Clamp tiny negative values so
                # sqrt(M2 / N) stays real instead of turning into NaN.
                self._state.M2 = np.maximum(
                    self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new),
                    0.0,
                )

        # Merge the new batch into the remaining statistics.
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N

        self._state.samples.append((n_b, mean_b, M2_b))

    def _update_stats_from_signal(self, x: npt.NDArray) -> None:
        """Fold signal samples into the statistics, skipping artifacts when a threshold is set."""
        if self.settings.artifact_z_thresh is not None and self._state.N > 0:
            std = np.maximum(np.sqrt(self._state.M2 / self._state.N), 1e-8)
            z = np.abs((x - self._state.mean) / std)
            # A sample is an artifact if any non-sample-axis element exceeds the threshold.
            mask = np.any(z > self.settings.artifact_z_thresh, axis=tuple(range(1, z.ndim)))
            x = x[~mask]
        if x.size > 0:
            self._add_batch_stats(x)

    def partial_fit(self, message: SampleMessage) -> None:
        """Update the rolling statistics with the training data in `message.sample`."""
        sample = message.sample
        if self._state.mean is None:
            # partial_fit may arrive before any signal message, while the state is still
            # uninitialized. The original code raised a TypeError on None arithmetic here.
            # Initialise from the sample and record its hash so the next matching signal
            # message does not reset the fitted statistics.
            # NOTE(review): relies on the base class caching the last hash in `self._hash`.
            self._reset_state(sample)
            self._hash = self._hash_message(sample)
        self._add_batch_stats(sample.data)

    def _process(self, message: AxisArray) -> AxisArray:
        """Z-score `message` with the current statistics (pass-through during warm-up)."""
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            # Warm-up: too little data for meaningful statistics; forward the input untouched.
            if self.settings.update_with_signal:
                self._update_stats_from_signal(message.data)
            return message

        std = np.maximum(np.sqrt(self._state.M2 / self._state.N), 1e-8)
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
        result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        if self.settings.update_with_signal:
            # Update after normalizing so the current message is scaled by prior data only.
            self._update_stats_from_signal(message.data)

        return replace(message, data=result)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class RollingScalerUnit(
    BaseAdaptiveTransformerUnit[RollingScalerSettings, AxisArray, AxisArray, RollingScalerProcessor]
):
    """
    ezmsg Unit that wraps :obj:`RollingScalerProcessor`.

    Incoming `AxisArray` messages are z-score normalized using rolling mean and variance
    statistics kept by the wrapped processor over the last `k_samples` samples it has
    received. Each message is scaled with whatever statistics are current when it arrives.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings
|