ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,287 +1,278 @@
|
|
|
1
|
-
|
|
2
|
-
import
|
|
1
|
+
import asyncio
|
|
2
|
+
import copy
|
|
3
|
+
import traceback
|
|
4
|
+
import typing
|
|
5
|
+
from collections import deque
|
|
3
6
|
|
|
4
7
|
import ezmsg.core as ez
|
|
5
8
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
9
|
+
from ezmsg.baseproc import (
|
|
10
|
+
BaseConsumerUnit,
|
|
11
|
+
BaseProducerUnit,
|
|
12
|
+
BaseStatefulProducer,
|
|
13
|
+
BaseStatefulTransformer,
|
|
14
|
+
BaseTransformerUnit,
|
|
15
|
+
processor_state,
|
|
16
|
+
)
|
|
17
|
+
from ezmsg.util.messages.axisarray import (
|
|
18
|
+
AxisArray,
|
|
19
|
+
)
|
|
20
|
+
from ezmsg.util.messages.util import replace
|
|
21
|
+
|
|
22
|
+
from .util.axisarray_buffer import HybridAxisArrayBuffer
|
|
23
|
+
from .util.buffer import UpdateStrategy
|
|
24
|
+
from .util.message import SampleMessage, SampleTriggerMessage
|
|
25
|
+
from .util.profile import profile_subpub
|
|
23
26
|
|
|
24
27
|
|
|
25
28
|
class SamplerSettings(ez.Settings):
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
] = None # Optional default period if unspecified in SampleTriggerMessage
|
|
31
|
-
value: Any = None # Optional default value if unspecified in SampleTriggerMessage
|
|
32
|
-
|
|
33
|
-
estimate_alignment: bool = True
|
|
34
|
-
# If true, use message timestamp fields and reported sampling rate to estimate
|
|
35
|
-
# sample-accurate alignment for samples.
|
|
36
|
-
# If false, sampling will be limited to incoming message rate -- "Block timing"
|
|
37
|
-
# NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
|
|
38
|
-
# "realtime" operation for estimate_alignment to operate correctly.
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
class SamplerState(ez.State):
|
|
42
|
-
cur_settings: SamplerSettings
|
|
43
|
-
triggers: Dict[SampleTriggerMessage, int] = field(default_factory=dict)
|
|
44
|
-
last_msg: Optional[AxisArray] = None
|
|
45
|
-
buffer: Optional[np.ndarray] = None
|
|
46
|
-
|
|
29
|
+
"""
|
|
30
|
+
Settings for :obj:`Sampler`.
|
|
31
|
+
See :obj:`sampler` for a description of the fields.
|
|
32
|
+
"""
|
|
47
33
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
34
|
+
buffer_dur: float
|
|
35
|
+
"""
|
|
36
|
+
The duration of the buffer in seconds. The buffer must be long enough to store the oldest
|
|
37
|
+
sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
|
|
38
|
+
need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
|
|
39
|
+
"""
|
|
51
40
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
41
|
+
axis: str | None = None
|
|
42
|
+
"""
|
|
43
|
+
The axis along which to sample the data.
|
|
44
|
+
None (default) will choose the first axis in the first input.
|
|
45
|
+
Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
|
|
46
|
+
"""
|
|
56
47
|
|
|
57
|
-
|
|
58
|
-
|
|
48
|
+
period: tuple[float, float] | None = None
|
|
49
|
+
"""Optional default period (in seconds) if unspecified in SampleTriggerMessage."""
|
|
59
50
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
self.STATE.cur_settings = msg
|
|
63
|
-
|
|
64
|
-
@ez.subscriber(INPUT_TRIGGER)
|
|
65
|
-
async def on_trigger(self, msg: SampleTriggerMessage) -> None:
|
|
66
|
-
if self.STATE.last_msg is not None:
|
|
67
|
-
axis_name = self.STATE.cur_settings.axis
|
|
68
|
-
if axis_name is None:
|
|
69
|
-
axis_name = self.STATE.last_msg.dims[0]
|
|
70
|
-
axis = self.STATE.last_msg.get_axis(axis_name)
|
|
71
|
-
axis_idx = self.STATE.last_msg.get_axis_idx(axis_name)
|
|
72
|
-
|
|
73
|
-
fs = 1.0 / axis.gain
|
|
74
|
-
last_msg_timestamp = axis.offset + (
|
|
75
|
-
self.STATE.last_msg.shape[axis_idx] / fs
|
|
76
|
-
)
|
|
77
|
-
|
|
78
|
-
period = (
|
|
79
|
-
msg.period if msg.period is not None else self.STATE.cur_settings.period
|
|
80
|
-
)
|
|
81
|
-
value = (
|
|
82
|
-
msg.value if msg.value is not None else self.STATE.cur_settings.value
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
if period is None:
|
|
86
|
-
ez.logger.warning(f"Sampling failed: period not specified")
|
|
87
|
-
return
|
|
88
|
-
|
|
89
|
-
# Check that period is valid
|
|
90
|
-
start_offset = int(period[0] * fs)
|
|
91
|
-
stop_offset = int(period[1] * fs)
|
|
92
|
-
if (stop_offset - start_offset) <= 0:
|
|
93
|
-
ez.logger.warning(f"Sampling failed: invalid period requested")
|
|
94
|
-
return
|
|
95
|
-
|
|
96
|
-
# Check that period is compatible with buffer duration
|
|
97
|
-
max_buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
|
|
98
|
-
req_buf_len = int((period[1] - period[0]) * fs)
|
|
99
|
-
if req_buf_len >= max_buf_len:
|
|
100
|
-
ez.logger.warning(
|
|
101
|
-
f"Sampling failed: {period=} >= {self.STATE.cur_settings.buffer_dur=}"
|
|
102
|
-
)
|
|
103
|
-
return
|
|
104
|
-
|
|
105
|
-
offset: int = 0
|
|
106
|
-
if self.STATE.cur_settings.estimate_alignment:
|
|
107
|
-
# Do what we can with the wall clock to determine sample alignment
|
|
108
|
-
wall_delta = msg.timestamp - last_msg_timestamp
|
|
109
|
-
offset = int(wall_delta * fs)
|
|
110
|
-
|
|
111
|
-
# Check that current buffer accumulation allows for offset - period start
|
|
112
|
-
if (
|
|
113
|
-
self.STATE.buffer is None
|
|
114
|
-
or -min(offset + start_offset, 0) >= self.STATE.buffer.shape[0]
|
|
115
|
-
):
|
|
116
|
-
ez.logger.warning(
|
|
117
|
-
"Sampling failed: insufficient buffer accumulation for requested sample period"
|
|
118
|
-
)
|
|
119
|
-
return
|
|
120
|
-
|
|
121
|
-
self.STATE.triggers[replace(msg, period=period, value=value)] = offset
|
|
51
|
+
value: typing.Any = None
|
|
52
|
+
"""Optional default value if unspecified in SampleTriggerMessage"""
|
|
122
53
|
|
|
54
|
+
estimate_alignment: bool = True
|
|
55
|
+
"""
|
|
56
|
+
If true, use message timestamp fields and reported sampling rate to estimate
|
|
57
|
+
sample-accurate alignment for samples.
|
|
58
|
+
If false, sampling will be limited to incoming message rate -- "Block timing"
|
|
59
|
+
NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
|
|
60
|
+
"realtime" operation for estimate_alignment to operate correctly.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
buffer_update_strategy: UpdateStrategy = "immediate"
|
|
64
|
+
"""
|
|
65
|
+
The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
|
|
66
|
+
If you expect to push data much more frequently than triggers, then "on_demand"
|
|
67
|
+
might be more efficient. For most other scenarios, "immediate" is best.
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@processor_state
|
|
72
|
+
class SamplerState:
|
|
73
|
+
buffer: HybridAxisArrayBuffer | None = None
|
|
74
|
+
triggers: deque[SampleTriggerMessage] | None = None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
|
|
78
|
+
def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
|
|
79
|
+
# TODO: Currently we have a single entry point that accepts both
|
|
80
|
+
# data and trigger messages and we choose a code path based on
|
|
81
|
+
# the message type. However, in the future we will likely replace
|
|
82
|
+
# SampleTriggerMessage with an agumented form of AxisArray,
|
|
83
|
+
# leveraging its attrs field, which makes this a bit harder.
|
|
84
|
+
# We should probably force callers of this object to explicitly
|
|
85
|
+
# call `push_trigger` for trigger messages. This will also
|
|
86
|
+
# simplify typing somewhat because `push_trigger` should not
|
|
87
|
+
# return anything yet we currently have it returning an empty
|
|
88
|
+
# list just to be compatible with __call__.
|
|
89
|
+
if isinstance(message, AxisArray):
|
|
90
|
+
return super().__call__(message)
|
|
123
91
|
else:
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
self.
|
|
138
|
-
|
|
139
|
-
# Easier to deal with timeseries on axis 0
|
|
140
|
-
last_msg = self.STATE.last_msg
|
|
141
|
-
msg_data = np.moveaxis(msg.data, msg.get_axis_idx(axis_name), 0)
|
|
142
|
-
last_msg_data = np.moveaxis(last_msg.data, last_msg.get_axis_idx(axis_name), 0)
|
|
143
|
-
last_msg_axis = last_msg.get_axis(axis_name)
|
|
144
|
-
last_msg_fs = 1.0 / last_msg_axis.gain
|
|
145
|
-
|
|
146
|
-
# Check if signal properties have changed in a breaking way
|
|
147
|
-
if fs != last_msg_fs or msg_data.shape[1:] != last_msg_data.shape[1:]:
|
|
148
|
-
# Data stream changed meaningfully -- flush buffer, stop sampling
|
|
149
|
-
if len(self.STATE.triggers) > 0:
|
|
150
|
-
ez.logger.warning("Sampling failed: Discarding all triggers")
|
|
151
|
-
ez.logger.warning("Flushing buffer: signal properties changed")
|
|
152
|
-
self.STATE.buffer = None
|
|
153
|
-
self.STATE.triggers = dict()
|
|
154
|
-
|
|
155
|
-
# Accumulate buffer ( time dim => dim 0 )
|
|
156
|
-
self.STATE.buffer = (
|
|
157
|
-
msg_data
|
|
158
|
-
if self.STATE.buffer is None
|
|
159
|
-
else np.concatenate((self.STATE.buffer, msg_data), axis=0)
|
|
92
|
+
return self.push_trigger(message)
|
|
93
|
+
|
|
94
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
95
|
+
# Compute hash based on message properties that require state reset
|
|
96
|
+
axis = self.settings.axis or message.dims[0]
|
|
97
|
+
axis_idx = message.get_axis_idx(axis)
|
|
98
|
+
sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
99
|
+
return hash((sample_shape, message.key))
|
|
100
|
+
|
|
101
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
102
|
+
self._state.buffer = HybridAxisArrayBuffer(
|
|
103
|
+
duration=self.settings.buffer_dur,
|
|
104
|
+
axis=self.settings.axis or message.dims[0],
|
|
105
|
+
update_strategy=self.settings.buffer_update_strategy,
|
|
106
|
+
overflow_strategy="warn-overwrite", # True circular buffer
|
|
160
107
|
)
|
|
108
|
+
if self._state.triggers is None:
|
|
109
|
+
self._state.triggers = deque()
|
|
110
|
+
self._state.triggers.clear()
|
|
161
111
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
buffer_offset = (buffer_offset * axis.gain) + axis.offset
|
|
112
|
+
def _process(self, message: AxisArray) -> list[SampleMessage]:
|
|
113
|
+
self._state.buffer.write(message)
|
|
165
114
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
# trigger_offset points to t = 0 within buffer
|
|
173
|
-
offset -= msg_data.shape[0]
|
|
174
|
-
start = offset + int(trigger.period[0] * fs)
|
|
175
|
-
stop = offset + int(trigger.period[1] * fs)
|
|
115
|
+
# How much data in the buffer?
|
|
116
|
+
buff_t_range = (
|
|
117
|
+
self._state.buffer.axis_first_value,
|
|
118
|
+
self._state.buffer.axis_final_value,
|
|
119
|
+
)
|
|
176
120
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
121
|
+
# Process in reverse order so that we can remove triggers safely as we iterate.
|
|
122
|
+
msgs_out: list[SampleMessage] = []
|
|
123
|
+
for trig_ix in range(len(self._state.triggers) - 1, -1, -1):
|
|
124
|
+
trig = self._state.triggers[trig_ix]
|
|
125
|
+
if trig.period is None:
|
|
126
|
+
ez.logger.warning("Sampling failed: trigger period not specified")
|
|
127
|
+
del self._state.triggers[trig_ix]
|
|
128
|
+
continue
|
|
180
129
|
|
|
181
|
-
|
|
182
|
-
sample_axis = replace(axis, offset=sample_offset)
|
|
183
|
-
sample_axes = {**msg.axes, **{axis_name: sample_axis}}
|
|
130
|
+
trig_range = trig.timestamp + np.array(trig.period)
|
|
184
131
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
132
|
+
# If the previous iteration had insufficient data for the trigger timestamp + period,
|
|
133
|
+
# and buffer-management removed data required for the trigger, then we will never be able
|
|
134
|
+
# to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
|
|
135
|
+
if trig_range[0] < buff_t_range[0]:
|
|
136
|
+
ez.logger.warning(
|
|
137
|
+
f"Sampling failed: Buffer span {buff_t_range} begins beyond the "
|
|
138
|
+
f"requested sample period start: {trig_range[0]}"
|
|
190
139
|
)
|
|
140
|
+
del self._state.triggers[trig_ix]
|
|
141
|
+
continue
|
|
191
142
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
for sample in pub_samples:
|
|
196
|
-
yield self.OUTPUT_SAMPLE, sample
|
|
143
|
+
if trig_range[1] > buff_t_range[1]:
|
|
144
|
+
# We don't *yet* have enough data to satisfy this trigger.
|
|
145
|
+
continue
|
|
197
146
|
|
|
198
|
-
|
|
147
|
+
# We know we have enough data in the buffer to satisfy this trigger.
|
|
148
|
+
buff_idx = self._state.buffer.axis_searchsorted(trig_range, side="right")
|
|
149
|
+
self._state.buffer.seek(buff_idx[0]) # FFWD to starting position.
|
|
150
|
+
buff_axarr = self._state.buffer.peek(buff_idx[1] - buff_idx[0])
|
|
151
|
+
self._state.buffer.seek(-buff_idx[0]) # Rewind it back.
|
|
152
|
+
# Note: buffer will trim itself as needed based on buffer_dur.
|
|
199
153
|
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
154
|
+
# Prepare output and drop trigger
|
|
155
|
+
msgs_out.append(SampleMessage(trigger=copy.copy(trig), sample=buff_axarr))
|
|
156
|
+
del self._state.triggers[trig_ix]
|
|
203
157
|
|
|
158
|
+
msgs_out.reverse() # in-place
|
|
159
|
+
return msgs_out
|
|
204
160
|
|
|
205
|
-
|
|
206
|
-
|
|
161
|
+
def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
|
|
162
|
+
# Input is a trigger message that we will use to sample the buffer.
|
|
207
163
|
|
|
208
|
-
|
|
209
|
-
|
|
164
|
+
if self._state.buffer is None:
|
|
165
|
+
# We've yet to see any data; drop the trigger.
|
|
166
|
+
return []
|
|
210
167
|
|
|
211
|
-
|
|
168
|
+
_period = message.period if message.period is not None else self.settings.period
|
|
169
|
+
_value = message.value if message.value is not None else self.settings.value
|
|
212
170
|
|
|
171
|
+
if _period is None:
|
|
172
|
+
ez.logger.warning("Sampling failed: period not specified")
|
|
173
|
+
return []
|
|
213
174
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
175
|
+
# Check that period is valid
|
|
176
|
+
if _period[0] >= _period[1]:
|
|
177
|
+
ez.logger.warning(f"Sampling failed: invalid period requested ({_period})")
|
|
178
|
+
return []
|
|
218
179
|
|
|
180
|
+
# Check that period is compatible with buffer duration.
|
|
181
|
+
if (_period[1] - _period[0]) > self.settings.buffer_dur:
|
|
182
|
+
ez.logger.warning(
|
|
183
|
+
f"Sampling failed: trigger period {_period=} >= buffer capacity {self.settings.buffer_dur=}"
|
|
184
|
+
)
|
|
185
|
+
return []
|
|
219
186
|
|
|
220
|
-
|
|
221
|
-
|
|
187
|
+
trigger_ts: float = message.timestamp
|
|
188
|
+
if not self.settings.estimate_alignment:
|
|
189
|
+
# Override the trigger timestamp with the next sample's likely timestamp.
|
|
190
|
+
trigger_ts = self._state.buffer.axis_final_value + self._state.buffer.axis_gain
|
|
222
191
|
|
|
223
|
-
|
|
192
|
+
new_trig_msg = replace(message, timestamp=trigger_ts, period=_period, value=_value)
|
|
193
|
+
self._state.triggers.append(new_trig_msg)
|
|
194
|
+
return []
|
|
224
195
|
|
|
225
|
-
@ez.publisher(OUTPUT_TRIGGER)
|
|
226
|
-
async def generate(self) -> AsyncGenerator:
|
|
227
|
-
await asyncio.sleep(self.SETTINGS.prewait)
|
|
228
196
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
yield self.OUTPUT_TRIGGER, SampleTriggerMessage(
|
|
232
|
-
period=self.SETTINGS.period, value=output
|
|
233
|
-
)
|
|
197
|
+
class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]):
|
|
198
|
+
SETTINGS = SamplerSettings
|
|
234
199
|
|
|
235
|
-
|
|
236
|
-
|
|
200
|
+
INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
|
|
201
|
+
OUTPUT_SIGNAL = ez.OutputStream(SampleMessage)
|
|
237
202
|
|
|
203
|
+
@ez.subscriber(INPUT_TRIGGER)
|
|
204
|
+
async def on_trigger(self, msg: SampleTriggerMessage) -> None:
|
|
205
|
+
_ = self.processor.push_trigger(msg)
|
|
206
|
+
|
|
207
|
+
@ez.subscriber(BaseConsumerUnit.INPUT_SIGNAL, zero_copy=True)
|
|
208
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
209
|
+
@profile_subpub(trace_oldest=False)
|
|
210
|
+
async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
|
|
211
|
+
try:
|
|
212
|
+
for sample in self.processor(message):
|
|
213
|
+
yield self.OUTPUT_SIGNAL, sample
|
|
214
|
+
except Exception as e:
|
|
215
|
+
ez.logger.info(f"{traceback.format_exc()} - {e}")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def sampler(
|
|
219
|
+
buffer_dur: float,
|
|
220
|
+
axis: str | None = None,
|
|
221
|
+
period: tuple[float, float] | None = None,
|
|
222
|
+
value: typing.Any = None,
|
|
223
|
+
estimate_alignment: bool = True,
|
|
224
|
+
) -> SamplerTransformer:
|
|
225
|
+
"""
|
|
226
|
+
Sample data into a buffer, accept triggers, and return slices of sampled
|
|
227
|
+
data around the trigger time.
|
|
228
|
+
|
|
229
|
+
Returns:
|
|
230
|
+
A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
|
|
231
|
+
or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
|
|
232
|
+
"""
|
|
233
|
+
return SamplerTransformer(
|
|
234
|
+
settings=SamplerSettings(
|
|
235
|
+
buffer_dur=buffer_dur,
|
|
236
|
+
axis=axis,
|
|
237
|
+
period=period,
|
|
238
|
+
value=value,
|
|
239
|
+
estimate_alignment=estimate_alignment,
|
|
240
|
+
)
|
|
241
|
+
)
|
|
238
242
|
|
|
239
|
-
class SamplerTestSystemSettings(ez.Settings):
|
|
240
|
-
sampler_settings: SamplerSettings
|
|
241
|
-
trigger_settings: TriggerGeneratorSettings
|
|
242
243
|
|
|
244
|
+
class TriggerGeneratorSettings(ez.Settings):
|
|
245
|
+
period: tuple[float, float]
|
|
246
|
+
"""The period around the trigger event."""
|
|
243
247
|
|
|
244
|
-
|
|
245
|
-
|
|
248
|
+
prewait: float = 0.5
|
|
249
|
+
"""The time before the first trigger (sec)"""
|
|
246
250
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
TRIGGER = TriggerGenerator()
|
|
250
|
-
DEBUG = DebugLog()
|
|
251
|
+
publish_period: float = 5.0
|
|
252
|
+
"""The period between triggers (sec)"""
|
|
251
253
|
|
|
252
|
-
def configure(self) -> None:
|
|
253
|
-
self.SAMPLER.apply_settings(self.SETTINGS.sampler_settings)
|
|
254
|
-
self.TRIGGER.apply_settings(self.SETTINGS.trigger_settings)
|
|
255
254
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
fs=10, # Sampling rate of signal output in Hz
|
|
260
|
-
dispatch_rate="realtime",
|
|
261
|
-
freq=2.0, # Oscillation frequency in Hz
|
|
262
|
-
amp=1.0, # Amplitude
|
|
263
|
-
phase=0.0, # Phase offset (in radians)
|
|
264
|
-
sync=True, # Adjust `freq` to sync with sampling rate
|
|
265
|
-
)
|
|
266
|
-
)
|
|
255
|
+
@processor_state
|
|
256
|
+
class TriggerGeneratorState:
|
|
257
|
+
output: int = 0
|
|
267
258
|
|
|
268
|
-
def network(self) -> ez.NetworkDefinition:
|
|
269
|
-
return (
|
|
270
|
-
(self.OSC.OUTPUT_SIGNAL, self.SAMPLER.INPUT_SIGNAL),
|
|
271
|
-
(self.TRIGGER.OUTPUT_TRIGGER, self.SAMPLER.INPUT_TRIGGER),
|
|
272
|
-
(self.TRIGGER.OUTPUT_TRIGGER, self.DEBUG.INPUT),
|
|
273
|
-
(self.SAMPLER.OUTPUT_SAMPLE, self.DEBUG.INPUT),
|
|
274
|
-
)
|
|
275
259
|
|
|
260
|
+
class TriggerProducer(BaseStatefulProducer[TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState]):
|
|
261
|
+
def _reset_state(self) -> None:
|
|
262
|
+
self._state.output = 0
|
|
276
263
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
),
|
|
283
|
-
)
|
|
264
|
+
async def _produce(self) -> SampleTriggerMessage:
|
|
265
|
+
await asyncio.sleep(self.settings.publish_period)
|
|
266
|
+
out_msg = SampleTriggerMessage(period=self.settings.period, value=self._state.output)
|
|
267
|
+
self._state.output += 1
|
|
268
|
+
return out_msg
|
|
284
269
|
|
|
285
|
-
system = SamplerTestSystem(settings)
|
|
286
270
|
|
|
287
|
-
|
|
271
|
+
class TriggerGenerator(
|
|
272
|
+
BaseProducerUnit[
|
|
273
|
+
TriggerGeneratorSettings,
|
|
274
|
+
SampleTriggerMessage,
|
|
275
|
+
TriggerProducer,
|
|
276
|
+
]
|
|
277
|
+
):
|
|
278
|
+
SETTINGS = TriggerGeneratorSettings
|