ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +4 -1
- ezmsg/sigproc/affinetransform.py +124 -0
- ezmsg/sigproc/aggregate.py +103 -0
- ezmsg/sigproc/bandpower.py +53 -0
- ezmsg/sigproc/butterworthfilter.py +41 -6
- ezmsg/sigproc/downsample.py +52 -26
- ezmsg/sigproc/ewmfilter.py +11 -3
- ezmsg/sigproc/filter.py +82 -14
- ezmsg/sigproc/sampler.py +173 -200
- ezmsg/sigproc/scaler.py +127 -0
- ezmsg/sigproc/signalinjector.py +67 -0
- ezmsg/sigproc/slicer.py +98 -0
- ezmsg/sigproc/spectral.py +9 -132
- ezmsg/sigproc/spectrogram.py +68 -0
- ezmsg/sigproc/spectrum.py +158 -0
- ezmsg/sigproc/synth.py +179 -80
- ezmsg/sigproc/window.py +212 -110
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.2.3.dist-info}/METADATA +15 -13
- ezmsg_sigproc-1.2.3.dist-info/RECORD +23 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.2.3.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/__version__.py +0 -1
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.2.3.dist-info}/LICENSE.txt +0 -0
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,15 +1,19 @@
|
|
|
1
|
+
from collections import deque
|
|
1
2
|
from dataclasses import dataclass, replace, field
|
|
2
3
|
import time
|
|
4
|
+
from typing import Optional, Any, Tuple, List, Union, AsyncGenerator, Generator
|
|
3
5
|
|
|
4
6
|
import ezmsg.core as ez
|
|
5
7
|
import numpy as np
|
|
6
8
|
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
10
|
+
from ezmsg.util.generator import consumer
|
|
8
11
|
|
|
9
|
-
|
|
12
|
+
# Dev/test apparatus
|
|
13
|
+
import asyncio
|
|
10
14
|
|
|
11
15
|
|
|
12
|
-
@dataclass(unsafe_hash
|
|
16
|
+
@dataclass(unsafe_hash=True)
|
|
13
17
|
class SampleTriggerMessage:
|
|
14
18
|
timestamp: float = field(default_factory=time.time)
|
|
15
19
|
period: Optional[Tuple[float, float]] = None
|
|
@@ -22,6 +26,156 @@ class SampleMessage:
|
|
|
22
26
|
sample: AxisArray
|
|
23
27
|
|
|
24
28
|
|
|
29
|
+
@consumer
def sampler(
    buffer_dur: float,
    axis: Optional[str] = None,
    period: Optional[Tuple[float, float]] = None,
    value: Any = None,
    estimate_alignment: bool = True
) -> Generator[Union[AxisArray, SampleTriggerMessage], List[SampleMessage], None]:
    """
    A generator function that accumulates streaming data in a buffer, accepts triggers, and returns
    slices of the buffered data around each trigger time.

    Parameters:
    - buffer_dur (float): The duration of the buffer in seconds. The buffer must be long enough to store the oldest
        sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
        need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
    - axis (Optional[str]): The axis along which to sample the data.
        None (default) will choose the first axis in the first input.
    - period (Optional[Tuple[float, float]]): The (start, stop) window in seconds, relative to the trigger
        timestamp. Defaults to None. Only used if the trigger message does not define its own period.
    - value (Any): Value attached to triggers whose own `value` is None. Defaults to None.
    - estimate_alignment (bool): Whether to estimate the sample alignment. Defaults to True.
        If True, the trigger timestamp field is used to slice the buffer.
        If False, the trigger timestamp is ignored and replaced by a timestamp derived from the most recent
        signal message (its axis offset and its number of samples).
        NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
        estimate_alignment to operate correctly.

    Sends:
    - AxisArray containing streaming data messages
    - SampleTriggerMessage containing a trigger
    Yields:
    - list[SampleMessage]: The samples completed by the message that was just sent. Always empty after a
        trigger is sent; None only for the very first (priming) yield.
    """
    msg_out: Optional[list[SampleMessage]] = None

    # State variables (most are shared between trigger- and data-processing).
    triggers: deque[SampleTriggerMessage] = deque()
    last_msg_stats = None
    buffer = None

    while True:
        msg_in = yield msg_out
        msg_out = []
        if isinstance(msg_in, SampleTriggerMessage):
            if last_msg_stats is None or buffer is None:
                # We've yet to see any data; drop the trigger.
                continue
            fs = last_msg_stats["fs"]

            # Trigger-specific fields take precedence over the generator-level defaults.
            _period = msg_in.period if msg_in.period is not None else period
            _value = msg_in.value if msg_in.value is not None else value

            if _period is None:
                ez.logger.warning("Sampling failed: period not specified")
                continue

            # Check that period is valid
            if _period[0] >= _period[1]:
                ez.logger.warning(f"Sampling failed: invalid period requested ({_period})")
                continue

            # Check that period is compatible with buffer duration.
            max_buf_len = int(buffer_dur * fs)
            req_buf_len = int((_period[1] - _period[0]) * fs)
            if req_buf_len >= max_buf_len:
                # Report the period that was actually evaluated, not the generator default.
                ez.logger.warning(
                    f"Sampling failed: period={_period!r} >= {buffer_dur=}"
                )
                continue

            trigger_ts: float = msg_in.timestamp
            if not estimate_alignment:
                # Override the trigger timestamp with the next sample's likely timestamp.
                # NOTE(review): the `+ 1` places the estimate one sample after offset + n_samples / fs;
                # kept as-is because downstream alignment may depend on it -- confirm intent.
                trigger_ts = last_msg_stats["offset"] + (last_msg_stats["n_samples"] + 1) / fs

            new_trig_msg = replace(msg_in, timestamp=trigger_ts, period=_period, value=_value)
            triggers.append(new_trig_msg)

        elif isinstance(msg_in, AxisArray):
            if axis is None:
                axis = msg_in.dims[0]
            axis_idx = msg_in.get_axis_idx(axis)
            axis_info = msg_in.get_axis(axis)
            fs = 1.0 / axis_info.gain
            n_new = msg_in.data.shape[axis_idx]
            sample_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1:]

            # If the signal properties have changed in a breaking way then reset buffer and triggers.
            if (
                last_msg_stats is None
                or fs != last_msg_stats["fs"]
                or sample_shape != last_msg_stats["sample_shape"]
            ):
                last_msg_stats = {"fs": fs, "sample_shape": sample_shape}
                buffer = None
                if len(triggers) > 0:
                    ez.logger.warning("Data stream changed: Discarding all triggers")
                triggers.clear()

            # These describe the most recent message, so they are refreshed on every message.
            last_msg_stats["axis_idx"] = axis_idx
            last_msg_stats["offset"] = axis_info.offset
            last_msg_stats["n_samples"] = n_new

            # Update buffer
            buffer = msg_in.data if buffer is None else np.concatenate((buffer, msg_in.data), axis=axis_idx)
            n_buf = buffer.shape[axis_idx]

            # Calculate timestamps associated with buffer. The first sample of the newest message sits at
            # index (n_buf - n_new) and carries the message's axis offset; this form is also valid when the
            # message (or the whole buffer) is empty.
            buffer_offset = np.arange(n_buf, dtype=float)
            buffer_offset -= n_buf - n_new
            buffer_offset *= axis_info.gain
            buffer_offset += axis_info.offset

            # ... for each trigger, collect the message (if possible) and append to msg_out.
            # Iterate over a copy because triggers are removed while looping. An empty buffer can neither
            # satisfy nor invalidate a trigger, so triggers are left pending in that case.
            pending = list(triggers) if n_buf > 0 else []
            for trig in pending:
                if trig.period is None:
                    # This trigger was malformed; drop it.
                    triggers.remove(trig)
                    continue

                # If the previous iteration had insufficient data for the trigger timestamp + period,
                # and buffer-management removed data required for the trigger, then we will never be able
                # to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
                t_start = trig.timestamp + trig.period[0]
                if t_start < buffer_offset[0]:
                    ez.logger.warning(
                        f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
                        f"requested sample period start: {t_start}"
                    )
                    triggers.remove(trig)
                    continue

                start = np.searchsorted(buffer_offset, t_start)
                stop = start + int(fs * (trig.period[1] - trig.period[0]))
                if n_buf > stop:
                    # Trigger period fully enclosed in buffer.
                    msg_out.append(
                        SampleMessage(
                            trigger=trig,
                            sample=replace(
                                msg_in,
                                data=slice_along_axis(buffer, slice(start, stop), axis_idx),
                                axes={**msg_in.axes, axis: replace(axis_info, offset=buffer_offset[start])}
                            )
                        )
                    )
                    triggers.remove(trig)
                # Otherwise the window is not complete yet; keep the trigger for a later message.

            # Keep only the newest buf_len samples. An explicit start index avoids the `-0:` slice, which
            # would retain the entire buffer (unbounded growth) when buf_len is 0.
            buf_len = int(buffer_dur * fs)
            buffer = slice_along_axis(buffer, slice(max(n_buf - buf_len, 0), None), axis_idx)
|
|
177
|
+
|
|
178
|
+
|
|
25
179
|
class SamplerSettings(ez.Settings):
|
|
26
180
|
buffer_dur: float
|
|
27
181
|
axis: Optional[str] = None
|
|
@@ -40,9 +194,7 @@ class SamplerSettings(ez.Settings):
|
|
|
40
194
|
|
|
41
195
|
class SamplerState(ez.State):
|
|
42
196
|
cur_settings: SamplerSettings
|
|
43
|
-
|
|
44
|
-
last_msg: Optional[AxisArray] = None
|
|
45
|
-
buffer: Optional[np.ndarray] = None
|
|
197
|
+
gen: Generator[Union[AxisArray, SampleTriggerMessage], List[SampleMessage], None]
|
|
46
198
|
|
|
47
199
|
|
|
48
200
|
class Sampler(ez.Unit):
|
|
@@ -54,162 +206,35 @@ class Sampler(ez.Unit):
|
|
|
54
206
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
55
207
|
OUTPUT_SAMPLE = ez.OutputStream(SampleMessage)
|
|
56
208
|
|
|
209
|
+
def construct_generator(self):
    """Build a new `sampler` generator from the active settings and store it in STATE.gen."""
    # Read the active settings once; each field is forwarded under the same keyword name.
    settings = self.STATE.cur_settings
    self.STATE.gen = sampler(
        buffer_dur=settings.buffer_dur,
        axis=settings.axis,
        period=settings.period,
        value=settings.value,
        estimate_alignment=settings.estimate_alignment,
    )
|
|
217
|
+
|
|
57
218
|
def initialize(self) -> None:
    """Adopt the unit's SETTINGS as the active settings, then build the sampler generator."""
    initial_settings = self.SETTINGS
    self.STATE.cur_settings = initial_settings
    # The generator is constructed from STATE.cur_settings, so it must be set first.
    self.construct_generator()
|
|
59
221
|
|
|
60
222
|
@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: SamplerSettings) -> None:
    """Replace the active settings with `msg` and rebuild the sampler generator from them."""
    self.STATE.cur_settings = msg
    # Rebuilding creates a fresh generator; state held inside the previous one is not carried over.
    self.construct_generator()
|
|
63
226
|
|
|
64
227
|
@ez.subscriber(INPUT_TRIGGER)
|
|
65
228
|
async def on_trigger(self, msg: SampleTriggerMessage) -> None:
|
|
66
|
-
|
|
67
|
-
axis_name = self.STATE.cur_settings.axis
|
|
68
|
-
if axis_name is None:
|
|
69
|
-
axis_name = self.STATE.last_msg.dims[0]
|
|
70
|
-
axis = self.STATE.last_msg.get_axis(axis_name)
|
|
71
|
-
axis_idx = self.STATE.last_msg.get_axis_idx(axis_name)
|
|
72
|
-
|
|
73
|
-
fs = 1.0 / axis.gain
|
|
74
|
-
last_msg_timestamp = axis.offset + (
|
|
75
|
-
self.STATE.last_msg.shape[axis_idx] / fs
|
|
76
|
-
)
|
|
77
|
-
|
|
78
|
-
period = (
|
|
79
|
-
msg.period if msg.period is not None else self.STATE.cur_settings.period
|
|
80
|
-
)
|
|
81
|
-
value = (
|
|
82
|
-
msg.value if msg.value is not None else self.STATE.cur_settings.value
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
if period is None:
|
|
86
|
-
ez.logger.warning(f"Sampling failed: period not specified")
|
|
87
|
-
return
|
|
88
|
-
|
|
89
|
-
# Check that period is valid
|
|
90
|
-
start_offset = int(period[0] * fs)
|
|
91
|
-
stop_offset = int(period[1] * fs)
|
|
92
|
-
if (stop_offset - start_offset) <= 0:
|
|
93
|
-
ez.logger.warning(f"Sampling failed: invalid period requested")
|
|
94
|
-
return
|
|
95
|
-
|
|
96
|
-
# Check that period is compatible with buffer duration
|
|
97
|
-
max_buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
|
|
98
|
-
req_buf_len = int((period[1] - period[0]) * fs)
|
|
99
|
-
if req_buf_len >= max_buf_len:
|
|
100
|
-
ez.logger.warning(
|
|
101
|
-
f"Sampling failed: {period=} >= {self.STATE.cur_settings.buffer_dur=}"
|
|
102
|
-
)
|
|
103
|
-
return
|
|
104
|
-
|
|
105
|
-
offset: int = 0
|
|
106
|
-
if self.STATE.cur_settings.estimate_alignment:
|
|
107
|
-
# Do what we can with the wall clock to determine sample alignment
|
|
108
|
-
wall_delta = msg.timestamp - last_msg_timestamp
|
|
109
|
-
offset = int(wall_delta * fs)
|
|
110
|
-
|
|
111
|
-
# Check that current buffer accumulation allows for offset - period start
|
|
112
|
-
if (
|
|
113
|
-
self.STATE.buffer is None
|
|
114
|
-
or -min(offset + start_offset, 0) >= self.STATE.buffer.shape[0]
|
|
115
|
-
):
|
|
116
|
-
ez.logger.warning(
|
|
117
|
-
"Sampling failed: insufficient buffer accumulation for requested sample period"
|
|
118
|
-
)
|
|
119
|
-
return
|
|
120
|
-
|
|
121
|
-
self.STATE.triggers[replace(msg, period=period, value=value)] = offset
|
|
122
|
-
|
|
123
|
-
else:
|
|
124
|
-
ez.logger.warning("Sampling failed: no signal to sample yet")
|
|
229
|
+
_ = self.STATE.gen.send(msg)
|
|
125
230
|
|
|
126
231
|
@ez.subscriber(INPUT_SIGNAL)
|
|
127
232
|
@ez.publisher(OUTPUT_SAMPLE)
|
|
128
233
|
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
129
|
-
|
|
130
|
-
if axis_name is None:
|
|
131
|
-
axis_name = msg.dims[0]
|
|
132
|
-
axis = msg.get_axis(axis_name)
|
|
133
|
-
|
|
134
|
-
fs = 1.0 / axis.gain
|
|
135
|
-
|
|
136
|
-
if self.STATE.last_msg is None:
|
|
137
|
-
self.STATE.last_msg = msg
|
|
138
|
-
|
|
139
|
-
# Easier to deal with timeseries on axis 0
|
|
140
|
-
last_msg = self.STATE.last_msg
|
|
141
|
-
msg_data = np.moveaxis(msg.data, msg.get_axis_idx(axis_name), 0)
|
|
142
|
-
last_msg_data = np.moveaxis(last_msg.data, last_msg.get_axis_idx(axis_name), 0)
|
|
143
|
-
last_msg_axis = last_msg.get_axis(axis_name)
|
|
144
|
-
last_msg_fs = 1.0 / last_msg_axis.gain
|
|
145
|
-
|
|
146
|
-
# Check if signal properties have changed in a breaking way
|
|
147
|
-
if fs != last_msg_fs or msg_data.shape[1:] != last_msg_data.shape[1:]:
|
|
148
|
-
# Data stream changed meaningfully -- flush buffer, stop sampling
|
|
149
|
-
if len(self.STATE.triggers) > 0:
|
|
150
|
-
ez.logger.warning("Sampling failed: Discarding all triggers")
|
|
151
|
-
ez.logger.warning("Flushing buffer: signal properties changed")
|
|
152
|
-
self.STATE.buffer = None
|
|
153
|
-
self.STATE.triggers = dict()
|
|
154
|
-
|
|
155
|
-
# Accumulate buffer ( time dim => dim 0 )
|
|
156
|
-
self.STATE.buffer = (
|
|
157
|
-
msg_data
|
|
158
|
-
if self.STATE.buffer is None
|
|
159
|
-
else np.concatenate((self.STATE.buffer, msg_data), axis=0)
|
|
160
|
-
)
|
|
161
|
-
|
|
162
|
-
buffer_offset = np.arange(self.STATE.buffer.shape[0] + msg_data.shape[0])
|
|
163
|
-
buffer_offset -= self.STATE.buffer.shape[0] + 1
|
|
164
|
-
buffer_offset = (buffer_offset * axis.gain) + axis.offset
|
|
165
|
-
|
|
166
|
-
pub_samples: List[SampleMessage] = []
|
|
167
|
-
remaining_triggers: Dict[SampleTriggerMessage, int] = dict()
|
|
168
|
-
for trigger, offset in self.STATE.triggers.items():
|
|
169
|
-
if trigger.period is None:
|
|
170
|
-
continue
|
|
171
|
-
|
|
172
|
-
# trigger_offset points to t = 0 within buffer
|
|
173
|
-
offset -= msg_data.shape[0]
|
|
174
|
-
start = offset + int(trigger.period[0] * fs)
|
|
175
|
-
stop = offset + int(trigger.period[1] * fs)
|
|
176
|
-
|
|
177
|
-
if stop < 0: # We should be able to dispatch a sample
|
|
178
|
-
sample_data = self.STATE.buffer[start:stop, ...]
|
|
179
|
-
sample_data = np.moveaxis(sample_data, msg.get_axis_idx(axis_name), 0)
|
|
180
|
-
|
|
181
|
-
sample_offset = buffer_offset[start]
|
|
182
|
-
sample_axis = replace(axis, offset=sample_offset)
|
|
183
|
-
sample_axes = {**msg.axes, **{axis_name: sample_axis}}
|
|
184
|
-
|
|
185
|
-
pub_samples.append(
|
|
186
|
-
SampleMessage(
|
|
187
|
-
trigger=trigger,
|
|
188
|
-
sample=replace(msg, data=sample_data, axes=sample_axes),
|
|
189
|
-
)
|
|
190
|
-
)
|
|
191
|
-
|
|
192
|
-
else:
|
|
193
|
-
remaining_triggers[trigger] = offset
|
|
194
|
-
|
|
234
|
+
pub_samples = self.STATE.gen.send(msg)
|
|
195
235
|
for sample in pub_samples:
|
|
196
236
|
yield self.OUTPUT_SAMPLE, sample
|
|
197
237
|
|
|
198
|
-
self.STATE.triggers = remaining_triggers
|
|
199
|
-
|
|
200
|
-
buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
|
|
201
|
-
self.STATE.buffer = self.STATE.buffer[-buf_len:, ...]
|
|
202
|
-
self.STATE.last_msg = msg
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
## Dev/test apparatus
|
|
206
|
-
import asyncio
|
|
207
|
-
|
|
208
|
-
from ezmsg.testing.debuglog import DebugLog
|
|
209
|
-
from ezmsg.sigproc.synth import Oscillator, OscillatorSettings
|
|
210
|
-
|
|
211
|
-
from typing import AsyncGenerator
|
|
212
|
-
|
|
213
238
|
|
|
214
239
|
class TriggerGeneratorSettings(ez.Settings):
|
|
215
240
|
period: Tuple[float, float] # sec
|
|
@@ -228,60 +253,8 @@ class TriggerGenerator(ez.Unit):
|
|
|
228
253
|
|
|
229
254
|
output = 0
|
|
230
255
|
while True:
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
)
|
|
256
|
+
out_msg = SampleTriggerMessage(period=self.SETTINGS.period, value=output)
|
|
257
|
+
yield self.OUTPUT_TRIGGER, out_msg
|
|
234
258
|
|
|
235
259
|
await asyncio.sleep(self.SETTINGS.publish_period)
|
|
236
260
|
output += 1
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
class SamplerTestSystemSettings(ez.Settings):
|
|
240
|
-
sampler_settings: SamplerSettings
|
|
241
|
-
trigger_settings: TriggerGeneratorSettings
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
class SamplerTestSystem(ez.Collection):
|
|
245
|
-
SETTINGS: SamplerTestSystemSettings
|
|
246
|
-
|
|
247
|
-
OSC = Oscillator()
|
|
248
|
-
SAMPLER = Sampler()
|
|
249
|
-
TRIGGER = TriggerGenerator()
|
|
250
|
-
DEBUG = DebugLog()
|
|
251
|
-
|
|
252
|
-
def configure(self) -> None:
|
|
253
|
-
self.SAMPLER.apply_settings(self.SETTINGS.sampler_settings)
|
|
254
|
-
self.TRIGGER.apply_settings(self.SETTINGS.trigger_settings)
|
|
255
|
-
|
|
256
|
-
self.OSC.apply_settings(
|
|
257
|
-
OscillatorSettings(
|
|
258
|
-
n_time=2, # Number of samples to output per block
|
|
259
|
-
fs=10, # Sampling rate of signal output in Hz
|
|
260
|
-
dispatch_rate="realtime",
|
|
261
|
-
freq=2.0, # Oscillation frequency in Hz
|
|
262
|
-
amp=1.0, # Amplitude
|
|
263
|
-
phase=0.0, # Phase offset (in radians)
|
|
264
|
-
sync=True, # Adjust `freq` to sync with sampling rate
|
|
265
|
-
)
|
|
266
|
-
)
|
|
267
|
-
|
|
268
|
-
def network(self) -> ez.NetworkDefinition:
|
|
269
|
-
return (
|
|
270
|
-
(self.OSC.OUTPUT_SIGNAL, self.SAMPLER.INPUT_SIGNAL),
|
|
271
|
-
(self.TRIGGER.OUTPUT_TRIGGER, self.SAMPLER.INPUT_TRIGGER),
|
|
272
|
-
(self.TRIGGER.OUTPUT_TRIGGER, self.DEBUG.INPUT),
|
|
273
|
-
(self.SAMPLER.OUTPUT_SAMPLE, self.DEBUG.INPUT),
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
if __name__ == "__main__":
|
|
278
|
-
settings = SamplerTestSystemSettings(
|
|
279
|
-
sampler_settings=SamplerSettings(buffer_dur=5.0),
|
|
280
|
-
trigger_settings=TriggerGeneratorSettings(
|
|
281
|
-
period=(1.0, 2.0), prewait=0.5, publish_period=5.0
|
|
282
|
-
),
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
system = SamplerTestSystem(settings)
|
|
286
|
-
|
|
287
|
-
ez.run(SYSTEM = system)
|
ezmsg/sigproc/scaler.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
from typing import Generator, Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.generator import consumer, GenAxisArray
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _tau_from_alpha(alpha: float, dt: float) -> float:
    """
    Convert an exponential-smoothing fading factor back into a time constant.

    Inverse of _alpha_from_tau: solves alpha = 1 - exp(-dt / tau) for tau.

    :param alpha: the "fading factor" in exponential smoothing.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: tau, in the same units as dt.
    """
    retained = 1 - alpha  # fraction of the previous estimate kept at each step
    return -dt / np.log(retained)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _alpha_from_tau(tau: float, dt: float) -> float:
    """
    Convert a time constant into an exponential-smoothing fading factor.

    See https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant

    :param tau: The amount of time for the smoothed response of a unit step function to reach
        1 - 1/e approx-eq 63.2%.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: alpha, the "fading factor" in exponential smoothing.
    """
    exponent = -dt / tau
    return 1 - np.exp(exponent)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@consumer
def scaler(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[AxisArray, AxisArray, None]:
    """
    Adaptively standardize streaming data, one sample at a time, using river's
    ``preprocessing.AdaptiveStandardScaler``.

    :param time_constant: Converted to river's ``fading_factor`` via ``_alpha_from_tau``, using the gain of
        `axis` in the first message as the sampling period.
    :param axis: Name of the axis to iterate over. None (default) selects the first dim of the first message.

    Sends: AxisArray with the data to scale.
    Yields: AxisArray with the same metadata as the input and the scaled data.
    """
    # river is imported lazily so that it is only required when this generator is used.
    from river import preprocessing

    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])
    _scaler = None

    while True:
        axis_arr_in = yield axis_arr_out

        if axis is None:
            axis = axis_arr_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = axis_arr_in.get_axis_idx(axis)

        # Iterate over the target axis by bringing it to the front.
        data = axis_arr_in.data
        if axis_idx != 0:
            data = np.moveaxis(data, axis_idx, 0)

        if _scaler is None:
            alpha = _alpha_from_tau(time_constant, axis_arr_in.axes[axis].gain)
            _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)

        scaled = []
        for sample in data:
            # river operates on dicts of features; key each flattened element by its position.
            features = dict(enumerate(sample.flatten().tolist()))
            _scaler.learn_one(features)
            transformed = _scaler.transform_one(features)
            ordered = [transformed[key] for key in sorted(transformed.keys())]
            scaled.append(np.array(ordered).reshape(sample.shape))

        # Restore the original axis order before emitting.
        result = np.moveaxis(np.stack(scaled), 0, axis_idx)
        axis_arr_out = replace(axis_arr_in, data=result)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@consumer
def scaler_np(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[AxisArray, AxisArray, None]:
    """
    Adaptively standardize streaming data using exponentially weighted statistics, with numpy as the
    only dependency.

    For every sample along `axis`, the exponentially weighted mean and mean-of-squares are updated and the
    sample is emitted as ``(sample - mean) / sqrt(mean_of_squares - mean**2)``. Non-finite results of a
    0/0 division (e.g. the very first sample, where the variance is zero) are replaced with 0.

    This is faster than `scaler` for multi-channel data but slower for single-channel data.

    :param time_constant: Converted to the fading factor via ``_alpha_from_tau``, using the gain of `axis`
        in the first message as the sampling period.
    :param axis: Name of the axis to iterate over. None (default) selects the first dim of the first message.

    Sends: AxisArray with the data to scale.
    Yields: AxisArray with the same metadata as the input and the scaled data.
    """
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])
    # Running exponentially weighted mean and mean-of-squares, one entry per element of a sample.
    means = sq_means = None
    alpha = None

    def _ew_update(arr, prev, _alpha):
        # An all-zero state is treated as "not yet initialised" and is seeded with the sample itself.
        if np.all(prev == 0):
            return arr
        # Equivalent to _alpha * arr + (1 - _alpha) * prev, with one fewer multiplication.
        return prev + _alpha * (arr - prev)

    while True:
        axis_arr_in = yield axis_arr_out

        data = axis_arr_in.data
        if axis is None:
            axis = axis_arr_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = axis_arr_in.get_axis_idx(axis)
        data = np.moveaxis(data, axis_idx, 0)

        if not np.issubdtype(data.dtype, np.inexact):
            # Integer/bool samples would overflow in `sample ** 2`; do the statistics in floating point.
            data = data.astype(float)

        if data.shape[0] == 0:
            # Nothing to scale and nothing to learn from; emit an equally empty message.
            axis_arr_out = replace(axis_arr_in, data=np.moveaxis(data, 0, axis_idx))
            continue

        if alpha is None:
            alpha = _alpha_from_tau(time_constant, axis_arr_in.axes[axis].gain)

        sample_shape = data.shape[1:]
        if means is None or means.shape != sample_shape:
            # First message, or the per-sample shape changed: restart the statistics.
            sq_means = np.zeros(sample_shape, dtype=float)
            means = np.zeros(sample_shape, dtype=float)

        result = []
        # Zero variance legitimately produces 0/0 here; the NaNs are replaced below, so the
        # accompanying floating-point warnings carry no information.
        with np.errstate(divide="ignore", invalid="ignore"):
            for sample in data:
                # Update step
                means = _ew_update(sample, means, alpha)
                sq_means = _ew_update(sample ** 2, sq_means, alpha)
                # Get step
                varis = sq_means - means ** 2
                # np.asarray turns the numpy scalar produced by 1-D input into a writable 0-d array,
                # so the masked assignment below works for any sample rank.
                y = np.asarray((sample - means) / (varis ** 0.5))
                y[np.isnan(y)] = 0.0
                result.append(y)

        result = np.stack(result, axis=0)
        result = np.moveaxis(result, 0, axis_idx)
        axis_arr_out = replace(axis_arr_in, data=result)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class AdaptiveStandardScalerSettings(ez.Settings):
    """Settings consumed by the AdaptiveStandardScaler unit."""

    # Time constant of the exponential weighting; presumably in the same units as the
    # target axis gain (seconds for a time axis) -- TODO confirm.
    time_constant: float = 1.0

    # Name of the axis along which samples are iterated; None is passed through to the generator.
    axis: Optional[str] = None
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class AdaptiveStandardScaler(GenAxisArray):
    """Unit wrapper that runs the numpy-based `scaler_np` generator with this unit's settings."""

    SETTINGS: AdaptiveStandardScalerSettings

    def construct_generator(self):
        """Create the `scaler_np` generator from SETTINGS and store it in STATE.gen."""
        settings = self.SETTINGS
        self.STATE.gen = scaler_np(
            time_constant=settings.time_constant,
            axis=settings.axis,
        )
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
|
|
5
|
+
from dataclasses import replace
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import numpy.typing as npt
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SignalInjectorSettings(ez.Settings):
    """Settings for the SignalInjector unit."""

    # Input signal needs a time dimension with units in sec.
    time_dim: str = 'time'

    # Frequency of the injected sinusoid in Hz; None means the input is passed through unchanged.
    frequency: typing.Optional[float] = None

    # Scale factor applied to the injected sinusoid.
    amplitude: float = 1.0

    # Seed for the random per-channel mixing weights; None lets numpy choose the seed.
    mixing_seed: typing.Optional[int] = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SignalInjectorState(ez.State):
    """Mutable runtime state of the SignalInjector unit."""

    # Shape of the most recently seen input message; None until the first message arrives.
    cur_shape: typing.Optional[typing.Tuple[int, ...]] = None

    # Currently active injection frequency in Hz; None disables injection.
    cur_frequency: typing.Optional[float] = None

    # Currently active injection amplitude (assigned in SignalInjector.initialize).
    cur_amplitude: float

    # Per-channel mixing weights (assigned in SignalInjector.initialize, filled on first message).
    mixing: npt.NDArray
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SignalInjector(ez.Unit):
    """
    Add a sinusoid to every channel of the incoming signal.

    The sinusoid is evaluated at the timestamps of the configured time dimension, scaled by the current
    amplitude, and weighted per channel by random mixing weights drawn uniformly from [-1, 1).
    Frequency and amplitude can be changed at runtime through INPUT_FREQUENCY and INPUT_AMPLITUDE;
    a frequency of None passes the input through unchanged.
    """

    SETTINGS: SignalInjectorSettings
    STATE: SignalInjectorState

    INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
    INPUT_AMPLITUDE = ez.InputStream(float)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Seed the runtime state from SETTINGS; the mixing weights are created on the first message."""
        self.STATE.cur_frequency = self.SETTINGS.frequency
        self.STATE.cur_amplitude = self.SETTINGS.amplitude
        self.STATE.mixing = np.array([])

    @ez.subscriber(INPUT_FREQUENCY)
    async def on_frequency(self, msg: typing.Optional[float]) -> None:
        """Set the injection frequency in Hz (None disables injection)."""
        self.STATE.cur_frequency = msg

    @ez.subscriber(INPUT_AMPLITUDE)
    async def on_amplitude(self, msg: float) -> None:
        """Set the injection amplitude."""
        self.STATE.cur_amplitude = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
        """
        Publish `msg` with the sinusoid added, or `msg` itself when no frequency is set.

        The input message is never modified; injection operates on a copy of its data.
        """
        # The mixing weights depend only on the number of non-time elements, so they are regenerated
        # only when that count changes. Keying on the full message shape would also re-draw them when
        # merely the chunk length varies, which changes the injected signal mid-stream if unseeded.
        n_ch = msg.shape2d(self.SETTINGS.time_dim)[1]
        if self.STATE.mixing.shape != (1, n_ch):
            rng = np.random.default_rng(self.SETTINGS.mixing_seed)
            # Map uniform [0, 1) draws onto [-1, 1).
            self.STATE.mixing = (rng.random((1, n_ch)) * 2.0) - 1.0
        self.STATE.cur_shape = msg.shape

        if self.STATE.cur_frequency is None:
            yield self.OUTPUT_SIGNAL, msg
        else:
            out_msg = replace(msg, data=msg.data.copy())
            # Column vector of timestamps so it broadcasts against the (1, n_ch) mixing row.
            t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
            signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
            mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
            with out_msg.view2d(self.SETTINGS.time_dim) as view:
                view[...] = view + mixed_signal.astype(view.dtype)
            yield self.OUTPUT_SIGNAL, out_msg
|