ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +123 -0
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +336 -0
  25. ezmsg/sigproc/fir_pmc.py +209 -0
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +232 -0
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
  60. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/synth.py DELETED
@@ -1,774 +0,0 @@
1
- import asyncio
2
- import traceback
3
- from dataclasses import dataclass, field
4
- import time
5
- import typing
6
-
7
- import numpy as np
8
- import ezmsg.core as ez
9
- from ezmsg.util.messages.axisarray import AxisArray
10
- from ezmsg.util.messages.util import replace
11
-
12
- from .butterworthfilter import ButterworthFilterSettings, ButterworthFilterTransformer
13
- from .base import (
14
- BaseStatefulProducer,
15
- BaseProducerUnit,
16
- BaseTransformer,
17
- BaseTransformerUnit,
18
- CompositeProducer,
19
- ProducerType,
20
- SettingsType,
21
- MessageInType,
22
- MessageOutType,
23
- processor_state,
24
- )
25
- from .util.asio import run_coroutine_sync
26
- from .util.profile import profile_subpub
27
-
28
-
29
- @dataclass
30
- class AddState:
31
- queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
32
- queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
33
-
34
-
35
- class AddProcessor:
36
- def __init__(self):
37
- self._state = AddState()
38
-
39
- @property
40
- def state(self) -> AddState:
41
- return self._state
42
-
43
- @state.setter
44
- def state(self, state: AddState | bytes | None) -> None:
45
- if state is not None:
46
- # TODO: Support hydrating state from bytes
47
- # if isinstance(state, bytes):
48
- # self._state = pickle.loads(state)
49
- # else:
50
- self._state = state
51
-
52
- def push_a(self, msg: AxisArray) -> None:
53
- self._state.queue_a.put_nowait(msg)
54
-
55
- def push_b(self, msg: AxisArray) -> None:
56
- self._state.queue_b.put_nowait(msg)
57
-
58
- async def __acall__(self) -> AxisArray:
59
- a = await self._state.queue_a.get()
60
- b = await self._state.queue_b.get()
61
- return replace(a, data=a.data + b.data)
62
-
63
- def __call__(self) -> AxisArray:
64
- return run_coroutine_sync(self.__acall__())
65
-
66
- # Aliases for legacy interface
67
- async def __anext__(self) -> AxisArray:
68
- return await self.__acall__()
69
-
70
- def __next__(self) -> AxisArray:
71
- return self.__call__()
72
-
73
-
74
- class Add(ez.Unit):
75
- """Add two signals together. Assumes compatible/similar axes/dimensions."""
76
-
77
- INPUT_SIGNAL_A = ez.InputStream(AxisArray)
78
- INPUT_SIGNAL_B = ez.InputStream(AxisArray)
79
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
80
-
81
- async def initialize(self) -> None:
82
- self.processor = AddProcessor()
83
-
84
- @ez.subscriber(INPUT_SIGNAL_A)
85
- async def on_a(self, msg: AxisArray) -> None:
86
- self.processor.push_a(msg)
87
-
88
- @ez.subscriber(INPUT_SIGNAL_B)
89
- async def on_b(self, msg: AxisArray) -> None:
90
- self.processor.push_b(msg)
91
-
92
- @ez.publisher(OUTPUT_SIGNAL)
93
- async def output(self) -> typing.AsyncGenerator:
94
- while True:
95
- yield self.OUTPUT_SIGNAL, await self.processor.__acall__()
96
-
97
-
98
- class ClockSettings(ez.Settings):
99
- """Settings for clock generator."""
100
-
101
- dispatch_rate: float | str | None = None
102
- """Dispatch rate in Hz, 'realtime', or None for external clock"""
103
-
104
-
105
- @processor_state
106
- class ClockState:
107
- """State for clock generator."""
108
-
109
- t_0: float = field(default_factory=time.time) # Start time
110
- n_dispatch: int = 0 # Number of dispatches
111
-
112
-
113
- class ClockProducer(BaseStatefulProducer[ClockSettings, ez.Flag, ClockState]):
114
- """
115
- Produces clock ticks at specified rate.
116
- Can be used to drive periodic operations.
117
- """
118
-
119
- def _reset_state(self) -> None:
120
- """Reset internal state."""
121
- self._state.t_0 = time.time()
122
- self._state.n_dispatch = 0
123
-
124
- def __call__(self) -> ez.Flag:
125
- """Synchronous clock production. We override __call__ (which uses run_coroutine_sync) to avoid async overhead."""
126
- if self._hash == -1:
127
- self._reset_state()
128
- self._hash = 0
129
-
130
- if isinstance(self.settings.dispatch_rate, (int, float)):
131
- # Manual dispatch_rate. (else it is 'as fast as possible')
132
- target_time = (
133
- self.state.t_0
134
- + (self.state.n_dispatch + 1) / self.settings.dispatch_rate
135
- )
136
- now = time.time()
137
- if target_time > now:
138
- time.sleep(target_time - now)
139
-
140
- self.state.n_dispatch += 1
141
- return ez.Flag()
142
-
143
- async def _produce(self) -> ez.Flag:
144
- """Generate next clock tick."""
145
- if isinstance(self.settings.dispatch_rate, (int, float)):
146
- # Manual dispatch_rate. (else it is 'as fast as possible')
147
- target_time = (
148
- self.state.t_0
149
- + (self.state.n_dispatch + 1) / self.settings.dispatch_rate
150
- )
151
- now = time.time()
152
- if target_time > now:
153
- await asyncio.sleep(target_time - now)
154
-
155
- self.state.n_dispatch += 1
156
- return ez.Flag()
157
-
158
-
159
- def aclock(dispatch_rate: float | None) -> ClockProducer:
160
- """
161
- Construct an async generator that yields events at a specified rate.
162
-
163
- Returns:
164
- A :obj:`ClockProducer` object.
165
- """
166
- return ClockProducer(ClockSettings(dispatch_rate=dispatch_rate))
167
-
168
-
169
- clock = aclock
170
- """
171
- Alias for :obj:`aclock` expected by synchronous methods. `ClockProducer` can be used in sync or async.
172
- """
173
-
174
-
175
- class Clock(
176
- BaseProducerUnit[
177
- ClockSettings, # SettingsType
178
- ez.Flag, # MessageType
179
- ClockProducer, # ProducerType
180
- ]
181
- ):
182
- SETTINGS = ClockSettings
183
-
184
- @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
185
- async def produce(self) -> typing.AsyncGenerator:
186
- # Override so we can not to yield if out is False-like
187
- while True:
188
- out = await self.producer.__acall__()
189
- if out:
190
- yield self.OUTPUT_SIGNAL, out
191
-
192
-
193
- # COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
194
- class CounterSettings(ez.Settings):
195
- # TODO: Adapt this to use ezmsg.util.rate?
196
- """
197
- Settings for :obj:`Counter`.
198
- See :obj:`acounter` for a description of the parameters.
199
- """
200
-
201
- n_time: int
202
- """Number of samples to output per block."""
203
-
204
- fs: float
205
- """Sampling rate of signal output in Hz"""
206
-
207
- n_ch: int = 1
208
- """Number of channels to synthesize"""
209
-
210
- dispatch_rate: float | str | None = None
211
- """
212
- Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
213
- Note: if dispatch_rate is a float then time offsets will be synthetic and the
214
- system will run faster or slower than wall clock time.
215
- """
216
-
217
- mod: int | None = None
218
- """If set to an integer, counter will rollover"""
219
-
220
-
221
- @processor_state
222
- class CounterState:
223
- """
224
- State for counter generator.
225
- """
226
-
227
- counter_start: int = 0
228
- """next sample's first value"""
229
-
230
- n_sent: int = 0
231
- """number of samples sent"""
232
-
233
- clock_zero: float | None = None
234
- """time of first sample"""
235
-
236
- timer_type: str = "unspecified"
237
- """
238
- "realtime" | "ext_clock" | "manual" | "unspecified"
239
- """
240
-
241
- new_generator: asyncio.Event | None = None
242
- """
243
- Event to signal the counter has been reset.
244
- """
245
-
246
-
247
- class CounterProducer(BaseStatefulProducer[CounterSettings, AxisArray, CounterState]):
248
- """Produces incrementing integer blocks as AxisArray."""
249
-
250
- # TODO: Adapt this to use ezmsg.util.rate?
251
-
252
- @classmethod
253
- def get_message_type(cls, dir: str) -> typing.Optional[type[AxisArray]]:
254
- if dir == "in":
255
- return None
256
- elif dir == "out":
257
- return AxisArray
258
- else:
259
- raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")
260
-
261
- def __init__(self, *args, **kwargs):
262
- super().__init__(*args, **kwargs)
263
- if isinstance(
264
- self.settings.dispatch_rate, str
265
- ) and self.settings.dispatch_rate not in ["realtime", "ext_clock"]:
266
- raise ValueError(f"Unknown dispatch_rate: {self.settings.dispatch_rate}")
267
- self._reset_state()
268
- self._hash = 0
269
-
270
- def _reset_state(self) -> None:
271
- """Reset internal state."""
272
- self._state.counter_start = 0
273
- self._state.n_sent = 0
274
- self._state.clock_zero = time.time()
275
- if self.settings.dispatch_rate is not None:
276
- if isinstance(self.settings.dispatch_rate, str):
277
- self._state.timer_type = self.settings.dispatch_rate.lower()
278
- else:
279
- self._state.timer_type = "manual"
280
- if self._state.new_generator is None:
281
- self._state.new_generator = asyncio.Event()
282
- # Set the event to indicate that the state has been reset.
283
- self._state.new_generator.set()
284
-
285
- async def _produce(self) -> AxisArray:
286
- """Generate next counter block."""
287
- # 1. Prepare counter data
288
- block_samp = np.arange(
289
- self.state.counter_start, self.state.counter_start + self.settings.n_time
290
- )[:, np.newaxis]
291
- if self.settings.mod is not None:
292
- block_samp %= self.settings.mod
293
- block_samp = np.tile(block_samp, (1, self.settings.n_ch))
294
-
295
- # 2. Sleep if necessary. 3. Calculate time offset.
296
- if self._state.timer_type == "realtime":
297
- n_next = self.state.n_sent + self.settings.n_time
298
- t_next = self.state.clock_zero + n_next / self.settings.fs
299
- await asyncio.sleep(t_next - time.time())
300
- offset = t_next - self.settings.n_time / self.settings.fs
301
- elif self._state.timer_type == "manual":
302
- # manual dispatch rate
303
- n_disp_next = 1 + self.state.n_sent / self.settings.n_time
304
- t_disp_next = (
305
- self.state.clock_zero + n_disp_next / self.settings.dispatch_rate
306
- )
307
- await asyncio.sleep(t_disp_next - time.time())
308
- offset = self.state.n_sent / self.settings.fs
309
- elif self._state.timer_type == "ext_clock":
310
- # ext_clock -- no sleep. Assume this is called at appropriate intervals.
311
- offset = time.time()
312
- else:
313
- # Was "unspecified"
314
- offset = self.state.n_sent / self.settings.fs
315
-
316
- # 4. Create output AxisArray
317
- # Note: We can make this a bit faster by preparing a template for self._state
318
- result = AxisArray(
319
- data=block_samp,
320
- dims=["time", "ch"],
321
- axes={
322
- "time": AxisArray.TimeAxis(fs=self.settings.fs, offset=offset),
323
- "ch": AxisArray.CoordinateAxis(
324
- data=np.array([f"Ch{_}" for _ in range(self.settings.n_ch)]),
325
- dims=["ch"],
326
- ),
327
- },
328
- key="acounter",
329
- )
330
-
331
- # 5. Update state
332
- self.state.counter_start = block_samp[-1, 0] + 1
333
- self.state.n_sent += self.settings.n_time
334
-
335
- return result
336
-
337
-
338
- def acounter(
339
- n_time: int,
340
- fs: float | None,
341
- n_ch: int = 1,
342
- dispatch_rate: float | str | None = None,
343
- mod: int | None = None,
344
- ) -> CounterProducer:
345
- """
346
- Construct an asynchronous generator to generate AxisArray objects at a specified rate
347
- and with the specified sampling rate.
348
-
349
- NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
350
- This method of sleeping/yielding execution priority has quirky behavior with
351
- sub-millisecond sleep periods which may result in unexpected behavior (e.g.
352
- fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
353
-
354
- Returns:
355
- An asynchronous generator.
356
- """
357
- return CounterProducer(
358
- CounterSettings(
359
- n_time=n_time, fs=fs, n_ch=n_ch, dispatch_rate=dispatch_rate, mod=mod
360
- )
361
- )
362
-
363
-
364
- class Counter(
365
- BaseProducerUnit[
366
- CounterSettings, # SettingsType
367
- AxisArray, # MessageOutType
368
- CounterProducer, # ProducerType
369
- ]
370
- ):
371
- """Generates monotonically increasing counter. Unit for :obj:`CounterProducer`."""
372
-
373
- SETTINGS = CounterSettings
374
- INPUT_CLOCK = ez.InputStream(ez.Flag)
375
-
376
- @ez.subscriber(INPUT_CLOCK)
377
- @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
378
- async def on_clock(self, _: ez.Flag):
379
- if self.producer.settings.dispatch_rate == "ext_clock":
380
- out = await self.producer.__acall__()
381
- yield self.OUTPUT_SIGNAL, out
382
-
383
- @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
384
- async def produce(self) -> typing.AsyncGenerator:
385
- """
386
- Generate counter output.
387
- This is an infinite loop, but we will likely only enter the loop once if we are self-timed,
388
- and twice if we are using an external clock.
389
-
390
- When using an internal clock, we enter the loop, and wait for the event which should have
391
- been reset upon initialization then we immediately clear, then go to the internal loop
392
- that will async call __acall__ to let the internal timer determine when to produce an output.
393
-
394
- When using an external clock, we enter the loop, and wait for the event which should have been
395
- reset upon initialization then we immediately clear, then we hit `continue` to loop back around
396
- and wait for the event to be set again -- potentially forever. In this case, it is expected that
397
- `on_clock` will be called to produce the output.
398
- """
399
- try:
400
- while True:
401
- # Once-only, enter the generator loop
402
- await self.producer.state.new_generator.wait()
403
- self.producer.state.new_generator.clear()
404
-
405
- if self.producer.settings.dispatch_rate == "ext_clock":
406
- # We shouldn't even be here. Cycle around and wait on the event again.
407
- continue
408
-
409
- # We are not using an external clock. Run the generator.
410
- while not self.producer.state.new_generator.is_set():
411
- out = await self.producer.__acall__()
412
- yield self.OUTPUT_SIGNAL, out
413
- except Exception:
414
- ez.logger.info(traceback.format_exc())
415
-
416
-
417
- class SinGeneratorSettings(ez.Settings):
418
- """
419
- Settings for :obj:`SinGenerator`.
420
- See :obj:`sin` for parameter descriptions.
421
- """
422
-
423
- axis: str | None = "time"
424
- """
425
- The name of the axis over which the sinusoid passes.
426
- Note: The axis must exist in the msg.axes and be of type AxisArray.LinearAxis.
427
- """
428
-
429
- freq: float = 1.0
430
- """The frequency of the sinusoid, in Hz."""
431
-
432
- amp: float = 1.0 # Amplitude
433
- """The amplitude of the sinusoid."""
434
-
435
- phase: float = 0.0 # Phase offset (in radians)
436
- """The initial phase of the sinusoid, in radians."""
437
-
438
-
439
- class SinTransformer(BaseTransformer[SinGeneratorSettings, AxisArray, AxisArray]):
440
- """Transforms counter values into sinusoidal waveforms."""
441
-
442
- def _process(self, message: AxisArray) -> AxisArray:
443
- """Transform input counter values into sinusoidal waveform."""
444
- axis = self.settings.axis or message.dims[0]
445
-
446
- ang_freq = 2.0 * np.pi * self.settings.freq
447
- w = (ang_freq * message.get_axis(axis).gain) * message.data
448
- out_data = self.settings.amp * np.sin(w + self.settings.phase)
449
-
450
- return replace(message, data=out_data)
451
-
452
-
453
- class SinGenerator(
454
- BaseTransformerUnit[SinGeneratorSettings, AxisArray, AxisArray, SinTransformer]
455
- ):
456
- """Unit for generating sinusoidal waveforms."""
457
-
458
- SETTINGS = SinGeneratorSettings
459
-
460
-
461
- def sin(
462
- axis: str | None = "time",
463
- freq: float = 1.0,
464
- amp: float = 1.0,
465
- phase: float = 0.0,
466
- ) -> SinTransformer:
467
- """
468
- Construct a generator of sinusoidal waveforms in AxisArray objects.
469
-
470
- Returns:
471
- A primed generator that expects .send(axis_array) of sample counts
472
- and yields an AxisArray of sinusoids.
473
- """
474
- return SinTransformer(
475
- SinGeneratorSettings(axis=axis, freq=freq, amp=amp, phase=phase)
476
- )
477
-
478
-
479
- class RandomGeneratorSettings(ez.Settings):
480
- loc: float = 0.0
481
- """loc argument for :obj:`numpy.random.normal`"""
482
-
483
- scale: float = 1.0
484
- """scale argument for :obj:`numpy.random.normal`"""
485
-
486
-
487
- class RandomTransformer(BaseTransformer[RandomGeneratorSettings, AxisArray, AxisArray]):
488
- """
489
- Replaces input data with random data and returns the result.
490
- """
491
-
492
- def __init__(
493
- self, *args, settings: RandomGeneratorSettings | None = None, **kwargs
494
- ):
495
- super().__init__(*args, settings=settings, **kwargs)
496
-
497
- def _process(self, message: AxisArray) -> AxisArray:
498
- random_data = np.random.normal(
499
- size=message.shape, loc=self.settings.loc, scale=self.settings.scale
500
- )
501
- return replace(message, data=random_data)
502
-
503
-
504
- class RandomGenerator(
505
- BaseTransformerUnit[
506
- RandomGeneratorSettings,
507
- AxisArray,
508
- AxisArray,
509
- RandomTransformer,
510
- ]
511
- ):
512
- SETTINGS = RandomGeneratorSettings
513
-
514
-
515
- class OscillatorSettings(ez.Settings):
516
- """Settings for :obj:`Oscillator`"""
517
-
518
- n_time: int
519
- """Number of samples to output per block."""
520
-
521
- fs: float
522
- """Sampling rate of signal output in Hz"""
523
-
524
- n_ch: int = 1
525
- """Number of channels to output per block"""
526
-
527
- dispatch_rate: float | str | None = None
528
- """(Hz) | 'realtime' | 'ext_clock'"""
529
-
530
- freq: float = 1.0
531
- """Oscillation frequency in Hz"""
532
-
533
- amp: float = 1.0
534
- """Amplitude"""
535
-
536
- phase: float = 0.0
537
- """Phase offset (in radians)"""
538
-
539
- sync: bool = False
540
- """Adjust `freq` to sync with sampling rate"""
541
-
542
-
543
- class OscillatorProducer(CompositeProducer[OscillatorSettings, AxisArray]):
544
- @staticmethod
545
- def _initialize_processors(
546
- settings: OscillatorSettings,
547
- ) -> dict[str, CounterProducer | SinTransformer]:
548
- # Calculate synchronous settings if necessary
549
- freq = settings.freq
550
- mod = None
551
- if settings.sync:
552
- period = 1.0 / settings.freq
553
- mod = round(period * settings.fs)
554
- freq = 1.0 / (mod / settings.fs)
555
-
556
- return {
557
- "counter": CounterProducer(
558
- CounterSettings(
559
- n_time=settings.n_time,
560
- fs=settings.fs,
561
- n_ch=settings.n_ch,
562
- dispatch_rate=settings.dispatch_rate,
563
- mod=mod,
564
- )
565
- ),
566
- "sin": SinTransformer(
567
- SinGeneratorSettings(freq=freq, amp=settings.amp, phase=settings.phase)
568
- ),
569
- }
570
-
571
-
572
- class BaseCounterFirstProducerUnit(
573
- BaseProducerUnit[SettingsType, MessageOutType, ProducerType],
574
- typing.Generic[SettingsType, MessageInType, MessageOutType, ProducerType],
575
- ):
576
- """
577
- Base class for units whose primary processor is a composite producer with a CounterProducer as the first
578
- processor (producer) in the chain.
579
- """
580
-
581
- INPUT_SIGNAL = ez.InputStream(MessageInType)
582
-
583
- def create_producer(self):
584
- super().create_producer()
585
-
586
- def recurse_get_counter(proc) -> CounterProducer:
587
- if hasattr(proc, "_procs"):
588
- return recurse_get_counter(list(proc._procs.values())[0])
589
- return proc
590
-
591
- self._counter = recurse_get_counter(self.producer)
592
-
593
- @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
594
- @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
595
- @profile_subpub(trace_oldest=False)
596
- async def on_signal(self, _: ez.Flag):
597
- if self.producer.settings.dispatch_rate == "ext_clock":
598
- out = await self.producer.__acall__()
599
- yield self.OUTPUT_SIGNAL, out
600
-
601
- @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
602
- async def produce(self) -> typing.AsyncGenerator:
603
- try:
604
- counter_state = self._counter.state
605
- while True:
606
- # Once-only, enter the generator loop
607
- await counter_state.new_generator.wait()
608
- counter_state.new_generator.clear()
609
-
610
- if self.producer.settings.dispatch_rate == "ext_clock":
611
- # We shouldn't even be here. Cycle around and wait on the event again.
612
- continue
613
-
614
- # We are not using an external clock. Run the generator.
615
- while not counter_state.new_generator.is_set():
616
- out = await self.producer.__acall__()
617
- yield self.OUTPUT_SIGNAL, out
618
- except Exception:
619
- ez.logger.info(traceback.format_exc())
620
-
621
-
622
- class Oscillator(
623
- BaseCounterFirstProducerUnit[
624
- OscillatorSettings, AxisArray, AxisArray, OscillatorProducer
625
- ]
626
- ):
627
- """Generates sinusoidal waveforms using a counter and sine transformer."""
628
-
629
- SETTINGS = OscillatorSettings
630
-
631
-
632
- class NoiseSettings(ez.Settings):
633
- """
634
- See :obj:`CounterSettings` and :obj:`RandomGeneratorSettings`.
635
- """
636
-
637
- n_time: int # Number of samples to output per block
638
- fs: float # Sampling rate of signal output in Hz
639
- n_ch: int = 1 # Number of channels to output
640
- dispatch_rate: float | str | None = None
641
- """(Hz), 'realtime', or 'ext_clock'"""
642
- loc: float = 0.0 # DC offset
643
- scale: float = 1.0 # Scale (in standard deviations)
644
-
645
-
646
- WhiteNoiseSettings = NoiseSettings
647
-
648
-
649
- class WhiteNoiseProducer(CompositeProducer[NoiseSettings, AxisArray]):
650
- @staticmethod
651
- def _initialize_processors(
652
- settings: NoiseSettings,
653
- ) -> dict[str, CounterProducer | RandomTransformer]:
654
- return {
655
- "counter": CounterProducer(
656
- CounterSettings(
657
- n_time=settings.n_time,
658
- fs=settings.fs,
659
- n_ch=settings.n_ch,
660
- dispatch_rate=settings.dispatch_rate,
661
- mod=None,
662
- )
663
- ),
664
- "random": RandomTransformer(
665
- RandomGeneratorSettings(
666
- loc=settings.loc,
667
- scale=settings.scale,
668
- )
669
- ),
670
- }
671
-
672
-
673
- class WhiteNoise(
674
- BaseCounterFirstProducerUnit[
675
- NoiseSettings, AxisArray, AxisArray, WhiteNoiseProducer
676
- ]
677
- ):
678
- """chains a :obj:`Counter` and :obj:`RandomGenerator`."""
679
-
680
- SETTINGS = NoiseSettings
681
-
682
-
683
- PinkNoiseSettings = NoiseSettings
684
-
685
-
686
- class PinkNoiseProducer(CompositeProducer[PinkNoiseSettings, AxisArray]):
687
- @staticmethod
688
- def _initialize_processors(
689
- settings: PinkNoiseSettings,
690
- ) -> dict[str, WhiteNoiseProducer | ButterworthFilterTransformer]:
691
- return {
692
- "white_noise": WhiteNoiseProducer(settings=settings),
693
- "filter": ButterworthFilterTransformer(
694
- settings=ButterworthFilterSettings(
695
- axis="time",
696
- order=1,
697
- cutoff=settings.fs * 0.01, # Hz
698
- )
699
- ),
700
- }
701
-
702
-
703
- class PinkNoise(
704
- BaseCounterFirstProducerUnit[NoiseSettings, AxisArray, AxisArray, PinkNoiseProducer]
705
- ):
706
- """chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`."""
707
-
708
- SETTINGS = NoiseSettings
709
-
710
-
711
- class EEGSynthSettings(ez.Settings):
712
- """See :obj:`OscillatorSettings`."""
713
-
714
- fs: float = 500.0 # Hz
715
- n_time: int = 100
716
- alpha_freq: float = 10.5 # Hz
717
- n_ch: int = 8
718
-
719
-
720
- class EEGSynth(ez.Collection):
721
- """
722
- A :obj:`Collection` that chains a :obj:`Clock` to both :obj:`PinkNoise`
723
- and :obj:`Oscillator`, then :obj:`Add` s the result.
724
-
725
- Unlike the Oscillator, WhiteNoise, and PinkNoise composite processors which have linear
726
- flows, this class has a diamond flow, with clock branching to both PinkNoise and Oscillator,
727
- which then are combined in Add.
728
-
729
- Optional: Refactor as a ProducerUnit, similar to Clock, but we manually add all the other
730
- transformers.
731
- """
732
-
733
- SETTINGS = EEGSynthSettings
734
-
735
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
736
-
737
- CLOCK = Clock()
738
- NOISE = PinkNoise()
739
- OSC = Oscillator()
740
- ADD = Add()
741
-
742
- def configure(self) -> None:
743
- self.CLOCK.apply_settings(
744
- ClockSettings(dispatch_rate=self.SETTINGS.fs / self.SETTINGS.n_time)
745
- )
746
-
747
- self.OSC.apply_settings(
748
- OscillatorSettings(
749
- n_time=self.SETTINGS.n_time,
750
- fs=self.SETTINGS.fs,
751
- n_ch=self.SETTINGS.n_ch,
752
- dispatch_rate="ext_clock",
753
- freq=self.SETTINGS.alpha_freq,
754
- )
755
- )
756
-
757
- self.NOISE.apply_settings(
758
- PinkNoiseSettings(
759
- n_time=self.SETTINGS.n_time,
760
- fs=self.SETTINGS.fs,
761
- n_ch=self.SETTINGS.n_ch,
762
- dispatch_rate="ext_clock",
763
- scale=5.0,
764
- )
765
- )
766
-
767
- def network(self) -> ez.NetworkDefinition:
768
- return (
769
- (self.CLOCK.OUTPUT_SIGNAL, self.OSC.INPUT_SIGNAL),
770
- (self.CLOCK.OUTPUT_SIGNAL, self.NOISE.INPUT_SIGNAL),
771
- (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A),
772
- (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B),
773
- (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
774
- )