ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +16 -1
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +234 -0
- ezmsg/sigproc/aggregate.py +158 -0
- ezmsg/sigproc/bandpower.py +74 -0
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +102 -11
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +95 -51
- ezmsg/sigproc/ewmfilter.py +38 -16
- ezmsg/sigproc/filter.py +108 -20
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +259 -224
- ezmsg/sigproc/scaler.py +173 -0
- ezmsg/sigproc/signalinjector.py +64 -0
- ezmsg/sigproc/slicer.py +133 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +86 -0
- ezmsg/sigproc/spectrum.py +259 -0
- ezmsg/sigproc/synth.py +299 -105
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +254 -116
- ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,287 +1,322 @@
|
|
|
1
|
+
import asyncio # Dev/test apparatus
|
|
2
|
+
from collections import deque
|
|
1
3
|
from dataclasses import dataclass, replace, field
|
|
2
4
|
import time
|
|
5
|
+
import typing
|
|
3
6
|
|
|
4
|
-
import ezmsg.core as ez
|
|
5
7
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
from
|
|
8
|
+
import numpy.typing as npt
|
|
9
|
+
import ezmsg.core as ez
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
11
|
+
from ezmsg.util.generator import consumer
|
|
10
12
|
|
|
11
13
|
|
|
12
|
-
@dataclass(unsafe_hash
|
|
14
|
+
@dataclass(unsafe_hash=True)
class SampleTriggerMessage:
    """A request to cut a window of data around a point in time."""

    # Time of the trigger, in seconds. The clock depends on the input but
    # defaults to time.time().
    timestamp: float = field(default_factory=time.time)

    # The (start, stop) period around the timestamp, in seconds.
    period: typing.Optional[typing.Tuple[float, float]] = None

    # A value or 'label' associated with the trigger.
    value: typing.Any = None
|
|
17
24
|
|
|
18
25
|
|
|
19
26
|
@dataclass
class SampleMessage:
    """A window of data cut out of a stream around a trigger."""

    # The time, window, and value (if any) associated with the trigger.
    trigger: SampleTriggerMessage

    # The data sampled around the trigger.
    sample: AxisArray
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@consumer
def sampler(
    buffer_dur: float,
    axis: typing.Optional[str] = None,
    period: typing.Optional[typing.Tuple[float, float]] = None,
    value: typing.Any = None,
    estimate_alignment: bool = True,
) -> typing.Generator[
    typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None
]:
    """
    A generator function that samples data into a buffer, accepts triggers, and returns slices of sampled
    data around the trigger time.

    Args:
        buffer_dur: The duration of the buffer in seconds. The buffer must be long enough to store the oldest
            sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
            need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
        axis: The axis along which to sample the data.
            None (default) will choose the first axis in the first input.
        period: The period in seconds, relative to the trigger timestamp, during which to sample the data.
            Defaults to None. Only used if not None and the trigger message does not define its own period.
        value: Default value ('label') attached to triggers whose own `value` is None. Defaults to None.
        estimate_alignment: Whether to estimate the sample alignment. Defaults to True.
            If True, the trigger timestamp field is used to slice the buffer.
            If False, the trigger timestamp is ignored and the next signal's .offset is used.
            NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
            estimate_alignment to operate correctly.

    Returns:
        A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
        or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
        Trigger messages always yield an empty list; samples are emitted when data arrives that completes
        a pending trigger's window.
    """
    msg_out: list[SampleMessage] = []

    # State variables (most shared between trigger- and data-processing).
    triggers: deque[SampleTriggerMessage] = deque()
    buffer: typing.Optional[npt.NDArray] = None
    n_samples: int = 0  # Number of samples (along `axis`) in the most recent data message.
    offset: float = 0.0  # Timestamp of the first sample of the most recent data message.

    # Properties of the stream; a change in any of them resets the buffer and triggers.
    check_inputs = {
        "fs": None,  # Also a state variable
        "key": None,
        "shape": None,
        "axis_idx": None,
    }

    while True:
        msg_in = yield msg_out
        msg_out = []

        if isinstance(msg_in, SampleTriggerMessage):
            # Input is a trigger message that we will use to sample the buffer.
            if buffer is None or check_inputs["fs"] is None:
                # We've yet to see any data; drop the trigger.
                continue

            _period = msg_in.period if msg_in.period is not None else period
            _value = msg_in.value if msg_in.value is not None else value

            if _period is None:
                ez.logger.warning("Sampling failed: period not specified")
                continue

            # Check that period is valid
            if _period[0] >= _period[1]:
                ez.logger.warning(
                    f"Sampling failed: invalid period requested ({_period})"
                )
                continue

            # Check that period is compatible with buffer duration.
            max_buf_len = int(np.round(buffer_dur * check_inputs["fs"]))
            req_buf_len = int(np.round((_period[1] - _period[0]) * check_inputs["fs"]))
            if req_buf_len >= max_buf_len:
                # Report the period actually requested (trigger's own or the default).
                ez.logger.warning(f"Sampling failed: period={_period} >= {buffer_dur=}")
                continue

            trigger_ts: float = msg_in.timestamp
            if not estimate_alignment:
                # Override the trigger timestamp with the timestamp of the sample that
                # immediately follows the last received one, i.e. the next message's offset
                # for a contiguous stream.
                trigger_ts = offset + n_samples / check_inputs["fs"]

            triggers.append(
                replace(msg_in, timestamp=trigger_ts, period=_period, value=_value)
            )

        elif isinstance(msg_in, AxisArray):
            # Get properties from message
            axis = axis or msg_in.dims[0]
            axis_idx = msg_in.get_axis_idx(axis)
            axis_info = msg_in.get_axis(axis)
            fs = 1.0 / axis_info.gain
            n_new = msg_in.data.shape[axis_idx]
            sample_shape = (
                msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
            )

            # TODO: We could accommodate a change in dim order (np.moveaxis onto the
            #  previous axis_idx) instead of resetting.

            # If the properties have changed in a breaking way then reset buffer and triggers.
            b_reset = (
                fs != check_inputs["fs"]
                or sample_shape != check_inputs["shape"]
                or axis_idx != check_inputs["axis_idx"]
                or msg_in.key != check_inputs["key"]
            )
            if b_reset:
                check_inputs.update(
                    fs=fs, shape=sample_shape, axis_idx=axis_idx, key=msg_in.key
                )
                buffer = None
                if len(triggers) > 0:
                    ez.logger.warning("Data stream changed: Discarding all triggers")
                triggers.clear()

            # Save some info for trigger processing (used when estimate_alignment is False).
            offset = axis_info.offset
            n_samples = n_new

            # Update buffer. The first chunk is copied so the buffer never aliases the
            # caller's array (the Sampler unit subscribes with zero_copy=True).
            if buffer is None:
                buffer = np.array(msg_in.data, copy=True)
            else:
                buffer = np.concatenate((buffer, msg_in.data), axis=axis_idx)

            n_buf = buffer.shape[axis_idx]
            if n_buf == 0:
                # Nothing buffered yet (only empty messages so far); nothing to sample.
                continue

            # Calculate timestamps associated with buffer: the first sample of msg_in is at
            # axis_info.offset and previously-buffered samples precede it.
            buffer_offset = (
                np.arange(n_buf, dtype=float) - (n_buf - n_new)
            ) * axis_info.gain + axis_info.offset

            # ... for each trigger, collect the message (if possible) and append to msg_out
            for trig in list(triggers):
                if trig.period is None:
                    # This trigger was malformed; drop it.
                    triggers.remove(trig)
                    continue

                # If the previous iteration had insufficient data for the trigger timestamp + period,
                # and buffer-management removed data required for the trigger, then we will never be able
                # to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
                t_start = trig.timestamp + trig.period[0]
                if t_start < buffer_offset[0]:
                    ez.logger.warning(
                        f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
                        f"requested sample period start: {trig.timestamp + trig.period[0]}"
                    )
                    triggers.remove(trig)
                    continue

                start = int(np.searchsorted(buffer_offset, t_start))
                stop = start + int(np.round(fs * (trig.period[1] - trig.period[0])))
                # slice(start, stop) is complete once stop <= n_buf; otherwise wait for more data.
                if start < n_buf and stop <= n_buf:
                    # Trigger period fully enclosed in buffer.
                    msg_out.append(
                        SampleMessage(
                            trigger=trig,
                            sample=replace(
                                msg_in,
                                data=slice_along_axis(
                                    buffer, slice(start, stop), axis_idx
                                ),
                                axes={
                                    **msg_in.axes,
                                    axis: replace(
                                        axis_info, offset=buffer_offset[start]
                                    ),
                                },
                            ),
                        )
                    )
                    triggers.remove(trig)

            # Keep at most buffer_dur seconds of the most recent samples. An explicit start
            # index is used because np.s_[-0:] would keep everything when buf_len == 0.
            buf_len = int(buffer_dur * fs)
            n_drop = max(n_buf - buf_len, 0)
            buffer = slice_along_axis(buffer, slice(n_drop, None), axis_idx)
|
|
23
219
|
|
|
24
220
|
|
|
25
221
|
class SamplerSettings(ez.Settings):
    """
    Settings for :obj:`Sampler`.
    See :obj:`sampler` for a description of the fields.
    """

    buffer_dur: float
    """Duration of the sample buffer, in seconds. Must exceed the longest requested period."""

    axis: typing.Optional[str] = None
    """Axis along which to sample. None selects the first dim of the first signal message."""

    period: typing.Optional[typing.Tuple[float, float]] = None
    """Optional default period if unspecified in SampleTriggerMessage"""

    value: typing.Any = None
    """Optional default value if unspecified in SampleTriggerMessage"""

    estimate_alignment: bool = True
    """
    If true, use message timestamp fields and reported sampling rate to estimate sample-accurate alignment for samples.
    If false, the trigger's timestamp is ignored and the trigger is aligned to the first sample of the
    next incoming signal message -- "Block timing".
    NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
    estimate_alignment to operate correctly.
    """
|
|
39
242
|
|
|
40
243
|
|
|
41
244
|
class SamplerState(ez.State):
    """Mutable state for :obj:`Sampler`."""

    # Settings currently in effect: the unit's SETTINGS, or the latest INPUT_SETTINGS message.
    cur_settings: SamplerSettings

    # Generator built by `sampler` from cur_settings. typing.Generator is
    # [YieldType, SendType, ReturnType]: it yields lists of SampleMessage and is sent
    # AxisArray / SampleTriggerMessage inputs (matching `sampler`'s return annotation).
    gen: typing.Generator[
        typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None
    ]
|
|
46
249
|
|
|
47
250
|
|
|
48
251
|
class Sampler(ez.Unit):
    """An :obj:`Unit` for :obj:`sampler`."""

    SETTINGS = SamplerSettings
    STATE = SamplerState

    INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
    INPUT_SETTINGS = ez.InputStream(SamplerSettings)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SAMPLE = ez.OutputStream(SampleMessage)

    def construct_generator(self):
        """
        (Re)build the `sampler` generator from the settings currently stored in STATE.

        A fresh generator starts with an empty buffer and no pending triggers, so anything
        held by a previous generator is discarded.
        """
        cfg = self.STATE.cur_settings
        self.STATE.gen = sampler(
            buffer_dur=cfg.buffer_dur,
            axis=cfg.axis,
            period=cfg.period,
            value=cfg.value,
            estimate_alignment=cfg.estimate_alignment,
        )

    async def initialize(self) -> None:
        self.STATE.cur_settings = self.SETTINGS
        self.construct_generator()

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SamplerSettings) -> None:
        self.STATE.cur_settings = msg
        self.construct_generator()

    @ez.subscriber(INPUT_TRIGGER)
    async def on_trigger(self, msg: SampleTriggerMessage) -> None:
        # The generator's reply to a trigger is discarded; samples are published from on_signal.
        self.STATE.gen.send(msg)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SAMPLE)
    async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator:
        # Feed the signal chunk to the generator and publish every completed sample it returns.
        completed = self.STATE.gen.send(msg)
        for sample_msg in completed:
            yield self.OUTPUT_SAMPLE, sample_msg
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
## Dev/test apparatus
|
|
206
|
-
import asyncio
|
|
207
291
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
from typing import AsyncGenerator
|
|
292
|
+
class TriggerGeneratorSettings(ez.Settings):
    """Settings for :obj:`TriggerGenerator`."""

    # The period around the trigger event.
    period: typing.Tuple[float, float]

    # The time before the first trigger (sec).
    prewait: float = 0.5

    # The period between triggers (sec).
    publish_period: float = 5.0
|
|
218
301
|
|
|
219
302
|
|
|
220
303
|
class TriggerGenerator(ez.Unit):
    """
    A unit to generate triggers every `publish_period` interval.
    """

    SETTINGS = TriggerGeneratorSettings

    OUTPUT_TRIGGER = ez.OutputStream(SampleTriggerMessage)

    @ez.publisher(OUTPUT_TRIGGER)
    async def generate(self) -> typing.AsyncGenerator:
        """
        Wait `prewait` seconds, then publish a trigger every `publish_period` seconds.

        Each trigger carries the configured `period` and an integer `value` that counts
        up from 0; its timestamp is the SampleTriggerMessage default (time.time()).
        """
        await asyncio.sleep(self.SETTINGS.prewait)

        trigger_index = 0
        while True:
            yield self.OUTPUT_TRIGGER, SampleTriggerMessage(
                period=self.SETTINGS.period, value=trigger_index
            )
            await asyncio.sleep(self.SETTINGS.publish_period)
            trigger_index += 1
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
class SamplerTestSystemSettings(ez.Settings):
|
|
240
|
-
sampler_settings: SamplerSettings
|
|
241
|
-
trigger_settings: TriggerGeneratorSettings
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
class SamplerTestSystem(ez.Collection):
|
|
245
|
-
SETTINGS: SamplerTestSystemSettings
|
|
246
|
-
|
|
247
|
-
OSC = Oscillator()
|
|
248
|
-
SAMPLER = Sampler()
|
|
249
|
-
TRIGGER = TriggerGenerator()
|
|
250
|
-
DEBUG = DebugLog()
|
|
251
|
-
|
|
252
|
-
def configure(self) -> None:
|
|
253
|
-
self.SAMPLER.apply_settings(self.SETTINGS.sampler_settings)
|
|
254
|
-
self.TRIGGER.apply_settings(self.SETTINGS.trigger_settings)
|
|
255
|
-
|
|
256
|
-
self.OSC.apply_settings(
|
|
257
|
-
OscillatorSettings(
|
|
258
|
-
n_time=2, # Number of samples to output per block
|
|
259
|
-
fs=10, # Sampling rate of signal output in Hz
|
|
260
|
-
dispatch_rate="realtime",
|
|
261
|
-
freq=2.0, # Oscillation frequency in Hz
|
|
262
|
-
amp=1.0, # Amplitude
|
|
263
|
-
phase=0.0, # Phase offset (in radians)
|
|
264
|
-
sync=True, # Adjust `freq` to sync with sampling rate
|
|
265
|
-
)
|
|
266
|
-
)
|
|
267
|
-
|
|
268
|
-
def network(self) -> ez.NetworkDefinition:
|
|
269
|
-
return (
|
|
270
|
-
(self.OSC.OUTPUT_SIGNAL, self.SAMPLER.INPUT_SIGNAL),
|
|
271
|
-
(self.TRIGGER.OUTPUT_TRIGGER, self.SAMPLER.INPUT_TRIGGER),
|
|
272
|
-
(self.TRIGGER.OUTPUT_TRIGGER, self.DEBUG.INPUT),
|
|
273
|
-
(self.SAMPLER.OUTPUT_SAMPLE, self.DEBUG.INPUT),
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
if __name__ == "__main__":
|
|
278
|
-
settings = SamplerTestSystemSettings(
|
|
279
|
-
sampler_settings=SamplerSettings(buffer_dur=5.0),
|
|
280
|
-
trigger_settings=TriggerGeneratorSettings(
|
|
281
|
-
period=(1.0, 2.0), prewait=0.5, publish_period=5.0
|
|
282
|
-
),
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
system = SamplerTestSystem(settings)
|
|
286
|
-
|
|
287
|
-
ez.run(SYSTEM = system)
|