ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -1,169 +1,25 @@
|
|
|
1
|
-
import functools
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
|
-
import numpy as np
|
|
5
|
-
import numpy.typing as npt
|
|
6
|
-
import scipy.signal
|
|
7
3
|
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
from ezmsg.baseproc import (
|
|
6
|
+
BaseStatefulTransformer,
|
|
7
|
+
BaseTransformerUnit,
|
|
8
|
+
processor_state,
|
|
9
|
+
)
|
|
10
|
+
from ezmsg.util.generator import consumer
|
|
8
11
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
12
|
from ezmsg.util.messages.util import replace
|
|
10
|
-
from ezmsg.util.generator import consumer
|
|
11
|
-
|
|
12
|
-
from .base import GenAxisArray
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def _tau_from_alpha(alpha: float, dt: float) -> float:
|
|
16
|
-
"""
|
|
17
|
-
Inverse of _alpha_from_tau. See that function for explanation.
|
|
18
|
-
"""
|
|
19
|
-
return -dt / np.log(1 - alpha)
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
23
|
-
"""
|
|
24
|
-
# https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
|
|
25
|
-
:param tau: The amount of time for the smoothed response of a unit step function to reach
|
|
26
|
-
1 - 1/e approx-eq 63.2%.
|
|
27
|
-
:param dt: sampling period, or 1 / sampling_rate.
|
|
28
|
-
:return: alpha, the "fading factor" in exponential smoothing.
|
|
29
|
-
"""
|
|
30
|
-
return 1 - np.exp(-dt / tau)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def ewma_step(
|
|
34
|
-
sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
|
|
35
|
-
):
|
|
36
|
-
"""
|
|
37
|
-
Do an exponentially weighted moving average step.
|
|
38
|
-
|
|
39
|
-
Args:
|
|
40
|
-
sample: The new sample.
|
|
41
|
-
zi: The output of the previous step.
|
|
42
|
-
alpha: Fading factor.
|
|
43
|
-
beta: Persisting factor. If None, it is calculated as 1-alpha.
|
|
44
|
-
|
|
45
|
-
Returns:
|
|
46
|
-
alpha * sample + beta * zi
|
|
47
|
-
|
|
48
|
-
"""
|
|
49
|
-
# Potential micro-optimization:
|
|
50
|
-
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
|
|
51
|
-
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
|
|
52
|
-
# return zi + alpha * (new_sample - zi)
|
|
53
|
-
beta = beta or (1 - alpha)
|
|
54
|
-
return alpha * sample + beta * zi
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
class EWMA:
|
|
58
|
-
def __init__(self, alpha: float):
|
|
59
|
-
self.beta = 1 - alpha
|
|
60
|
-
self._filt_func = functools.partial(
|
|
61
|
-
scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
|
|
62
|
-
)
|
|
63
|
-
self.prev = None
|
|
64
|
-
|
|
65
|
-
def compute(self, arr: npt.NDArray) -> npt.NDArray:
|
|
66
|
-
if self.prev is None:
|
|
67
|
-
self.prev = self.beta * arr[:1]
|
|
68
|
-
expected, self.prev = self._filt_func(arr, zi=self.prev)
|
|
69
|
-
return expected
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class EWMA_Deprecated:
|
|
73
|
-
"""
|
|
74
|
-
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
|
|
75
|
-
but they ended up being slower than the scipy.signal.lfilter method.
|
|
76
|
-
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
|
|
77
|
-
and beta**n approaches zero.
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
|
-
def __init__(self, alpha: float, max_len: int):
|
|
81
|
-
self.alpha = alpha
|
|
82
|
-
self.beta = 1 - alpha
|
|
83
|
-
self.prev: npt.NDArray | None = None
|
|
84
|
-
self.weights = np.empty((max_len + 1,), float)
|
|
85
|
-
self._precalc_weights(max_len)
|
|
86
|
-
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
|
|
87
|
-
|
|
88
|
-
def _precalc_weights(self, n: int):
|
|
89
|
-
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
|
|
90
|
-
np.power(self.beta, np.arange(n + 1), out=self.weights)
|
|
91
|
-
|
|
92
|
-
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
|
|
93
|
-
if out is None:
|
|
94
|
-
out = np.empty(arr.shape, arr.dtype)
|
|
95
|
-
|
|
96
|
-
n = arr.shape[0]
|
|
97
|
-
weights = self.weights[:n]
|
|
98
|
-
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
99
|
-
|
|
100
|
-
# α*P0, α*P1, α*P2, ..., α*Pn
|
|
101
|
-
np.multiply(self.alpha, arr, out)
|
|
102
|
-
|
|
103
|
-
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
|
|
104
|
-
np.divide(out, weights, out)
|
|
105
|
-
|
|
106
|
-
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
|
|
107
|
-
np.cumsum(out, axis=0, out=out)
|
|
108
|
-
|
|
109
|
-
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
|
|
110
|
-
np.multiply(out, weights, out)
|
|
111
|
-
|
|
112
|
-
# Add the previous output
|
|
113
|
-
if self.prev is None:
|
|
114
|
-
self.prev = arr[:1]
|
|
115
|
-
|
|
116
|
-
out += self.prev * np.expand_dims(
|
|
117
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
self.prev = out[-1:]
|
|
121
13
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
Args:
|
|
129
|
-
arr: The input array to be smoothed.
|
|
130
|
-
|
|
131
|
-
Returns:
|
|
132
|
-
The smoothed array.
|
|
133
|
-
"""
|
|
134
|
-
n = arr.shape[0]
|
|
135
|
-
if n > len(self.weights):
|
|
136
|
-
self._precalc_weights(n)
|
|
137
|
-
weights = self.weights[:n][::-1]
|
|
138
|
-
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
139
|
-
|
|
140
|
-
result = np.cumsum(self.alpha * weights * arr, axis=0)
|
|
141
|
-
result = result / weights
|
|
142
|
-
|
|
143
|
-
# Handle the first call when prev is unset
|
|
144
|
-
if self.prev is None:
|
|
145
|
-
self.prev = arr[:1]
|
|
146
|
-
|
|
147
|
-
result += self.prev * np.expand_dims(
|
|
148
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
# Store the result back into prev
|
|
152
|
-
self.prev = result[-1]
|
|
153
|
-
|
|
154
|
-
return result
|
|
155
|
-
|
|
156
|
-
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
|
|
157
|
-
if self.prev is None:
|
|
158
|
-
self.prev = new_sample
|
|
159
|
-
self.prev = self._step_func(new_sample, self.prev)
|
|
160
|
-
return self.prev
|
|
14
|
+
# Imports for backwards compatibility with previous module location
|
|
15
|
+
from .ewma import EWMA_Deprecated as EWMA_Deprecated
|
|
16
|
+
from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
|
|
17
|
+
from .ewma import _tau_from_alpha as _tau_from_alpha
|
|
18
|
+
from .ewma import ewma_step as ewma_step
|
|
161
19
|
|
|
162
20
|
|
|
163
21
|
@consumer
|
|
164
|
-
def scaler(
|
|
165
|
-
time_constant: float = 1.0, axis: str | None = None
|
|
166
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
22
|
+
def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
167
23
|
"""
|
|
168
24
|
Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
|
|
169
25
|
This is faster than :obj:`scaler_np` for single-channel data.
|
|
@@ -208,83 +64,102 @@ def scaler(
|
|
|
208
64
|
msg_out = replace(msg_in, data=result)
|
|
209
65
|
|
|
210
66
|
|
|
211
|
-
|
|
212
|
-
def scaler_np(
|
|
213
|
-
time_constant: float = 1.0, axis: str | None = None
|
|
214
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
215
|
-
"""
|
|
216
|
-
Create a generator function that applies an adaptive standard scaler.
|
|
217
|
-
This is faster than :obj:`scaler` for multichannel data.
|
|
67
|
+
class AdaptiveStandardScalerSettings(EWMASettings): ...
|
|
218
68
|
|
|
219
|
-
Args:
|
|
220
|
-
time_constant: Decay constant `tau` in seconds.
|
|
221
|
-
axis: The name of the axis to accumulate statistics over.
|
|
222
|
-
Note: The axis must be in the msg.axes and be of type AxisArray.LinearAxis.
|
|
223
69
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
70
|
+
@processor_state
|
|
71
|
+
class AdaptiveStandardScalerState:
|
|
72
|
+
samps_ewma: EWMATransformer | None = None
|
|
73
|
+
vars_sq_ewma: EWMATransformer | None = None
|
|
74
|
+
alpha: float | None = None
|
|
229
75
|
|
|
230
|
-
# State variables
|
|
231
|
-
samps_ewma: EWMA | None = None
|
|
232
|
-
vars_sq_ewma: EWMA | None = None
|
|
233
76
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
77
|
+
class AdaptiveStandardScalerTransformer(
|
|
78
|
+
BaseStatefulTransformer[
|
|
79
|
+
AdaptiveStandardScalerSettings,
|
|
80
|
+
AxisArray,
|
|
81
|
+
AxisArray,
|
|
82
|
+
AdaptiveStandardScalerState,
|
|
83
|
+
]
|
|
84
|
+
):
|
|
85
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
86
|
+
self._state.samps_ewma = EWMATransformer(
|
|
87
|
+
time_constant=self.settings.time_constant,
|
|
88
|
+
axis=self.settings.axis,
|
|
89
|
+
accumulate=self.settings.accumulate,
|
|
90
|
+
)
|
|
91
|
+
self._state.vars_sq_ewma = EWMATransformer(
|
|
92
|
+
time_constant=self.settings.time_constant,
|
|
93
|
+
axis=self.settings.axis,
|
|
94
|
+
accumulate=self.settings.accumulate,
|
|
95
|
+
)
|
|
240
96
|
|
|
241
|
-
|
|
242
|
-
|
|
97
|
+
@property
|
|
98
|
+
def accumulate(self) -> bool:
|
|
99
|
+
"""Whether to accumulate statistics from incoming samples."""
|
|
100
|
+
return self.settings.accumulate
|
|
243
101
|
|
|
244
|
-
|
|
245
|
-
|
|
102
|
+
@accumulate.setter
|
|
103
|
+
def accumulate(self, value: bool) -> None:
|
|
104
|
+
"""
|
|
105
|
+
Set the accumulate mode and propagate to child EWMA transformers.
|
|
246
106
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
if
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
|
|
256
|
-
samps_ewma = EWMA(alpha=alpha)
|
|
257
|
-
vars_sq_ewma = EWMA(alpha=alpha)
|
|
107
|
+
Args:
|
|
108
|
+
value: If True, update statistics with each sample.
|
|
109
|
+
If False, only apply current statistics without updating.
|
|
110
|
+
"""
|
|
111
|
+
if self._state.samps_ewma is not None:
|
|
112
|
+
self._state.samps_ewma.settings = replace(self._state.samps_ewma.settings, accumulate=value)
|
|
113
|
+
if self._state.vars_sq_ewma is not None:
|
|
114
|
+
self._state.vars_sq_ewma.settings = replace(self._state.vars_sq_ewma.settings, accumulate=value)
|
|
258
115
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
116
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
117
|
+
# Update step (respects accumulate setting via child EWMAs)
|
|
118
|
+
mean_message = self._state.samps_ewma(message)
|
|
119
|
+
var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
|
|
262
120
|
|
|
263
121
|
# Get step
|
|
264
|
-
varis =
|
|
122
|
+
varis = var_sq_message.data - mean_message.data**2
|
|
265
123
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
266
|
-
result = (data -
|
|
124
|
+
result = (message.data - mean_message.data) / (varis**0.5)
|
|
267
125
|
result[np.isnan(result)] = 0.0
|
|
268
|
-
|
|
269
|
-
msg_out = replace(msg_in, data=result)
|
|
126
|
+
return replace(message, data=result)
|
|
270
127
|
|
|
271
128
|
|
|
272
|
-
class
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
129
|
+
class AdaptiveStandardScaler(
|
|
130
|
+
BaseTransformerUnit[
|
|
131
|
+
AdaptiveStandardScalerSettings,
|
|
132
|
+
AxisArray,
|
|
133
|
+
AxisArray,
|
|
134
|
+
AdaptiveStandardScalerTransformer,
|
|
135
|
+
]
|
|
136
|
+
):
|
|
137
|
+
SETTINGS = AdaptiveStandardScalerSettings
|
|
277
138
|
|
|
278
|
-
|
|
279
|
-
|
|
139
|
+
@ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
|
|
140
|
+
async def on_settings(self, msg: AdaptiveStandardScalerSettings) -> None:
|
|
141
|
+
"""
|
|
142
|
+
Handle settings updates with smart reset behavior.
|
|
280
143
|
|
|
144
|
+
Only resets state if `axis` changes (structural change).
|
|
145
|
+
Changes to `time_constant` or `accumulate` are applied without
|
|
146
|
+
resetting accumulated statistics.
|
|
147
|
+
"""
|
|
148
|
+
old_axis = self.SETTINGS.axis
|
|
149
|
+
self.apply_settings(msg)
|
|
281
150
|
|
|
282
|
-
|
|
283
|
-
|
|
151
|
+
if msg.axis != old_axis:
|
|
152
|
+
# Axis changed - need full reset
|
|
153
|
+
self.create_processor()
|
|
154
|
+
else:
|
|
155
|
+
# Update accumulate on processor (propagates to child EWMAs)
|
|
156
|
+
self.processor.accumulate = msg.accumulate
|
|
157
|
+
# Also update own settings reference
|
|
158
|
+
self.processor.settings = msg
|
|
284
159
|
|
|
285
|
-
SETTINGS = AdaptiveStandardScalerSettings
|
|
286
160
|
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
)
|
|
161
|
+
# Backwards compatibility...
|
|
162
|
+
def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
|
|
163
|
+
return AdaptiveStandardScalerTransformer(
|
|
164
|
+
settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
|
|
165
|
+
)
|
ezmsg/sigproc/signalinjector.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
1
|
import ezmsg.core as ez
|
|
4
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
-
from ezmsg.util.messages.util import replace
|
|
6
2
|
import numpy as np
|
|
7
3
|
import numpy.typing as npt
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseAsyncTransformer,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
processor_state,
|
|
8
|
+
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
8
11
|
|
|
9
12
|
|
|
10
13
|
class SignalInjectorSettings(ez.Settings):
|
|
@@ -14,56 +17,54 @@ class SignalInjectorSettings(ez.Settings):
|
|
|
14
17
|
mixing_seed: int | None = None
|
|
15
18
|
|
|
16
19
|
|
|
17
|
-
|
|
20
|
+
@processor_state
|
|
21
|
+
class SignalInjectorState:
|
|
18
22
|
cur_shape: tuple[int, ...] | None = None
|
|
19
23
|
cur_frequency: float | None = None
|
|
20
|
-
cur_amplitude: float
|
|
21
|
-
mixing: npt.NDArray
|
|
24
|
+
cur_amplitude: float | None = None
|
|
25
|
+
mixing: npt.NDArray | None = None
|
|
22
26
|
|
|
23
27
|
|
|
24
|
-
class
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
28
|
+
class SignalInjectorTransformer(
|
|
29
|
+
BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
|
|
30
|
+
):
|
|
31
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
32
|
+
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
33
|
+
sample_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
34
|
+
return hash((message.key,) + sample_shape)
|
|
29
35
|
|
|
30
|
-
|
|
31
|
-
|
|
36
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
37
|
+
if self._state.cur_frequency is None:
|
|
38
|
+
self._state.cur_frequency = self.settings.frequency
|
|
39
|
+
if self._state.cur_amplitude is None:
|
|
40
|
+
self._state.cur_amplitude = self.settings.amplitude
|
|
41
|
+
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
42
|
+
self._state.cur_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
43
|
+
rng = np.random.default_rng(self.settings.mixing_seed)
|
|
44
|
+
self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
|
|
45
|
+
self._state.mixing = (self._state.mixing * 2.0) - 1.0
|
|
32
46
|
|
|
47
|
+
async def _aprocess(self, message: AxisArray) -> AxisArray:
|
|
48
|
+
if self._state.cur_frequency is None:
|
|
49
|
+
return message
|
|
50
|
+
out_msg = replace(message, data=message.data.copy())
|
|
51
|
+
t = out_msg.ax(self.settings.time_dim).values[..., np.newaxis]
|
|
52
|
+
signal = np.sin(2 * np.pi * self._state.cur_frequency * t)
|
|
53
|
+
mixed_signal = signal * self._state.mixing * self._state.cur_amplitude
|
|
54
|
+
with out_msg.view2d(self.settings.time_dim) as view:
|
|
55
|
+
view[...] = view + mixed_signal.astype(view.dtype)
|
|
56
|
+
return out_msg
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
|
|
60
|
+
SETTINGS = SignalInjectorSettings
|
|
33
61
|
INPUT_FREQUENCY = ez.InputStream(float | None)
|
|
34
62
|
INPUT_AMPLITUDE = ez.InputStream(float)
|
|
35
|
-
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
36
|
-
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
37
|
-
|
|
38
|
-
async def initialize(self) -> None:
|
|
39
|
-
self.STATE.cur_frequency = self.SETTINGS.frequency
|
|
40
|
-
self.STATE.cur_amplitude = self.SETTINGS.amplitude
|
|
41
|
-
self.STATE.mixing = np.array([])
|
|
42
63
|
|
|
43
64
|
@ez.subscriber(INPUT_FREQUENCY)
|
|
44
65
|
async def on_frequency(self, msg: float | None) -> None:
|
|
45
|
-
self.
|
|
66
|
+
self.processor.state.cur_frequency = msg
|
|
46
67
|
|
|
47
68
|
@ez.subscriber(INPUT_AMPLITUDE)
|
|
48
69
|
async def on_amplitude(self, msg: float) -> None:
|
|
49
|
-
self.
|
|
50
|
-
|
|
51
|
-
@ez.subscriber(INPUT_SIGNAL)
|
|
52
|
-
@ez.publisher(OUTPUT_SIGNAL)
|
|
53
|
-
async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
54
|
-
if self.STATE.cur_shape != msg.shape:
|
|
55
|
-
self.STATE.cur_shape = msg.shape
|
|
56
|
-
rng = np.random.default_rng(self.SETTINGS.mixing_seed)
|
|
57
|
-
self.STATE.mixing = rng.random((1, msg.shape2d(self.SETTINGS.time_dim)[1]))
|
|
58
|
-
self.STATE.mixing = (self.STATE.mixing * 2.0) - 1.0
|
|
59
|
-
|
|
60
|
-
if self.STATE.cur_frequency is None:
|
|
61
|
-
yield self.OUTPUT_SIGNAL, msg
|
|
62
|
-
else:
|
|
63
|
-
out_msg = replace(msg, data=msg.data.copy())
|
|
64
|
-
t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
|
|
65
|
-
signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
|
|
66
|
-
mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
|
|
67
|
-
with out_msg.view2d(self.SETTINGS.time_dim) as view:
|
|
68
|
-
view[...] = view + mixed_signal.astype(view.dtype)
|
|
69
|
-
yield self.OUTPUT_SIGNAL, out_msg
|
|
70
|
+
self.processor.state.cur_amplitude = msg
|
ezmsg/sigproc/slicer.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
1
|
+
import ezmsg.core as ez
|
|
3
2
|
import numpy as np
|
|
4
3
|
import numpy.typing as npt
|
|
5
|
-
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseStatefulTransformer,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
processor_state,
|
|
8
|
+
)
|
|
6
9
|
from ezmsg.util.messages.axisarray import (
|
|
7
10
|
AxisArray,
|
|
8
|
-
slice_along_axis,
|
|
9
11
|
AxisBase,
|
|
10
12
|
replace,
|
|
13
|
+
slice_along_axis,
|
|
11
14
|
)
|
|
12
|
-
from ezmsg.util.generator import consumer
|
|
13
|
-
|
|
14
|
-
from .base import GenAxisArray
|
|
15
|
-
|
|
16
15
|
|
|
17
16
|
"""
|
|
18
17
|
Slicer:Select a subset of data along a particular axis.
|
|
@@ -49,11 +48,7 @@ def parse_slice(
|
|
|
49
48
|
if "," not in s:
|
|
50
49
|
parts = [part.strip() for part in s.split(":")]
|
|
51
50
|
if len(parts) == 1:
|
|
52
|
-
if (
|
|
53
|
-
axinfo is not None
|
|
54
|
-
and hasattr(axinfo, "data")
|
|
55
|
-
and parts[0] in axinfo.data
|
|
56
|
-
):
|
|
51
|
+
if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
|
|
57
52
|
return tuple(np.where(axinfo.data == parts[0])[0])
|
|
58
53
|
return (int(parts[0]),)
|
|
59
54
|
return (slice(*(int(part.strip()) if part else None for part in parts)),)
|
|
@@ -61,106 +56,83 @@ def parse_slice(
|
|
|
61
56
|
return tuple([item for sublist in suplist for item in sublist])
|
|
62
57
|
|
|
63
58
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
selection:
|
|
67
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
68
|
-
"""
|
|
69
|
-
Slice along a particular axis.
|
|
70
|
-
|
|
71
|
-
Args:
|
|
72
|
-
selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
|
|
73
|
-
axis: The name of the axis to slice along. If None, the last axis is used.
|
|
59
|
+
class SlicerSettings(ez.Settings):
|
|
60
|
+
selection: str = ""
|
|
61
|
+
"""selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details."""
|
|
74
62
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
with the data payload containing a sliced view of the input data.
|
|
63
|
+
axis: str | None = None
|
|
64
|
+
"""The name of the axis to slice along. If None, the last axis is used."""
|
|
78
65
|
|
|
79
|
-
"""
|
|
80
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
81
66
|
|
|
82
|
-
|
|
83
|
-
|
|
67
|
+
@processor_state
|
|
68
|
+
class SlicerState:
|
|
69
|
+
slice_: slice | int | npt.NDArray | None = None
|
|
84
70
|
new_axis: AxisBase | None = None
|
|
85
|
-
b_change_dims: bool = False
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
if
|
|
114
|
-
|
|
115
|
-
# Do we drop the sliced dimension?
|
|
116
|
-
b_change_dims = isinstance(_slice, int)
|
|
71
|
+
b_change_dims: bool = False
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
|
|
75
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
76
|
+
axis = self.settings.axis or message.dims[-1]
|
|
77
|
+
axis_idx = message.get_axis_idx(axis)
|
|
78
|
+
return hash((message.key, message.data.shape[axis_idx]))
|
|
79
|
+
|
|
80
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
81
|
+
axis = self.settings.axis or message.dims[-1]
|
|
82
|
+
axis_idx = message.get_axis_idx(axis)
|
|
83
|
+
self._state.new_axis = None
|
|
84
|
+
self._state.b_change_dims = False
|
|
85
|
+
|
|
86
|
+
# Calculate the slice
|
|
87
|
+
_slices = parse_slice(self.settings.selection, message.axes.get(axis, None))
|
|
88
|
+
if len(_slices) == 1:
|
|
89
|
+
self._state.slice_ = _slices[0]
|
|
90
|
+
self._state.b_change_dims = isinstance(self._state.slice_, int)
|
|
91
|
+
else:
|
|
92
|
+
indices = np.arange(message.data.shape[axis_idx])
|
|
93
|
+
indices = np.hstack([indices[_] for _ in _slices])
|
|
94
|
+
self._state.slice_ = np.s_[indices]
|
|
95
|
+
|
|
96
|
+
# Create the output axis
|
|
97
|
+
if axis in message.axes and hasattr(message.axes[axis], "data") and len(message.axes[axis].data) > 0:
|
|
98
|
+
in_data = np.array(message.axes[axis].data)
|
|
99
|
+
if self._state.b_change_dims:
|
|
100
|
+
out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
|
|
117
101
|
else:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
# Create the output axis.
|
|
125
|
-
if (
|
|
126
|
-
axis in msg_in.axes
|
|
127
|
-
and hasattr(msg_in.axes[axis], "data")
|
|
128
|
-
and len(msg_in.axes[axis].data) > 0
|
|
129
|
-
):
|
|
130
|
-
in_data = np.array(msg_in.axes[axis].data)
|
|
131
|
-
if b_change_dims:
|
|
132
|
-
out_data = in_data[_slice : _slice + 1]
|
|
133
|
-
else:
|
|
134
|
-
out_data = in_data[_slice]
|
|
135
|
-
new_axis = replace(msg_in.axes[axis], data=out_data)
|
|
102
|
+
out_data = in_data[self._state.slice_]
|
|
103
|
+
self._state.new_axis = replace(message.axes[axis], data=out_data)
|
|
104
|
+
|
|
105
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
106
|
+
axis = self.settings.axis or message.dims[-1]
|
|
107
|
+
axis_idx = message.get_axis_idx(axis)
|
|
136
108
|
|
|
137
109
|
replace_kwargs = {}
|
|
138
|
-
if b_change_dims:
|
|
139
|
-
|
|
140
|
-
replace_kwargs["
|
|
141
|
-
|
|
142
|
-
]
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
}
|
|
148
|
-
msg_out = replace(
|
|
149
|
-
msg_in,
|
|
150
|
-
data=slice_along_axis(msg_in.data, _slice, axis_idx),
|
|
110
|
+
if self._state.b_change_dims:
|
|
111
|
+
replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
|
|
112
|
+
replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
|
|
113
|
+
elif self._state.new_axis is not None:
|
|
114
|
+
replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}
|
|
115
|
+
|
|
116
|
+
return replace(
|
|
117
|
+
message,
|
|
118
|
+
data=slice_along_axis(message.data, self._state.slice_, axis_idx),
|
|
151
119
|
**replace_kwargs,
|
|
152
120
|
)
|
|
153
121
|
|
|
154
122
|
|
|
155
|
-
class SlicerSettings
|
|
156
|
-
|
|
157
|
-
axis: str | None = None
|
|
123
|
+
class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
|
|
124
|
+
SETTINGS = SlicerSettings
|
|
158
125
|
|
|
159
126
|
|
|
160
|
-
|
|
161
|
-
|
|
127
|
+
def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
|
|
128
|
+
"""
|
|
129
|
+
Slice along a particular axis.
|
|
162
130
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
131
|
+
Args:
|
|
132
|
+
selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
|
|
133
|
+
axis: The name of the axis to slice along. If None, the last axis is used.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
:obj:`SlicerTransformer`
|
|
137
|
+
"""
|
|
138
|
+
return SlicerTransformer(SlicerSettings(selection=selection, axis=axis))
|