ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +119 -104
  6. ezmsg/sigproc/bandpower.py +58 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -78
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/synth.py CHANGED
@@ -1,345 +1,515 @@
1
1
  import asyncio
2
- from dataclasses import field
2
+ import traceback
3
+ from dataclasses import dataclass, field
3
4
  import time
4
5
  import typing
5
6
 
6
7
  import numpy as np
7
8
  import ezmsg.core as ez
8
- from ezmsg.util.generator import consumer
9
9
  from ezmsg.util.messages.axisarray import AxisArray
10
10
  from ezmsg.util.messages.util import replace
11
11
 
12
- from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
13
- from .base import GenAxisArray
12
+ from .butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer
13
+ from .base import (
14
+ BaseStatefulProducer,
15
+ BaseProducerUnit,
16
+ BaseTransformer,
17
+ BaseTransformerUnit,
18
+ CompositeProducer,
19
+ ProducerType,
20
+ SettingsType,
21
+ MessageInType,
22
+ MessageOutType,
23
+ processor_state,
24
+ )
25
+ from .util.asio import run_coroutine_sync
26
+ from .util.profile import profile_subpub
27
+
28
+
29
+ @dataclass
30
+ class AddState:
31
+ queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
32
+ queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
14
33
 
15
34
 
16
- def clock(dispatch_rate: float | None) -> typing.Generator[ez.Flag, None, None]:
17
- """
18
- Construct a generator that yields events at a specified rate.
35
+ class AddProcessor:
36
+ def __init__(self):
37
+ self._state = AddState()
19
38
 
20
- Args:
21
- dispatch_rate: event rate in seconds.
39
+ @property
40
+ def state(self) -> AddState:
41
+ return self._state
22
42
 
23
- Returns:
24
- A generator object that yields :obj:`ez.Flag` events at a specified rate.
25
- """
26
- n_dispatch = -1
27
- t_0 = time.time()
28
- while True:
29
- if dispatch_rate is not None:
30
- n_dispatch += 1
31
- t_next = t_0 + n_dispatch / dispatch_rate
32
- time.sleep(max(0, t_next - time.time()))
33
- yield ez.Flag()
43
+ @state.setter
44
+ def state(self, state: AddState | bytes | None) -> None:
45
+ if state is not None:
46
+ # TODO: Support hydrating state from bytes
47
+ # if isinstance(state, bytes):
48
+ # self._state = pickle.loads(state)
49
+ # else:
50
+ self._state = state
34
51
 
52
+ def push_a(self, msg: AxisArray) -> None:
53
+ self._state.queue_a.put_nowait(msg)
35
54
 
36
- async def aclock(dispatch_rate: float | None) -> typing.AsyncGenerator[ez.Flag, None]:
37
- """
38
- ``asyncio`` version of :obj:`clock`.
55
+ def push_b(self, msg: AxisArray) -> None:
56
+ self._state.queue_b.put_nowait(msg)
39
57
 
40
- Returns:
41
- asynchronous generator object. Must use `anext` or `async for`.
42
- """
43
- t_0 = time.time()
44
- n_dispatch = -1
45
- while True:
46
- if dispatch_rate is not None:
47
- n_dispatch += 1
48
- t_next = t_0 + n_dispatch / dispatch_rate
49
- await asyncio.sleep(t_next - time.time())
50
- yield ez.Flag()
58
+ async def __acall__(self) -> AxisArray:
59
+ a = await self._state.queue_a.get()
60
+ b = await self._state.queue_b.get()
61
+ return replace(a, data=a.data + b.data)
62
+
63
+ def __call__(self) -> AxisArray:
64
+ return run_coroutine_sync(self.__acall__())
65
+
66
+ # Aliases for legacy interface
67
+ async def __anext__(self) -> AxisArray:
68
+ return await self.__acall__()
69
+
70
+ def __next__(self) -> AxisArray:
71
+ return self.__call__()
72
+
73
+
74
+ class Add(ez.Unit):
75
+ """Add two signals together. Assumes compatible/similar axes/dimensions."""
76
+
77
+ INPUT_SIGNAL_A = ez.InputStream(AxisArray)
78
+ INPUT_SIGNAL_B = ez.InputStream(AxisArray)
79
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
80
+
81
+ async def initialize(self) -> None:
82
+ self.processor = AddProcessor()
83
+
84
+ @ez.subscriber(INPUT_SIGNAL_A)
85
+ async def on_a(self, msg: AxisArray) -> None:
86
+ self.processor.push_a(msg)
87
+
88
+ @ez.subscriber(INPUT_SIGNAL_B)
89
+ async def on_b(self, msg: AxisArray) -> None:
90
+ self.processor.push_b(msg)
91
+
92
+ @ez.publisher(OUTPUT_SIGNAL)
93
+ async def output(self) -> typing.AsyncGenerator:
94
+ while True:
95
+ yield self.OUTPUT_SIGNAL, await self.processor.__acall__()
51
96
 
52
97
 
53
98
  class ClockSettings(ez.Settings):
54
- """Settings for :obj:`Clock`. See :obj:`clock` for parameter description."""
99
+ """Settings for clock generator."""
55
100
 
56
- # Message dispatch rate (Hz), or None (fast as possible)
57
- dispatch_rate: float | None
101
+ dispatch_rate: float | str | None = None
102
+ """Dispatch rate in Hz, 'realtime', or None for external clock"""
58
103
 
59
104
 
60
- class ClockState(ez.State):
61
- cur_settings: ClockSettings
62
- gen: typing.AsyncGenerator
105
+ @processor_state
106
+ class ClockState:
107
+ """State for clock generator."""
63
108
 
109
+ t_0: float = field(default_factory=time.time) # Start time
110
+ n_dispatch: int = 0 # Number of dispatches
64
111
 
65
- class Clock(ez.Unit):
66
- """Unit for :obj:`clock`."""
67
112
 
68
- SETTINGS = ClockSettings
69
- STATE = ClockState
113
+ class ClockProducer(BaseStatefulProducer[ClockSettings, ez.Flag, ClockState]):
114
+ """
115
+ Produces clock ticks at specified rate.
116
+ Can be used to drive periodic operations.
117
+ """
70
118
 
71
- INPUT_SETTINGS = ez.InputStream(ClockSettings)
72
- OUTPUT_CLOCK = ez.OutputStream(ez.Flag)
119
+ def _reset_state(self) -> None:
120
+ """Reset internal state."""
121
+ self._state.t_0 = time.time()
122
+ self._state.n_dispatch = 0
123
+
124
+ def __call__(self) -> ez.Flag:
125
+ """Synchronous clock production. We override __call__ (which uses run_coroutine_sync) to avoid async overhead."""
126
+ if self._hash == -1:
127
+ self._reset_state()
128
+ self._hash = 0
129
+
130
+ if isinstance(self.settings.dispatch_rate, (int, float)):
131
+ # Manual dispatch_rate. (else it is 'as fast as possible')
132
+ target_time = (
133
+ self.state.t_0
134
+ + (self.state.n_dispatch + 1) / self.settings.dispatch_rate
135
+ )
136
+ now = time.time()
137
+ if target_time > now:
138
+ time.sleep(target_time - now)
139
+
140
+ self.state.n_dispatch += 1
141
+ return ez.Flag()
142
+
143
+ async def _produce(self) -> ez.Flag:
144
+ """Generate next clock tick."""
145
+ if isinstance(self.settings.dispatch_rate, (int, float)):
146
+ # Manual dispatch_rate. (else it is 'as fast as possible')
147
+ target_time = (
148
+ self.state.t_0
149
+ + (self.state.n_dispatch + 1) / self.settings.dispatch_rate
150
+ )
151
+ now = time.time()
152
+ if target_time > now:
153
+ await asyncio.sleep(target_time - now)
154
+
155
+ self.state.n_dispatch += 1
156
+ return ez.Flag()
157
+
158
+
159
+ def aclock(dispatch_rate: float | None) -> ClockProducer:
160
+ """
161
+ Construct an async generator that yields events at a specified rate.
162
+
163
+ Returns:
164
+ A :obj:`ClockProducer` object.
165
+ """
166
+ return ClockProducer(ClockSettings(dispatch_rate=dispatch_rate))
73
167
 
74
- async def initialize(self) -> None:
75
- self.STATE.cur_settings = self.SETTINGS
76
- self.construct_generator()
77
168
 
78
- def construct_generator(self):
79
- self.STATE.gen = aclock(self.STATE.cur_settings.dispatch_rate)
169
+ clock = aclock
170
+ """
171
+ Alias for :obj:`aclock` expected by synchronous methods. `ClockProducer` can be used in sync or async.
172
+ """
80
173
 
81
- @ez.subscriber(INPUT_SETTINGS)
82
- async def on_settings(self, msg: ClockSettings) -> None:
83
- self.STATE.cur_settings = msg
84
- self.construct_generator()
85
174
 
86
- @ez.publisher(OUTPUT_CLOCK)
87
- async def generate(self) -> typing.AsyncGenerator:
175
+ class Clock(
176
+ BaseProducerUnit[
177
+ ClockSettings, # SettingsType
178
+ ez.Flag, # MessageType
179
+ ClockProducer, # ProducerType
180
+ ]
181
+ ):
182
+ SETTINGS = ClockSettings
183
+
184
+ @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
185
+ async def produce(self) -> typing.AsyncGenerator:
186
+ # Override so we can not to yield if out is False-like
88
187
  while True:
89
- out = await self.STATE.gen.__anext__()
188
+ out = await self.producer.__acall__()
90
189
  if out:
91
- yield self.OUTPUT_CLOCK, out
190
+ yield self.OUTPUT_SIGNAL, out
92
191
 
93
192
 
94
193
  # COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
95
- async def acounter(
96
- n_time: int,
97
- fs: float | None,
98
- n_ch: int = 1,
99
- dispatch_rate: float | str | None = None,
100
- mod: int | None = None,
101
- ) -> typing.AsyncGenerator[AxisArray, None]:
194
+ class CounterSettings(ez.Settings):
195
+ # TODO: Adapt this to use ezmsg.util.rate?
196
+ """
197
+ Settings for :obj:`Counter`.
198
+ See :obj:`acounter` for a description of the parameters.
102
199
  """
103
- Construct an asynchronous generator to generate AxisArray objects at a specified rate
104
- and with the specified sampling rate.
105
200
 
106
- NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
107
- This method of sleeping/yielding execution priority has quirky behavior with
108
- sub-millisecond sleep periods which may result in unexpected behavior (e.g.
109
- fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
201
+ n_time: int
202
+ """Number of samples to output per block."""
110
203
 
111
- Args:
112
- n_time: Number of samples to output per block.
113
- fs: Sampling rate of signal output in Hz.
114
- n_ch: Number of channels to synthesize
115
- dispatch_rate: Message dispatch rate (Hz), 'realtime' or None (fast as possible)
116
- Note: if dispatch_rate is a float then time offsets will be synthetic and the
117
- system will run faster or slower than wall clock time.
118
- mod: If set to an integer, counter will rollover at this number.
204
+ fs: float
205
+ """Sampling rate of signal output in Hz"""
119
206
 
120
- Returns:
121
- An asynchronous generator.
207
+ n_ch: int = 1
208
+ """Number of channels to synthesize"""
209
+
210
+ dispatch_rate: float | str | None = None
211
+ """
212
+ Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
213
+ Note: if dispatch_rate is a float then time offsets will be synthetic and the
214
+ system will run faster or slower than wall clock time.
215
+ """
216
+
217
+ mod: int | None = None
218
+ """If set to an integer, counter will rollover"""
219
+
220
+
221
+ @processor_state
222
+ class CounterState:
223
+ """
224
+ State for counter generator.
122
225
  """
123
226
 
227
+ counter_start: int = 0
228
+ """next sample's first value"""
229
+
230
+ n_sent: int = 0
231
+ """number of samples sent"""
232
+
233
+ clock_zero: float | None = None
234
+ """time of first sample"""
235
+
236
+ timer_type: str = "unspecified"
237
+ """
238
+ "realtime" | "ext_clock" | "manual" | "unspecified"
239
+ """
240
+
241
+ new_generator: asyncio.Event | None = None
242
+ """
243
+ Event to signal the counter has been reset.
244
+ """
245
+
246
+
247
+ class CounterProducer(BaseStatefulProducer[CounterSettings, AxisArray, CounterState]):
248
+ """Produces incrementing integer blocks as AxisArray."""
249
+
124
250
  # TODO: Adapt this to use ezmsg.util.rate?
125
251
 
126
- counter_start: int = 0 # next sample's first value
127
-
128
- b_realtime = False
129
- b_manual_dispatch = False
130
- b_ext_clock = False
131
- if dispatch_rate is not None:
132
- if isinstance(dispatch_rate, str):
133
- if dispatch_rate.lower() == "realtime":
134
- b_realtime = True
135
- elif dispatch_rate.lower() == "ext_clock":
136
- b_ext_clock = True
252
+ @classmethod
253
+ def get_message_type(cls, dir: str) -> typing.Optional[type[AxisArray]]:
254
+ if dir == "in":
255
+ return None
256
+ elif dir == "out":
257
+ return AxisArray
137
258
  else:
138
- b_manual_dispatch = True
139
-
140
- n_sent: int = 0 # It is convenient to know how many samples we have sent.
141
- clock_zero: float = time.time() # time associated with first sample
142
- template = AxisArray(
143
- data=np.array([[]]),
144
- dims=["time", "ch"],
145
- axes={
146
- "time": AxisArray.TimeAxis(fs=fs),
147
- "ch": AxisArray.CoordinateAxis(
148
- data=np.array([f"Ch{_}" for _ in range(n_ch)]), dims=["ch"]
149
- ),
150
- },
151
- key="acounter",
152
- )
259
+ raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")
153
260
 
154
- while True:
155
- # 1. Sleep, if necessary, until we are at the end of the current block
156
- if b_realtime:
157
- n_next = n_sent + n_time
158
- t_next = clock_zero + n_next / fs
261
+ def __init__(self, *args, **kwargs):
262
+ super().__init__(*args, **kwargs)
263
+ if isinstance(
264
+ self.settings.dispatch_rate, str
265
+ ) and self.settings.dispatch_rate not in ["realtime", "ext_clock"]:
266
+ raise ValueError(f"Unknown dispatch_rate: {self.settings.dispatch_rate}")
267
+ self._reset_state()
268
+ self._hash = 0
269
+
270
+ def _reset_state(self) -> None:
271
+ """Reset internal state."""
272
+ self._state.counter_start = 0
273
+ self._state.n_sent = 0
274
+ self._state.clock_zero = time.time()
275
+ if self.settings.dispatch_rate is not None:
276
+ if isinstance(self.settings.dispatch_rate, str):
277
+ self._state.timer_type = self.settings.dispatch_rate.lower()
278
+ else:
279
+ self._state.timer_type = "manual"
280
+ if self._state.new_generator is None:
281
+ self._state.new_generator = asyncio.Event()
282
+ # Set the event to indicate that the state has been reset.
283
+ self._state.new_generator.set()
284
+
285
+ async def _produce(self) -> AxisArray:
286
+ """Generate next counter block."""
287
+ # 1. Prepare counter data
288
+ block_samp = np.arange(
289
+ self.state.counter_start, self.state.counter_start + self.settings.n_time
290
+ )[:, np.newaxis]
291
+ if self.settings.mod is not None:
292
+ block_samp %= self.settings.mod
293
+ block_samp = np.tile(block_samp, (1, self.settings.n_ch))
294
+
295
+ # 2. Sleep if necessary. 3. Calculate time offset.
296
+ if self._state.timer_type == "realtime":
297
+ n_next = self.state.n_sent + self.settings.n_time
298
+ t_next = self.state.clock_zero + n_next / self.settings.fs
159
299
  await asyncio.sleep(t_next - time.time())
160
- elif b_manual_dispatch:
161
- n_disp_next = 1 + n_sent / n_time
162
- t_disp_next = clock_zero + n_disp_next / dispatch_rate
300
+ offset = t_next - self.settings.n_time / self.settings.fs
301
+ elif self._state.timer_type == "manual":
302
+ # manual dispatch rate
303
+ n_disp_next = 1 + self.state.n_sent / self.settings.n_time
304
+ t_disp_next = (
305
+ self.state.clock_zero + n_disp_next / self.settings.dispatch_rate
306
+ )
163
307
  await asyncio.sleep(t_disp_next - time.time())
164
-
165
- # 2. Prepare counter data.
166
- block_samp = np.arange(counter_start, counter_start + n_time)[:, np.newaxis]
167
- if mod is not None:
168
- block_samp %= mod
169
- block_samp = np.tile(block_samp, (1, n_ch))
170
-
171
- # 3. Prepare offset - the time associated with block_samp[0]
172
- if b_realtime:
173
- offset = t_next - n_time / fs
174
- elif b_ext_clock:
308
+ offset = self.state.n_sent / self.settings.fs
309
+ elif self._state.timer_type == "ext_clock":
310
+ # ext_clock -- no sleep. Assume this is called at appropriate intervals.
175
311
  offset = time.time()
176
312
  else:
177
- # Purely synthetic.
178
- offset = n_sent / fs
179
- # offset += clock_zero # ??
313
+ # Was "unspecified"
314
+ offset = self.state.n_sent / self.settings.fs
180
315
 
181
- # 4. yield output
182
- yield replace(
183
- template,
316
+ # 4. Create output AxisArray
317
+ # Note: We can make this a bit faster by preparing a template for self._state
318
+ result = AxisArray(
184
319
  data=block_samp,
320
+ dims=["time", "ch"],
185
321
  axes={
186
- "time": replace(template.axes["time"], offset=offset),
187
- "ch": template.axes["ch"],
322
+ "time": AxisArray.TimeAxis(fs=self.settings.fs, offset=offset),
323
+ "ch": AxisArray.CoordinateAxis(
324
+ data=np.array([f"Ch{_}" for _ in range(self.settings.n_ch)]),
325
+ dims=["ch"],
326
+ ),
188
327
  },
328
+ key="acounter",
189
329
  )
190
330
 
191
- # 5. Update state for next iteration (after next yield)
192
- counter_start = block_samp[-1, 0] + 1 # do not % mod
193
- n_sent += n_time
331
+ # 5. Update state
332
+ self.state.counter_start = block_samp[-1, 0] + 1
333
+ self.state.n_sent += self.settings.n_time
194
334
 
335
+ return result
195
336
 
196
- class CounterSettings(ez.Settings):
197
- # TODO: Adapt this to use ezmsg.util.rate?
337
+
338
+ def acounter(
339
+ n_time: int,
340
+ fs: float | None,
341
+ n_ch: int = 1,
342
+ dispatch_rate: float | str | None = None,
343
+ mod: int | None = None,
344
+ ) -> CounterProducer:
198
345
  """
199
- Settings for :obj:`Counter`.
200
- See :obj:`acounter` for a description of the parameters.
346
+ Construct an asynchronous generator to generate AxisArray objects at a specified rate
347
+ and with the specified sampling rate.
348
+
349
+ NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
350
+ This method of sleeping/yielding execution priority has quirky behavior with
351
+ sub-millisecond sleep periods which may result in unexpected behavior (e.g.
352
+ fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
353
+
354
+ Returns:
355
+ An asynchronous generator.
201
356
  """
357
+ return CounterProducer(
358
+ CounterSettings(
359
+ n_time=n_time, fs=fs, n_ch=n_ch, dispatch_rate=dispatch_rate, mod=mod
360
+ )
361
+ )
202
362
 
203
- n_time: int # Number of samples to output per block
204
- fs: float # Sampling rate of signal output in Hz
205
- n_ch: int = 1 # Number of channels to synthesize
206
363
 
207
- # Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
208
- # Note: if dispatch_rate is a float then time offsets will be synthetic and the
209
- # system will run faster or slower than wall clock time.
210
- dispatch_rate: float | str | None = None
364
+ class Counter(
365
+ BaseProducerUnit[
366
+ CounterSettings, # SettingsType
367
+ AxisArray, # MessageOutType
368
+ CounterProducer, # ProducerType
369
+ ]
370
+ ):
371
+ """Generates monotonically increasing counter. Unit for :obj:`CounterProducer`."""
211
372
 
212
- # If set to an integer, counter will rollover
213
- mod: int | None = None
373
+ SETTINGS = CounterSettings
374
+ INPUT_CLOCK = ez.InputStream(ez.Flag)
214
375
 
376
+ @ez.subscriber(INPUT_CLOCK)
377
+ @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
378
+ async def on_clock(self, _: ez.Flag):
379
+ if self.producer.settings.dispatch_rate == "ext_clock":
380
+ out = await self.producer.__acall__()
381
+ yield self.OUTPUT_SIGNAL, out
215
382
 
216
- class CounterState(ez.State):
217
- gen: typing.AsyncGenerator[AxisArray, ez.Flag | None]
218
- cur_settings: CounterSettings
219
- new_generator: asyncio.Event
383
+ @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
384
+ async def produce(self) -> typing.AsyncGenerator:
385
+ """
386
+ Generate counter output.
387
+ This is an infinite loop, but we will likely only enter the loop once if we are self-timed,
388
+ and twice if we are using an external clock.
389
+
390
+ When using an internal clock, we enter the loop, and wait for the event which should have
391
+ been reset upon initialization then we immediately clear, then go to the internal loop
392
+ that will async call __acall__ to let the internal timer determine when to produce an output.
393
+
394
+ When using an external clock, we enter the loop, and wait for the event which should have been
395
+ reset upon initialization then we immediately clear, then we hit `continue` to loop back around
396
+ and wait for the event to be set again -- potentially forever. In this case, it is expected that
397
+ `on_clock` will be called to produce the output.
398
+ """
399
+ try:
400
+ while True:
401
+ # Once-only, enter the generator loop
402
+ await self.producer.state.new_generator.wait()
403
+ self.producer.state.new_generator.clear()
404
+
405
+ if self.producer.settings.dispatch_rate == "ext_clock":
406
+ # We shouldn't even be here. Cycle around and wait on the event again.
407
+ continue
408
+
409
+ # We are not using an external clock. Run the generator.
410
+ while not self.producer.state.new_generator.is_set():
411
+ out = await self.producer.__acall__()
412
+ yield self.OUTPUT_SIGNAL, out
413
+ except Exception:
414
+ ez.logger.info(traceback.format_exc())
220
415
 
221
416
 
222
- class Counter(ez.Unit):
223
- """Generates monotonically increasing counter. Unit for :obj:`acounter`."""
417
+ class SinGeneratorSettings(ez.Settings):
418
+ """
419
+ Settings for :obj:`SinGenerator`.
420
+ See :obj:`sin` for parameter descriptions.
421
+ """
224
422
 
225
- SETTINGS = CounterSettings
226
- STATE = CounterState
423
+ axis: str | None = "time"
424
+ """
425
+ The name of the axis over which the sinusoid passes.
426
+ Note: The axis must exist in the msg.axes and be of type AxisArray.LinearAxis.
427
+ """
227
428
 
228
- INPUT_CLOCK = ez.InputStream(ez.Flag)
229
- INPUT_SETTINGS = ez.InputStream(CounterSettings)
230
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
429
+ freq: float = 1.0
430
+ """The frequency of the sinusoid, in Hz."""
231
431
 
232
- async def initialize(self) -> None:
233
- self.STATE.new_generator = asyncio.Event()
234
- self.validate_settings(self.SETTINGS)
432
+ amp: float = 1.0 # Amplitude
433
+ """The amplitude of the sinusoid."""
235
434
 
236
- @ez.subscriber(INPUT_SETTINGS)
237
- async def on_settings(self, msg: CounterSettings) -> None:
238
- self.validate_settings(msg)
435
+ phase: float = 0.0 # Phase offset (in radians)
436
+ """The initial phase of the sinusoid, in radians."""
239
437
 
240
- def validate_settings(self, settings: CounterSettings) -> None:
241
- if isinstance(
242
- settings.dispatch_rate, str
243
- ) and self.SETTINGS.dispatch_rate not in ["realtime", "ext_clock"]:
244
- raise ValueError(f"Unknown dispatch_rate: {self.SETTINGS.dispatch_rate}")
245
- self.STATE.cur_settings = settings
246
- self.construct_generator()
247
-
248
- def construct_generator(self):
249
- self.STATE.gen = acounter(
250
- self.STATE.cur_settings.n_time,
251
- self.STATE.cur_settings.fs,
252
- n_ch=self.STATE.cur_settings.n_ch,
253
- dispatch_rate=self.STATE.cur_settings.dispatch_rate,
254
- mod=self.STATE.cur_settings.mod,
255
- )
256
- self.STATE.new_generator.set()
257
438
 
258
- @ez.subscriber(INPUT_CLOCK)
259
- @ez.publisher(OUTPUT_SIGNAL)
260
- async def on_clock(self, clock: ez.Flag):
261
- if self.STATE.cur_settings.dispatch_rate == "ext_clock":
262
- out = await self.STATE.gen.__anext__()
263
- yield self.OUTPUT_SIGNAL, out
439
+ class SinTransformer(BaseTransformer[SinGeneratorSettings, AxisArray, AxisArray]):
440
+ """Transforms counter values into sinusoidal waveforms."""
264
441
 
265
- @ez.publisher(OUTPUT_SIGNAL)
266
- async def run_generator(self) -> typing.AsyncGenerator:
267
- while True:
268
- await self.STATE.new_generator.wait()
269
- self.STATE.new_generator.clear()
442
+ def _process(self, message: AxisArray) -> AxisArray:
443
+ """Transform input counter values into sinusoidal waveform."""
444
+ axis = self.settings.axis or message.dims[0]
270
445
 
271
- if self.STATE.cur_settings.dispatch_rate == "ext_clock":
272
- continue
446
+ ang_freq = 2.0 * np.pi * self.settings.freq
447
+ w = (ang_freq * message.get_axis(axis).gain) * message.data
448
+ out_data = self.settings.amp * np.sin(w + self.settings.phase)
449
+
450
+ return replace(message, data=out_data)
273
451
 
274
- while not self.STATE.new_generator.is_set():
275
- out = await self.STATE.gen.__anext__()
276
- yield self.OUTPUT_SIGNAL, out
452
+
453
+ class SinGenerator(
454
+ BaseTransformerUnit[SinGeneratorSettings, AxisArray, AxisArray, SinTransformer]
455
+ ):
456
+ """Unit for generating sinusoidal waveforms."""
457
+
458
+ SETTINGS = SinGeneratorSettings
277
459
 
278
460
 
279
- @consumer
280
461
  def sin(
281
462
  axis: str | None = "time",
282
463
  freq: float = 1.0,
283
464
  amp: float = 1.0,
284
465
  phase: float = 0.0,
285
- ) -> typing.Generator[AxisArray, AxisArray, None]:
466
+ ) -> SinTransformer:
286
467
  """
287
468
  Construct a generator of sinusoidal waveforms in AxisArray objects.
288
469
 
289
- Args:
290
- axis: The name of the axis over which the sinusoid passes.
291
- Note: The axis must exist in the msg.axes and be of type AxisArray.LinearAxis.
292
- freq: The frequency of the sinusoid, in Hz.
293
- amp: The amplitude of the sinusoid.
294
- phase: The initial phase of the sinusoid, in radians.
295
-
296
470
  Returns:
297
471
  A primed generator that expects .send(axis_array) of sample counts
298
472
  and yields an AxisArray of sinusoids.
299
473
  """
300
- msg_out = AxisArray(np.array([]), dims=[""])
301
-
302
- ang_freq = 2.0 * np.pi * freq
474
+ return SinTransformer(
475
+ SinGeneratorSettings(axis=axis, freq=freq, amp=amp, phase=phase)
476
+ )
303
477
 
304
- while True:
305
- msg_in: AxisArray = yield msg_out
306
- # msg_in is expected to be sample counts
307
478
 
308
- axis_name = axis
309
- if axis_name is None:
310
- axis_name = msg_in.dims[0]
479
+ class RandomGeneratorSettings(ez.Settings):
480
+ loc: float = 0.0
481
+ """loc argument for :obj:`numpy.random.normal`"""
311
482
 
312
- w = (ang_freq * msg_in.get_axis(axis_name).gain) * msg_in.data
313
- out_data = amp * np.sin(w + phase)
314
- msg_out = replace(msg_in, data=out_data)
483
+ scale: float = 1.0
484
+ """scale argument for :obj:`numpy.random.normal`"""
315
485
 
316
486
 
317
- class SinGeneratorSettings(ez.Settings):
487
+ class RandomTransformer(BaseTransformer[RandomGeneratorSettings, AxisArray, AxisArray]):
318
488
  """
319
- Settings for :obj:`SinGenerator`.
320
- See :obj:`sin` for parameter descriptions.
489
+ Replaces input data with random data and returns the result.
321
490
  """
322
491
 
323
- time_axis: str | None = "time"
324
- freq: float = 1.0 # Oscillation frequency in Hz
325
- amp: float = 1.0 # Amplitude
326
- phase: float = 0.0 # Phase offset (in radians)
327
-
492
+ def __init__(
493
+ self, *args, settings: RandomGeneratorSettings | None = None, **kwargs
494
+ ):
495
+ super().__init__(*args, settings=settings, **kwargs)
328
496
 
329
- class SinGenerator(GenAxisArray):
330
- """
331
- Unit for :obj:`sin`.
332
- """
497
+ def _process(self, message: AxisArray) -> AxisArray:
498
+ random_data = np.random.normal(
499
+ size=message.shape, loc=self.settings.loc, scale=self.settings.scale
500
+ )
501
+ return replace(message, data=random_data)
333
502
 
334
- SETTINGS = SinGeneratorSettings
335
503
 
336
- def construct_generator(self):
337
- self.STATE.gen = sin(
338
- axis=self.SETTINGS.time_axis,
339
- freq=self.SETTINGS.freq,
340
- amp=self.SETTINGS.amp,
341
- phase=self.SETTINGS.phase,
342
- )
504
+ class RandomGenerator(
505
+ BaseTransformerUnit[
506
+ RandomGeneratorSettings,
507
+ AxisArray,
508
+ AxisArray,
509
+ RandomTransformer,
510
+ ]
511
+ ):
512
+ SETTINGS = RandomGeneratorSettings
343
513
 
344
514
 
345
515
  class OscillatorSettings(ez.Settings):
@@ -370,78 +540,93 @@ class OscillatorSettings(ez.Settings):
370
540
  """Adjust `freq` to sync with sampling rate"""
371
541
 
372
542
 
373
- class Oscillator(ez.Collection):
374
- """
375
- :obj:`Collection that chains :obj:`Counter` and :obj:`SinGenerator`.
376
- """
377
-
378
- SETTINGS = OscillatorSettings
379
-
380
- INPUT_CLOCK = ez.InputStream(ez.Flag)
381
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
382
-
383
- COUNTER = Counter()
384
- SIN = SinGenerator()
385
-
386
- def configure(self) -> None:
543
+ class OscillatorProducer(CompositeProducer[OscillatorSettings, AxisArray]):
544
+ @staticmethod
545
+ def _initialize_processors(
546
+ settings: OscillatorSettings,
547
+ ) -> dict[str, CounterProducer | SinTransformer]:
387
548
  # Calculate synchronous settings if necessary
388
- freq = self.SETTINGS.freq
549
+ freq = settings.freq
389
550
  mod = None
390
- if self.SETTINGS.sync:
391
- period = 1.0 / self.SETTINGS.freq
392
- mod = round(period * self.SETTINGS.fs)
393
- freq = 1.0 / (mod / self.SETTINGS.fs)
394
-
395
- self.COUNTER.apply_settings(
396
- CounterSettings(
397
- n_time=self.SETTINGS.n_time,
398
- fs=self.SETTINGS.fs,
399
- n_ch=self.SETTINGS.n_ch,
400
- dispatch_rate=self.SETTINGS.dispatch_rate,
401
- mod=mod,
402
- )
403
- )
404
-
405
- self.SIN.apply_settings(
406
- SinGeneratorSettings(
407
- freq=freq, amp=self.SETTINGS.amp, phase=self.SETTINGS.phase
408
- )
409
- )
410
-
411
- def network(self) -> ez.NetworkDefinition:
412
- return (
413
- (self.INPUT_CLOCK, self.COUNTER.INPUT_CLOCK),
414
- (self.COUNTER.OUTPUT_SIGNAL, self.SIN.INPUT_SIGNAL),
415
- (self.SIN.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
416
- )
551
+ if settings.sync:
552
+ period = 1.0 / settings.freq
553
+ mod = round(period * settings.fs)
554
+ freq = 1.0 / (mod / settings.fs)
555
+
556
+ return {
557
+ "counter": CounterProducer(
558
+ CounterSettings(
559
+ n_time=settings.n_time,
560
+ fs=settings.fs,
561
+ n_ch=settings.n_ch,
562
+ dispatch_rate=settings.dispatch_rate,
563
+ mod=mod,
564
+ )
565
+ ),
566
+ "sin": SinTransformer(
567
+ SinGeneratorSettings(freq=freq, amp=settings.amp, phase=settings.phase)
568
+ ),
569
+ }
417
570
 
418
571
 
419
- class RandomGeneratorSettings(ez.Settings):
420
- loc: float = 0.0
421
- """loc argument for :obj:`numpy.random.normal`"""
572
+ class BaseCounterFirstProducerUnit(
573
+ BaseProducerUnit[SettingsType, MessageOutType, ProducerType],
574
+ typing.Generic[SettingsType, MessageInType, MessageOutType, ProducerType],
575
+ ):
576
+ """
577
+ Base class for units whose primary processor is a composite producer with a CounterProducer as the first
578
+ processor (producer) in the chain.
579
+ """
422
580
 
423
- scale: float = 1.0
424
- """scale argument for :obj:`numpy.random.normal`"""
581
+ INPUT_SIGNAL = ez.InputStream(MessageInType)
425
582
 
583
+ def create_producer(self):
584
+ super().create_producer()
426
585
 
427
- class RandomGenerator(ez.Unit):
428
- """
429
- Replaces input data with random data and yields the result.
430
- """
586
+ def recurse_get_counter(proc) -> CounterProducer:
587
+ if hasattr(proc, "_procs"):
588
+ return recurse_get_counter(list(proc._procs.values())[0])
589
+ return proc
431
590
 
432
- SETTINGS = RandomGeneratorSettings
591
+ self._counter = recurse_get_counter(self.producer)
433
592
 
434
- INPUT_SIGNAL = ez.InputStream(AxisArray)
435
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
593
+ @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
594
+ @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
595
+ @profile_subpub(trace_oldest=False)
596
+ async def on_signal(self, _: ez.Flag):
597
+ if self.producer.settings.dispatch_rate == "ext_clock":
598
+ out = await self.producer.__acall__()
599
+ yield self.OUTPUT_SIGNAL, out
436
600
 
437
- @ez.subscriber(INPUT_SIGNAL)
438
- @ez.publisher(OUTPUT_SIGNAL)
439
- async def generate(self, msg: AxisArray) -> typing.AsyncGenerator:
440
- random_data = np.random.normal(
441
- size=msg.shape, loc=self.SETTINGS.loc, scale=self.SETTINGS.scale
442
- )
601
+ @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
602
+ async def produce(self) -> typing.AsyncGenerator:
603
+ try:
604
+ counter_state = self._counter.state
605
+ while True:
606
+ # Once-only, enter the generator loop
607
+ await counter_state.new_generator.wait()
608
+ counter_state.new_generator.clear()
609
+
610
+ if self.producer.settings.dispatch_rate == "ext_clock":
611
+ # We shouldn't even be here. Cycle around and wait on the event again.
612
+ continue
613
+
614
+ # We are not using an external clock. Run the generator.
615
+ while not counter_state.new_generator.is_set():
616
+ out = await self.producer.__acall__()
617
+ yield self.OUTPUT_SIGNAL, out
618
+ except Exception:
619
+ ez.logger.info(traceback.format_exc())
620
+
621
+
622
+ class Oscillator(
623
+ BaseCounterFirstProducerUnit[
624
+ OscillatorSettings, AxisArray, AxisArray, OscillatorProducer
625
+ ]
626
+ ):
627
+ """Generates sinusoidal waveforms using a counter and sine transformer."""
443
628
 
444
- yield self.OUTPUT_SIGNAL, replace(msg, data=random_data)
629
+ SETTINGS = OscillatorSettings
445
630
 
446
631
 
447
632
  class NoiseSettings(ez.Settings):
@@ -461,105 +646,66 @@ class NoiseSettings(ez.Settings):
461
646
  WhiteNoiseSettings = NoiseSettings
462
647
 
463
648
 
464
- class WhiteNoise(ez.Collection):
465
- """
466
- A :obj:`Collection` that chains a :obj:`Counter` and :obj:`RandomGenerator`.
467
- """
468
-
469
- SETTINGS = NoiseSettings
649
+ class WhiteNoiseProducer(CompositeProducer[NoiseSettings, AxisArray]):
650
+ @staticmethod
651
+ def _initialize_processors(
652
+ settings: NoiseSettings,
653
+ ) -> dict[str, CounterProducer | RandomTransformer]:
654
+ return {
655
+ "counter": CounterProducer(
656
+ CounterSettings(
657
+ n_time=settings.n_time,
658
+ fs=settings.fs,
659
+ n_ch=settings.n_ch,
660
+ dispatch_rate=settings.dispatch_rate,
661
+ mod=None,
662
+ )
663
+ ),
664
+ "random": RandomTransformer(
665
+ RandomGeneratorSettings(
666
+ loc=settings.loc,
667
+ scale=settings.scale,
668
+ )
669
+ ),
670
+ }
470
671
 
471
- INPUT_CLOCK = ez.InputStream(ez.Flag)
472
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
473
672
 
474
- COUNTER = Counter()
475
- RANDOM = RandomGenerator()
673
+ class WhiteNoise(
674
+ BaseCounterFirstProducerUnit[
675
+ NoiseSettings, AxisArray, AxisArray, WhiteNoiseProducer
676
+ ]
677
+ ):
678
+ """chains a :obj:`Counter` and :obj:`RandomGenerator`."""
476
679
 
477
- def configure(self) -> None:
478
- self.RANDOM.apply_settings(
479
- RandomGeneratorSettings(loc=self.SETTINGS.loc, scale=self.SETTINGS.scale)
480
- )
481
-
482
- self.COUNTER.apply_settings(
483
- CounterSettings(
484
- n_time=self.SETTINGS.n_time,
485
- fs=self.SETTINGS.fs,
486
- n_ch=self.SETTINGS.n_ch,
487
- dispatch_rate=self.SETTINGS.dispatch_rate,
488
- mod=None,
489
- )
490
- )
491
-
492
- def network(self) -> ez.NetworkDefinition:
493
- return (
494
- (self.INPUT_CLOCK, self.COUNTER.INPUT_CLOCK),
495
- (self.COUNTER.OUTPUT_SIGNAL, self.RANDOM.INPUT_SIGNAL),
496
- (self.RANDOM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
497
- )
680
+ SETTINGS = NoiseSettings
498
681
 
499
682
 
500
683
  PinkNoiseSettings = NoiseSettings
501
684
 
502
685
 
503
- class PinkNoise(ez.Collection):
504
- """
505
- A :obj:`Collection` that chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`.
506
- """
507
-
508
- SETTINGS = PinkNoiseSettings
509
-
510
- INPUT_CLOCK = ez.InputStream(ez.Flag)
511
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
512
-
513
- WHITE_NOISE = WhiteNoise()
514
- FILTER = ButterworthFilter()
515
-
516
- def configure(self) -> None:
517
- self.WHITE_NOISE.apply_settings(self.SETTINGS)
518
- self.FILTER.apply_settings(
519
- ButterworthFilterSettings(
520
- axis="time",
521
- order=1,
522
- cutoff=self.SETTINGS.fs * 0.01, # Hz
523
- )
524
- )
525
-
526
- def network(self) -> ez.NetworkDefinition:
527
- return (
528
- (self.INPUT_CLOCK, self.WHITE_NOISE.INPUT_CLOCK),
529
- (self.WHITE_NOISE.OUTPUT_SIGNAL, self.FILTER.INPUT_SIGNAL),
530
- (self.FILTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
531
- )
532
-
533
-
534
- class AddState(ez.State):
535
- queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
536
- queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
537
-
538
-
539
- class Add(ez.Unit):
540
- """Add two signals together. Assumes compatible/similar axes/dimensions."""
686
+ class PinkNoiseProducer(CompositeProducer[PinkNoiseSettings, AxisArray]):
687
+ @staticmethod
688
+ def _initialize_processors(
689
+ settings: PinkNoiseSettings,
690
+ ) -> dict[str, WhiteNoiseProducer | ButterworthFilterTransformer]:
691
+ return {
692
+ "white_noise": WhiteNoiseProducer(settings=settings),
693
+ "filter": ButterworthFilterTransformer(
694
+ settings=ButterworthFilterSettings(
695
+ axis="time",
696
+ order=1,
697
+ cutoff=settings.fs * 0.01, # Hz
698
+ )
699
+ ),
700
+ }
541
701
 
542
- STATE = AddState
543
702
 
544
- INPUT_SIGNAL_A = ez.InputStream(AxisArray)
545
- INPUT_SIGNAL_B = ez.InputStream(AxisArray)
546
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
703
+ class PinkNoise(
704
+ BaseCounterFirstProducerUnit[NoiseSettings, AxisArray, AxisArray, PinkNoiseProducer]
705
+ ):
706
+ """chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`."""
547
707
 
548
- @ez.subscriber(INPUT_SIGNAL_A)
549
- async def on_a(self, msg: AxisArray) -> None:
550
- self.STATE.queue_a.put_nowait(msg)
551
-
552
- @ez.subscriber(INPUT_SIGNAL_B)
553
- async def on_b(self, msg: AxisArray) -> None:
554
- self.STATE.queue_b.put_nowait(msg)
555
-
556
- @ez.publisher(OUTPUT_SIGNAL)
557
- async def output(self) -> typing.AsyncGenerator:
558
- while True:
559
- a = await self.STATE.queue_a.get()
560
- b = await self.STATE.queue_b.get()
561
-
562
- yield self.OUTPUT_SIGNAL, replace(a, data=a.data + b.data)
708
+ SETTINGS = NoiseSettings
563
709
 
564
710
 
565
711
  class EEGSynthSettings(ez.Settings):
@@ -575,6 +721,13 @@ class EEGSynth(ez.Collection):
575
721
  """
576
722
  A :obj:`Collection` that chains a :obj:`Clock` to both :obj:`PinkNoise`
577
723
  and :obj:`Oscillator`, then :obj:`Add` s the result.
724
+
725
+ Unlike the Oscillator, WhiteNoise, and PinkNoise composite processors which have linear
726
+ flows, this class has a diamond flow, with clock branching to both PinkNoise and Oscillator,
727
+ which then are combined in Add.
728
+
729
+ Optional: Refactor as a ProducerUnit, similar to Clock, but we manually add all the other
730
+ transformers.
578
731
  """
579
732
 
580
733
  SETTINGS = EEGSynthSettings
@@ -613,8 +766,8 @@ class EEGSynth(ez.Collection):
613
766
 
614
767
  def network(self) -> ez.NetworkDefinition:
615
768
  return (
616
- (self.CLOCK.OUTPUT_CLOCK, self.OSC.INPUT_CLOCK),
617
- (self.CLOCK.OUTPUT_CLOCK, self.NOISE.INPUT_CLOCK),
769
+ (self.CLOCK.OUTPUT_SIGNAL, self.OSC.INPUT_SIGNAL),
770
+ (self.CLOCK.OUTPUT_SIGNAL, self.NOISE.INPUT_SIGNAL),
618
771
  (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A),
619
772
  (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B),
620
773
  (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),