ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +119 -104
- ezmsg/sigproc/bandpower.py +58 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -84
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
|
-
import asyncio
|
|
1
|
+
import asyncio
|
|
2
2
|
from collections import deque
|
|
3
|
-
|
|
4
|
-
import time
|
|
3
|
+
import traceback
|
|
5
4
|
import typing
|
|
6
5
|
|
|
7
6
|
import numpy as np
|
|
@@ -12,215 +11,17 @@ from ezmsg.util.messages.axisarray import (
|
|
|
12
11
|
slice_along_axis,
|
|
13
12
|
)
|
|
14
13
|
from ezmsg.util.messages.util import replace
|
|
15
|
-
from ezmsg.util.generator import consumer
|
|
16
14
|
|
|
17
15
|
from .util.profile import profile_subpub
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
value: typing.Any = None
|
|
29
|
-
"""A value or 'label' associated with the trigger."""
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
@dataclass
|
|
33
|
-
class SampleMessage:
|
|
34
|
-
trigger: SampleTriggerMessage
|
|
35
|
-
"""The time, window, and value (if any) associated with the trigger."""
|
|
36
|
-
|
|
37
|
-
sample: AxisArray
|
|
38
|
-
"""The data sampled around the trigger."""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@consumer
|
|
42
|
-
def sampler(
|
|
43
|
-
buffer_dur: float,
|
|
44
|
-
axis: str | None = None,
|
|
45
|
-
period: tuple[float, float] | None = None,
|
|
46
|
-
value: typing.Any = None,
|
|
47
|
-
estimate_alignment: bool = True,
|
|
48
|
-
) -> typing.Generator[list[SampleMessage], AxisArray | SampleTriggerMessage, None]:
|
|
49
|
-
"""
|
|
50
|
-
Sample data into a buffer, accept triggers, and return slices of sampled
|
|
51
|
-
data around the trigger time.
|
|
52
|
-
|
|
53
|
-
Args:
|
|
54
|
-
buffer_dur: The duration of the buffer in seconds. The buffer must be long enough to store the oldest
|
|
55
|
-
sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
|
|
56
|
-
need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
|
|
57
|
-
axis: The axis along which to sample the data.
|
|
58
|
-
None (default) will choose the first axis in the first input.
|
|
59
|
-
Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
|
|
60
|
-
period: The period in seconds during which to sample the data.
|
|
61
|
-
Defaults to None. Only used if not None and the trigger message does not define its own period.
|
|
62
|
-
value: The value to sample. Defaults to None.
|
|
63
|
-
estimate_alignment: Whether to estimate the sample alignment. Defaults to True.
|
|
64
|
-
If True, the trigger timestamp field is used to slice the buffer.
|
|
65
|
-
If False, the trigger timestamp is ignored and the next signal's .offset is used.
|
|
66
|
-
NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
|
|
67
|
-
estimate_alignment to operate correctly.
|
|
68
|
-
|
|
69
|
-
Returns:
|
|
70
|
-
A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
|
|
71
|
-
or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
|
|
72
|
-
"""
|
|
73
|
-
msg_out: list[SampleMessage] = []
|
|
74
|
-
|
|
75
|
-
# State variables (most shared between trigger- and data-processing.
|
|
76
|
-
triggers: deque[SampleTriggerMessage] = deque()
|
|
77
|
-
buffer: npt.NDArray | None = None
|
|
78
|
-
n_samples: int = 0
|
|
79
|
-
offset: float = 0.0
|
|
80
|
-
|
|
81
|
-
check_inputs = {
|
|
82
|
-
"fs": None, # Also a state variable
|
|
83
|
-
"key": None,
|
|
84
|
-
"shape": None,
|
|
85
|
-
}
|
|
86
|
-
|
|
87
|
-
while True:
|
|
88
|
-
msg_in = yield msg_out
|
|
89
|
-
msg_out = []
|
|
90
|
-
|
|
91
|
-
if isinstance(msg_in, SampleTriggerMessage):
|
|
92
|
-
# Input is a trigger message that we will use to sample the buffer.
|
|
93
|
-
|
|
94
|
-
if buffer is None or check_inputs["fs"] is None:
|
|
95
|
-
# We've yet to see any data; drop the trigger.
|
|
96
|
-
continue
|
|
97
|
-
|
|
98
|
-
_period = msg_in.period if msg_in.period is not None else period
|
|
99
|
-
_value = msg_in.value if msg_in.value is not None else value
|
|
100
|
-
|
|
101
|
-
if _period is None:
|
|
102
|
-
ez.logger.warning("Sampling failed: period not specified")
|
|
103
|
-
continue
|
|
104
|
-
|
|
105
|
-
# Check that period is valid
|
|
106
|
-
if _period[0] >= _period[1]:
|
|
107
|
-
ez.logger.warning(
|
|
108
|
-
f"Sampling failed: invalid period requested ({_period})"
|
|
109
|
-
)
|
|
110
|
-
continue
|
|
111
|
-
|
|
112
|
-
# Check that period is compatible with buffer duration.
|
|
113
|
-
max_buf_len = int(np.round(buffer_dur * check_inputs["fs"]))
|
|
114
|
-
req_buf_len = int(np.round((_period[1] - _period[0]) * check_inputs["fs"]))
|
|
115
|
-
if req_buf_len >= max_buf_len:
|
|
116
|
-
ez.logger.warning(f"Sampling failed: {period=} >= {buffer_dur=}")
|
|
117
|
-
continue
|
|
118
|
-
|
|
119
|
-
trigger_ts: float = msg_in.timestamp
|
|
120
|
-
if not estimate_alignment:
|
|
121
|
-
# Override the trigger timestamp with the next sample's likely timestamp.
|
|
122
|
-
trigger_ts = offset + (n_samples + 1) / check_inputs["fs"]
|
|
123
|
-
|
|
124
|
-
new_trig_msg = replace(
|
|
125
|
-
msg_in, timestamp=trigger_ts, period=_period, value=_value
|
|
126
|
-
)
|
|
127
|
-
triggers.append(new_trig_msg)
|
|
128
|
-
|
|
129
|
-
elif isinstance(msg_in, AxisArray):
|
|
130
|
-
# Get properties from message
|
|
131
|
-
axis = axis or msg_in.dims[0]
|
|
132
|
-
axis_idx = msg_in.get_axis_idx(axis)
|
|
133
|
-
axis_info = msg_in.get_axis(axis)
|
|
134
|
-
fs = 1.0 / axis_info.gain
|
|
135
|
-
sample_shape = (
|
|
136
|
-
msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
# TODO: We could accommodate change in dim order.
|
|
140
|
-
# if axis_idx != check_inputs["axis_idx"]:
|
|
141
|
-
# msg_in = replace(
|
|
142
|
-
# msg_in,
|
|
143
|
-
# data=np.moveaxis(msg_in.data, axis_idx, check_inputs["axis_idx"]),
|
|
144
|
-
# dims=TODO...
|
|
145
|
-
# )
|
|
146
|
-
# axis_idx = check_inputs["axis_idx"]
|
|
147
|
-
|
|
148
|
-
# If the properties have changed in a breaking way then reset buffer and triggers.
|
|
149
|
-
b_reset = fs != check_inputs["fs"]
|
|
150
|
-
b_reset = b_reset or sample_shape != check_inputs["shape"]
|
|
151
|
-
# TODO: Skip next line if we do np.moveaxis above
|
|
152
|
-
b_reset = b_reset or axis_idx != check_inputs["axis_idx"]
|
|
153
|
-
b_reset = b_reset or msg_in.key != check_inputs["key"]
|
|
154
|
-
if b_reset:
|
|
155
|
-
check_inputs["fs"] = fs
|
|
156
|
-
check_inputs["shape"] = sample_shape
|
|
157
|
-
check_inputs["axis_idx"] = axis_idx
|
|
158
|
-
check_inputs["key"] = msg_in.key
|
|
159
|
-
n_samples = msg_in.data.shape[axis_idx]
|
|
160
|
-
buffer = None
|
|
161
|
-
if len(triggers) > 0:
|
|
162
|
-
ez.logger.warning("Data stream changed: Discarding all triggers")
|
|
163
|
-
triggers.clear()
|
|
164
|
-
|
|
165
|
-
# Save some info for trigger processing
|
|
166
|
-
offset = axis_info.offset
|
|
167
|
-
|
|
168
|
-
# Update buffer
|
|
169
|
-
buffer = (
|
|
170
|
-
msg_in.data
|
|
171
|
-
if buffer is None
|
|
172
|
-
else np.concatenate((buffer, msg_in.data), axis=axis_idx)
|
|
173
|
-
)
|
|
174
|
-
|
|
175
|
-
# Calculate timestamps associated with buffer.
|
|
176
|
-
buffer_offset = np.arange(buffer.shape[axis_idx], dtype=float)
|
|
177
|
-
buffer_offset -= buffer_offset[-msg_in.data.shape[axis_idx]]
|
|
178
|
-
buffer_offset *= axis_info.gain
|
|
179
|
-
buffer_offset += axis_info.offset
|
|
180
|
-
|
|
181
|
-
# ... for each trigger, collect the message (if possible) and append to msg_out
|
|
182
|
-
for trig in list(triggers):
|
|
183
|
-
if trig.period is None:
|
|
184
|
-
# This trigger was malformed; drop it.
|
|
185
|
-
triggers.remove(trig)
|
|
186
|
-
|
|
187
|
-
# If the previous iteration had insufficient data for the trigger timestamp + period,
|
|
188
|
-
# and buffer-management removed data required for the trigger, then we will never be able
|
|
189
|
-
# to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
|
|
190
|
-
if (trig.timestamp + trig.period[0]) < buffer_offset[0]:
|
|
191
|
-
ez.logger.warning(
|
|
192
|
-
f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
|
|
193
|
-
f"requested sample period start: {trig.timestamp + trig.period[0]}"
|
|
194
|
-
)
|
|
195
|
-
triggers.remove(trig)
|
|
196
|
-
|
|
197
|
-
t_start = trig.timestamp + trig.period[0]
|
|
198
|
-
if t_start >= buffer_offset[0]:
|
|
199
|
-
start = np.searchsorted(buffer_offset, t_start)
|
|
200
|
-
stop = start + int(np.round(fs * (trig.period[1] - trig.period[0])))
|
|
201
|
-
if buffer.shape[axis_idx] > stop:
|
|
202
|
-
# Trigger period fully enclosed in buffer.
|
|
203
|
-
msg_out.append(
|
|
204
|
-
SampleMessage(
|
|
205
|
-
trigger=trig,
|
|
206
|
-
sample=replace(
|
|
207
|
-
msg_in,
|
|
208
|
-
data=slice_along_axis(
|
|
209
|
-
buffer, slice(start, stop), axis_idx
|
|
210
|
-
),
|
|
211
|
-
axes={
|
|
212
|
-
**msg_in.axes,
|
|
213
|
-
axis: replace(
|
|
214
|
-
axis_info, offset=buffer_offset[start]
|
|
215
|
-
),
|
|
216
|
-
},
|
|
217
|
-
),
|
|
218
|
-
)
|
|
219
|
-
)
|
|
220
|
-
triggers.remove(trig)
|
|
221
|
-
|
|
222
|
-
buf_len = int(buffer_dur * fs)
|
|
223
|
-
buffer = slice_along_axis(buffer, np.s_[-buf_len:], axis_idx)
|
|
16
|
+
from .util.message import SampleMessage, SampleTriggerMessage
|
|
17
|
+
from .base import (
|
|
18
|
+
BaseStatefulTransformer,
|
|
19
|
+
BaseConsumerUnit,
|
|
20
|
+
BaseTransformerUnit,
|
|
21
|
+
BaseStatefulProducer,
|
|
22
|
+
BaseProducerUnit,
|
|
23
|
+
processor_state,
|
|
24
|
+
)
|
|
224
25
|
|
|
225
26
|
|
|
226
27
|
class SamplerSettings(ez.Settings):
|
|
@@ -230,9 +31,20 @@ class SamplerSettings(ez.Settings):
|
|
|
230
31
|
"""
|
|
231
32
|
|
|
232
33
|
buffer_dur: float
|
|
34
|
+
"""
|
|
35
|
+
The duration of the buffer in seconds. The buffer must be long enough to store the oldest
|
|
36
|
+
sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
|
|
37
|
+
need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
|
|
38
|
+
"""
|
|
39
|
+
|
|
233
40
|
axis: str | None = None
|
|
41
|
+
"""
|
|
42
|
+
The axis along which to sample the data.
|
|
43
|
+
None (default) will choose the first axis in the first input.
|
|
44
|
+
Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
|
|
45
|
+
"""
|
|
234
46
|
period: tuple[float, float] | None = None
|
|
235
|
-
"""Optional default period if unspecified in SampleTriggerMessage"""
|
|
47
|
+
"""Optional default period (in seconds) if unspecified in SampleTriggerMessage."""
|
|
236
48
|
|
|
237
49
|
value: typing.Any = None
|
|
238
50
|
"""Optional default value if unspecified in SampleTriggerMessage"""
|
|
@@ -246,51 +58,211 @@ class SamplerSettings(ez.Settings):
|
|
|
246
58
|
"""
|
|
247
59
|
|
|
248
60
|
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
61
|
+
@processor_state
|
|
62
|
+
class SamplerState:
|
|
63
|
+
fs: float = 0.0
|
|
64
|
+
offset: float | None = None
|
|
65
|
+
buffer: npt.NDArray | None = None
|
|
66
|
+
triggers: deque[SampleTriggerMessage] | None = None
|
|
67
|
+
n_samples: int = 0
|
|
252
68
|
|
|
253
69
|
|
|
254
|
-
class
|
|
255
|
-
|
|
70
|
+
class SamplerTransformer(
|
|
71
|
+
BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]
|
|
72
|
+
):
|
|
73
|
+
def __call__(
|
|
74
|
+
self, message: AxisArray | SampleTriggerMessage
|
|
75
|
+
) -> list[SampleMessage]:
|
|
76
|
+
if isinstance(message, AxisArray):
|
|
77
|
+
return super().__call__(message)
|
|
78
|
+
else:
|
|
79
|
+
return self.push_trigger(message)
|
|
80
|
+
|
|
81
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
82
|
+
# Compute hash based on message properties that require state reset
|
|
83
|
+
axis = self.settings.axis or message.dims[0]
|
|
84
|
+
axis_idx = message.get_axis_idx(axis)
|
|
85
|
+
fs = 1.0 / message.get_axis(axis).gain
|
|
86
|
+
sample_shape = (
|
|
87
|
+
message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
88
|
+
)
|
|
89
|
+
return hash((fs, sample_shape, axis_idx, message.key))
|
|
90
|
+
|
|
91
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
92
|
+
axis = self.settings.axis or message.dims[0]
|
|
93
|
+
axis_idx = message.get_axis_idx(axis)
|
|
94
|
+
axis_info = message.get_axis(axis)
|
|
95
|
+
self._state.fs = 1.0 / axis_info.gain
|
|
96
|
+
self._state.buffer = None
|
|
97
|
+
if self._state.triggers is None:
|
|
98
|
+
self._state.triggers = deque()
|
|
99
|
+
self._state.triggers.clear()
|
|
100
|
+
self._state.n_samples = message.data.shape[axis_idx]
|
|
101
|
+
|
|
102
|
+
def _process(self, message: AxisArray) -> list[SampleMessage]:
|
|
103
|
+
axis = self.settings.axis or message.dims[0]
|
|
104
|
+
axis_idx = message.get_axis_idx(axis)
|
|
105
|
+
axis_info = message.get_axis(axis)
|
|
106
|
+
self._state.offset = axis_info.offset
|
|
107
|
+
|
|
108
|
+
# Update buffer
|
|
109
|
+
self._state.buffer = (
|
|
110
|
+
message.data
|
|
111
|
+
if self._state.buffer is None
|
|
112
|
+
else np.concatenate((self._state.buffer, message.data), axis=axis_idx)
|
|
113
|
+
)
|
|
256
114
|
|
|
257
|
-
|
|
258
|
-
|
|
115
|
+
# Calculate timestamps associated with buffer.
|
|
116
|
+
buffer_offset = np.arange(self._state.buffer.shape[axis_idx], dtype=float)
|
|
117
|
+
buffer_offset -= buffer_offset[-message.data.shape[axis_idx]]
|
|
118
|
+
buffer_offset *= axis_info.gain
|
|
119
|
+
buffer_offset += axis_info.offset
|
|
120
|
+
|
|
121
|
+
# ... for each trigger, collect the message (if possible) and append to msg_out
|
|
122
|
+
msg_out: list[SampleMessage] = []
|
|
123
|
+
for trig in list(self._state.triggers):
|
|
124
|
+
if trig.period is None:
|
|
125
|
+
# This trigger was malformed; drop it.
|
|
126
|
+
self._state.triggers.remove(trig)
|
|
127
|
+
|
|
128
|
+
# If the previous iteration had insufficient data for the trigger timestamp + period,
|
|
129
|
+
# and buffer-management removed data required for the trigger, then we will never be able
|
|
130
|
+
# to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
|
|
131
|
+
if (trig.timestamp + trig.period[0]) < buffer_offset[0]:
|
|
132
|
+
ez.logger.warning(
|
|
133
|
+
f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
|
|
134
|
+
f"requested sample period start: {trig.timestamp + trig.period[0]}"
|
|
135
|
+
)
|
|
136
|
+
self._state.triggers.remove(trig)
|
|
259
137
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
138
|
+
t_start = trig.timestamp + trig.period[0]
|
|
139
|
+
if t_start >= buffer_offset[0]:
|
|
140
|
+
start = np.searchsorted(buffer_offset, t_start)
|
|
141
|
+
stop = start + int(
|
|
142
|
+
np.round(self._state.fs * (trig.period[1] - trig.period[0]))
|
|
143
|
+
)
|
|
144
|
+
if self._state.buffer.shape[axis_idx] > stop:
|
|
145
|
+
# Trigger period fully enclosed in buffer.
|
|
146
|
+
msg_out.append(
|
|
147
|
+
SampleMessage(
|
|
148
|
+
trigger=trig,
|
|
149
|
+
sample=replace(
|
|
150
|
+
message,
|
|
151
|
+
data=slice_along_axis(
|
|
152
|
+
self._state.buffer, slice(start, stop), axis_idx
|
|
153
|
+
),
|
|
154
|
+
axes={
|
|
155
|
+
**message.axes,
|
|
156
|
+
axis: replace(
|
|
157
|
+
axis_info, offset=buffer_offset[start]
|
|
158
|
+
),
|
|
159
|
+
},
|
|
160
|
+
),
|
|
161
|
+
)
|
|
162
|
+
)
|
|
163
|
+
self._state.triggers.remove(trig)
|
|
164
|
+
|
|
165
|
+
# Trim buffer
|
|
166
|
+
buf_len = int(self.settings.buffer_dur * self._state.fs)
|
|
167
|
+
self._state.buffer = slice_along_axis(
|
|
168
|
+
self._state.buffer, np.s_[-buf_len:], axis_idx
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
return msg_out
|
|
172
|
+
|
|
173
|
+
def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
|
|
174
|
+
# Input is a trigger message that we will use to sample the buffer.
|
|
175
|
+
|
|
176
|
+
if (
|
|
177
|
+
self._state.buffer is None
|
|
178
|
+
or not self._state.fs
|
|
179
|
+
or self._state.offset is None
|
|
180
|
+
):
|
|
181
|
+
# We've yet to see any data; drop the trigger.
|
|
182
|
+
return []
|
|
183
|
+
|
|
184
|
+
_period = message.period if message.period is not None else self.settings.period
|
|
185
|
+
_value = message.value if message.value is not None else self.settings.value
|
|
186
|
+
|
|
187
|
+
if _period is None:
|
|
188
|
+
ez.logger.warning("Sampling failed: period not specified")
|
|
189
|
+
return []
|
|
190
|
+
|
|
191
|
+
# Check that period is valid
|
|
192
|
+
if _period[0] >= _period[1]:
|
|
193
|
+
ez.logger.warning(f"Sampling failed: invalid period requested ({_period})")
|
|
194
|
+
return []
|
|
195
|
+
|
|
196
|
+
# Check that period is compatible with buffer duration.
|
|
197
|
+
max_buf_len = int(np.round(self.settings.buffer_dur * self._state.fs))
|
|
198
|
+
req_buf_len = int(np.round((_period[1] - _period[0]) * self._state.fs))
|
|
199
|
+
if req_buf_len >= max_buf_len:
|
|
200
|
+
ez.logger.warning(
|
|
201
|
+
f"Sampling failed: {_period=} >= {self.settings.buffer_dur=}"
|
|
202
|
+
)
|
|
203
|
+
return []
|
|
204
|
+
|
|
205
|
+
trigger_ts: float = message.timestamp
|
|
206
|
+
if not self.settings.estimate_alignment:
|
|
207
|
+
# Override the trigger timestamp with the next sample's likely timestamp.
|
|
208
|
+
trigger_ts = (
|
|
209
|
+
self._state.offset + (self.state.n_samples + 1) / self._state.fs
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
new_trig_msg = replace(
|
|
213
|
+
message, timestamp=trigger_ts, period=_period, value=_value
|
|
272
214
|
)
|
|
215
|
+
self._state.triggers.append(new_trig_msg)
|
|
216
|
+
return []
|
|
273
217
|
|
|
274
|
-
async def initialize(self) -> None:
|
|
275
|
-
self.STATE.cur_settings = self.SETTINGS
|
|
276
|
-
self.construct_generator()
|
|
277
218
|
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
219
|
+
class Sampler(
|
|
220
|
+
BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]
|
|
221
|
+
):
|
|
222
|
+
SETTINGS = SamplerSettings
|
|
223
|
+
|
|
224
|
+
INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
|
|
225
|
+
OUTPUT_SIGNAL = ez.OutputStream(SampleMessage)
|
|
282
226
|
|
|
283
227
|
@ez.subscriber(INPUT_TRIGGER)
|
|
284
228
|
async def on_trigger(self, msg: SampleTriggerMessage) -> None:
|
|
285
|
-
_ = self.
|
|
229
|
+
_ = self.processor.push_trigger(msg)
|
|
286
230
|
|
|
287
|
-
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
|
|
288
|
-
@ez.publisher(
|
|
231
|
+
@ez.subscriber(BaseConsumerUnit.INPUT_SIGNAL, zero_copy=True)
|
|
232
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
289
233
|
@profile_subpub(trace_oldest=False)
|
|
290
|
-
async def on_signal(self,
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
234
|
+
async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
|
|
235
|
+
try:
|
|
236
|
+
for sample in self.processor(message):
|
|
237
|
+
yield self.OUTPUT_SIGNAL, sample
|
|
238
|
+
except Exception as e:
|
|
239
|
+
ez.logger.info(f"{traceback.format_exc()} - {e}")
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def sampler(
|
|
243
|
+
buffer_dur: float,
|
|
244
|
+
axis: str | None = None,
|
|
245
|
+
period: tuple[float, float] | None = None,
|
|
246
|
+
value: typing.Any = None,
|
|
247
|
+
estimate_alignment: bool = True,
|
|
248
|
+
) -> SamplerTransformer:
|
|
249
|
+
"""
|
|
250
|
+
Sample data into a buffer, accept triggers, and return slices of sampled
|
|
251
|
+
data around the trigger time.
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
|
|
255
|
+
or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
|
|
256
|
+
"""
|
|
257
|
+
return SamplerTransformer(
|
|
258
|
+
settings=SamplerSettings(
|
|
259
|
+
buffer_dur=buffer_dur,
|
|
260
|
+
axis=axis,
|
|
261
|
+
period=period,
|
|
262
|
+
value=value,
|
|
263
|
+
estimate_alignment=estimate_alignment,
|
|
264
|
+
)
|
|
265
|
+
)
|
|
294
266
|
|
|
295
267
|
|
|
296
268
|
class TriggerGeneratorSettings(ez.Settings):
|
|
@@ -304,23 +276,33 @@ class TriggerGeneratorSettings(ez.Settings):
|
|
|
304
276
|
"""The period between triggers (sec)"""
|
|
305
277
|
|
|
306
278
|
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
"""
|
|
279
|
+
@processor_state
|
|
280
|
+
class TriggerGeneratorState:
|
|
281
|
+
output: int = 0
|
|
311
282
|
|
|
312
|
-
SETTINGS = TriggerGeneratorSettings
|
|
313
283
|
|
|
314
|
-
|
|
284
|
+
class TriggerProducer(
|
|
285
|
+
BaseStatefulProducer[
|
|
286
|
+
TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState
|
|
287
|
+
]
|
|
288
|
+
):
|
|
289
|
+
def _reset_state(self) -> None:
|
|
290
|
+
self._state.output = 0
|
|
315
291
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
292
|
+
async def _produce(self) -> SampleTriggerMessage:
|
|
293
|
+
await asyncio.sleep(self.settings.publish_period)
|
|
294
|
+
out_msg = SampleTriggerMessage(
|
|
295
|
+
period=self.settings.period, value=self._state.output
|
|
296
|
+
)
|
|
297
|
+
self._state.output += 1
|
|
298
|
+
return out_msg
|
|
319
299
|
|
|
320
|
-
output = 0
|
|
321
|
-
while True:
|
|
322
|
-
out_msg = SampleTriggerMessage(period=self.SETTINGS.period, value=output)
|
|
323
|
-
yield self.OUTPUT_TRIGGER, out_msg
|
|
324
300
|
|
|
325
|
-
|
|
326
|
-
|
|
301
|
+
class TriggerGenerator(
|
|
302
|
+
BaseProducerUnit[
|
|
303
|
+
TriggerGeneratorSettings,
|
|
304
|
+
SampleTriggerMessage,
|
|
305
|
+
TriggerProducer,
|
|
306
|
+
]
|
|
307
|
+
):
|
|
308
|
+
SETTINGS = TriggerGeneratorSettings
|