ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/sampler.py CHANGED
@@ -1,293 +1,244 @@
1
- import asyncio # Dev/test apparatus
2
- from collections import deque
3
- from dataclasses import dataclass, field
4
- import time
1
+ import asyncio
2
+ import copy
3
+ import traceback
5
4
  import typing
5
+ from collections import deque
6
6
 
7
- import numpy as np
8
- import numpy.typing as npt
9
7
  import ezmsg.core as ez
8
+ import numpy as np
9
+ from ezmsg.baseproc import (
10
+ BaseConsumerUnit,
11
+ BaseProducerUnit,
12
+ BaseStatefulProducer,
13
+ BaseStatefulTransformer,
14
+ BaseTransformerUnit,
15
+ processor_state,
16
+ )
10
17
  from ezmsg.util.messages.axisarray import (
11
18
  AxisArray,
12
- slice_along_axis,
13
19
  )
14
20
  from ezmsg.util.messages.util import replace
15
- from ezmsg.util.generator import consumer
16
-
17
21
 
18
- @dataclass(unsafe_hash=True)
19
- class SampleTriggerMessage:
20
- timestamp: float = field(default_factory=time.time)
21
- """Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
22
+ from .util.axisarray_buffer import HybridAxisArrayBuffer
23
+ from .util.buffer import UpdateStrategy
24
+ from .util.message import SampleMessage, SampleTriggerMessage
25
+ from .util.profile import profile_subpub
22
26
 
23
- period: tuple[float, float] | None = None
24
- """The period around the timestamp, in seconds"""
25
27
 
26
- value: typing.Any = None
27
- """A value or 'label' associated with the trigger."""
28
+ class SamplerSettings(ez.Settings):
29
+ """
30
+ Settings for :obj:`Sampler`.
31
+ See :obj:`sampler` for a description of the fields.
32
+ """
28
33
 
34
+ buffer_dur: float
35
+ """
36
+ The duration of the buffer in seconds. The buffer must be long enough to store the oldest
37
+ sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
38
+ need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
39
+ """
29
40
 
30
- @dataclass
31
- class SampleMessage:
32
- trigger: SampleTriggerMessage
33
- """The time, window, and value (if any) associated with the trigger."""
41
+ axis: str | None = None
42
+ """
43
+ The axis along which to sample the data.
44
+ None (default) will choose the first axis in the first input.
45
+ Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
46
+ """
34
47
 
35
- sample: AxisArray
36
- """The data sampled around the trigger."""
48
+ period: tuple[float, float] | None = None
49
+ """Optional default period (in seconds) if unspecified in SampleTriggerMessage."""
37
50
 
51
+ value: typing.Any = None
52
+ """Optional default value if unspecified in SampleTriggerMessage"""
38
53
 
39
- @consumer
40
- def sampler(
41
- buffer_dur: float,
42
- axis: str | None = None,
43
- period: tuple[float, float] | None = None,
44
- value: typing.Any = None,
45
- estimate_alignment: bool = True,
46
- ) -> typing.Generator[list[SampleMessage], AxisArray | SampleTriggerMessage, None]:
54
+ estimate_alignment: bool = True
55
+ """
56
+ If true, use message timestamp fields and reported sampling rate to estimate
57
+ sample-accurate alignment for samples.
58
+ If false, sampling will be limited to incoming message rate -- "Block timing"
59
+ NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
60
+ "realtime" operation for estimate_alignment to operate correctly.
47
61
  """
48
- Sample data into a buffer, accept triggers, and return slices of sampled
49
- data around the trigger time.
50
-
51
- Args:
52
- buffer_dur: The duration of the buffer in seconds. The buffer must be long enough to store the oldest
53
- sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
54
- need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
55
- axis: The axis along which to sample the data.
56
- None (default) will choose the first axis in the first input.
57
- Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
58
- period: The period in seconds during which to sample the data.
59
- Defaults to None. Only used if not None and the trigger message does not define its own period.
60
- value: The value to sample. Defaults to None.
61
- estimate_alignment: Whether to estimate the sample alignment. Defaults to True.
62
- If True, the trigger timestamp field is used to slice the buffer.
63
- If False, the trigger timestamp is ignored and the next signal's .offset is used.
64
- NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
65
- estimate_alignment to operate correctly.
66
62
 
67
- Returns:
68
- A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
69
- or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
63
+ buffer_update_strategy: UpdateStrategy = "immediate"
64
+ """
65
+ The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
66
+ If you expect to push data much more frequently than triggers, then "on_demand"
67
+ might be more efficient. For most other scenarios, "immediate" is best.
70
68
  """
71
- msg_out: list[SampleMessage] = []
72
69
 
73
- # State variables (most shared between trigger- and data-processing.
74
- triggers: deque[SampleTriggerMessage] = deque()
75
- buffer: npt.NDArray | None = None
76
- n_samples: int = 0
77
- offset: float = 0.0
78
70
 
79
- check_inputs = {
80
- "fs": None, # Also a state variable
81
- "key": None,
82
- "shape": None,
83
- }
71
+ @processor_state
72
+ class SamplerState:
73
+ buffer: HybridAxisArrayBuffer | None = None
74
+ triggers: deque[SampleTriggerMessage] | None = None
75
+
76
+
77
+ class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
78
+ def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
79
+ # TODO: Currently we have a single entry point that accepts both
80
+ # data and trigger messages and we choose a code path based on
81
+ # the message type. However, in the future we will likely replace
82
+ # SampleTriggerMessage with an agumented form of AxisArray,
83
+ # leveraging its attrs field, which makes this a bit harder.
84
+ # We should probably force callers of this object to explicitly
85
+ # call `push_trigger` for trigger messages. This will also
86
+ # simplify typing somewhat because `push_trigger` should not
87
+ # return anything yet we currently have it returning an empty
88
+ # list just to be compatible with __call__.
89
+ if isinstance(message, AxisArray):
90
+ return super().__call__(message)
91
+ else:
92
+ return self.push_trigger(message)
93
+
94
+ def _hash_message(self, message: AxisArray) -> int:
95
+ # Compute hash based on message properties that require state reset
96
+ axis = self.settings.axis or message.dims[0]
97
+ axis_idx = message.get_axis_idx(axis)
98
+ sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
99
+ return hash((sample_shape, message.key))
100
+
101
+ def _reset_state(self, message: AxisArray) -> None:
102
+ self._state.buffer = HybridAxisArrayBuffer(
103
+ duration=self.settings.buffer_dur,
104
+ axis=self.settings.axis or message.dims[0],
105
+ update_strategy=self.settings.buffer_update_strategy,
106
+ overflow_strategy="warn-overwrite", # True circular buffer
107
+ )
108
+ if self._state.triggers is None:
109
+ self._state.triggers = deque()
110
+ self._state.triggers.clear()
84
111
 
85
- while True:
86
- msg_in = yield msg_out
87
- msg_out = []
112
+ def _process(self, message: AxisArray) -> list[SampleMessage]:
113
+ self._state.buffer.write(message)
88
114
 
89
- if isinstance(msg_in, SampleTriggerMessage):
90
- # Input is a trigger message that we will use to sample the buffer.
115
+ # How much data in the buffer?
116
+ buff_t_range = (
117
+ self._state.buffer.axis_first_value,
118
+ self._state.buffer.axis_final_value,
119
+ )
91
120
 
92
- if buffer is None or check_inputs["fs"] is None:
93
- # We've yet to see any data; drop the trigger.
121
+ # Process in reverse order so that we can remove triggers safely as we iterate.
122
+ msgs_out: list[SampleMessage] = []
123
+ for trig_ix in range(len(self._state.triggers) - 1, -1, -1):
124
+ trig = self._state.triggers[trig_ix]
125
+ if trig.period is None:
126
+ ez.logger.warning("Sampling failed: trigger period not specified")
127
+ del self._state.triggers[trig_ix]
94
128
  continue
95
129
 
96
- _period = msg_in.period if msg_in.period is not None else period
97
- _value = msg_in.value if msg_in.value is not None else value
130
+ trig_range = trig.timestamp + np.array(trig.period)
98
131
 
99
- if _period is None:
100
- ez.logger.warning("Sampling failed: period not specified")
101
- continue
102
-
103
- # Check that period is valid
104
- if _period[0] >= _period[1]:
132
+ # If the previous iteration had insufficient data for the trigger timestamp + period,
133
+ # and buffer-management removed data required for the trigger, then we will never be able
134
+ # to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
135
+ if trig_range[0] < buff_t_range[0]:
105
136
  ez.logger.warning(
106
- f"Sampling failed: invalid period requested ({_period})"
137
+ f"Sampling failed: Buffer span {buff_t_range} begins beyond the "
138
+ f"requested sample period start: {trig_range[0]}"
107
139
  )
140
+ del self._state.triggers[trig_ix]
108
141
  continue
109
142
 
110
- # Check that period is compatible with buffer duration.
111
- max_buf_len = int(np.round(buffer_dur * check_inputs["fs"]))
112
- req_buf_len = int(np.round((_period[1] - _period[0]) * check_inputs["fs"]))
113
- if req_buf_len >= max_buf_len:
114
- ez.logger.warning(f"Sampling failed: {period=} >= {buffer_dur=}")
143
+ if trig_range[1] > buff_t_range[1]:
144
+ # We don't *yet* have enough data to satisfy this trigger.
115
145
  continue
116
146
 
117
- trigger_ts: float = msg_in.timestamp
118
- if not estimate_alignment:
119
- # Override the trigger timestamp with the next sample's likely timestamp.
120
- trigger_ts = offset + (n_samples + 1) / check_inputs["fs"]
147
+ # We know we have enough data in the buffer to satisfy this trigger.
148
+ buff_idx = self._state.buffer.axis_searchsorted(trig_range, side="right")
149
+ self._state.buffer.seek(buff_idx[0]) # FFWD to starting position.
150
+ buff_axarr = self._state.buffer.peek(buff_idx[1] - buff_idx[0])
151
+ self._state.buffer.seek(-buff_idx[0]) # Rewind it back.
152
+ # Note: buffer will trim itself as needed based on buffer_dur.
121
153
 
122
- new_trig_msg = replace(
123
- msg_in, timestamp=trigger_ts, period=_period, value=_value
124
- )
125
- triggers.append(new_trig_msg)
126
-
127
- elif isinstance(msg_in, AxisArray):
128
- # Get properties from message
129
- axis = axis or msg_in.dims[0]
130
- axis_idx = msg_in.get_axis_idx(axis)
131
- axis_info = msg_in.get_axis(axis)
132
- fs = 1.0 / axis_info.gain
133
- sample_shape = (
134
- msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
135
- )
154
+ # Prepare output and drop trigger
155
+ msgs_out.append(SampleMessage(trigger=copy.copy(trig), sample=buff_axarr))
156
+ del self._state.triggers[trig_ix]
136
157
 
137
- # TODO: We could accommodate change in dim order.
138
- # if axis_idx != check_inputs["axis_idx"]:
139
- # msg_in = replace(
140
- # msg_in,
141
- # data=np.moveaxis(msg_in.data, axis_idx, check_inputs["axis_idx"]),
142
- # dims=TODO...
143
- # )
144
- # axis_idx = check_inputs["axis_idx"]
145
-
146
- # If the properties have changed in a breaking way then reset buffer and triggers.
147
- b_reset = fs != check_inputs["fs"]
148
- b_reset = b_reset or sample_shape != check_inputs["shape"]
149
- # TODO: Skip next line if we do np.moveaxis above
150
- b_reset = b_reset or axis_idx != check_inputs["axis_idx"]
151
- b_reset = b_reset or msg_in.key != check_inputs["key"]
152
- if b_reset:
153
- check_inputs["fs"] = fs
154
- check_inputs["shape"] = sample_shape
155
- check_inputs["axis_idx"] = axis_idx
156
- check_inputs["key"] = msg_in.key
157
- n_samples = msg_in.data.shape[axis_idx]
158
- buffer = None
159
- if len(triggers) > 0:
160
- ez.logger.warning("Data stream changed: Discarding all triggers")
161
- triggers.clear()
162
-
163
- # Save some info for trigger processing
164
- offset = axis_info.offset
165
-
166
- # Update buffer
167
- buffer = (
168
- msg_in.data
169
- if buffer is None
170
- else np.concatenate((buffer, msg_in.data), axis=axis_idx)
171
- )
158
+ msgs_out.reverse() # in-place
159
+ return msgs_out
172
160
 
173
- # Calculate timestamps associated with buffer.
174
- buffer_offset = np.arange(buffer.shape[axis_idx], dtype=float)
175
- buffer_offset -= buffer_offset[-msg_in.data.shape[axis_idx]]
176
- buffer_offset *= axis_info.gain
177
- buffer_offset += axis_info.offset
178
-
179
- # ... for each trigger, collect the message (if possible) and append to msg_out
180
- for trig in list(triggers):
181
- if trig.period is None:
182
- # This trigger was malformed; drop it.
183
- triggers.remove(trig)
184
-
185
- # If the previous iteration had insufficient data for the trigger timestamp + period,
186
- # and buffer-management removed data required for the trigger, then we will never be able
187
- # to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
188
- if (trig.timestamp + trig.period[0]) < buffer_offset[0]:
189
- ez.logger.warning(
190
- f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
191
- f"requested sample period start: {trig.timestamp + trig.period[0]}"
192
- )
193
- triggers.remove(trig)
194
-
195
- t_start = trig.timestamp + trig.period[0]
196
- if t_start >= buffer_offset[0]:
197
- start = np.searchsorted(buffer_offset, t_start)
198
- stop = start + int(np.round(fs * (trig.period[1] - trig.period[0])))
199
- if buffer.shape[axis_idx] > stop:
200
- # Trigger period fully enclosed in buffer.
201
- msg_out.append(
202
- SampleMessage(
203
- trigger=trig,
204
- sample=replace(
205
- msg_in,
206
- data=slice_along_axis(
207
- buffer, slice(start, stop), axis_idx
208
- ),
209
- axes={
210
- **msg_in.axes,
211
- axis: replace(
212
- axis_info, offset=buffer_offset[start]
213
- ),
214
- },
215
- ),
216
- )
217
- )
218
- triggers.remove(trig)
219
-
220
- buf_len = int(buffer_dur * fs)
221
- buffer = slice_along_axis(buffer, np.s_[-buf_len:], axis_idx)
161
+ def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
162
+ # Input is a trigger message that we will use to sample the buffer.
222
163
 
164
+ if self._state.buffer is None:
165
+ # We've yet to see any data; drop the trigger.
166
+ return []
223
167
 
224
- class SamplerSettings(ez.Settings):
225
- """
226
- Settings for :obj:`Sampler`.
227
- See :obj:`sampler` for a description of the fields.
228
- """
229
-
230
- buffer_dur: float
231
- axis: str | None = None
232
- period: tuple[float, float] | None = None
233
- """Optional default period if unspecified in SampleTriggerMessage"""
168
+ _period = message.period if message.period is not None else self.settings.period
169
+ _value = message.value if message.value is not None else self.settings.value
234
170
 
235
- value: typing.Any = None
236
- """Optional default value if unspecified in SampleTriggerMessage"""
171
+ if _period is None:
172
+ ez.logger.warning("Sampling failed: period not specified")
173
+ return []
237
174
 
238
- estimate_alignment: bool = True
239
- """
240
- If true, use message timestamp fields and reported sampling rate to estimate sample-accurate alignment for samples.
241
- If false, sampling will be limited to incoming message rate -- "Block timing"
242
- NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
243
- "realtime" operation for estimate_alignment to operate correctly.
244
- """
175
+ # Check that period is valid
176
+ if _period[0] >= _period[1]:
177
+ ez.logger.warning(f"Sampling failed: invalid period requested ({_period})")
178
+ return []
245
179
 
180
+ # Check that period is compatible with buffer duration.
181
+ if (_period[1] - _period[0]) > self.settings.buffer_dur:
182
+ ez.logger.warning(
183
+ f"Sampling failed: trigger period {_period=} >= buffer capacity {self.settings.buffer_dur=}"
184
+ )
185
+ return []
246
186
 
247
- class SamplerState(ez.State):
248
- cur_settings: SamplerSettings
249
- gen: typing.Generator[AxisArray | SampleTriggerMessage, list[SampleMessage], None]
187
+ trigger_ts: float = message.timestamp
188
+ if not self.settings.estimate_alignment:
189
+ # Override the trigger timestamp with the next sample's likely timestamp.
190
+ trigger_ts = self._state.buffer.axis_final_value + self._state.buffer.axis_gain
250
191
 
192
+ new_trig_msg = replace(message, timestamp=trigger_ts, period=_period, value=_value)
193
+ self._state.triggers.append(new_trig_msg)
194
+ return []
251
195
 
252
- class Sampler(ez.Unit):
253
- """An :obj:`Unit` for :obj:`sampler`."""
254
196
 
197
+ class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, SamplerTransformer]):
255
198
  SETTINGS = SamplerSettings
256
- STATE = SamplerState
257
199
 
258
200
  INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
259
- INPUT_SETTINGS = ez.InputStream(SamplerSettings)
260
- INPUT_SIGNAL = ez.InputStream(AxisArray)
261
- OUTPUT_SAMPLE = ez.OutputStream(SampleMessage)
262
-
263
- def construct_generator(self):
264
- self.STATE.gen = sampler(
265
- buffer_dur=self.STATE.cur_settings.buffer_dur,
266
- axis=self.STATE.cur_settings.axis,
267
- period=self.STATE.cur_settings.period,
268
- value=self.STATE.cur_settings.value,
269
- estimate_alignment=self.STATE.cur_settings.estimate_alignment,
270
- )
271
-
272
- async def initialize(self) -> None:
273
- self.STATE.cur_settings = self.SETTINGS
274
- self.construct_generator()
275
-
276
- @ez.subscriber(INPUT_SETTINGS)
277
- async def on_settings(self, msg: SamplerSettings) -> None:
278
- self.STATE.cur_settings = msg
279
- self.construct_generator()
201
+ OUTPUT_SIGNAL = ez.OutputStream(SampleMessage)
280
202
 
281
203
  @ez.subscriber(INPUT_TRIGGER)
282
204
  async def on_trigger(self, msg: SampleTriggerMessage) -> None:
283
- _ = self.STATE.gen.send(msg)
205
+ _ = self.processor.push_trigger(msg)
206
+
207
+ @ez.subscriber(BaseConsumerUnit.INPUT_SIGNAL, zero_copy=True)
208
+ @ez.publisher(OUTPUT_SIGNAL)
209
+ @profile_subpub(trace_oldest=False)
210
+ async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
211
+ try:
212
+ for sample in self.processor(message):
213
+ yield self.OUTPUT_SIGNAL, sample
214
+ except Exception as e:
215
+ ez.logger.info(f"{traceback.format_exc()} - {e}")
216
+
217
+
218
+ def sampler(
219
+ buffer_dur: float,
220
+ axis: str | None = None,
221
+ period: tuple[float, float] | None = None,
222
+ value: typing.Any = None,
223
+ estimate_alignment: bool = True,
224
+ ) -> SamplerTransformer:
225
+ """
226
+ Sample data into a buffer, accept triggers, and return slices of sampled
227
+ data around the trigger time.
284
228
 
285
- @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
286
- @ez.publisher(OUTPUT_SAMPLE)
287
- async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator:
288
- pub_samples = self.STATE.gen.send(msg)
289
- for sample in pub_samples:
290
- yield self.OUTPUT_SAMPLE, sample
229
+ Returns:
230
+ A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
231
+ or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
232
+ """
233
+ return SamplerTransformer(
234
+ settings=SamplerSettings(
235
+ buffer_dur=buffer_dur,
236
+ axis=axis,
237
+ period=period,
238
+ value=value,
239
+ estimate_alignment=estimate_alignment,
240
+ )
241
+ )
291
242
 
292
243
 
293
244
  class TriggerGeneratorSettings(ez.Settings):
@@ -301,23 +252,27 @@ class TriggerGeneratorSettings(ez.Settings):
301
252
  """The period between triggers (sec)"""
302
253
 
303
254
 
304
- class TriggerGenerator(ez.Unit):
305
- """
306
- A unit to generate triggers every `publish_period` interval.
307
- """
255
+ @processor_state
256
+ class TriggerGeneratorState:
257
+ output: int = 0
308
258
 
309
- SETTINGS = TriggerGeneratorSettings
310
259
 
311
- OUTPUT_TRIGGER = ez.OutputStream(SampleTriggerMessage)
260
+ class TriggerProducer(BaseStatefulProducer[TriggerGeneratorSettings, SampleTriggerMessage, TriggerGeneratorState]):
261
+ def _reset_state(self) -> None:
262
+ self._state.output = 0
312
263
 
313
- @ez.publisher(OUTPUT_TRIGGER)
314
- async def generate(self) -> typing.AsyncGenerator:
315
- await asyncio.sleep(self.SETTINGS.prewait)
264
+ async def _produce(self) -> SampleTriggerMessage:
265
+ await asyncio.sleep(self.settings.publish_period)
266
+ out_msg = SampleTriggerMessage(period=self.settings.period, value=self._state.output)
267
+ self._state.output += 1
268
+ return out_msg
316
269
 
317
- output = 0
318
- while True:
319
- out_msg = SampleTriggerMessage(period=self.SETTINGS.period, value=output)
320
- yield self.OUTPUT_TRIGGER, out_msg
321
270
 
322
- await asyncio.sleep(self.SETTINGS.publish_period)
323
- output += 1
271
+ class TriggerGenerator(
272
+ BaseProducerUnit[
273
+ TriggerGeneratorSettings,
274
+ SampleTriggerMessage,
275
+ TriggerProducer,
276
+ ]
277
+ ):
278
+ SETTINGS = TriggerGeneratorSettings