ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -4
- ezmsg/sigproc/__version__.py +16 -0
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +149 -39
- ezmsg/sigproc/aggregate.py +84 -29
- ezmsg/sigproc/bandpower.py +36 -15
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +76 -20
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +79 -61
- ezmsg/sigproc/ewmfilter.py +28 -14
- ezmsg/sigproc/filter.py +51 -31
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +152 -90
- ezmsg/sigproc/scaler.py +88 -42
- ezmsg/sigproc/signalinjector.py +7 -10
- ezmsg/sigproc/slicer.py +71 -36
- ezmsg/sigproc/spectral.py +6 -9
- ezmsg/sigproc/spectrogram.py +48 -30
- ezmsg/sigproc/spectrum.py +177 -76
- ezmsg/sigproc/synth.py +162 -67
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +193 -157
- ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -1
- ezmsg_sigproc-1.2.3.dist-info/METADATA +0 -38
- ezmsg_sigproc-1.2.3.dist-info/RECORD +0 -23
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,81 +1,94 @@
|
|
|
1
|
+
import asyncio # Dev/test apparatus
|
|
1
2
|
from collections import deque
|
|
2
3
|
from dataclasses import dataclass, replace, field
|
|
3
4
|
import time
|
|
4
|
-
|
|
5
|
+
import typing
|
|
5
6
|
|
|
6
|
-
import ezmsg.core as ez
|
|
7
7
|
import numpy as np
|
|
8
|
-
|
|
8
|
+
import numpy.typing as npt
|
|
9
|
+
import ezmsg.core as ez
|
|
9
10
|
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
10
11
|
from ezmsg.util.generator import consumer
|
|
11
12
|
|
|
12
|
-
# Dev/test apparatus
|
|
13
|
-
import asyncio
|
|
14
|
-
|
|
15
13
|
|
|
16
14
|
@dataclass(unsafe_hash=True)
|
|
17
15
|
class SampleTriggerMessage:
|
|
18
16
|
timestamp: float = field(default_factory=time.time)
|
|
19
|
-
|
|
20
|
-
|
|
17
|
+
"""Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
|
|
18
|
+
|
|
19
|
+
period: typing.Optional[typing.Tuple[float, float]] = None
|
|
20
|
+
"""The period around the timestamp, in seconds"""
|
|
21
|
+
|
|
22
|
+
value: typing.Any = None
|
|
23
|
+
"""A value or 'label' associated with the trigger."""
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
@dataclass
|
|
24
27
|
class SampleMessage:
|
|
25
28
|
trigger: SampleTriggerMessage
|
|
29
|
+
"""The time, window, and value (if any) associated with the trigger."""
|
|
30
|
+
|
|
26
31
|
sample: AxisArray
|
|
32
|
+
"""The data sampled around the trigger."""
|
|
27
33
|
|
|
28
34
|
|
|
29
35
|
@consumer
|
|
30
36
|
def sampler(
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
) -> Generator[
|
|
37
|
+
buffer_dur: float,
|
|
38
|
+
axis: typing.Optional[str] = None,
|
|
39
|
+
period: typing.Optional[typing.Tuple[float, float]] = None,
|
|
40
|
+
value: typing.Any = None,
|
|
41
|
+
estimate_alignment: bool = True,
|
|
42
|
+
) -> typing.Generator[
|
|
43
|
+
typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None
|
|
44
|
+
]:
|
|
37
45
|
"""
|
|
38
46
|
A generator function that samples data into a buffer, accepts triggers, and returns slices of sampled
|
|
39
47
|
data around the trigger time.
|
|
40
48
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
Yields:
|
|
60
|
-
- list[SampleMessage]: The list of sample messages.
|
|
49
|
+
Args:
|
|
50
|
+
buffer_dur: The duration of the buffer in seconds. The buffer must be long enough to store the oldest
|
|
51
|
+
sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
|
|
52
|
+
need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
|
|
53
|
+
axis: The axis along which to sample the data.
|
|
54
|
+
None (default) will choose the first axis in the first input.
|
|
55
|
+
period: The period in seconds during which to sample the data.
|
|
56
|
+
Defaults to None. Only used if not None and the trigger message does not define its own period.
|
|
57
|
+
value: The value to sample. Defaults to None.
|
|
58
|
+
estimate_alignment: Whether to estimate the sample alignment. Defaults to True.
|
|
59
|
+
If True, the trigger timestamp field is used to slice the buffer.
|
|
60
|
+
If False, the trigger timestamp is ignored and the next signal's .offset is used.
|
|
61
|
+
NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
|
|
62
|
+
estimate_alignment to operate correctly.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
|
|
66
|
+
or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
|
|
61
67
|
"""
|
|
62
|
-
|
|
63
|
-
msg_out: Optional[list[SampleMessage]] = None
|
|
68
|
+
msg_out: list[SampleMessage] = []
|
|
64
69
|
|
|
65
70
|
# State variables (most shared between trigger- and data-processing.
|
|
66
71
|
triggers: deque[SampleTriggerMessage] = deque()
|
|
67
|
-
|
|
68
|
-
|
|
72
|
+
buffer: typing.Optional[npt.NDArray] = None
|
|
73
|
+
n_samples: int = 0
|
|
74
|
+
offset: float = 0.0
|
|
75
|
+
|
|
76
|
+
check_inputs = {
|
|
77
|
+
"fs": None, # Also a state variable
|
|
78
|
+
"key": None,
|
|
79
|
+
"shape": None,
|
|
80
|
+
}
|
|
69
81
|
|
|
70
82
|
while True:
|
|
71
83
|
msg_in = yield msg_out
|
|
72
84
|
msg_out = []
|
|
85
|
+
|
|
73
86
|
if isinstance(msg_in, SampleTriggerMessage):
|
|
74
|
-
|
|
87
|
+
# Input is a trigger message that we will use to sample the buffer.
|
|
88
|
+
|
|
89
|
+
if buffer is None or check_inputs["fs"] is None:
|
|
75
90
|
# We've yet to see any data; drop the trigger.
|
|
76
91
|
continue
|
|
77
|
-
fs = last_msg_stats["fs"]
|
|
78
|
-
axis_idx = last_msg_stats["axis_idx"]
|
|
79
92
|
|
|
80
93
|
_period = msg_in.period if msg_in.period is not None else period
|
|
81
94
|
_value = msg_in.value if msg_in.value is not None else value
|
|
@@ -86,50 +99,73 @@ def sampler(
|
|
|
86
99
|
|
|
87
100
|
# Check that period is valid
|
|
88
101
|
if _period[0] >= _period[1]:
|
|
89
|
-
ez.logger.warning(
|
|
102
|
+
ez.logger.warning(
|
|
103
|
+
f"Sampling failed: invalid period requested ({_period})"
|
|
104
|
+
)
|
|
90
105
|
continue
|
|
91
106
|
|
|
92
107
|
# Check that period is compatible with buffer duration.
|
|
93
|
-
max_buf_len = int(buffer_dur * fs)
|
|
94
|
-
req_buf_len = int((_period[1] - _period[0]) * fs)
|
|
108
|
+
max_buf_len = int(np.round(buffer_dur * check_inputs["fs"]))
|
|
109
|
+
req_buf_len = int(np.round((_period[1] - _period[0]) * check_inputs["fs"]))
|
|
95
110
|
if req_buf_len >= max_buf_len:
|
|
96
|
-
ez.logger.warning(
|
|
97
|
-
f"Sampling failed: {period=} >= {buffer_dur=}"
|
|
98
|
-
)
|
|
111
|
+
ez.logger.warning(f"Sampling failed: {period=} >= {buffer_dur=}")
|
|
99
112
|
continue
|
|
100
113
|
|
|
101
114
|
trigger_ts: float = msg_in.timestamp
|
|
102
115
|
if not estimate_alignment:
|
|
103
116
|
# Override the trigger timestamp with the next sample's likely timestamp.
|
|
104
|
-
trigger_ts =
|
|
117
|
+
trigger_ts = offset + (n_samples + 1) / check_inputs["fs"]
|
|
105
118
|
|
|
106
|
-
new_trig_msg = replace(
|
|
119
|
+
new_trig_msg = replace(
|
|
120
|
+
msg_in, timestamp=trigger_ts, period=_period, value=_value
|
|
121
|
+
)
|
|
107
122
|
triggers.append(new_trig_msg)
|
|
108
123
|
|
|
109
124
|
elif isinstance(msg_in, AxisArray):
|
|
110
|
-
|
|
111
|
-
|
|
125
|
+
# Get properties from message
|
|
126
|
+
axis = axis or msg_in.dims[0]
|
|
112
127
|
axis_idx = msg_in.get_axis_idx(axis)
|
|
113
128
|
axis_info = msg_in.get_axis(axis)
|
|
114
129
|
fs = 1.0 / axis_info.gain
|
|
115
|
-
sample_shape =
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
130
|
+
sample_shape = (
|
|
131
|
+
msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
# TODO: We could accommodate change in dim order.
|
|
135
|
+
# if axis_idx != check_inputs["axis_idx"]:
|
|
136
|
+
# msg_in = replace(
|
|
137
|
+
# msg_in,
|
|
138
|
+
# data=np.moveaxis(msg_in.data, axis_idx, check_inputs["axis_idx"]),
|
|
139
|
+
# dims=TODO...
|
|
140
|
+
# )
|
|
141
|
+
# axis_idx = check_inputs["axis_idx"]
|
|
142
|
+
|
|
143
|
+
# If the properties have changed in a breaking way then reset buffer and triggers.
|
|
144
|
+
b_reset = fs != check_inputs["fs"]
|
|
145
|
+
b_reset = b_reset or sample_shape != check_inputs["shape"]
|
|
146
|
+
# TODO: Skip next line if we do np.moveaxis above
|
|
147
|
+
b_reset = b_reset or axis_idx != check_inputs["axis_idx"]
|
|
148
|
+
b_reset = b_reset or msg_in.key != check_inputs["key"]
|
|
149
|
+
if b_reset:
|
|
150
|
+
check_inputs["fs"] = fs
|
|
151
|
+
check_inputs["shape"] = sample_shape
|
|
152
|
+
check_inputs["axis_idx"] = axis_idx
|
|
153
|
+
check_inputs["key"] = msg_in.key
|
|
154
|
+
n_samples = msg_in.data.shape[axis_idx]
|
|
125
155
|
buffer = None
|
|
126
156
|
if len(triggers) > 0:
|
|
127
157
|
ez.logger.warning("Data stream changed: Discarding all triggers")
|
|
128
158
|
triggers.clear()
|
|
129
|
-
|
|
159
|
+
|
|
160
|
+
# Save some info for trigger processing
|
|
161
|
+
offset = axis_info.offset
|
|
130
162
|
|
|
131
163
|
# Update buffer
|
|
132
|
-
buffer =
|
|
164
|
+
buffer = (
|
|
165
|
+
msg_in.data
|
|
166
|
+
if buffer is None
|
|
167
|
+
else np.concatenate((buffer, msg_in.data), axis=axis_idx)
|
|
168
|
+
)
|
|
133
169
|
|
|
134
170
|
# Calculate timestamps associated with buffer.
|
|
135
171
|
buffer_offset = np.arange(buffer.shape[axis_idx], dtype=float)
|
|
@@ -153,11 +189,10 @@ def sampler(
|
|
|
153
189
|
)
|
|
154
190
|
triggers.remove(trig)
|
|
155
191
|
|
|
156
|
-
# TODO: Speed up with searchsorted?
|
|
157
192
|
t_start = trig.timestamp + trig.period[0]
|
|
158
193
|
if t_start >= buffer_offset[0]:
|
|
159
194
|
start = np.searchsorted(buffer_offset, t_start)
|
|
160
|
-
stop = start + int(fs * (trig.period[1] - trig.period[0]))
|
|
195
|
+
stop = start + int(np.round(fs * (trig.period[1] - trig.period[0])))
|
|
161
196
|
if buffer.shape[axis_idx] > stop:
|
|
162
197
|
# Trigger period fully enclosed in buffer.
|
|
163
198
|
msg_out.append(
|
|
@@ -165,9 +200,16 @@ def sampler(
|
|
|
165
200
|
trigger=trig,
|
|
166
201
|
sample=replace(
|
|
167
202
|
msg_in,
|
|
168
|
-
data=slice_along_axis(
|
|
169
|
-
|
|
170
|
-
|
|
203
|
+
data=slice_along_axis(
|
|
204
|
+
buffer, slice(start, stop), axis_idx
|
|
205
|
+
),
|
|
206
|
+
axes={
|
|
207
|
+
**msg_in.axes,
|
|
208
|
+
axis: replace(
|
|
209
|
+
axis_info, offset=buffer_offset[start]
|
|
210
|
+
),
|
|
211
|
+
},
|
|
212
|
+
),
|
|
171
213
|
)
|
|
172
214
|
)
|
|
173
215
|
triggers.remove(trig)
|
|
@@ -177,29 +219,40 @@ def sampler(
|
|
|
177
219
|
|
|
178
220
|
|
|
179
221
|
class SamplerSettings(ez.Settings):
|
|
222
|
+
"""
|
|
223
|
+
Settings for :obj:`Sampler`.
|
|
224
|
+
See :obj:`sampler` for a description of the fields.
|
|
225
|
+
"""
|
|
226
|
+
|
|
180
227
|
buffer_dur: float
|
|
181
|
-
axis: Optional[str] = None
|
|
182
|
-
period: Optional[
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
value: Any = None
|
|
228
|
+
axis: typing.Optional[str] = None
|
|
229
|
+
period: typing.Optional[typing.Tuple[float, float]] = None
|
|
230
|
+
"""Optional default period if unspecified in SampleTriggerMessage"""
|
|
231
|
+
|
|
232
|
+
value: typing.Any = None
|
|
233
|
+
"""Optional default value if unspecified in SampleTriggerMessage"""
|
|
186
234
|
|
|
187
235
|
estimate_alignment: bool = True
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
236
|
+
"""
|
|
237
|
+
If true, use message timestamp fields and reported sampling rate to estimate sample-accurate alignment for samples.
|
|
238
|
+
If false, sampling will be limited to incoming message rate -- "Block timing"
|
|
239
|
+
NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
|
|
240
|
+
"realtime" operation for estimate_alignment to operate correctly.
|
|
241
|
+
"""
|
|
193
242
|
|
|
194
243
|
|
|
195
244
|
class SamplerState(ez.State):
|
|
196
245
|
cur_settings: SamplerSettings
|
|
197
|
-
gen: Generator[
|
|
246
|
+
gen: typing.Generator[
|
|
247
|
+
typing.Union[AxisArray, SampleTriggerMessage], typing.List[SampleMessage], None
|
|
248
|
+
]
|
|
198
249
|
|
|
199
250
|
|
|
200
251
|
class Sampler(ez.Unit):
|
|
201
|
-
|
|
202
|
-
|
|
252
|
+
"""An :obj:`Unit` for :obj:`sampler`."""
|
|
253
|
+
|
|
254
|
+
SETTINGS = SamplerSettings
|
|
255
|
+
STATE = SamplerState
|
|
203
256
|
|
|
204
257
|
INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
|
|
205
258
|
INPUT_SETTINGS = ez.InputStream(SamplerSettings)
|
|
@@ -212,10 +265,10 @@ class Sampler(ez.Unit):
|
|
|
212
265
|
axis=self.STATE.cur_settings.axis,
|
|
213
266
|
period=self.STATE.cur_settings.period,
|
|
214
267
|
value=self.STATE.cur_settings.value,
|
|
215
|
-
estimate_alignment=self.STATE.cur_settings.estimate_alignment
|
|
268
|
+
estimate_alignment=self.STATE.cur_settings.estimate_alignment,
|
|
216
269
|
)
|
|
217
270
|
|
|
218
|
-
def initialize(self) -> None:
|
|
271
|
+
async def initialize(self) -> None:
|
|
219
272
|
self.STATE.cur_settings = self.SETTINGS
|
|
220
273
|
self.construct_generator()
|
|
221
274
|
|
|
@@ -228,27 +281,36 @@ class Sampler(ez.Unit):
|
|
|
228
281
|
async def on_trigger(self, msg: SampleTriggerMessage) -> None:
|
|
229
282
|
_ = self.STATE.gen.send(msg)
|
|
230
283
|
|
|
231
|
-
@ez.subscriber(INPUT_SIGNAL)
|
|
284
|
+
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
|
|
232
285
|
@ez.publisher(OUTPUT_SAMPLE)
|
|
233
|
-
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
286
|
+
async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
234
287
|
pub_samples = self.STATE.gen.send(msg)
|
|
235
288
|
for sample in pub_samples:
|
|
236
289
|
yield self.OUTPUT_SAMPLE, sample
|
|
237
290
|
|
|
238
291
|
|
|
239
292
|
class TriggerGeneratorSettings(ez.Settings):
|
|
240
|
-
period: Tuple[float, float]
|
|
241
|
-
|
|
242
|
-
|
|
293
|
+
period: typing.Tuple[float, float]
|
|
294
|
+
"""The period around the trigger event."""
|
|
295
|
+
|
|
296
|
+
prewait: float = 0.5
|
|
297
|
+
"""The time before the first trigger (sec)"""
|
|
298
|
+
|
|
299
|
+
publish_period: float = 5.0
|
|
300
|
+
"""The period between triggers (sec)"""
|
|
243
301
|
|
|
244
302
|
|
|
245
303
|
class TriggerGenerator(ez.Unit):
|
|
246
|
-
|
|
304
|
+
"""
|
|
305
|
+
A unit to generate triggers every `publish_period` interval.
|
|
306
|
+
"""
|
|
307
|
+
|
|
308
|
+
SETTINGS = TriggerGeneratorSettings
|
|
247
309
|
|
|
248
310
|
OUTPUT_TRIGGER = ez.OutputStream(SampleTriggerMessage)
|
|
249
311
|
|
|
250
312
|
@ez.publisher(OUTPUT_TRIGGER)
|
|
251
|
-
async def generate(self) -> AsyncGenerator:
|
|
313
|
+
async def generate(self) -> typing.AsyncGenerator:
|
|
252
314
|
await asyncio.sleep(self.SETTINGS.prewait)
|
|
253
315
|
|
|
254
316
|
output = 0
|
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from dataclasses import replace
|
|
2
|
-
|
|
2
|
+
import typing
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
|
-
|
|
5
|
+
import numpy.typing as npt
|
|
6
6
|
import ezmsg.core as ez
|
|
7
7
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
-
from ezmsg.util.generator import consumer
|
|
8
|
+
from ezmsg.util.generator import consumer
|
|
9
|
+
|
|
10
|
+
from .base import GenAxisArray
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
def _tau_from_alpha(alpha: float, dt: float) -> float:
|
|
@@ -27,24 +29,39 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
|
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
@consumer
|
|
30
|
-
def scaler(
|
|
32
|
+
def scaler(
|
|
33
|
+
time_constant: float = 1.0, axis: typing.Optional[str] = None
|
|
34
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
35
|
+
"""
|
|
36
|
+
Create a generator function that applies the
|
|
37
|
+
adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
|
|
38
|
+
This is faster than :obj:`scaler_np` for single-channel data.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
time_constant: Decay constant `tau` in seconds.
|
|
42
|
+
axis: The name of the axis to accumulate statistics over.
|
|
43
|
+
|
|
44
|
+
Returns:
|
|
45
|
+
A primed generator object that expects `.send(axis_array)` and yields a
|
|
46
|
+
standardized, or "Z-scored" version of the input.
|
|
47
|
+
"""
|
|
31
48
|
from river import preprocessing
|
|
32
|
-
|
|
33
|
-
|
|
49
|
+
|
|
50
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
34
51
|
_scaler = None
|
|
35
52
|
while True:
|
|
36
|
-
|
|
37
|
-
data =
|
|
53
|
+
msg_in: AxisArray = yield msg_out
|
|
54
|
+
data = msg_in.data
|
|
38
55
|
if axis is None:
|
|
39
|
-
axis =
|
|
56
|
+
axis = msg_in.dims[0]
|
|
40
57
|
axis_idx = 0
|
|
41
58
|
else:
|
|
42
|
-
axis_idx =
|
|
59
|
+
axis_idx = msg_in.get_axis_idx(axis)
|
|
43
60
|
if axis_idx != 0:
|
|
44
61
|
data = np.moveaxis(data, axis_idx, 0)
|
|
45
62
|
|
|
46
63
|
if _scaler is None:
|
|
47
|
-
alpha = _alpha_from_tau(time_constant,
|
|
64
|
+
alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
|
|
48
65
|
_scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)
|
|
49
66
|
|
|
50
67
|
result = []
|
|
@@ -57,17 +74,39 @@ def scaler(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[
|
|
|
57
74
|
|
|
58
75
|
result = np.stack(result)
|
|
59
76
|
result = np.moveaxis(result, 0, axis_idx)
|
|
60
|
-
|
|
77
|
+
msg_out = replace(msg_in, data=result)
|
|
61
78
|
|
|
62
79
|
|
|
63
80
|
@consumer
|
|
64
|
-
def scaler_np(
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
81
|
+
def scaler_np(
|
|
82
|
+
time_constant: float = 1.0, axis: typing.Optional[str] = None
|
|
83
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
84
|
+
"""
|
|
85
|
+
Create a generator function that applies an adaptive standard scaler.
|
|
86
|
+
This is faster than :obj:`scaler` for multichannel data.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
time_constant: Decay constant `tau` in seconds.
|
|
90
|
+
axis: The name of the axis to accumulate statistics over.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
A primed generator object that expects `.send(axis_array)` and yields a
|
|
94
|
+
standardized, or "Z-scored" version of the input.
|
|
95
|
+
"""
|
|
96
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
97
|
+
|
|
98
|
+
# State variables
|
|
99
|
+
alpha: float = 0.0
|
|
100
|
+
means: typing.Optional[npt.NDArray] = None
|
|
101
|
+
vars_means: typing.Optional[npt.NDArray] = None
|
|
102
|
+
vars_sq_means: typing.Optional[npt.NDArray] = None
|
|
103
|
+
|
|
104
|
+
# Reset if input changes
|
|
105
|
+
check_input = {
|
|
106
|
+
"gain": None, # Resets alpha
|
|
107
|
+
"shape": None,
|
|
108
|
+
"key": None, # Key change implies buffered means/vars are invalid.
|
|
109
|
+
}
|
|
71
110
|
|
|
72
111
|
def _ew_update(arr, prev, _alpha):
|
|
73
112
|
if np.all(prev == 0):
|
|
@@ -77,51 +116,58 @@ def scaler_np(time_constant: float = 1.0, axis: Optional[str] = None) -> Generat
|
|
|
77
116
|
return prev + _alpha * (arr - prev)
|
|
78
117
|
|
|
79
118
|
while True:
|
|
80
|
-
|
|
119
|
+
msg_in: AxisArray = yield msg_out
|
|
81
120
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
axis = axis_arr_in.dims[0]
|
|
85
|
-
axis_idx = 0
|
|
86
|
-
else:
|
|
87
|
-
axis_idx = axis_arr_in.get_axis_idx(axis)
|
|
88
|
-
data = np.moveaxis(data, axis_idx, 0)
|
|
121
|
+
axis = axis or msg_in.dims[0]
|
|
122
|
+
axis_idx = msg_in.get_axis_idx(axis)
|
|
89
123
|
|
|
90
|
-
if
|
|
91
|
-
alpha = _alpha_from_tau(time_constant,
|
|
124
|
+
if msg_in.axes[axis].gain != check_input["gain"]:
|
|
125
|
+
alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
|
|
126
|
+
check_input["gain"] = msg_in.axes[axis].gain
|
|
92
127
|
|
|
93
|
-
|
|
128
|
+
data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
|
|
129
|
+
b_reset = data.shape[1:] != check_input["shape"]
|
|
130
|
+
b_reset |= msg_in.key != check_input["key"]
|
|
131
|
+
if b_reset:
|
|
132
|
+
check_input["shape"] = data.shape[1:]
|
|
133
|
+
check_input["key"] = msg_in.key
|
|
94
134
|
vars_sq_means = np.zeros_like(data[0], dtype=float)
|
|
95
135
|
vars_means = np.zeros_like(data[0], dtype=float)
|
|
96
136
|
means = np.zeros_like(data[0], dtype=float)
|
|
97
137
|
|
|
98
|
-
result =
|
|
99
|
-
for
|
|
138
|
+
result = np.zeros_like(data)
|
|
139
|
+
for sample_ix in range(data.shape[0]):
|
|
140
|
+
sample = data[sample_ix]
|
|
100
141
|
# Update step
|
|
101
142
|
vars_means = _ew_update(sample, vars_means, alpha)
|
|
102
143
|
vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha)
|
|
103
144
|
means = _ew_update(sample, means, alpha)
|
|
104
145
|
# Get step
|
|
105
|
-
varis = vars_sq_means - vars_means
|
|
106
|
-
y = (
|
|
107
|
-
|
|
108
|
-
result.append(y)
|
|
146
|
+
varis = vars_sq_means - vars_means**2
|
|
147
|
+
y = (sample - means) / (varis**0.5)
|
|
148
|
+
result[sample_ix] = y
|
|
109
149
|
|
|
110
|
-
result
|
|
150
|
+
result[np.isnan(result)] = 0.0
|
|
111
151
|
result = np.moveaxis(result, 0, axis_idx)
|
|
112
|
-
|
|
152
|
+
msg_out = replace(msg_in, data=result)
|
|
113
153
|
|
|
114
154
|
|
|
115
155
|
class AdaptiveStandardScalerSettings(ez.Settings):
|
|
156
|
+
"""
|
|
157
|
+
Settings for :obj:`AdaptiveStandardScaler`.
|
|
158
|
+
See :obj:`scaler_np` for a description of the parameters.
|
|
159
|
+
"""
|
|
160
|
+
|
|
116
161
|
time_constant: float = 1.0
|
|
117
|
-
axis: Optional[str] = None
|
|
162
|
+
axis: typing.Optional[str] = None
|
|
118
163
|
|
|
119
164
|
|
|
120
165
|
class AdaptiveStandardScaler(GenAxisArray):
|
|
121
|
-
|
|
166
|
+
"""Unit for :obj:`scaler_np`"""
|
|
167
|
+
|
|
168
|
+
SETTINGS = AdaptiveStandardScalerSettings
|
|
122
169
|
|
|
123
170
|
def construct_generator(self):
|
|
124
171
|
self.STATE.gen = scaler_np(
|
|
125
|
-
time_constant=self.SETTINGS.time_constant,
|
|
126
|
-
axis=self.SETTINGS.axis
|
|
172
|
+
time_constant=self.SETTINGS.time_constant, axis=self.SETTINGS.axis
|
|
127
173
|
)
|
ezmsg/sigproc/signalinjector.py
CHANGED
|
@@ -1,17 +1,15 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
1
2
|
import typing
|
|
2
3
|
|
|
3
4
|
import ezmsg.core as ez
|
|
4
|
-
|
|
5
|
-
from dataclasses import replace
|
|
6
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
|
|
8
6
|
import numpy as np
|
|
9
7
|
import numpy.typing as npt
|
|
10
8
|
|
|
11
9
|
|
|
12
10
|
class SignalInjectorSettings(ez.Settings):
|
|
13
|
-
time_dim: str =
|
|
14
|
-
frequency: typing.Optional[float] = None
|
|
11
|
+
time_dim: str = "time" # Input signal needs a time dimension with units in sec.
|
|
12
|
+
frequency: typing.Optional[float] = None # Hz
|
|
15
13
|
amplitude: float = 1.0
|
|
16
14
|
mixing_seed: typing.Optional[int] = None
|
|
17
15
|
|
|
@@ -24,8 +22,8 @@ class SignalInjectorState(ez.State):
|
|
|
24
22
|
|
|
25
23
|
|
|
26
24
|
class SignalInjector(ez.Unit):
|
|
27
|
-
SETTINGS
|
|
28
|
-
STATE
|
|
25
|
+
SETTINGS = SignalInjectorSettings
|
|
26
|
+
STATE = SignalInjectorState
|
|
29
27
|
|
|
30
28
|
INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
|
|
31
29
|
INPUT_AMPLITUDE = ez.InputStream(float)
|
|
@@ -48,7 +46,6 @@ class SignalInjector(ez.Unit):
|
|
|
48
46
|
@ez.subscriber(INPUT_SIGNAL)
|
|
49
47
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
50
48
|
async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
|
|
51
|
-
|
|
52
49
|
if self.STATE.cur_shape != msg.shape:
|
|
53
50
|
self.STATE.cur_shape = msg.shape
|
|
54
51
|
rng = np.random.default_rng(self.SETTINGS.mixing_seed)
|
|
@@ -58,10 +55,10 @@ class SignalInjector(ez.Unit):
|
|
|
58
55
|
if self.STATE.cur_frequency is None:
|
|
59
56
|
yield self.OUTPUT_SIGNAL, msg
|
|
60
57
|
else:
|
|
61
|
-
out_msg = replace(msg, data
|
|
58
|
+
out_msg = replace(msg, data=msg.data.copy())
|
|
62
59
|
t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
|
|
63
60
|
signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
|
|
64
61
|
mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
|
|
65
62
|
with out_msg.view2d(self.SETTINGS.time_dim) as view:
|
|
66
63
|
view[...] = view + mixed_signal.astype(view.dtype)
|
|
67
|
-
yield self.OUTPUT_SIGNAL, out_msg
|
|
64
|
+
yield self.OUTPUT_SIGNAL, out_msg
|