ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
- ezmsg/sigproc/affinetransform.py +13 -38
- ezmsg/sigproc/aggregate.py +13 -30
- ezmsg/sigproc/bandpower.py +7 -15
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +123 -0
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/decimate.py +2 -6
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +6 -14
- ezmsg/sigproc/ewma.py +11 -27
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +31 -56
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +33 -70
- ezmsg/sigproc/filterbankdesign.py +5 -12
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +1 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +98 -36
- ezmsg/sigproc/math/invert.py +1 -3
- ezmsg/sigproc/math/log.py +2 -6
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +2 -4
- ezmsg/sigproc/resample.py +13 -34
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +17 -35
- ezmsg/sigproc/scaler.py +8 -18
- ezmsg/sigproc/signalinjector.py +6 -16
- ezmsg/sigproc/slicer.py +9 -28
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +12 -32
- ezmsg/sigproc/transpose.py +7 -18
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +10 -26
- ezmsg/sigproc/util/buffer.py +18 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +5 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +6 -15
- ezmsg/sigproc/window.py +24 -78
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
- ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
|
+
|
|
9
|
+
from ezmsg.sigproc.base import (
|
|
10
|
+
BaseAdaptiveTransformer,
|
|
11
|
+
BaseAdaptiveTransformerUnit,
|
|
12
|
+
processor_state,
|
|
13
|
+
)
|
|
14
|
+
from ezmsg.sigproc.sampler import SampleMessage
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RollingScalerSettings(ez.Settings):
    """Settings for :obj:`RollingScalerProcessor` / :obj:`RollingScalerUnit`."""

    axis: str = "time"
    """
    Name of the axis along which samples are arranged.
    Statistics are accumulated over the leading dimension of the data, so this is expected to be
    the first dimension. Its `gain` (sample period) is used to convert `window_size` and
    `min_seconds` into sample counts.
    """

    k_samples: int | None = 20
    """
    Maximum number of batches retained in the rolling window.
    Each `partial_fit` call contributes one batch (the data of one `SampleMessage`), as does each
    processed message when `update_with_signal` is True. When the window is full, the oldest batch
    is removed before a new one is added. If None, the window is unbounded.
    """

    window_size: float | None = None
    """
    Rolling window length in seconds.
    If set, overrides `k_samples`: it is converted to a count as `ceil(window_size / gain)`, where
    `gain` is the sample period of `axis`. That count limits the number of retained batches
    (see `k_samples`), so it matches the requested duration only when each batch holds a single
    sample along `axis`.
    `update_with_signal` likely should be True if using this option.
    """

    update_with_signal: bool = False
    """
    If True, data passed to the processor is also added to the rolling statistics.
    The update happens after that message has been normalized with the previous statistics.
    """

    min_samples: int = 1
    """
    Minimum total number of samples (rows along the leading dimension, summed over all retained
    batches) required before data is normalized. Until then, input passes through unchanged.
    Used when `window_size` is not set.
    """

    min_seconds: float = 1.0
    """
    Minimum accumulated duration in seconds required before data is normalized.
    Converted to a sample count as `ceil(min_seconds / gain)`.
    Used when `window_size` is set.
    """

    artifact_z_thresh: float | None = None
    """
    Threshold for z-score based artifact detection.
    If set, samples in which any channel's absolute z-score (under the current statistics) exceeds
    this value are excluded from updating the rolling statistics. This applies to updates made via
    `update_with_signal`; data given to `partial_fit` is not filtered.
    """

    clip: float | None = 10.0
    """
    If set, clip the output values to the range [-clip, clip].
    """
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@processor_state
class RollingScalerState:
    """Mutable state of :obj:`RollingScalerProcessor` (rolling mean/variance bookkeeping)."""

    # Rolling mean of the retained data, reduced along axis 0.
    mean: npt.NDArray | None = None
    # Total number of rows (along axis 0) currently contributing to the statistics.
    N: int = 0
    # Sum of squared deviations from `mean`; the variance used for scaling is M2 / N.
    M2: npt.NDArray | None = None
    # Per-batch statistics tuples (n, mean, M2) in arrival order; the oldest entry is
    # subtracted from the totals when the window is full.
    samples: deque | None = None
    # Effective maximum number of retained batches resolved from settings (None = unbounded).
    k_samples: int | None = None
    # Effective minimum N required before normalization, resolved from settings.
    min_samples: int | None = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance) over the most recent
    `k_samples` batches of data. Each call to `partial_fit()` adds one batch (the data of one
    `SampleMessage`); when `update_with_signal` is True, each processed message is added as one
    batch too. Batches are merged into, and removed from, the running totals with the pairwise
    (Chan et al.) update of count, mean and sum of squared deviations.

    When processing an `AxisArray` message, the data is normalized with the current statistics as
    `(x - mean) / max(std, 1e-8)`. NaN/inf results are mapped to 0, and the output is optionally
    clipped to `[-clip, clip]`. Until at least `min_samples` rows have been accumulated, messages
    pass through unchanged.

    The input is expected to have samples along its first dimension, e.g. shape `(time, ch)`.
    Statistics are computed independently for every element of the remaining dimensions
    (i.e. per channel for 2-D input).

    Note: You should consider instead using the AdaptiveStandardScalerTransformer which
    is computationally more efficient and uses less memory. This RollingScalerProcessor
    is primarily provided to reproduce processing in the literature.

    Settings:
    ---------
    k_samples: int
        Number of previous batches to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _resolve_axis(self, message: AxisArray) -> str:
        """Return the configured sample axis, falling back to the message's first dimension."""
        return message.dims[0] if self.settings.axis is None else self.settings.axis

    def _axis_gain(self, message: AxisArray):
        """Return the `gain` of the sample axis, or None if the axis is absent or has no gain."""
        return getattr(message.axes.get(self._resolve_axis(message)), "gain", None)

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the properties that require a state reset: key, per-sample shape and axis gain."""
        axis = self._resolve_axis(message)
        gain = self._axis_gain(message)
        if gain is None:
            gain = 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Initialize empty statistics and resolve the window and warm-up sizes for `message`.

        Raises:
            ValueError: if `window_size` is set but the sample axis has no `gain`.
        """
        # Statistics are reduced over axis 0 (see `_add_batch_stats`), so they take the shape of
        # one sample: all remaining dimensions (== (ch,) for (time, ch) input).
        sample_shape = message.data.shape[1:]
        self._state.mean = np.zeros(sample_shape)
        self._state.N = 0
        self._state.M2 = np.zeros(sample_shape)

        if self.settings.window_size is not None:
            # Durations are converted to counts using the sample period of the configured axis.
            gain = self._axis_gain(message)
            if gain is None:
                raise ValueError(
                    f"window_size requires axis '{self._resolve_axis(message)}' to have a `gain` (sample period)."
                )
            k_samples = int(np.ceil(self.settings.window_size / gain))
            min_samples = int(np.ceil(self.settings.min_seconds / gain))
        else:
            k_samples = self.settings.k_samples
            min_samples = self.settings.min_samples

        if k_samples is not None and k_samples < 1:
            ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
            k_samples = 1
        elif k_samples is None:
            ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")

        # NOTE(review): k_samples bounds the number of batches, whereas min_samples is compared
        # against the total row count N; this clamp mixes the two units.
        if k_samples is not None and min_samples > k_samples:
            ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
            min_samples = k_samples

        self._state.k_samples = k_samples
        self._state.min_samples = min_samples
        self._state.samples = deque(maxlen=k_samples)

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """Merge batch `x` (rows along axis 0) into the rolling statistics.

        If the window already holds `k_samples` batches, the oldest batch's contribution is first
        removed from the totals. Empty batches are ignored.
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        if n_b == 0:
            # An empty batch carries no information: np.mean would yield NaN and poison the
            # state, and appending it would also evict a real batch from a full window.
            return
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)

        if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
            # Remove the oldest batch (inverse of the pairwise merge below).
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old

            if N_new <= 0:
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)

        # Pairwise merge of the current totals (A) with the new batch (B).
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N

        self._state.samples.append((n_b, mean_b, M2_b))

    def partial_fit(self, message: SampleMessage) -> None:
        """Add the data of one `SampleMessage` to the rolling statistics as a single batch."""
        # NOTE(review): assumes the state was already initialized by `_reset_state`
        # (mean/samples not None); confirm the base class guarantees this before partial_fit.
        self._add_batch_stats(message.sample.data)

    def _current_std(self) -> npt.NDArray:
        """Standard deviation from the current statistics, floored at 1e-8 to avoid division by zero."""
        return np.maximum(np.sqrt(self._state.M2 / self._state.N), 1e-8)

    def _fit_signal(self, x: npt.NDArray, std: npt.NDArray | None) -> None:
        """Update the statistics with processed data `x`, optionally rejecting artifact samples.

        When `std` is given and `artifact_z_thresh` is set, every row of `x` (along axis 0) in which
        any element's absolute z-score exceeds the threshold is dropped before the update.
        """
        if std is not None and self.settings.artifact_z_thresh is not None:
            z_scores = np.abs((x - self._state.mean) / std)
            per_sample_axes = tuple(range(1, z_scores.ndim))
            x = x[~np.any(z_scores > self.settings.artifact_z_thresh, axis=per_sample_axes)]
        if x.size > 0:
            self._add_batch_stats(x)

    def _process(self, message: AxisArray) -> AxisArray:
        """Z-score `message.data` with the current rolling statistics.

        During warm-up (fewer than `min_samples` rows accumulated) the message is returned
        unchanged. If `update_with_signal` is set, the data is added to the statistics after
        normalization.
        """
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            if self.settings.update_with_signal:
                # Artifact rejection is only possible once some statistics exist.
                can_reject = self.settings.artifact_z_thresh is not None and self._state.N > 0
                self._fit_signal(message.data, self._current_std() if can_reject else None)
            return message

        std = self._current_std()
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
        result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)

        if self.settings.update_with_signal:
            self._fit_signal(message.data, std)

        return replace(message, data=result)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class RollingScalerUnit(BaseAdaptiveTransformerUnit[RollingScalerSettings, AxisArray, AxisArray, RollingScalerProcessor]):
    """
    ezmsg Unit that wraps :obj:`RollingScalerProcessor`.

    Incoming `AxisArray` messages are z-scored using rolling statistics (mean and variance).
    These statistics are kept over the most recent `k_samples` batches of training data.
    All behavior is delegated to the wrapped processor, configured by :obj:`RollingScalerSettings`.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings
|
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
from collections import deque
|
|
3
2
|
import copy
|
|
4
3
|
import traceback
|
|
5
4
|
import typing
|
|
5
|
+
from collections import deque
|
|
6
6
|
|
|
7
|
-
import numpy as np
|
|
8
7
|
import ezmsg.core as ez
|
|
8
|
+
import numpy as np
|
|
9
9
|
from ezmsg.util.messages.axisarray import (
|
|
10
10
|
AxisArray,
|
|
11
11
|
)
|
|
12
12
|
from ezmsg.util.messages.util import replace
|
|
13
13
|
|
|
14
|
-
from .util.profile import profile_subpub
|
|
15
|
-
from .util.axisarray_buffer import HybridAxisArrayBuffer
|
|
16
|
-
from .util.buffer import UpdateStrategy
|
|
17
|
-
from .util.message import SampleMessage, SampleTriggerMessage
|
|
18
14
|
from .base import (
|
|
19
|
-
BaseStatefulTransformer,
|
|
20
15
|
BaseConsumerUnit,
|
|
21
|
-
BaseTransformerUnit,
|
|
22
|
-
BaseStatefulProducer,
|
|
23
16
|
BaseProducerUnit,
|
|
17
|
+
BaseStatefulProducer,
|
|
18
|
+
BaseStatefulTransformer,
|
|
19
|
+
BaseTransformerUnit,
|
|
24
20
|
processor_state,
|
|
25
21
|
)
|
|
22
|
+
from .util.axisarray_buffer import HybridAxisArrayBuffer
|
|
23
|
+
from .util.buffer import UpdateStrategy
|
|
24
|
+
from .util.message import SampleMessage, SampleTriggerMessage
|
|
25
|
+
from .util.profile import profile_subpub
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class SamplerSettings(ez.Settings):
|
|
@@ -74,12 +74,8 @@ class SamplerState:
|
|
|
74
74
|
triggers: deque[SampleTriggerMessage] | None = None
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
class SamplerTransformer(
|
|
78
|
-
|
|
79
|
-
):
|
|
80
|
-
def __call__(
|
|
81
|
-
self, message: AxisArray | SampleTriggerMessage
|
|
82
|
-
) -> list[SampleMessage]:
|
|
77
|
+
class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
|
|
78
|
+
def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
|
|
83
79
|
# TODO: Currently we have a single entry point that accepts both
|
|
84
80
|
# data and trigger messages and we choose a code path based on
|
|
85
81
|
# the message type. However, in the future we will likely replace
|
|
@@ -99,9 +95,7 @@ class SamplerTransformer(
|
|
|
99
95
|
# Compute hash based on message properties that require state reset
|
|
100
96
|
axis = self.settings.axis or message.dims[0]
|
|
101
97
|
axis_idx = message.get_axis_idx(axis)
|
|
102
|
-
sample_shape =
|
|
103
|
-
message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
104
|
-
)
|
|
98
|
+
sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
105
99
|
return hash((sample_shape, message.key))
|
|
106
100
|
|
|
107
101
|
def _reset_state(self, message: AxisArray) -> None:
|
|
@@ -193,20 +187,14 @@ class SamplerTransformer(
|
|
|
193
187
|
trigger_ts: float = message.timestamp
|
|
194
188
|
if not self.settings.estimate_alignment:
|
|
195
189
|
# Override the trigger timestamp with the next sample's likely timestamp.
|
|
196
|
-
trigger_ts =
|
|
197
|
-
self._state.buffer.axis_final_value + self._state.buffer.axis_gain
|
|
198
|
-
)
|
|
190
|
+
trigger_ts = self._state.buffer.axis_final_value + self._state.buffer.axis_gain
|
|
199
191
|
|
|
200
|
-
new_trig_msg = replace(
|
|
201
|
-
message, timestamp=trigger_ts, period=_period, value=_value
|
|
202
|
-
)
|
|
192
|
+
new_trig_msg = replace(message, timestamp=trigger_ts, period=_period, value=_value)
|
|
203
193
|
self._state.triggers.append(new_trig_msg)
|
|
204
194
|
return []
|
|
205
195
|
|
|
206
196
|
|
|
207
|
-
class Sampler(
|
|
208
|
-
BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]
|
|
209
|
-
):
|
|
197
|
+
class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]):
|
|
210
198
|
SETTINGS = SamplerSettings
|
|
211
199
|
|
|
212
200
|
INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
|
|
@@ -269,19 +257,13 @@ class TriggerGeneratorState:
|
|
|
269
257
|
output: int = 0
|
|
270
258
|
|
|
271
259
|
|
|
272
|
-
class TriggerProducer(
|
|
273
|
-
BaseStatefulProducer[
|
|
274
|
-
TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState
|
|
275
|
-
]
|
|
276
|
-
):
|
|
260
|
+
class TriggerProducer(BaseStatefulProducer[TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState]):
|
|
277
261
|
def _reset_state(self) -> None:
|
|
278
262
|
self._state.output = 0
|
|
279
263
|
|
|
280
264
|
async def _produce(self) -> SampleTriggerMessage:
|
|
281
265
|
await asyncio.sleep(self.settings.publish_period)
|
|
282
|
-
out_msg = SampleTriggerMessage(
|
|
283
|
-
period=self.settings.period, value=self._state.output
|
|
284
|
-
)
|
|
266
|
+
out_msg = SampleTriggerMessage(period=self.settings.period, value=self._state.output)
|
|
285
267
|
self._state.output += 1
|
|
286
268
|
return out_msg
|
|
287
269
|
|
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -1,27 +1,25 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
+
from ezmsg.util.generator import consumer
|
|
4
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
6
|
from ezmsg.util.messages.util import replace
|
|
6
|
-
from ezmsg.util.generator import consumer
|
|
7
7
|
|
|
8
8
|
from .base import (
|
|
9
9
|
BaseStatefulTransformer,
|
|
10
10
|
BaseTransformerUnit,
|
|
11
11
|
processor_state,
|
|
12
12
|
)
|
|
13
|
-
from .ewma import EWMATransformer, EWMASettings, _alpha_from_tau
|
|
14
13
|
|
|
15
14
|
# Imports for backwards compatibility with previous module location
|
|
16
15
|
from .ewma import EWMA_Deprecated as EWMA_Deprecated
|
|
17
|
-
from .ewma import
|
|
16
|
+
from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
|
|
18
17
|
from .ewma import _tau_from_alpha as _tau_from_alpha
|
|
18
|
+
from .ewma import ewma_step as ewma_step
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
@consumer
|
|
22
|
-
def scaler(
|
|
23
|
-
time_constant: float = 1.0, axis: str | None = None
|
|
24
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
22
|
+
def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
25
23
|
"""
|
|
26
24
|
Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
|
|
27
25
|
This is faster than :obj:`scaler_np` for single-channel data.
|
|
@@ -85,19 +83,13 @@ class AdaptiveStandardScalerTransformer(
|
|
|
85
83
|
]
|
|
86
84
|
):
|
|
87
85
|
def _reset_state(self, message: AxisArray) -> None:
|
|
88
|
-
self._state.samps_ewma = EWMATransformer(
|
|
89
|
-
|
|
90
|
-
)
|
|
91
|
-
self._state.vars_sq_ewma = EWMATransformer(
|
|
92
|
-
time_constant=self.settings.time_constant, axis=self.settings.axis
|
|
93
|
-
)
|
|
86
|
+
self._state.samps_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
|
|
87
|
+
self._state.vars_sq_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
|
|
94
88
|
|
|
95
89
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
96
90
|
# Update step
|
|
97
91
|
mean_message = self._state.samps_ewma(message)
|
|
98
|
-
var_sq_message = self._state.vars_sq_ewma(
|
|
99
|
-
replace(message, data=message.data**2)
|
|
100
|
-
)
|
|
92
|
+
var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
|
|
101
93
|
|
|
102
94
|
# Get step
|
|
103
95
|
varis = var_sq_message.data - mean_message.data**2
|
|
@@ -119,9 +111,7 @@ class AdaptiveStandardScaler(
|
|
|
119
111
|
|
|
120
112
|
|
|
121
113
|
# Backwards compatibility...
|
|
122
|
-
def scaler_np(
|
|
123
|
-
time_constant: float = 1.0, axis: str | None = None
|
|
124
|
-
) -> AdaptiveStandardScalerTransformer:
|
|
114
|
+
def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
|
|
125
115
|
return AdaptiveStandardScalerTransformer(
|
|
126
116
|
settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
|
|
127
117
|
)
|
ezmsg/sigproc/signalinjector.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
3
|
-
from ezmsg.util.messages.util import replace
|
|
4
2
|
import numpy as np
|
|
5
3
|
import numpy.typing as npt
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
6
6
|
|
|
7
7
|
from .base import (
|
|
8
8
|
BaseAsyncTransformer,
|
|
@@ -27,15 +27,11 @@ class SignalInjectorState:
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class SignalInjectorTransformer(
|
|
30
|
-
BaseAsyncTransformer[
|
|
31
|
-
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
|
|
32
|
-
]
|
|
30
|
+
BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
|
|
33
31
|
):
|
|
34
32
|
def _hash_message(self, message: AxisArray) -> int:
|
|
35
33
|
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
36
|
-
sample_shape =
|
|
37
|
-
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
38
|
-
)
|
|
34
|
+
sample_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
39
35
|
return hash((message.key,) + sample_shape)
|
|
40
36
|
|
|
41
37
|
def _reset_state(self, message: AxisArray) -> None:
|
|
@@ -44,9 +40,7 @@ class SignalInjectorTransformer(
|
|
|
44
40
|
if self._state.cur_amplitude is None:
|
|
45
41
|
self._state.cur_amplitude = self.settings.amplitude
|
|
46
42
|
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
47
|
-
self._state.cur_shape =
|
|
48
|
-
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
49
|
-
)
|
|
43
|
+
self._state.cur_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
50
44
|
rng = np.random.default_rng(self.settings.mixing_seed)
|
|
51
45
|
self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
|
|
52
46
|
self._state.mixing = (self._state.mixing * 2.0) - 1.0
|
|
@@ -63,11 +57,7 @@ class SignalInjectorTransformer(
|
|
|
63
57
|
return out_msg
|
|
64
58
|
|
|
65
59
|
|
|
66
|
-
class SignalInjector(
|
|
67
|
-
BaseTransformerUnit[
|
|
68
|
-
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
|
|
69
|
-
]
|
|
70
|
-
):
|
|
60
|
+
class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
|
|
71
61
|
SETTINGS = SignalInjectorSettings
|
|
72
62
|
INPUT_FREQUENCY = ez.InputStream(float | None)
|
|
73
63
|
INPUT_AMPLITUDE = ez.InputStream(float)
|
ezmsg/sigproc/slicer.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
1
2
|
import numpy as np
|
|
2
3
|
import numpy.typing as npt
|
|
3
|
-
import ezmsg.core as ez
|
|
4
4
|
from ezmsg.util.messages.axisarray import (
|
|
5
5
|
AxisArray,
|
|
6
|
-
slice_along_axis,
|
|
7
6
|
AxisBase,
|
|
8
7
|
replace,
|
|
8
|
+
slice_along_axis,
|
|
9
9
|
)
|
|
10
10
|
|
|
11
11
|
from .base import (
|
|
@@ -49,11 +49,7 @@ def parse_slice(
|
|
|
49
49
|
if "," not in s:
|
|
50
50
|
parts = [part.strip() for part in s.split(":")]
|
|
51
51
|
if len(parts) == 1:
|
|
52
|
-
if (
|
|
53
|
-
axinfo is not None
|
|
54
|
-
and hasattr(axinfo, "data")
|
|
55
|
-
and parts[0] in axinfo.data
|
|
56
|
-
):
|
|
52
|
+
if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
|
|
57
53
|
return tuple(np.where(axinfo.data == parts[0])[0])
|
|
58
54
|
return (int(parts[0]),)
|
|
59
55
|
return (slice(*(int(part.strip()) if part else None for part in parts)),)
|
|
@@ -76,9 +72,7 @@ class SlicerState:
|
|
|
76
72
|
b_change_dims: bool = False
|
|
77
73
|
|
|
78
74
|
|
|
79
|
-
class SlicerTransformer(
|
|
80
|
-
BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
|
|
81
|
-
):
|
|
75
|
+
class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
|
|
82
76
|
def _hash_message(self, message: AxisArray) -> int:
|
|
83
77
|
axis = self.settings.axis or message.dims[-1]
|
|
84
78
|
axis_idx = message.get_axis_idx(axis)
|
|
@@ -101,11 +95,7 @@ class SlicerTransformer(
|
|
|
101
95
|
self._state.slice_ = np.s_[indices]
|
|
102
96
|
|
|
103
97
|
# Create the output axis
|
|
104
|
-
if (
|
|
105
|
-
axis in message.axes
|
|
106
|
-
and hasattr(message.axes[axis], "data")
|
|
107
|
-
and len(message.axes[axis].data) > 0
|
|
108
|
-
):
|
|
98
|
+
if axis in message.axes and hasattr(message.axes[axis], "data") and len(message.axes[axis].data) > 0:
|
|
109
99
|
in_data = np.array(message.axes[axis].data)
|
|
110
100
|
if self._state.b_change_dims:
|
|
111
101
|
out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
|
|
@@ -119,17 +109,10 @@ class SlicerTransformer(
|
|
|
119
109
|
|
|
120
110
|
replace_kwargs = {}
|
|
121
111
|
if self._state.b_change_dims:
|
|
122
|
-
replace_kwargs["dims"] = [
|
|
123
|
-
|
|
124
|
-
]
|
|
125
|
-
replace_kwargs["axes"] = {
|
|
126
|
-
k: v for k, v in message.axes.items() if k != axis
|
|
127
|
-
}
|
|
112
|
+
replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
|
|
113
|
+
replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
|
|
128
114
|
elif self._state.new_axis is not None:
|
|
129
|
-
replace_kwargs["axes"] = {
|
|
130
|
-
k: (v if k != axis else self._state.new_axis)
|
|
131
|
-
for k, v in message.axes.items()
|
|
132
|
-
}
|
|
115
|
+
replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}
|
|
133
116
|
|
|
134
117
|
return replace(
|
|
135
118
|
message,
|
|
@@ -138,9 +121,7 @@ class SlicerTransformer(
|
|
|
138
121
|
)
|
|
139
122
|
|
|
140
123
|
|
|
141
|
-
class Slicer(
|
|
142
|
-
BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
|
|
143
|
-
):
|
|
124
|
+
class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
|
|
144
125
|
SETTINGS = SlicerSettings
|
|
145
126
|
|
|
146
127
|
|
ezmsg/sigproc/spectral.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from .spectrum import OptionsEnum as OptionsEnum
|
|
2
|
-
from .spectrum import WindowFunction as WindowFunction
|
|
3
|
-
from .spectrum import SpectralTransform as SpectralTransform
|
|
4
2
|
from .spectrum import SpectralOutput as SpectralOutput
|
|
5
|
-
from .spectrum import
|
|
3
|
+
from .spectrum import SpectralTransform as SpectralTransform
|
|
6
4
|
from .spectrum import Spectrum as Spectrum
|
|
5
|
+
from .spectrum import SpectrumSettings as SpectrumSettings
|
|
6
|
+
from .spectrum import WindowFunction as WindowFunction
|
ezmsg/sigproc/spectrogram.py
CHANGED
|
@@ -1,20 +1,21 @@
|
|
|
1
1
|
from typing import Generator
|
|
2
|
+
|
|
2
3
|
import ezmsg.core as ez
|
|
3
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
5
|
from ezmsg.util.messages.modify import modify_axis
|
|
5
6
|
|
|
6
|
-
from .window import Anchor, WindowTransformer
|
|
7
|
-
from .spectrum import (
|
|
8
|
-
WindowFunction,
|
|
9
|
-
SpectralTransform,
|
|
10
|
-
SpectralOutput,
|
|
11
|
-
SpectrumTransformer,
|
|
12
|
-
)
|
|
13
7
|
from .base import (
|
|
14
|
-
CompositeProcessor,
|
|
15
8
|
BaseStatefulProcessor,
|
|
16
9
|
BaseTransformerUnit,
|
|
10
|
+
CompositeProcessor,
|
|
11
|
+
)
|
|
12
|
+
from .spectrum import (
|
|
13
|
+
SpectralOutput,
|
|
14
|
+
SpectralTransform,
|
|
15
|
+
SpectrumTransformer,
|
|
16
|
+
WindowFunction,
|
|
17
17
|
)
|
|
18
|
+
from .window import Anchor, WindowTransformer
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
class SpectrogramSettings(ez.Settings):
|
|
@@ -41,9 +42,7 @@ class SpectrogramSettings(ez.Settings):
|
|
|
41
42
|
"""The :obj:`SpectralOutput` format."""
|
|
42
43
|
|
|
43
44
|
|
|
44
|
-
class SpectrogramTransformer(
|
|
45
|
-
CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]
|
|
46
|
-
):
|
|
45
|
+
class SpectrogramTransformer(CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]):
|
|
47
46
|
@staticmethod
|
|
48
47
|
def _initialize_processors(
|
|
49
48
|
settings: SpectrogramSettings,
|
|
@@ -54,9 +53,7 @@ class SpectrogramTransformer(
|
|
|
54
53
|
newaxis="win",
|
|
55
54
|
window_dur=settings.window_dur,
|
|
56
55
|
window_shift=settings.window_shift,
|
|
57
|
-
zero_pad_until="shift"
|
|
58
|
-
if settings.window_shift is not None
|
|
59
|
-
else "input",
|
|
56
|
+
zero_pad_until="shift" if settings.window_shift is not None else "input",
|
|
60
57
|
anchor=settings.window_anchor,
|
|
61
58
|
),
|
|
62
59
|
"spectrum": SpectrumTransformer(
|
|
@@ -69,11 +66,7 @@ class SpectrogramTransformer(
|
|
|
69
66
|
}
|
|
70
67
|
|
|
71
68
|
|
|
72
|
-
class Spectrogram(
|
|
73
|
-
BaseTransformerUnit[
|
|
74
|
-
SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer
|
|
75
|
-
]
|
|
76
|
-
):
|
|
69
|
+
class Spectrogram(BaseTransformerUnit[SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer]):
|
|
77
70
|
SETTINGS = SpectrogramSettings
|
|
78
71
|
|
|
79
72
|
|