ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/synth.py CHANGED
@@ -1,21 +1,27 @@
1
1
  import asyncio
2
- from collections import deque
3
- from dataclasses import dataclass, replace, field
2
+ from dataclasses import replace, field
4
3
  import time
5
4
  from typing import Optional, Generator, AsyncGenerator, Union
6
5
 
7
6
  import numpy as np
8
7
  import ezmsg.core as ez
9
- from ezmsg.util.generator import consumer, GenAxisArray
8
+ from ezmsg.util.generator import consumer
10
9
  from ezmsg.util.messages.axisarray import AxisArray
11
10
 
12
11
  from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
12
+ from .base import GenAxisArray
13
13
 
14
14
 
15
- # CLOCK -- generate events at a specified rate #
16
- def clock(
17
- dispatch_rate: Optional[float]
18
- ) -> Generator[ez.Flag, None, None]:
15
+ def clock(dispatch_rate: Optional[float]) -> Generator[ez.Flag, None, None]:
16
+ """
17
+ Construct a generator that yields events at a specified rate.
18
+
19
+ Args:
20
+ dispatch_rate: event rate in seconds.
21
+
22
+ Returns:
23
+ A generator object that yields :obj:`ez.Flag` events at a specified rate.
24
+ """
19
25
  n_dispatch = -1
20
26
  t_0 = time.time()
21
27
  while True:
@@ -26,9 +32,13 @@ def clock(
26
32
  yield ez.Flag()
27
33
 
28
34
 
29
- async def aclock(
30
- dispatch_rate: Optional[float]
31
- ) -> AsyncGenerator[ez.Flag, None]:
35
+ async def aclock(dispatch_rate: Optional[float]) -> AsyncGenerator[ez.Flag, None]:
36
+ """
37
+ ``asyncio`` version of :obj:`clock`.
38
+
39
+ Returns:
40
+ asynchronous generator object. Must use `anext` or `async for`.
41
+ """
32
42
  t_0 = time.time()
33
43
  n_dispatch = -1
34
44
  while True:
@@ -40,6 +50,8 @@ async def aclock(
40
50
 
41
51
 
42
52
class ClockSettings(ez.Settings):
    """Settings for :obj:`Clock`. See :obj:`clock` for parameter description."""

    # Message dispatch rate (Hz), or None (fast as possible).
    # NOTE(review): the `clock` docstring describes this argument as an
    # "event rate in seconds" while this comment says Hz -- confirm which
    # unit is intended and make the two agree.
    dispatch_rate: Optional[float]
45
57
 
@@ -50,13 +62,15 @@ class ClockState(ez.State):
50
62
 
51
63
 
52
64
  class Clock(ez.Unit):
53
- SETTINGS: ClockSettings
54
- STATE: ClockState
65
+ """Unit for :obj:`clock`."""
66
+
67
+ SETTINGS = ClockSettings
68
+ STATE = ClockState
55
69
 
56
70
  INPUT_SETTINGS = ez.InputStream(ClockSettings)
57
71
  OUTPUT_CLOCK = ez.OutputStream(ez.Flag)
58
72
 
59
- def initialize(self) -> None:
73
+ async def initialize(self) -> None:
60
74
  self.STATE.cur_settings = self.SETTINGS
61
75
  self.construct_generator()
62
76
 
@@ -78,18 +92,33 @@ class Clock(ez.Unit):
78
92
 
79
93
  # COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
80
94
  async def acounter(
81
- n_time: int, # Number of samples to output per block
82
- fs: Optional[float], # Sampling rate of signal output in Hz
83
- n_ch: int = 1, # Number of channels to synthesize
84
-
85
- # Message dispatch rate (Hz), 'realtime' or None (fast as possible)
86
- # Note: if dispatch_rate is a float then time offsets will be synthetic and the
87
- # system will run faster or slower than wall clock time.
95
+ n_time: int,
96
+ fs: Optional[float],
97
+ n_ch: int = 1,
88
98
  dispatch_rate: Optional[Union[float, str]] = None,
89
-
90
- # If set to an integer, counter will rollover at this number.
91
99
  mod: Optional[int] = None,
92
100
  ) -> AsyncGenerator[AxisArray, None]:
101
+ """
102
+ Construct an asynchronous generator to generate AxisArray objects at a specified rate
103
+ and with the specified sampling rate.
104
+
105
+ NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
106
+ This method of sleeping/yielding execution priority has quirky behavior with
107
+ sub-millisecond sleep periods which may result in unexpected behavior (e.g.
108
+ fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
109
+
110
+ Args:
111
+ n_time: Number of samples to output per block.
112
+ fs: Sampling rate of signal output in Hz.
113
+ n_ch: Number of channels to synthesize
114
+ dispatch_rate: Message dispatch rate (Hz), 'realtime' or None (fast as possible)
115
+ Note: if dispatch_rate is a float then time offsets will be synthetic and the
116
+ system will run faster or slower than wall clock time.
117
+ mod: If set to an integer, counter will rollover at this number.
118
+
119
+ Returns:
120
+ An asynchronous generator.
121
+ """
93
122
 
94
123
  # TODO: Adapt this to use ezmsg.util.rate?
95
124
 
@@ -150,12 +179,10 @@ async def acounter(
150
179
 
151
180
 
152
181
  class CounterSettings(ez.Settings):
182
+ # TODO: Adapt this to use ezmsg.util.rate?
153
183
  """
154
- TODO: Adapt this to use ezmsg.util.rate?
155
- NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
156
- This method of sleeping/yielding execution priority has quirky behavior with
157
- sub-millisecond sleep periods which may result in unexpected behavior (e.g.
158
- fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
184
+ Settings for :obj:`Counter`.
185
+ See :obj:`acounter` for a description of the parameters.
159
186
  """
160
187
 
161
188
  n_time: int # Number of samples to output per block
@@ -178,10 +205,10 @@ class CounterState(ez.State):
178
205
 
179
206
 
180
207
  class Counter(ez.Unit):
181
- """Generates monotonically increasing counter"""
208
+ """Generates monotonically increasing counter. Unit for :obj:`acounter`."""
182
209
 
183
- SETTINGS: CounterSettings
184
- STATE: CounterState
210
+ SETTINGS = CounterSettings
211
+ STATE = CounterState
185
212
 
186
213
  INPUT_CLOCK = ez.InputStream(ez.Flag)
187
214
  INPUT_SETTINGS = ez.InputStream(CounterSettings)
@@ -209,27 +236,26 @@ class Counter(ez.Unit):
209
236
  self.STATE.cur_settings.fs,
210
237
  n_ch=self.STATE.cur_settings.n_ch,
211
238
  dispatch_rate=self.STATE.cur_settings.dispatch_rate,
212
- mod=self.STATE.cur_settings.mod
239
+ mod=self.STATE.cur_settings.mod,
213
240
  )
214
241
  self.STATE.new_generator.set()
215
-
242
+
216
243
  @ez.subscriber(INPUT_CLOCK)
217
244
  @ez.publisher(OUTPUT_SIGNAL)
218
245
  async def on_clock(self, clock: ez.Flag):
219
- if self.STATE.cur_settings.dispatch_rate == 'ext_clock':
246
+ if self.STATE.cur_settings.dispatch_rate == "ext_clock":
220
247
  out = await self.STATE.gen.__anext__()
221
248
  yield self.OUTPUT_SIGNAL, out
222
249
 
223
250
  @ez.publisher(OUTPUT_SIGNAL)
224
251
  async def run_generator(self) -> AsyncGenerator:
225
252
  while True:
226
-
227
253
  await self.STATE.new_generator.wait()
228
254
  self.STATE.new_generator.clear()
229
-
230
- if self.STATE.cur_settings.dispatch_rate == 'ext_clock':
255
+
256
+ if self.STATE.cur_settings.dispatch_rate == "ext_clock":
231
257
  continue
232
-
258
+
233
259
  while not self.STATE.new_generator.is_set():
234
260
  out = await self.STATE.gen.__anext__()
235
261
  yield self.OUTPUT_SIGNAL, out
@@ -238,29 +264,46 @@ class Counter(ez.Unit):
238
264
@consumer
def sin(
    axis: Optional[str] = "time",
    freq: float = 1.0,
    amp: float = 1.0,
    phase: float = 0.0,
) -> Generator[AxisArray, AxisArray, None]:
    """
    Construct a generator of sinusoidal waveforms in AxisArray objects.

    Each message sent in is expected to carry sample counts. The output keeps
    the input's metadata and replaces its data with
    ``amp * sin(2 * pi * freq * gain * counts + phase)``, where ``gain`` is the
    gain of the selected axis.

    Args:
        axis: The name of the axis over which the sinusoid passes. If None,
            the first dimension of each incoming message is used.
        freq: The frequency of the sinusoid, in Hz.
        amp: The amplitude of the sinusoid.
        phase: The initial phase of the sinusoid, in radians.

    Returns:
        A primed generator that expects .send(axis_array) of sample counts
        and yields an AxisArray of sinusoids.
    """
    result = AxisArray(np.array([]), dims=[""])
    omega = 2.0 * np.pi * freq

    while True:
        counts: AxisArray = yield result
        # `counts.data` is expected to hold sample counts.
        target_axis = axis if axis is not None else counts.dims[0]

        # Radians advanced per count along the target axis.
        radians_per_count = omega * counts.get_axis(target_axis).gain
        waveform = amp * np.sin(radians_per_count * counts.data + phase)
        result = replace(counts, data=waveform)
261
299
 
262
300
 
263
301
  class SinGeneratorSettings(ez.Settings):
302
+ """
303
+ Settings for :obj:`SinGenerator`.
304
+ See :obj:`sin` for parameter descriptions.
305
+ """
306
+
264
307
  time_axis: Optional[str] = "time"
265
308
  freq: float = 1.0 # Oscillation frequency in Hz
266
309
  amp: float = 1.0 # Amplitude
@@ -268,30 +311,55 @@ class SinGeneratorSettings(ez.Settings):
268
311
 
269
312
 
270
313
class SinGenerator(GenAxisArray):
    """
    Unit for :obj:`sin`.

    Builds its generator from :obj:`SinGeneratorSettings`.
    """

    SETTINGS = SinGeneratorSettings

    def construct_generator(self):
        """Store a :obj:`sin` generator configured from ``SETTINGS`` in ``STATE.gen``."""
        cfg = self.SETTINGS
        generator = sin(
            axis=cfg.time_axis,
            freq=cfg.freq,
            amp=cfg.amp,
            phase=cfg.phase,
        )
        self.STATE.gen = generator
280
327
 
281
328
 
282
329
class OscillatorSettings(ez.Settings):
    """
    Settings for :obj:`Oscillator`.

    Oscillator chains a :obj:`Counter` and a :obj:`SinGenerator`. The
    block-generation fields (``n_time``, ``fs``, ``n_ch``, ``dispatch_rate``)
    presumably configure the counter, and the waveform fields (``freq``,
    ``amp``, ``phase``) configure the sine generator. NOTE(review): confirm
    against ``Oscillator.configure``, which is not shown here.
    """

    n_time: int
    """Number of samples to output per block."""

    fs: float
    """Sampling rate of signal output in Hz"""

    n_ch: int = 1
    """Number of channels to output per block"""

    dispatch_rate: Optional[Union[float, str]] = None
    """(Hz) | 'realtime' | 'ext_clock'"""

    freq: float = 1.0
    """Oscillation frequency in Hz"""

    amp: float = 1.0
    """Amplitude"""

    phase: float = 0.0
    """Phase offset (in radians)"""

    sync: bool = False
    """Adjust `freq` to sync with sampling rate"""
291
355
 
292
356
 
293
357
  class Oscillator(ez.Collection):
294
- SETTINGS: OscillatorSettings
358
+ """
359
+ :obj:`Collection that chains :obj:`Counter` and :obj:`SinGenerator`.
360
+ """
361
+
362
+ SETTINGS = OscillatorSettings
295
363
 
296
364
  INPUT_CLOCK = ez.InputStream(ez.Flag)
297
365
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -334,11 +402,18 @@ class Oscillator(ez.Collection):
334
402
 
335
403
class RandomGeneratorSettings(ez.Settings):
    """Settings for :obj:`RandomGenerator`."""

    loc: float = 0.0
    """loc argument for :obj:`numpy.random.normal`"""

    scale: float = 1.0
    """scale argument for :obj:`numpy.random.normal`"""
338
409
 
339
410
 
340
411
  class RandomGenerator(ez.Unit):
341
- SETTINGS: RandomGeneratorSettings
412
+ """
413
+ Replaces input data with random data and yields the result.
414
+ """
415
+
416
+ SETTINGS = RandomGeneratorSettings
342
417
 
343
418
  INPUT_SIGNAL = ez.InputStream(AxisArray)
344
419
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -354,12 +429,15 @@ class RandomGenerator(ez.Unit):
354
429
 
355
430
 
356
431
class NoiseSettings(ez.Settings):
    """
    Settings shared by the noise collections; the module also exposes this
    class under the aliases ``WhiteNoiseSettings`` and ``PinkNoiseSettings``.

    See :obj:`CounterSettings` and :obj:`RandomGeneratorSettings`.
    """

    n_time: int  # Number of samples to output per block
    fs: float  # Sampling rate of signal output in Hz
    n_ch: int = 1  # Number of channels to output
    dispatch_rate: Optional[Union[float, str]] = None
    """(Hz), 'realtime', or 'ext_clock'"""
    loc: float = 0.0  # DC offset
    scale: float = 1.0  # Scale (in standard deviations)
365
443
 
@@ -368,7 +446,11 @@ WhiteNoiseSettings = NoiseSettings
368
446
 
369
447
 
370
448
  class WhiteNoise(ez.Collection):
371
- SETTINGS: NoiseSettings
449
+ """
450
+ A :obj:`Collection` that chains a :obj:`Counter` and :obj:`RandomGenerator`.
451
+ """
452
+
453
+ SETTINGS = NoiseSettings
372
454
 
373
455
  INPUT_CLOCK = ez.InputStream(ez.Flag)
374
456
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -403,7 +485,11 @@ PinkNoiseSettings = NoiseSettings
403
485
 
404
486
 
405
487
  class PinkNoise(ez.Collection):
406
- SETTINGS: PinkNoiseSettings
488
+ """
489
+ A :obj:`Collection` that chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`.
490
+ """
491
+
492
+ SETTINGS = PinkNoiseSettings
407
493
 
408
494
  INPUT_CLOCK = ez.InputStream(ez.Flag)
409
495
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -415,7 +501,9 @@ class PinkNoise(ez.Collection):
415
501
  self.WHITE_NOISE.apply_settings(self.SETTINGS)
416
502
  self.FILTER.apply_settings(
417
503
  ButterworthFilterSettings(
418
- axis="time", order=1, cutoff=self.SETTINGS.fs * 0.01 # Hz
504
+ axis="time",
505
+ order=1,
506
+ cutoff=self.SETTINGS.fs * 0.01, # Hz
419
507
  )
420
508
  )
421
509
 
@@ -435,7 +523,7 @@ class AddState(ez.State):
435
523
  class Add(ez.Unit):
436
524
  """Add two signals together. Assumes compatible/similar axes/dimensions."""
437
525
 
438
- STATE: AddState
526
+ STATE = AddState
439
527
 
440
528
  INPUT_SIGNAL_A = ez.InputStream(AxisArray)
441
529
  INPUT_SIGNAL_B = ez.InputStream(AxisArray)
@@ -459,6 +547,8 @@ class Add(ez.Unit):
459
547
 
460
548
 
461
549
  class EEGSynthSettings(ez.Settings):
550
+ """See :obj:`OscillatorSettings`."""
551
+
462
552
  fs: float = 500.0 # Hz
463
553
  n_time: int = 100
464
554
  alpha_freq: float = 10.5 # Hz
@@ -466,7 +556,12 @@ class EEGSynthSettings(ez.Settings):
466
556
 
467
557
 
468
558
  class EEGSynth(ez.Collection):
469
- SETTINGS: EEGSynthSettings
559
+ """
560
+ A :obj:`Collection` that chains a :obj:`Clock` to both :obj:`PinkNoise`
561
+ and :obj:`Oscillator`, then :obj:`Add` s the result.
562
+ """
563
+
564
+ SETTINGS = EEGSynthSettings
470
565
 
471
566
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
472
567
 
@@ -0,0 +1,167 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import pywt
7
+ import ezmsg.core as ez
8
+ from ezmsg.util.messages.axisarray import AxisArray
9
+ from ezmsg.util.generator import consumer
10
+
11
+ from .base import GenAxisArray
12
+ from .filterbank import filterbank, FilterbankMode, MinPhaseMode
13
+
14
+
15
@consumer
def cwt(
    scales: typing.Union[list, tuple, npt.NDArray],
    wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet],
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Build a generator to perform a continuous wavelet transform on sent AxisArray messages.
    The function is equivalent to the `pywt.cwt` function, but is designed to work with streaming data.

    Args:
        scales: The wavelet scales to use. Must be a 1D sequence of positive values.
        wavelet: Wavelet object or name of wavelet to use.
        min_phase: See filterbank MinPhaseMode for details.
        axis: The target axis for operation. Note that this will be moved to the -1th dimension
            because fft and matrix multiplication is much faster on the last axis.

    Returns:
        A Generator object (wrapped by `consumer`) that expects `.send(axis_array)` of
        continuous data. Each output message has the input's non-target dims followed by
        "freq" and `axis`.

    Raises:
        AssertionError: if `scales` contains non-positive values or is not 1D.
    """
    msg_out: typing.Optional[AxisArray] = None

    # Check parameters
    scales = np.array(scales)
    assert np.all(scales > 0), "Scales must be positive."
    assert scales.ndim == 1, "Scales must be a 1D list, tuple, or array."
    if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
        wavelet = pywt.DiscreteContinuousWavelet(wavelet)
    precision = 10

    # Full-precision integrated wavelet. These are never overwritten. Each reset
    # casts copies into the current input's precision, so one float32 or
    # complex-valued message cannot permanently degrade the kernels built for
    # later messages.
    base_int_psi, base_wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
    if wavelet.complex_cwt:
        base_int_psi = np.conj(base_int_psi)
    neg_rt_scales = -np.sqrt(scales)[:, None]

    # State variables (rebuilt on reset)
    template: typing.Optional[AxisArray] = None
    fbgen: typing.Optional[typing.Generator[AxisArray, AxisArray, None]] = None
    last_conv_samp: typing.Optional[npt.NDArray] = None
    scale_factors: typing.Optional[npt.NDArray] = None

    # Reset if input changed
    check_input = {
        "kind": None,  # Need to recalc kernels at same complexity as input
        "gain": None,  # Need to recalc freqs
        "shape": None,  # Need to recalc template and buffer
        "key": None,  # Buffer obsolete
    }

    while True:
        msg_in: AxisArray = yield msg_out
        ax_idx = msg_in.get_axis_idx(axis)
        in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]

        b_reset = msg_in.data.dtype.kind != check_input["kind"]
        b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
        b_reset = b_reset or in_shape != check_input["shape"]
        b_reset = b_reset or msg_in.key != check_input["key"]
        # Empty messages never trigger a reset.
        # NOTE(review): if the very first message is empty, `fbgen` is still None
        # below and `.send` will fail -- confirm upstream never sends that.
        b_reset = b_reset and msg_in.data.size > 0
        if b_reset:
            check_input["kind"] = msg_in.data.dtype.kind
            check_input["gain"] = msg_in.axes[axis].gain
            check_input["shape"] = in_shape
            check_input["key"] = msg_in.key

            # Work in floating point at the input's precision. Integer/bool input is
            # promoted to float64 (as pywt.cwt does) so the kernels and the wave grid
            # are not truncated to integers.
            dt_in = msg_in.data.dtype
            if dt_in.kind in "fc":
                dt_data = np.result_type(dt_in, np.float32)
            else:
                dt_data = np.dtype(np.float64)
            dt_cplx = np.result_type(dt_data, np.complex64)
            dt_real = np.finfo(dt_data).dtype
            dt_psi = dt_cplx if base_int_psi.dtype.kind == "c" else dt_data
            int_psi = np.asarray(base_int_psi, dtype=dt_psi)
            wave_xvec = np.asarray(base_wave_xvec, dtype=dt_real)
            scale_factors = neg_rt_scales.astype(dt_real)

            # Calculate waves for each scale (same resampling as pywt.cwt).
            wave_range = wave_xvec[-1] - wave_xvec[0]
            step = wave_xvec[1] - wave_xvec[0]
            int_psi_scales = []
            for scale in scales:
                reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
                if reix[-1] >= int_psi.size:
                    reix = np.extract(reix < int_psi.size, reix)
                int_psi_scales.append(int_psi[reix][::-1])

            # CONV is probably best because we often get huge kernels.
            fbgen = filterbank(
                int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
            )

            freqs = (
                pywt.scale2frequency(wavelet, scales, precision)
                / msg_in.axes[axis].gain
            )
            fstep = (freqs[1] - freqs[0]) if len(freqs) > 1 else 1.0
            # Create output template
            dummy_shape = in_shape + (len(scales), 0)
            template = AxisArray(
                np.zeros(
                    dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data
                ),
                dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1 :] + ["freq", axis],
                axes={
                    **msg_in.axes,
                    "freq": AxisArray.Axis("Hz", offset=freqs[0], gain=fstep),
                },
            )
            last_conv_samp = np.zeros(
                dummy_shape[:-1] + (1,), dtype=template.data.dtype
            )

        conv_msg = fbgen.send(msg_in)

        # Prepend with last_conv_samp before doing diff
        dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
        coef = scale_factors * np.diff(dat, axis=-1)
        # Carry the last convolution sample into the next chunk. An empty chunk
        # must not clobber it, or the next chunk would lose its first output sample.
        if conv_msg.data.shape[-1] > 0:
            last_conv_samp = conv_msg.data[..., -1:]

        if template.data.dtype.kind != "c":
            coef = coef.real
        # Keep the output dtype consistent with the declared template dtype.
        coef = coef.astype(template.data.dtype, copy=False)

        # pywt.cwt slices off the beginning and end of the result where the
        # convolution overran. We don't have that luxury when streaming.
        msg_out = replace(
            template, data=coef, axes={**template.axes, axis: msg_in.axes[axis]}
        )
140
+
141
+
142
class CWTSettings(ez.Settings):
    """
    Settings for :obj:`CWT`.
    See :obj:`cwt` for argument details.
    """

    # Wavelet scales; cwt() asserts this is a 1D collection of positive values.
    scales: typing.Union[list, tuple, npt.NDArray]
    # Wavelet object, or a name that cwt() passes to pywt.DiscreteContinuousWavelet.
    wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet]
    # Forwarded to the filterbank; see MinPhaseMode.
    min_phase: MinPhaseMode = MinPhaseMode.NONE
    # Name of the axis the transform runs along.
    axis: str = "time"
152
+
153
+
154
class CWT(GenAxisArray):
    """
    :obj:`Unit` for :obj:`cwt`.

    Builds its generator from :obj:`CWTSettings`; all processing is delegated to :obj:`cwt`.
    """

    SETTINGS = CWTSettings

    def construct_generator(self):
        """Store a :obj:`cwt` generator configured from ``SETTINGS`` in ``STATE.gen``."""
        self.STATE.gen = cwt(
            scales=self.SETTINGS.scales,
            wavelet=self.SETTINGS.wavelet,
            min_phase=self.SETTINGS.min_phase,
            axis=self.SETTINGS.axis,
        )