ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
- ezmsg/sigproc/affinetransform.py +16 -42
- ezmsg/sigproc/aggregate.py +17 -34
- ezmsg/sigproc/bandpower.py +12 -20
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +7 -16
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/coordinatespaces.py +142 -0
- ezmsg/sigproc/decimate.py +3 -7
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +11 -20
- ezmsg/sigproc/ewma.py +11 -28
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +34 -59
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +37 -74
- ezmsg/sigproc/filterbankdesign.py +7 -14
- ezmsg/sigproc/fir_hilbert.py +13 -30
- ezmsg/sigproc/fir_pmc.py +5 -10
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +4 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +4 -1
- ezmsg/sigproc/math/difference.py +100 -36
- ezmsg/sigproc/math/invert.py +3 -3
- ezmsg/sigproc/math/log.py +5 -6
- ezmsg/sigproc/math/scale.py +2 -0
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +3 -6
- ezmsg/sigproc/resample.py +17 -38
- ezmsg/sigproc/rollingscaler.py +12 -37
- ezmsg/sigproc/sampler.py +19 -37
- ezmsg/sigproc/scaler.py +11 -22
- ezmsg/sigproc/signalinjector.py +7 -18
- ezmsg/sigproc/slicer.py +14 -34
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +17 -38
- ezmsg/sigproc/transpose.py +12 -24
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +12 -26
- ezmsg/sigproc/util/buffer.py +22 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +7 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +10 -19
- ezmsg/sigproc/window.py +29 -83
- ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
- ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
- {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/math/log.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
|
-
|
|
1
|
+
"""Take the logarithm of the data."""
|
|
2
|
+
|
|
3
|
+
# TODO: Array API
|
|
2
4
|
import ezmsg.core as ez
|
|
5
|
+
import numpy as np
|
|
3
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
7
|
from ezmsg.util.messages.util import replace
|
|
5
8
|
|
|
@@ -17,11 +20,7 @@ class LogSettings(ez.Settings):
|
|
|
17
20
|
class LogTransformer(BaseTransformer[LogSettings, AxisArray, AxisArray]):
|
|
18
21
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
19
22
|
data = message.data
|
|
20
|
-
if (
|
|
21
|
-
self.settings.clip_zero
|
|
22
|
-
and np.any(data <= 0)
|
|
23
|
-
and np.issubdtype(data.dtype, np.floating)
|
|
24
|
-
):
|
|
23
|
+
if self.settings.clip_zero and np.any(data <= 0) and np.issubdtype(data.dtype, np.floating):
|
|
25
24
|
data = np.clip(data, a_min=np.finfo(data.dtype).tiny, a_max=None)
|
|
26
25
|
return replace(message, data=np.log(data) / np.log(self.settings.base))
|
|
27
26
|
|
ezmsg/sigproc/math/scale.py
CHANGED
ezmsg/sigproc/messages.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
import warnings
|
|
2
1
|
import time
|
|
2
|
+
import warnings
|
|
3
3
|
|
|
4
4
|
import numpy.typing as npt
|
|
5
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
6
|
|
|
7
|
-
|
|
8
7
|
# UPCOMING: TSMessage Deprecation
|
|
9
8
|
# TSMessage is deprecated because it doesn't handle multiple time axes well.
|
|
10
9
|
# AxisArray has an incompatible API but supports a superset of functionality.
|
ezmsg/sigproc/quantize.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
import numpy as np
|
|
2
1
|
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
3
4
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
5
|
|
|
5
|
-
from .base import BaseTransformer, BaseTransformerUnit
|
|
6
|
-
|
|
7
6
|
|
|
8
7
|
class QuantizeSettings(ez.Settings):
|
|
9
8
|
"""
|
|
@@ -65,7 +64,5 @@ class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray
|
|
|
65
64
|
return replace(message, data=data)
|
|
66
65
|
|
|
67
66
|
|
|
68
|
-
class QuantizerUnit(
|
|
69
|
-
BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]
|
|
70
|
-
):
|
|
67
|
+
class QuantizerUnit(BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]):
|
|
71
68
|
SETTINGS = QuantizeSettings
|
ezmsg/sigproc/resample.py
CHANGED
|
@@ -2,17 +2,17 @@ import asyncio
|
|
|
2
2
|
import math
|
|
3
3
|
import time
|
|
4
4
|
|
|
5
|
+
import ezmsg.core as ez
|
|
5
6
|
import numpy as np
|
|
6
7
|
import scipy.interpolate
|
|
7
|
-
|
|
8
|
-
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
|
|
9
|
-
from ezmsg.util.messages.util import replace
|
|
10
|
-
|
|
11
|
-
from .base import (
|
|
12
|
-
BaseStatefulProcessor,
|
|
8
|
+
from ezmsg.baseproc import (
|
|
13
9
|
BaseConsumerUnit,
|
|
10
|
+
BaseStatefulProcessor,
|
|
14
11
|
processor_state,
|
|
15
12
|
)
|
|
13
|
+
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
|
|
14
|
+
from ezmsg.util.messages.util import replace
|
|
15
|
+
|
|
16
16
|
from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
|
|
17
17
|
from .util.buffer import UpdateStrategy
|
|
18
18
|
|
|
@@ -29,7 +29,7 @@ class ResampleSettings(ez.Settings):
|
|
|
29
29
|
fill_value: str = "extrapolate"
|
|
30
30
|
"""
|
|
31
31
|
Value to use for out-of-bounds samples.
|
|
32
|
-
If 'extrapolate', the transformer will extrapolate.
|
|
32
|
+
If 'extrapolate', the transformer will extrapolate.
|
|
33
33
|
If 'last', the transformer will use the last sample.
|
|
34
34
|
See scipy.interpolate.interp1d for more options.
|
|
35
35
|
"""
|
|
@@ -57,9 +57,9 @@ class ResampleState:
|
|
|
57
57
|
"""
|
|
58
58
|
The buffer for the reference axis (usually a time axis). The interpolation function
|
|
59
59
|
will be evaluated at the reference axis values.
|
|
60
|
-
When resample_rate is None, this buffer will be filled with the axis from incoming
|
|
60
|
+
When resample_rate is None, this buffer will be filled with the axis from incoming
|
|
61
61
|
_reference_ messages.
|
|
62
|
-
When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
|
|
62
|
+
When resample_rate is not None (i.e., prescribed float resample_rate), this buffer
|
|
63
63
|
is filled with a synthetic axis that is generated from the incoming signal messages.
|
|
64
64
|
"""
|
|
65
65
|
|
|
@@ -67,7 +67,7 @@ class ResampleState:
|
|
|
67
67
|
"""
|
|
68
68
|
The last value of the reference axis that was returned. This helps us to know
|
|
69
69
|
what the _next_ returned value should be, and to avoid returning the same value.
|
|
70
|
-
TODO: We can eliminate this variable if we maintain "by convention" that the
|
|
70
|
+
TODO: We can eliminate this variable if we maintain "by convention" that the
|
|
71
71
|
reference axis always has 1 value at its start that we exclude from the resampling.
|
|
72
72
|
"""
|
|
73
73
|
|
|
@@ -79,9 +79,7 @@ class ResampleState:
|
|
|
79
79
|
"""
|
|
80
80
|
|
|
81
81
|
|
|
82
|
-
class ResampleProcessor(
|
|
83
|
-
BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]
|
|
84
|
-
):
|
|
82
|
+
class ResampleProcessor(BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]):
|
|
85
83
|
def _hash_message(self, message: AxisArray) -> int:
|
|
86
84
|
ax_idx: int = message.get_axis_idx(self.settings.axis)
|
|
87
85
|
sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
@@ -135,17 +133,11 @@ class ResampleProcessor(
|
|
|
135
133
|
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
136
134
|
if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
|
|
137
135
|
in_ax = message.axes[self.settings.axis]
|
|
138
|
-
in_t_end = (
|
|
139
|
-
in_ax.data[-1]
|
|
140
|
-
if hasattr(in_ax, "data")
|
|
141
|
-
else in_ax.value(message.data.shape[ax_idx] - 1)
|
|
142
|
-
)
|
|
136
|
+
in_t_end = in_ax.data[-1] if hasattr(in_ax, "data") else in_ax.value(message.data.shape[ax_idx] - 1)
|
|
143
137
|
out_gain = 1 / self.settings.resample_rate
|
|
144
138
|
prev_t_end = self.state.last_ref_ax_val
|
|
145
139
|
n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
|
|
146
|
-
synth_ref_axis = LinearAxis(
|
|
147
|
-
unit="s", gain=out_gain, offset=prev_t_end + out_gain
|
|
148
|
-
)
|
|
140
|
+
synth_ref_axis = LinearAxis(unit="s", gain=out_gain, offset=prev_t_end + out_gain)
|
|
149
141
|
self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)
|
|
150
142
|
|
|
151
143
|
self.state.last_write_time = time.time()
|
|
@@ -193,11 +185,7 @@ class ResampleProcessor(
|
|
|
193
185
|
# Get source to train interpolation
|
|
194
186
|
src_axarr = src.peek()
|
|
195
187
|
src_axis = src_axarr.axes[self.settings.axis]
|
|
196
|
-
x = (
|
|
197
|
-
src_axis.data
|
|
198
|
-
if hasattr(src_axis, "data")
|
|
199
|
-
else src_axis.value(np.arange(src_axarr.data.shape[0]))
|
|
200
|
-
)
|
|
188
|
+
x = src_axis.data if hasattr(src_axis, "data") else src_axis.value(np.arange(src_axarr.data.shape[0]))
|
|
201
189
|
|
|
202
190
|
# Only resample at reference values that have not been interpolated over previously.
|
|
203
191
|
b_ref = ref_xvec > self.state.last_ref_ax_val
|
|
@@ -208,11 +196,7 @@ class ResampleProcessor(
|
|
|
208
196
|
|
|
209
197
|
if len(ref_idx) == 0:
|
|
210
198
|
# Nothing to interpolate over; return empty data
|
|
211
|
-
null_ref = (
|
|
212
|
-
replace(ref_ax, data=ref_ax.data[:0])
|
|
213
|
-
if hasattr(ref_ax, "data")
|
|
214
|
-
else ref_ax
|
|
215
|
-
)
|
|
199
|
+
null_ref = replace(ref_ax, data=ref_ax.data[:0]) if hasattr(ref_ax, "data") else ref_ax
|
|
216
200
|
return replace(
|
|
217
201
|
src_axarr,
|
|
218
202
|
data=src_axarr.data[:0, ...],
|
|
@@ -222,17 +206,12 @@ class ResampleProcessor(
|
|
|
222
206
|
xnew = ref_xvec[ref_idx]
|
|
223
207
|
|
|
224
208
|
# Identify source data indices around ref tvec with some padding for better interpolation.
|
|
225
|
-
src_start_ix = max(
|
|
226
|
-
0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0
|
|
227
|
-
)
|
|
209
|
+
src_start_ix = max(0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0)
|
|
228
210
|
|
|
229
211
|
x = x[src_start_ix:]
|
|
230
212
|
y = src_axarr.data[src_start_ix:]
|
|
231
213
|
|
|
232
|
-
if (
|
|
233
|
-
isinstance(self.settings.fill_value, str)
|
|
234
|
-
and self.settings.fill_value == "last"
|
|
235
|
-
):
|
|
214
|
+
if isinstance(self.settings.fill_value, str) and self.settings.fill_value == "last":
|
|
236
215
|
fill_value = (y[0], y[-1])
|
|
237
216
|
else:
|
|
238
217
|
fill_value = self.settings.fill_value
|
ezmsg/sigproc/rollingscaler.py
CHANGED
|
@@ -3,14 +3,15 @@ from collections import deque
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
4
|
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
|
+
|
|
6
9
|
from ezmsg.sigproc.base import (
|
|
7
10
|
BaseAdaptiveTransformer,
|
|
8
11
|
BaseAdaptiveTransformerUnit,
|
|
9
12
|
processor_state,
|
|
10
13
|
)
|
|
11
14
|
from ezmsg.sigproc.sampler import SampleMessage
|
|
12
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
|
-
from ezmsg.util.messages.util import replace
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class RollingScalerSettings(ez.Settings):
|
|
@@ -71,11 +72,7 @@ class RollingScalerState:
|
|
|
71
72
|
min_samples: int | None = None
|
|
72
73
|
|
|
73
74
|
|
|
74
|
-
class RollingScalerProcessor(
|
|
75
|
-
BaseAdaptiveTransformer[
|
|
76
|
-
RollingScalerSettings, AxisArray, AxisArray, RollingScalerState
|
|
77
|
-
]
|
|
78
|
-
):
|
|
75
|
+
class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]):
|
|
79
76
|
"""
|
|
80
77
|
Processor for rolling z-score normalization of input `AxisArray` messages.
|
|
81
78
|
|
|
@@ -119,40 +116,23 @@ class RollingScalerProcessor(
|
|
|
119
116
|
self._state.N = 0
|
|
120
117
|
self._state.M2 = np.zeros(ch)
|
|
121
118
|
self._state.k_samples = (
|
|
122
|
-
int(
|
|
123
|
-
np.ceil(
|
|
124
|
-
self.settings.window_size / message.axes[self.settings.axis].gain
|
|
125
|
-
)
|
|
126
|
-
)
|
|
119
|
+
int(np.ceil(self.settings.window_size / message.axes[self.settings.axis].gain))
|
|
127
120
|
if self.settings.window_size is not None
|
|
128
121
|
else self.settings.k_samples
|
|
129
122
|
)
|
|
130
123
|
if self._state.k_samples is not None and self._state.k_samples < 1:
|
|
131
|
-
ez.logger.warning(
|
|
132
|
-
"window_size smaller than sample gain; setting k_samples to 1."
|
|
133
|
-
)
|
|
124
|
+
ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
|
|
134
125
|
self._state.k_samples = 1
|
|
135
126
|
elif self._state.k_samples is None:
|
|
136
|
-
ez.logger.warning(
|
|
137
|
-
"k_samples is None; z-score accumulation will be unbounded."
|
|
138
|
-
)
|
|
127
|
+
ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
|
|
139
128
|
self._state.samples = deque(maxlen=self._state.k_samples)
|
|
140
129
|
self._state.min_samples = (
|
|
141
|
-
int(
|
|
142
|
-
np.ceil(
|
|
143
|
-
self.settings.min_seconds / message.axes[self.settings.axis].gain
|
|
144
|
-
)
|
|
145
|
-
)
|
|
130
|
+
int(np.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain))
|
|
146
131
|
if self.settings.window_size is not None
|
|
147
132
|
else self.settings.min_samples
|
|
148
133
|
)
|
|
149
|
-
if
|
|
150
|
-
|
|
151
|
-
and self._state.min_samples > self._state.k_samples
|
|
152
|
-
):
|
|
153
|
-
ez.logger.warning(
|
|
154
|
-
"min_samples is greater than k_samples; adjusting min_samples to k_samples."
|
|
155
|
-
)
|
|
134
|
+
if self._state.k_samples is not None and self._state.min_samples > self._state.k_samples:
|
|
135
|
+
ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
|
|
156
136
|
self._state.min_samples = self._state.k_samples
|
|
157
137
|
|
|
158
138
|
def _add_batch_stats(self, x: npt.NDArray) -> None:
|
|
@@ -161,10 +141,7 @@ class RollingScalerProcessor(
|
|
|
161
141
|
mean_b = np.mean(x, axis=0)
|
|
162
142
|
M2_b = np.sum((x - mean_b) ** 2, axis=0)
|
|
163
143
|
|
|
164
|
-
if (
|
|
165
|
-
self._state.k_samples is not None
|
|
166
|
-
and len(self._state.samples) == self._state.k_samples
|
|
167
|
-
):
|
|
144
|
+
if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
|
|
168
145
|
n_old, mean_old, M2_old = self._state.samples.popleft()
|
|
169
146
|
N_T = self._state.N
|
|
170
147
|
N_new = N_T - n_old
|
|
@@ -177,9 +154,7 @@ class RollingScalerProcessor(
|
|
|
177
154
|
delta = mean_old - self._state.mean
|
|
178
155
|
self._state.N = N_new
|
|
179
156
|
self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
|
|
180
|
-
self._state.M2 = (
|
|
181
|
-
self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
|
|
182
|
-
)
|
|
157
|
+
self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
|
|
183
158
|
|
|
184
159
|
N_A = self._state.N
|
|
185
160
|
N = N_A + n_b
|
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
from collections import deque
|
|
3
2
|
import copy
|
|
4
3
|
import traceback
|
|
5
4
|
import typing
|
|
5
|
+
from collections import deque
|
|
6
6
|
|
|
7
|
-
import numpy as np
|
|
8
7
|
import ezmsg.core as ez
|
|
8
|
+
import numpy as np
|
|
9
|
+
from ezmsg.baseproc import (
|
|
10
|
+
BaseConsumerUnit,
|
|
11
|
+
BaseProducerUnit,
|
|
12
|
+
BaseStatefulProducer,
|
|
13
|
+
BaseStatefulTransformer,
|
|
14
|
+
BaseTransformerUnit,
|
|
15
|
+
processor_state,
|
|
16
|
+
)
|
|
9
17
|
from ezmsg.util.messages.axisarray import (
|
|
10
18
|
AxisArray,
|
|
11
19
|
)
|
|
12
20
|
from ezmsg.util.messages.util import replace
|
|
13
21
|
|
|
14
|
-
from .util.profile import profile_subpub
|
|
15
22
|
from .util.axisarray_buffer import HybridAxisArrayBuffer
|
|
16
23
|
from .util.buffer import UpdateStrategy
|
|
17
24
|
from .util.message import SampleMessage, SampleTriggerMessage
|
|
18
|
-
from .
|
|
19
|
-
BaseStatefulTransformer,
|
|
20
|
-
BaseConsumerUnit,
|
|
21
|
-
BaseTransformerUnit,
|
|
22
|
-
BaseStatefulProducer,
|
|
23
|
-
BaseProducerUnit,
|
|
24
|
-
processor_state,
|
|
25
|
-
)
|
|
25
|
+
from .util.profile import profile_subpub
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class SamplerSettings(ez.Settings):
|
|
@@ -74,12 +74,8 @@ class SamplerState:
|
|
|
74
74
|
triggers: deque[SampleTriggerMessage] | None = None
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
class SamplerTransformer(
|
|
78
|
-
|
|
79
|
-
):
|
|
80
|
-
def __call__(
|
|
81
|
-
self, message: AxisArray | SampleTriggerMessage
|
|
82
|
-
) -> list[SampleMessage]:
|
|
77
|
+
class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
|
|
78
|
+
def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
|
|
83
79
|
# TODO: Currently we have a single entry point that accepts both
|
|
84
80
|
# data and trigger messages and we choose a code path based on
|
|
85
81
|
# the message type. However, in the future we will likely replace
|
|
@@ -99,9 +95,7 @@ class SamplerTransformer(
|
|
|
99
95
|
# Compute hash based on message properties that require state reset
|
|
100
96
|
axis = self.settings.axis or message.dims[0]
|
|
101
97
|
axis_idx = message.get_axis_idx(axis)
|
|
102
|
-
sample_shape =
|
|
103
|
-
message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
104
|
-
)
|
|
98
|
+
sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
105
99
|
return hash((sample_shape, message.key))
|
|
106
100
|
|
|
107
101
|
def _reset_state(self, message: AxisArray) -> None:
|
|
@@ -193,20 +187,14 @@ class SamplerTransformer(
|
|
|
193
187
|
trigger_ts: float = message.timestamp
|
|
194
188
|
if not self.settings.estimate_alignment:
|
|
195
189
|
# Override the trigger timestamp with the next sample's likely timestamp.
|
|
196
|
-
trigger_ts =
|
|
197
|
-
self._state.buffer.axis_final_value + self._state.buffer.axis_gain
|
|
198
|
-
)
|
|
190
|
+
trigger_ts = self._state.buffer.axis_final_value + self._state.buffer.axis_gain
|
|
199
191
|
|
|
200
|
-
new_trig_msg = replace(
|
|
201
|
-
message, timestamp=trigger_ts, period=_period, value=_value
|
|
202
|
-
)
|
|
192
|
+
new_trig_msg = replace(message, timestamp=trigger_ts, period=_period, value=_value)
|
|
203
193
|
self._state.triggers.append(new_trig_msg)
|
|
204
194
|
return []
|
|
205
195
|
|
|
206
196
|
|
|
207
|
-
class Sampler(
|
|
208
|
-
BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]
|
|
209
|
-
):
|
|
197
|
+
class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]):
|
|
210
198
|
SETTINGS = SamplerSettings
|
|
211
199
|
|
|
212
200
|
INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
|
|
@@ -269,19 +257,13 @@ class TriggerGeneratorState:
|
|
|
269
257
|
output: int = 0
|
|
270
258
|
|
|
271
259
|
|
|
272
|
-
class TriggerProducer(
|
|
273
|
-
BaseStatefulProducer[
|
|
274
|
-
TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState
|
|
275
|
-
]
|
|
276
|
-
):
|
|
260
|
+
class TriggerProducer(BaseStatefulProducer[TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState]):
|
|
277
261
|
def _reset_state(self) -> None:
|
|
278
262
|
self._state.output = 0
|
|
279
263
|
|
|
280
264
|
async def _produce(self) -> SampleTriggerMessage:
|
|
281
265
|
await asyncio.sleep(self.settings.publish_period)
|
|
282
|
-
out_msg = SampleTriggerMessage(
|
|
283
|
-
period=self.settings.period, value=self._state.output
|
|
284
|
-
)
|
|
266
|
+
out_msg = SampleTriggerMessage(period=self.settings.period, value=self._state.output)
|
|
285
267
|
self._state.output += 1
|
|
286
268
|
return out_msg
|
|
287
269
|
|
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -1,27 +1,24 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
-
from ezmsg.
|
|
5
|
-
from ezmsg.util.messages.util import replace
|
|
6
|
-
from ezmsg.util.generator import consumer
|
|
7
|
-
|
|
8
|
-
from .base import (
|
|
4
|
+
from ezmsg.baseproc import (
|
|
9
5
|
BaseStatefulTransformer,
|
|
10
6
|
BaseTransformerUnit,
|
|
11
7
|
processor_state,
|
|
12
8
|
)
|
|
13
|
-
from .
|
|
9
|
+
from ezmsg.util.generator import consumer
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
14
12
|
|
|
15
13
|
# Imports for backwards compatibility with previous module location
|
|
16
14
|
from .ewma import EWMA_Deprecated as EWMA_Deprecated
|
|
17
|
-
from .ewma import
|
|
15
|
+
from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
|
|
18
16
|
from .ewma import _tau_from_alpha as _tau_from_alpha
|
|
17
|
+
from .ewma import ewma_step as ewma_step
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
@consumer
|
|
22
|
-
def scaler(
|
|
23
|
-
time_constant: float = 1.0, axis: str | None = None
|
|
24
|
-
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
21
|
+
def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
25
22
|
"""
|
|
26
23
|
Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
|
|
27
24
|
This is faster than :obj:`scaler_np` for single-channel data.
|
|
@@ -85,19 +82,13 @@ class AdaptiveStandardScalerTransformer(
|
|
|
85
82
|
]
|
|
86
83
|
):
|
|
87
84
|
def _reset_state(self, message: AxisArray) -> None:
|
|
88
|
-
self._state.samps_ewma = EWMATransformer(
|
|
89
|
-
|
|
90
|
-
)
|
|
91
|
-
self._state.vars_sq_ewma = EWMATransformer(
|
|
92
|
-
time_constant=self.settings.time_constant, axis=self.settings.axis
|
|
93
|
-
)
|
|
85
|
+
self._state.samps_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
|
|
86
|
+
self._state.vars_sq_ewma = EWMATransformer(time_constant=self.settings.time_constant, axis=self.settings.axis)
|
|
94
87
|
|
|
95
88
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
96
89
|
# Update step
|
|
97
90
|
mean_message = self._state.samps_ewma(message)
|
|
98
|
-
var_sq_message = self._state.vars_sq_ewma(
|
|
99
|
-
replace(message, data=message.data**2)
|
|
100
|
-
)
|
|
91
|
+
var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
|
|
101
92
|
|
|
102
93
|
# Get step
|
|
103
94
|
varis = var_sq_message.data - mean_message.data**2
|
|
@@ -119,9 +110,7 @@ class AdaptiveStandardScaler(
|
|
|
119
110
|
|
|
120
111
|
|
|
121
112
|
# Backwards compatibility...
|
|
122
|
-
def scaler_np(
|
|
123
|
-
time_constant: float = 1.0, axis: str | None = None
|
|
124
|
-
) -> AdaptiveStandardScalerTransformer:
|
|
113
|
+
def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
|
|
125
114
|
return AdaptiveStandardScalerTransformer(
|
|
126
115
|
settings=AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
|
|
127
116
|
)
|
ezmsg/sigproc/signalinjector.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
3
|
-
from ezmsg.util.messages.util import replace
|
|
4
2
|
import numpy as np
|
|
5
3
|
import numpy.typing as npt
|
|
6
|
-
|
|
7
|
-
from .base import (
|
|
4
|
+
from ezmsg.baseproc import (
|
|
8
5
|
BaseAsyncTransformer,
|
|
9
6
|
BaseTransformerUnit,
|
|
10
7
|
processor_state,
|
|
11
8
|
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
class SignalInjectorSettings(ez.Settings):
|
|
@@ -27,15 +26,11 @@ class SignalInjectorState:
|
|
|
27
26
|
|
|
28
27
|
|
|
29
28
|
class SignalInjectorTransformer(
|
|
30
|
-
BaseAsyncTransformer[
|
|
31
|
-
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState
|
|
32
|
-
]
|
|
29
|
+
BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
|
|
33
30
|
):
|
|
34
31
|
def _hash_message(self, message: AxisArray) -> int:
|
|
35
32
|
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
36
|
-
sample_shape =
|
|
37
|
-
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
38
|
-
)
|
|
33
|
+
sample_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
39
34
|
return hash((message.key,) + sample_shape)
|
|
40
35
|
|
|
41
36
|
def _reset_state(self, message: AxisArray) -> None:
|
|
@@ -44,9 +39,7 @@ class SignalInjectorTransformer(
|
|
|
44
39
|
if self._state.cur_amplitude is None:
|
|
45
40
|
self._state.cur_amplitude = self.settings.amplitude
|
|
46
41
|
time_ax_idx = message.get_axis_idx(self.settings.time_dim)
|
|
47
|
-
self._state.cur_shape =
|
|
48
|
-
message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
49
|
-
)
|
|
42
|
+
self._state.cur_shape = message.data.shape[:time_ax_idx] + message.data.shape[time_ax_idx + 1 :]
|
|
50
43
|
rng = np.random.default_rng(self.settings.mixing_seed)
|
|
51
44
|
self._state.mixing = rng.random((1, message.shape2d(self.settings.time_dim)[1]))
|
|
52
45
|
self._state.mixing = (self._state.mixing * 2.0) - 1.0
|
|
@@ -63,11 +56,7 @@ class SignalInjectorTransformer(
|
|
|
63
56
|
return out_msg
|
|
64
57
|
|
|
65
58
|
|
|
66
|
-
class SignalInjector(
|
|
67
|
-
BaseTransformerUnit[
|
|
68
|
-
SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer
|
|
69
|
-
]
|
|
70
|
-
):
|
|
59
|
+
class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
|
|
71
60
|
SETTINGS = SignalInjectorSettings
|
|
72
61
|
INPUT_FREQUENCY = ez.InputStream(float | None)
|
|
73
62
|
INPUT_AMPLITUDE = ez.InputStream(float)
|
ezmsg/sigproc/slicer.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
1
2
|
import numpy as np
|
|
2
3
|
import numpy.typing as npt
|
|
3
|
-
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseStatefulTransformer,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
processor_state,
|
|
8
|
+
)
|
|
4
9
|
from ezmsg.util.messages.axisarray import (
|
|
5
10
|
AxisArray,
|
|
6
|
-
slice_along_axis,
|
|
7
11
|
AxisBase,
|
|
8
12
|
replace,
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
from .base import (
|
|
12
|
-
BaseStatefulTransformer,
|
|
13
|
-
BaseTransformerUnit,
|
|
14
|
-
processor_state,
|
|
13
|
+
slice_along_axis,
|
|
15
14
|
)
|
|
16
15
|
|
|
17
16
|
"""
|
|
@@ -49,11 +48,7 @@ def parse_slice(
|
|
|
49
48
|
if "," not in s:
|
|
50
49
|
parts = [part.strip() for part in s.split(":")]
|
|
51
50
|
if len(parts) == 1:
|
|
52
|
-
if (
|
|
53
|
-
axinfo is not None
|
|
54
|
-
and hasattr(axinfo, "data")
|
|
55
|
-
and parts[0] in axinfo.data
|
|
56
|
-
):
|
|
51
|
+
if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
|
|
57
52
|
return tuple(np.where(axinfo.data == parts[0])[0])
|
|
58
53
|
return (int(parts[0]),)
|
|
59
54
|
return (slice(*(int(part.strip()) if part else None for part in parts)),)
|
|
@@ -76,9 +71,7 @@ class SlicerState:
|
|
|
76
71
|
b_change_dims: bool = False
|
|
77
72
|
|
|
78
73
|
|
|
79
|
-
class SlicerTransformer(
|
|
80
|
-
BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]
|
|
81
|
-
):
|
|
74
|
+
class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
|
|
82
75
|
def _hash_message(self, message: AxisArray) -> int:
|
|
83
76
|
axis = self.settings.axis or message.dims[-1]
|
|
84
77
|
axis_idx = message.get_axis_idx(axis)
|
|
@@ -101,11 +94,7 @@ class SlicerTransformer(
|
|
|
101
94
|
self._state.slice_ = np.s_[indices]
|
|
102
95
|
|
|
103
96
|
# Create the output axis
|
|
104
|
-
if (
|
|
105
|
-
axis in message.axes
|
|
106
|
-
and hasattr(message.axes[axis], "data")
|
|
107
|
-
and len(message.axes[axis].data) > 0
|
|
108
|
-
):
|
|
97
|
+
if axis in message.axes and hasattr(message.axes[axis], "data") and len(message.axes[axis].data) > 0:
|
|
109
98
|
in_data = np.array(message.axes[axis].data)
|
|
110
99
|
if self._state.b_change_dims:
|
|
111
100
|
out_data = in_data[self._state.slice_ : self._state.slice_ + 1]
|
|
@@ -119,17 +108,10 @@ class SlicerTransformer(
|
|
|
119
108
|
|
|
120
109
|
replace_kwargs = {}
|
|
121
110
|
if self._state.b_change_dims:
|
|
122
|
-
replace_kwargs["dims"] = [
|
|
123
|
-
|
|
124
|
-
]
|
|
125
|
-
replace_kwargs["axes"] = {
|
|
126
|
-
k: v for k, v in message.axes.items() if k != axis
|
|
127
|
-
}
|
|
111
|
+
replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
|
|
112
|
+
replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
|
|
128
113
|
elif self._state.new_axis is not None:
|
|
129
|
-
replace_kwargs["axes"] = {
|
|
130
|
-
k: (v if k != axis else self._state.new_axis)
|
|
131
|
-
for k, v in message.axes.items()
|
|
132
|
-
}
|
|
114
|
+
replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}
|
|
133
115
|
|
|
134
116
|
return replace(
|
|
135
117
|
message,
|
|
@@ -138,9 +120,7 @@ class SlicerTransformer(
|
|
|
138
120
|
)
|
|
139
121
|
|
|
140
122
|
|
|
141
|
-
class Slicer(
|
|
142
|
-
BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]
|
|
143
|
-
):
|
|
123
|
+
class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
|
|
144
124
|
SETTINGS = SlicerSettings
|
|
145
125
|
|
|
146
126
|
|
ezmsg/sigproc/spectral.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from .spectrum import OptionsEnum as OptionsEnum
|
|
2
|
-
from .spectrum import WindowFunction as WindowFunction
|
|
3
|
-
from .spectrum import SpectralTransform as SpectralTransform
|
|
4
2
|
from .spectrum import SpectralOutput as SpectralOutput
|
|
5
|
-
from .spectrum import
|
|
3
|
+
from .spectrum import SpectralTransform as SpectralTransform
|
|
6
4
|
from .spectrum import Spectrum as Spectrum
|
|
5
|
+
from .spectrum import SpectrumSettings as SpectrumSettings
|
|
6
|
+
from .spectrum import WindowFunction as WindowFunction
|