ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +119 -104
- ezmsg/sigproc/bandpower.py +58 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -78
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -1,163 +1,21 @@
|
|
|
1
|
-
import functools
|
|
2
1
|
import typing
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
|
-
import numpy.typing as npt
|
|
6
|
-
import scipy.signal
|
|
7
|
-
import ezmsg.core as ez
|
|
8
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
5
|
from ezmsg.util.messages.util import replace
|
|
10
6
|
from ezmsg.util.generator import consumer
|
|
11
7
|
|
|
12
|
-
from .base import
|
|
8
|
+
from .base import (
|
|
9
|
+
BaseStatefulTransformer,
|
|
10
|
+
BaseTransformerUnit,
|
|
11
|
+
processor_state,
|
|
12
|
+
)
|
|
13
|
+
from .ewma import EWMATransformer, EWMASettings, _alpha_from_tau
|
|
13
14
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
"""
|
|
19
|
-
return -dt / np.log(1 - alpha)
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
23
|
-
"""
|
|
24
|
-
# https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
|
|
25
|
-
:param tau: The amount of time for the smoothed response of a unit step function to reach
|
|
26
|
-
1 - 1/e approx-eq 63.2%.
|
|
27
|
-
:param dt: sampling period, or 1 / sampling_rate.
|
|
28
|
-
:return: alpha, the "fading factor" in exponential smoothing.
|
|
29
|
-
"""
|
|
30
|
-
return 1 - np.exp(-dt / tau)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def ewma_step(
|
|
34
|
-
sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
|
|
35
|
-
):
|
|
36
|
-
"""
|
|
37
|
-
Do an exponentially weighted moving average step.
|
|
38
|
-
|
|
39
|
-
Args:
|
|
40
|
-
sample: The new sample.
|
|
41
|
-
zi: The output of the previous step.
|
|
42
|
-
alpha: Fading factor.
|
|
43
|
-
beta: Persisting factor. If None, it is calculated as 1-alpha.
|
|
44
|
-
|
|
45
|
-
Returns:
|
|
46
|
-
alpha * sample + beta * zi
|
|
47
|
-
|
|
48
|
-
"""
|
|
49
|
-
# Potential micro-optimization:
|
|
50
|
-
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
|
|
51
|
-
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
|
|
52
|
-
# return zi + alpha * (new_sample - zi)
|
|
53
|
-
beta = beta or (1 - alpha)
|
|
54
|
-
return alpha * sample + beta * zi
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
class EWMA:
|
|
58
|
-
def __init__(self, alpha: float):
|
|
59
|
-
self.beta = 1 - alpha
|
|
60
|
-
self._filt_func = functools.partial(
|
|
61
|
-
scipy.signal.lfilter, [alpha], [1.0, alpha - 1.0], axis=0
|
|
62
|
-
)
|
|
63
|
-
self.prev = None
|
|
64
|
-
|
|
65
|
-
def compute(self, arr: npt.NDArray) -> npt.NDArray:
|
|
66
|
-
if self.prev is None:
|
|
67
|
-
self.prev = self.beta * arr[:1]
|
|
68
|
-
expected, self.prev = self._filt_func(arr, zi=self.prev)
|
|
69
|
-
return expected
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class EWMA_Deprecated:
|
|
73
|
-
"""
|
|
74
|
-
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
|
|
75
|
-
but they ended up being slower than the scipy.signal.lfilter method.
|
|
76
|
-
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
|
|
77
|
-
and beta**n approaches zero.
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
|
-
def __init__(self, alpha: float, max_len: int):
|
|
81
|
-
self.alpha = alpha
|
|
82
|
-
self.beta = 1 - alpha
|
|
83
|
-
self.prev: npt.NDArray | None = None
|
|
84
|
-
self.weights = np.empty((max_len + 1,), float)
|
|
85
|
-
self._precalc_weights(max_len)
|
|
86
|
-
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
|
|
87
|
-
|
|
88
|
-
def _precalc_weights(self, n: int):
|
|
89
|
-
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
|
|
90
|
-
np.power(self.beta, np.arange(n + 1), out=self.weights)
|
|
91
|
-
|
|
92
|
-
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
|
|
93
|
-
if out is None:
|
|
94
|
-
out = np.empty(arr.shape, arr.dtype)
|
|
95
|
-
|
|
96
|
-
n = arr.shape[0]
|
|
97
|
-
weights = self.weights[:n]
|
|
98
|
-
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
99
|
-
|
|
100
|
-
# α*P0, α*P1, α*P2, ..., α*Pn
|
|
101
|
-
np.multiply(self.alpha, arr, out)
|
|
102
|
-
|
|
103
|
-
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
|
|
104
|
-
np.divide(out, weights, out)
|
|
105
|
-
|
|
106
|
-
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
|
|
107
|
-
np.cumsum(out, axis=0, out=out)
|
|
108
|
-
|
|
109
|
-
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
|
|
110
|
-
np.multiply(out, weights, out)
|
|
111
|
-
|
|
112
|
-
# Add the previous output
|
|
113
|
-
if self.prev is None:
|
|
114
|
-
self.prev = arr[:1]
|
|
115
|
-
|
|
116
|
-
out += self.prev * np.expand_dims(
|
|
117
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
self.prev = out[-1:]
|
|
121
|
-
|
|
122
|
-
return out
|
|
123
|
-
|
|
124
|
-
def compute2(self, arr: npt.NDArray) -> npt.NDArray:
|
|
125
|
-
"""
|
|
126
|
-
Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
|
|
127
|
-
|
|
128
|
-
Args:
|
|
129
|
-
arr: The input array to be smoothed.
|
|
130
|
-
|
|
131
|
-
Returns:
|
|
132
|
-
The smoothed array.
|
|
133
|
-
"""
|
|
134
|
-
n = arr.shape[0]
|
|
135
|
-
if n > len(self.weights):
|
|
136
|
-
self._precalc_weights(n)
|
|
137
|
-
weights = self.weights[:n][::-1]
|
|
138
|
-
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
|
|
139
|
-
|
|
140
|
-
result = np.cumsum(self.alpha * weights * arr, axis=0)
|
|
141
|
-
result = result / weights
|
|
142
|
-
|
|
143
|
-
# Handle the first call when prev is unset
|
|
144
|
-
if self.prev is None:
|
|
145
|
-
self.prev = arr[:1]
|
|
146
|
-
|
|
147
|
-
result += self.prev * np.expand_dims(
|
|
148
|
-
self.weights[1 : n + 1], list(range(1, arr.ndim))
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
# Store the result back into prev
|
|
152
|
-
self.prev = result[-1]
|
|
153
|
-
|
|
154
|
-
return result
|
|
155
|
-
|
|
156
|
-
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
|
|
157
|
-
if self.prev is None:
|
|
158
|
-
self.prev = new_sample
|
|
159
|
-
self.prev = self._step_func(new_sample, self.prev)
|
|
160
|
-
return self.prev
|
|
15
|
+
# Imports for backwards compatibility with previous module location
|
|
16
|
+
from .ewma import EWMA_Deprecated as EWMA_Deprecated
|
|
17
|
+
from .ewma import ewma_step as ewma_step
|
|
18
|
+
from .ewma import _tau_from_alpha as _tau_from_alpha
|
|
161
19
|
|
|
162
20
|
|
|
163
21
|
@consumer
|
|
@@ -208,83 +66,62 @@ def scaler(
|
|
|
208
66
|
msg_out = replace(msg_in, data=result)
|
|
209
67
|
|
|
210
68
|
|
|
211
|
-
|
|
212
|
-
def scaler_np(
|
|
213
|
-
time_constant: float = 1.0, axis: str | None = None
|
|
214
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
215
|
-
"""
|
|
216
|
-
Create a generator function that applies an adaptive standard scaler.
|
|
217
|
-
This is faster than :obj:`scaler` for multichannel data.
|
|
218
|
-
|
|
219
|
-
Args:
|
|
220
|
-
time_constant: Decay constant `tau` in seconds.
|
|
221
|
-
axis: The name of the axis to accumulate statistics over.
|
|
222
|
-
Note: The axis must be in the msg.axes and be of type AxisArray.LinearAxis.
|
|
223
|
-
|
|
224
|
-
Returns:
|
|
225
|
-
A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
|
|
226
|
-
and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
|
|
227
|
-
"""
|
|
228
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
229
|
-
|
|
230
|
-
# State variables
|
|
231
|
-
samps_ewma: EWMA | None = None
|
|
232
|
-
vars_sq_ewma: EWMA | None = None
|
|
69
|
+
class AdaptiveStandardScalerSettings(EWMASettings): ...
|
|
233
70
|
|
|
234
|
-
# Reset if input changes
|
|
235
|
-
check_input = {
|
|
236
|
-
"gain": None, # Resets alpha
|
|
237
|
-
"shape": None,
|
|
238
|
-
"key": None, # Key change implies buffered means/vars are invalid.
|
|
239
|
-
}
|
|
240
71
|
|
|
241
|
-
|
|
242
|
-
|
|
72
|
+
@processor_state
|
|
73
|
+
class AdaptiveStandardScalerState:
|
|
74
|
+
samps_ewma: EWMATransformer | None = None
|
|
75
|
+
vars_sq_ewma: EWMATransformer | None = None
|
|
76
|
+
alpha: float | None = None
|
|
243
77
|
|
|
244
|
-
axis = axis or msg_in.dims[0]
|
|
245
|
-
axis_idx = msg_in.get_axis_idx(axis)
|
|
246
78
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
79
|
+
class AdaptiveStandardScalerTransformer(
|
|
80
|
+
BaseStatefulTransformer[
|
|
81
|
+
AdaptiveStandardScalerSettings,
|
|
82
|
+
AxisArray,
|
|
83
|
+
AxisArray,
|
|
84
|
+
AdaptiveStandardScalerState,
|
|
85
|
+
]
|
|
86
|
+
):
|
|
87
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
88
|
+
self._state.samps_ewma = EWMATransformer(
|
|
89
|
+
time_constant=self.settings.time_constant, axis=self.settings.axis
|
|
90
|
+
)
|
|
91
|
+
self._state.vars_sq_ewma = EWMATransformer(
|
|
92
|
+
time_constant=self.settings.time_constant, axis=self.settings.axis
|
|
93
|
+
)
|
|
258
94
|
|
|
95
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
259
96
|
# Update step
|
|
260
|
-
|
|
261
|
-
|
|
97
|
+
mean_message = self._state.samps_ewma(message)
|
|
98
|
+
var_sq_message = self._state.vars_sq_ewma(
|
|
99
|
+
replace(message, data=message.data**2)
|
|
100
|
+
)
|
|
262
101
|
|
|
263
102
|
# Get step
|
|
264
|
-
varis =
|
|
103
|
+
varis = var_sq_message.data - mean_message.data**2
|
|
265
104
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
266
|
-
result = (data -
|
|
105
|
+
result = (message.data - mean_message.data) / (varis**0.5)
|
|
267
106
|
result[np.isnan(result)] = 0.0
|
|
268
|
-
|
|
269
|
-
msg_out = replace(msg_in, data=result)
|
|
107
|
+
return replace(message, data=result)
|
|
270
108
|
|
|
271
109
|
|
|
272
|
-
class
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
class AdaptiveStandardScaler(GenAxisArray):
|
|
283
|
-
"""Unit for :obj:`scaler_np`"""
|
|
284
|
-
|
|
110
|
+
class AdaptiveStandardScaler(
|
|
111
|
+
BaseTransformerUnit[
|
|
112
|
+
AdaptiveStandardScalerSettings,
|
|
113
|
+
AxisArray,
|
|
114
|
+
AxisArray,
|
|
115
|
+
AdaptiveStandardScalerTransformer,
|
|
116
|
+
]
|
|
117
|
+
):
|
|
285
118
|
SETTINGS = AdaptiveStandardScalerSettings
|
|
286
119
|
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
120
|
+
|
|
121
|
+
# Backwards compatibility...
|
|
122
|
+
def scaler_np(
|
|
123
|
+
time_constant: float = 1.0, axis: str | None = None
|
|
124
|
+
) -> AdaptiveStandardScalerTransformer:
|
|
125
|
+
return AdaptiveStandardScalerTransformer(
|
|
126
|
+
settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
|
|
127
|
+
)
|
ezmsg/sigproc/signalinjector.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
1
|
import ezmsg.core as ez
|
|
4
2
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
3
|
from ezmsg.util.messages.util import replace
|
|
6
4
|
import numpy as np
|
|
7
5
|
import numpy.typing as npt
|
|
8
6
|
|
|
9
|
-
from .
|
|
7
|
+
from .base import (
|
|
8
|
+
BaseAsyncTransformer,
|
|
9
|
+
BaseTransformerUnit,
|
|
10
|
+
processor_state,
|
|
11
|
+
)
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
class SignalInjectorSettings(ez.Settings):
|
|
@@ -16,57 +18,64 @@ class SignalInjectorSettings(ez.Settings):
|
|
|
16
18
|
mixing_seed: int | None = None
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
|
|
21
|
+
@processor_state
|
|
22
|
+
class SignalInjectorState:
|
|
20
23
|
cur_shape: tuple[int, ...] | None = None
|
|
21
24
|
cur_frequency: float | None = None
|
|
22
|
-
cur_amplitude: float
|
|
23
|
-
mixing: npt.NDArray
|
|
25
|
+
cur_amplitude: float | None = None
|
|
26
|
+
mixing: npt.NDArray | None = None
|
|
24
27
|
|
|
25
28
|
|
|
26
|
-
class
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
29
|
+
class SignalInjectorTransformer(
|
|
30
|
+
BaseAsyncTransformer[
|
|
31
|
+
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
|
|
32
|
+
]
|
|
33
|
+
):
|
|
34
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
35
|
+
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
36
|
+
sample_shape = (
|
|
37
|
+
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
38
|
+
)
|
|
39
|
+
return hash((message.key,) + sample_shape)
|
|
40
|
+
|
|
41
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
42
|
+
if self._state.cur_frequency is None:
|
|
43
|
+
self._state.cur_frequency = self.settings.frequency
|
|
44
|
+
if self._state.cur_amplitude is None:
|
|
45
|
+
self._state.cur_amplitude = self.settings.amplitude
|
|
46
|
+
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
47
|
+
self._state.cur_shape = (
|
|
48
|
+
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
49
|
+
)
|
|
50
|
+
rng = np.random.default_rng(self.settings.mixing_seed)
|
|
51
|
+
self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
|
|
52
|
+
self._state.mixing = (self._state.mixing * 2.0) - 1.0
|
|
53
|
+
|
|
54
|
+
async def _aprocess(self, message: AxisArray) -> AxisArray:
|
|
55
|
+
if self._state.cur_frequency is None:
|
|
56
|
+
return message
|
|
57
|
+
out_msg = replace(message, data=message.data.copy())
|
|
58
|
+
t = out_msg.ax(self.settings.time_dim).values[..., np.newaxis]
|
|
59
|
+
signal = np.sin(2 * np.pi * self._state.cur_frequency * t)
|
|
60
|
+
mixed_signal = signal * self._state.mixing * self._state.cur_amplitude
|
|
61
|
+
with out_msg.view2d(self.settings.time_dim) as view:
|
|
62
|
+
view[...] = view + mixed_signal.astype(view.dtype)
|
|
63
|
+
return out_msg
|
|
31
64
|
|
|
32
|
-
SETTINGS = SignalInjectorSettings
|
|
33
|
-
STATE = SignalInjectorState
|
|
34
65
|
|
|
66
|
+
class SignalInjector(
|
|
67
|
+
BaseTransformerUnit[
|
|
68
|
+
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
|
|
69
|
+
]
|
|
70
|
+
):
|
|
71
|
+
SETTINGS = SignalInjectorSettings
|
|
35
72
|
INPUT_FREQUENCY = ez.InputStream(float | None)
|
|
36
73
|
INPUT_AMPLITUDE = ez.InputStream(float)
|
|
37
|
-
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
38
|
-
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
39
|
-
|
|
40
|
-
async def initialize(self) -> None:
|
|
41
|
-
self.STATE.cur_frequency = self.SETTINGS.frequency
|
|
42
|
-
self.STATE.cur_amplitude = self.SETTINGS.amplitude
|
|
43
|
-
self.STATE.mixing = np.array([])
|
|
44
74
|
|
|
45
75
|
@ez.subscriber(INPUT_FREQUENCY)
|
|
46
76
|
async def on_frequency(self, msg: float | None) -> None:
|
|
47
|
-
self.
|
|
77
|
+
self.processor.state.cur_frequency = msg
|
|
48
78
|
|
|
49
79
|
@ez.subscriber(INPUT_AMPLITUDE)
|
|
50
80
|
async def on_amplitude(self, msg: float) -> None:
|
|
51
|
-
self.
|
|
52
|
-
|
|
53
|
-
@ez.subscriber(INPUT_SIGNAL)
|
|
54
|
-
@ez.publisher(OUTPUT_SIGNAL)
|
|
55
|
-
@profile_subpub(trace_oldest=False)
|
|
56
|
-
async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
57
|
-
if self.STATE.cur_shape != msg.shape:
|
|
58
|
-
self.STATE.cur_shape = msg.shape
|
|
59
|
-
rng = np.random.default_rng(self.SETTINGS.mixing_seed)
|
|
60
|
-
self.STATE.mixing = rng.random((1, msg.shape2d(self.SETTINGS.time_dim)[1]))
|
|
61
|
-
self.STATE.mixing = (self.STATE.mixing * 2.0) - 1.0
|
|
62
|
-
|
|
63
|
-
if self.STATE.cur_frequency is None:
|
|
64
|
-
yield self.OUTPUT_SIGNAL, msg
|
|
65
|
-
else:
|
|
66
|
-
out_msg = replace(msg, data=msg.data.copy())
|
|
67
|
-
t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
|
|
68
|
-
signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
|
|
69
|
-
mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
|
|
70
|
-
with out_msg.view2d(self.SETTINGS.time_dim) as view:
|
|
71
|
-
view[...] = view + mixed_signal.astype(view.dtype)
|
|
72
|
-
yield self.OUTPUT_SIGNAL, out_msg
|
|
81
|
+
self.processor.state.cur_amplitude = msg
|
ezmsg/sigproc/slicer.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
import typing
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
import numpy.typing as npt
|
|
5
3
|
import ezmsg.core as ez
|
|
@@ -9,10 +7,12 @@ from ezmsg.util.messages.axisarray import (
|
|
|
9
7
|
AxisBase,
|
|
10
8
|
replace,
|
|
11
9
|
)
|
|
12
|
-
from ezmsg.util.generator import consumer
|
|
13
|
-
|
|
14
|
-
from .base import GenAxisArray
|
|
15
10
|
|
|
11
|
+
from .base import (
|
|
12
|
+
BaseStatefulTransformer,
|
|
13
|
+
BaseTransformerUnit,
|
|
14
|
+
processor_state,
|
|
15
|
+
)
|
|
16
16
|
|
|
17
17
|
"""
|
|
18
18
|
Slicer:Select a subset of data along a particular axis.
|
|
@@ -61,106 +61,98 @@ def parse_slice(
|
|
|
61
61
|
return tuple([item for sublist in suplist for item in sublist])
|
|
62
62
|
|
|
63
63
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
selection:
|
|
67
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
68
|
-
"""
|
|
69
|
-
Slice along a particular axis.
|
|
70
|
-
|
|
71
|
-
Args:
|
|
72
|
-
selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
|
|
73
|
-
axis: The name of the axis to slice along. If None, the last axis is used.
|
|
64
|
+
class SlicerSettings(ez.Settings):
|
|
65
|
+
selection: str = ""
|
|
66
|
+
"""selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details."""
|
|
74
67
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
with the data payload containing a sliced view of the input data.
|
|
68
|
+
axis: str | None = None
|
|
69
|
+
"""The name of the axis to slice along. If None, the last axis is used."""
|
|
78
70
|
|
|
79
|
-
"""
|
|
80
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
81
71
|
|
|
82
|
-
|
|
83
|
-
|
|
72
|
+
@processor_state
|
|
73
|
+
class SlicerState:
|
|
74
|
+
slice_: slice | int | npt.NDArray | None = None
|
|
84
75
|
new_axis: AxisBase | None = None
|
|
85
|
-
b_change_dims: bool = False
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
76
|
+
b_change_dims: bool = False
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class SlicerTransformer(
|
|
80
|
+
BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
|
|
81
|
+
):
|
|
82
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
83
|
+
axis = self.settings.axis or message.dims[-1]
|
|
84
|
+
axis_idx = message.get_axis_idx(axis)
|
|
85
|
+
return hash((message.key, message.data.shape[axis_idx]))
|
|
86
|
+
|
|
87
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
88
|
+
axis = self.settings.axis or message.dims[-1]
|
|
89
|
+
axis_idx = message.get_axis_idx(axis)
|
|
90
|
+
self._state.new_axis = None
|
|
91
|
+
self._state.b_change_dims = False
|
|
92
|
+
|
|
93
|
+
# Calculate the slice
|
|
94
|
+
_slices = parse_slice(self.settings.selection, message.axes.get(axis, None))
|
|
95
|
+
if len(_slices) == 1:
|
|
96
|
+
self._state.slice_ = _slices[0]
|
|
97
|
+
self._state.b_change_dims = isinstance(self._state.slice_, int)
|
|
98
|
+
else:
|
|
99
|
+
indices = np.arange(message.data.shape[axis_idx])
|
|
100
|
+
indices = np.hstack([indices[_] for _ in _slices])
|
|
101
|
+
self._state.slice_ = np.s_[indices]
|
|
102
|
+
|
|
103
|
+
# Create the output axis
|
|
104
|
+
if (
|
|
105
|
+
axis in message.axes
|
|
106
|
+
and hasattr(message.axes[axis], "data")
|
|
107
|
+
and len(message.axes[axis].data) > 0
|
|
108
|
+
):
|
|
109
|
+
in_data = np.array(message.axes[axis].data)
|
|
110
|
+
if self._state.b_change_dims:
|
|
111
|
+
out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
|
|
117
112
|
else:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
indices = np.arange(msg_in.data.shape[axis_idx])
|
|
121
|
-
indices = np.hstack([indices[_] for _ in _slices])
|
|
122
|
-
_slice = np.s_[indices] # Integer scalar array
|
|
113
|
+
out_data = in_data[self._state.slice_]
|
|
114
|
+
self._state.new_axis = replace(message.axes[axis], data=out_data)
|
|
123
115
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
and hasattr(msg_in.axes[axis], "data")
|
|
128
|
-
and len(msg_in.axes[axis].data) > 0
|
|
129
|
-
):
|
|
130
|
-
in_data = np.array(msg_in.axes[axis].data)
|
|
131
|
-
if b_change_dims:
|
|
132
|
-
out_data = in_data[_slice : _slice + 1]
|
|
133
|
-
else:
|
|
134
|
-
out_data = in_data[_slice]
|
|
135
|
-
new_axis = replace(msg_in.axes[axis], data=out_data)
|
|
116
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
117
|
+
axis = self.settings.axis or message.dims[-1]
|
|
118
|
+
axis_idx = message.get_axis_idx(axis)
|
|
136
119
|
|
|
137
120
|
replace_kwargs = {}
|
|
138
|
-
if b_change_dims:
|
|
139
|
-
# Dropping the target axis
|
|
121
|
+
if self._state.b_change_dims:
|
|
140
122
|
replace_kwargs["dims"] = [
|
|
141
|
-
_ for dim_ix, _ in enumerate(
|
|
123
|
+
_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx
|
|
142
124
|
]
|
|
143
|
-
replace_kwargs["axes"] = {k: v for k, v in msg_in.axes.items() if k != axis}
|
|
144
|
-
elif new_axis is not None:
|
|
145
125
|
replace_kwargs["axes"] = {
|
|
146
|
-
k:
|
|
126
|
+
k: v for k, v in message.axes.items() if k != axis
|
|
147
127
|
}
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
128
|
+
elif self._state.new_axis is not None:
|
|
129
|
+
replace_kwargs["axes"] = {
|
|
130
|
+
k: (v if k != axis else self._state.new_axis)
|
|
131
|
+
for k, v in message.axes.items()
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
return replace(
|
|
135
|
+
message,
|
|
136
|
+
data=slice_along_axis(message.data, self._state.slice_, axis_idx),
|
|
151
137
|
**replace_kwargs,
|
|
152
138
|
)
|
|
153
139
|
|
|
154
140
|
|
|
155
|
-
class
|
|
156
|
-
|
|
157
|
-
|
|
141
|
+
class Slicer(
|
|
142
|
+
BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
|
|
143
|
+
):
|
|
144
|
+
SETTINGS = SlicerSettings
|
|
158
145
|
|
|
159
146
|
|
|
160
|
-
|
|
161
|
-
|
|
147
|
+
def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
|
|
148
|
+
"""
|
|
149
|
+
Slice along a particular axis.
|
|
162
150
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
151
|
+
Args:
|
|
152
|
+
selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
|
|
153
|
+
axis: The name of the axis to slice along. If None, the last axis is used.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
:obj:`SlicerTransformer`
|
|
157
|
+
"""
|
|
158
|
+
return SlicerTransformer(SlicerSettings(selection=selection, axis=axis))
|