ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -4
- ezmsg/sigproc/__version__.py +16 -0
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +149 -39
- ezmsg/sigproc/aggregate.py +84 -29
- ezmsg/sigproc/bandpower.py +36 -15
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +76 -20
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +79 -61
- ezmsg/sigproc/ewmfilter.py +28 -14
- ezmsg/sigproc/filter.py +51 -31
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +152 -90
- ezmsg/sigproc/scaler.py +88 -42
- ezmsg/sigproc/signalinjector.py +7 -10
- ezmsg/sigproc/slicer.py +71 -36
- ezmsg/sigproc/spectral.py +6 -9
- ezmsg/sigproc/spectrogram.py +48 -30
- ezmsg/sigproc/spectrum.py +177 -76
- ezmsg/sigproc/synth.py +162 -67
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +193 -157
- ezmsg_sigproc-1.3.2.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.2.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.2.dist-info}/WHEEL +1 -1
- ezmsg_sigproc-1.2.3.dist-info/METADATA +0 -38
- ezmsg_sigproc-1.2.3.dist-info/RECORD +0 -23
- {ezmsg_sigproc-1.2.3.dist-info → ezmsg_sigproc-1.3.2.dist-info/licenses}/LICENSE.txt +0 -0
ezmsg/sigproc/synth.py
CHANGED
|
@@ -1,21 +1,27 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
from
|
|
3
|
-
from dataclasses import dataclass, replace, field
|
|
2
|
+
from dataclasses import replace, field
|
|
4
3
|
import time
|
|
5
4
|
from typing import Optional, Generator, AsyncGenerator, Union
|
|
6
5
|
|
|
7
6
|
import numpy as np
|
|
8
7
|
import ezmsg.core as ez
|
|
9
|
-
from ezmsg.util.generator import consumer
|
|
8
|
+
from ezmsg.util.generator import consumer
|
|
10
9
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
10
|
|
|
12
11
|
from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
|
|
12
|
+
from .base import GenAxisArray
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
15
|
+
def clock(dispatch_rate: Optional[float]) -> Generator[ez.Flag, None, None]:
|
|
16
|
+
"""
|
|
17
|
+
Construct a generator that yields events at a specified rate.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
dispatch_rate: event rate in Hz, or None to dispatch as fast as possible.
|
|
21
|
+
|
|
22
|
+
Returns:
|
|
23
|
+
A generator object that yields :obj:`ez.Flag` events at a specified rate.
|
|
24
|
+
"""
|
|
19
25
|
n_dispatch = -1
|
|
20
26
|
t_0 = time.time()
|
|
21
27
|
while True:
|
|
@@ -26,9 +32,13 @@ def clock(
|
|
|
26
32
|
yield ez.Flag()
|
|
27
33
|
|
|
28
34
|
|
|
29
|
-
async def aclock(
|
|
30
|
-
|
|
31
|
-
|
|
35
|
+
async def aclock(dispatch_rate: Optional[float]) -> AsyncGenerator[ez.Flag, None]:
|
|
36
|
+
"""
|
|
37
|
+
``asyncio`` version of :obj:`clock`.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
asynchronous generator object. Must use `anext` or `async for`.
|
|
41
|
+
"""
|
|
32
42
|
t_0 = time.time()
|
|
33
43
|
n_dispatch = -1
|
|
34
44
|
while True:
|
|
@@ -40,6 +50,8 @@ async def aclock(
|
|
|
40
50
|
|
|
41
51
|
|
|
42
52
|
class ClockSettings(ez.Settings):
|
|
53
|
+
"""Settings for :obj:`Clock`. See :obj:`clock` for parameter description."""
|
|
54
|
+
|
|
43
55
|
# Message dispatch rate (Hz), or None (fast as possible)
|
|
44
56
|
dispatch_rate: Optional[float]
|
|
45
57
|
|
|
@@ -50,13 +62,15 @@ class ClockState(ez.State):
|
|
|
50
62
|
|
|
51
63
|
|
|
52
64
|
class Clock(ez.Unit):
|
|
53
|
-
|
|
54
|
-
|
|
65
|
+
"""Unit for :obj:`clock`."""
|
|
66
|
+
|
|
67
|
+
SETTINGS = ClockSettings
|
|
68
|
+
STATE = ClockState
|
|
55
69
|
|
|
56
70
|
INPUT_SETTINGS = ez.InputStream(ClockSettings)
|
|
57
71
|
OUTPUT_CLOCK = ez.OutputStream(ez.Flag)
|
|
58
72
|
|
|
59
|
-
def initialize(self) -> None:
|
|
73
|
+
async def initialize(self) -> None:
|
|
60
74
|
self.STATE.cur_settings = self.SETTINGS
|
|
61
75
|
self.construct_generator()
|
|
62
76
|
|
|
@@ -78,18 +92,33 @@ class Clock(ez.Unit):
|
|
|
78
92
|
|
|
79
93
|
# COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
|
|
80
94
|
async def acounter(
|
|
81
|
-
n_time: int,
|
|
82
|
-
fs: Optional[float],
|
|
83
|
-
n_ch: int = 1,
|
|
84
|
-
|
|
85
|
-
# Message dispatch rate (Hz), 'realtime' or None (fast as possible)
|
|
86
|
-
# Note: if dispatch_rate is a float then time offsets will be synthetic and the
|
|
87
|
-
# system will run faster or slower than wall clock time.
|
|
95
|
+
n_time: int,
|
|
96
|
+
fs: Optional[float],
|
|
97
|
+
n_ch: int = 1,
|
|
88
98
|
dispatch_rate: Optional[Union[float, str]] = None,
|
|
89
|
-
|
|
90
|
-
# If set to an integer, counter will rollover at this number.
|
|
91
99
|
mod: Optional[int] = None,
|
|
92
100
|
) -> AsyncGenerator[AxisArray, None]:
|
|
101
|
+
"""
|
|
102
|
+
Construct an asynchronous generator to generate AxisArray objects at a specified rate
|
|
103
|
+
and with the specified sampling rate.
|
|
104
|
+
|
|
105
|
+
NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
|
|
106
|
+
This method of sleeping/yielding execution priority has quirky behavior with
|
|
107
|
+
sub-millisecond sleep periods which may result in unexpected behavior (e.g.
|
|
108
|
+
fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
n_time: Number of samples to output per block.
|
|
112
|
+
fs: Sampling rate of signal output in Hz.
|
|
113
|
+
n_ch: Number of channels to synthesize
|
|
114
|
+
dispatch_rate: Message dispatch rate (Hz), 'realtime' or None (fast as possible)
|
|
115
|
+
Note: if dispatch_rate is a float then time offsets will be synthetic and the
|
|
116
|
+
system will run faster or slower than wall clock time.
|
|
117
|
+
mod: If set to an integer, counter will rollover at this number.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
An asynchronous generator.
|
|
121
|
+
"""
|
|
93
122
|
|
|
94
123
|
# TODO: Adapt this to use ezmsg.util.rate?
|
|
95
124
|
|
|
@@ -150,12 +179,10 @@ async def acounter(
|
|
|
150
179
|
|
|
151
180
|
|
|
152
181
|
class CounterSettings(ez.Settings):
|
|
182
|
+
# TODO: Adapt this to use ezmsg.util.rate?
|
|
153
183
|
"""
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
This method of sleeping/yielding execution priority has quirky behavior with
|
|
157
|
-
sub-millisecond sleep periods which may result in unexpected behavior (e.g.
|
|
158
|
-
fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
|
|
184
|
+
Settings for :obj:`Counter`.
|
|
185
|
+
See :obj:`acounter` for a description of the parameters.
|
|
159
186
|
"""
|
|
160
187
|
|
|
161
188
|
n_time: int # Number of samples to output per block
|
|
@@ -178,10 +205,10 @@ class CounterState(ez.State):
|
|
|
178
205
|
|
|
179
206
|
|
|
180
207
|
class Counter(ez.Unit):
|
|
181
|
-
"""Generates monotonically increasing counter"""
|
|
208
|
+
"""Generates monotonically increasing counter. Unit for :obj:`acounter`."""
|
|
182
209
|
|
|
183
|
-
SETTINGS
|
|
184
|
-
STATE
|
|
210
|
+
SETTINGS = CounterSettings
|
|
211
|
+
STATE = CounterState
|
|
185
212
|
|
|
186
213
|
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
187
214
|
INPUT_SETTINGS = ez.InputStream(CounterSettings)
|
|
@@ -209,27 +236,26 @@ class Counter(ez.Unit):
|
|
|
209
236
|
self.STATE.cur_settings.fs,
|
|
210
237
|
n_ch=self.STATE.cur_settings.n_ch,
|
|
211
238
|
dispatch_rate=self.STATE.cur_settings.dispatch_rate,
|
|
212
|
-
mod=self.STATE.cur_settings.mod
|
|
239
|
+
mod=self.STATE.cur_settings.mod,
|
|
213
240
|
)
|
|
214
241
|
self.STATE.new_generator.set()
|
|
215
|
-
|
|
242
|
+
|
|
216
243
|
@ez.subscriber(INPUT_CLOCK)
|
|
217
244
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
218
245
|
async def on_clock(self, clock: ez.Flag):
|
|
219
|
-
if self.STATE.cur_settings.dispatch_rate ==
|
|
246
|
+
if self.STATE.cur_settings.dispatch_rate == "ext_clock":
|
|
220
247
|
out = await self.STATE.gen.__anext__()
|
|
221
248
|
yield self.OUTPUT_SIGNAL, out
|
|
222
249
|
|
|
223
250
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
224
251
|
async def run_generator(self) -> AsyncGenerator:
|
|
225
252
|
while True:
|
|
226
|
-
|
|
227
253
|
await self.STATE.new_generator.wait()
|
|
228
254
|
self.STATE.new_generator.clear()
|
|
229
|
-
|
|
230
|
-
if self.STATE.cur_settings.dispatch_rate ==
|
|
255
|
+
|
|
256
|
+
if self.STATE.cur_settings.dispatch_rate == "ext_clock":
|
|
231
257
|
continue
|
|
232
|
-
|
|
258
|
+
|
|
233
259
|
while not self.STATE.new_generator.is_set():
|
|
234
260
|
out = await self.STATE.gen.__anext__()
|
|
235
261
|
yield self.OUTPUT_SIGNAL, out
|
|
@@ -238,29 +264,46 @@ class Counter(ez.Unit):
|
|
|
238
264
|
@consumer
|
|
239
265
|
def sin(
|
|
240
266
|
axis: Optional[str] = "time",
|
|
241
|
-
freq: float = 1.0,
|
|
242
|
-
amp: float = 1.0,
|
|
243
|
-
phase: float = 0.0,
|
|
267
|
+
freq: float = 1.0,
|
|
268
|
+
amp: float = 1.0,
|
|
269
|
+
phase: float = 0.0,
|
|
244
270
|
) -> Generator[AxisArray, AxisArray, None]:
|
|
245
|
-
|
|
246
|
-
|
|
271
|
+
"""
|
|
272
|
+
Construct a generator of sinusoidal waveforms in AxisArray objects.
|
|
273
|
+
|
|
274
|
+
Args:
|
|
275
|
+
axis: The name of the axis over which the sinusoid passes.
|
|
276
|
+
freq: The frequency of the sinusoid, in Hz.
|
|
277
|
+
amp: The amplitude of the sinusoid.
|
|
278
|
+
phase: The initial phase of the sinusoid, in radians.
|
|
279
|
+
|
|
280
|
+
Returns:
|
|
281
|
+
A primed generator that expects .send(axis_array) of sample counts
|
|
282
|
+
and yields an AxisArray of sinusoids.
|
|
283
|
+
"""
|
|
284
|
+
msg_out = AxisArray(np.array([]), dims=[""])
|
|
247
285
|
|
|
248
286
|
ang_freq = 2.0 * np.pi * freq
|
|
249
287
|
|
|
250
288
|
while True:
|
|
251
|
-
|
|
252
|
-
#
|
|
289
|
+
msg_in: AxisArray = yield msg_out
|
|
290
|
+
# msg_in is expected to be sample counts
|
|
253
291
|
|
|
254
292
|
axis_name = axis
|
|
255
293
|
if axis_name is None:
|
|
256
|
-
axis_name =
|
|
294
|
+
axis_name = msg_in.dims[0]
|
|
257
295
|
|
|
258
|
-
w = (ang_freq *
|
|
296
|
+
w = (ang_freq * msg_in.get_axis(axis_name).gain) * msg_in.data
|
|
259
297
|
out_data = amp * np.sin(w + phase)
|
|
260
|
-
|
|
298
|
+
msg_out = replace(msg_in, data=out_data)
|
|
261
299
|
|
|
262
300
|
|
|
263
301
|
class SinGeneratorSettings(ez.Settings):
|
|
302
|
+
"""
|
|
303
|
+
Settings for :obj:`SinGenerator`.
|
|
304
|
+
See :obj:`sin` for parameter descriptions.
|
|
305
|
+
"""
|
|
306
|
+
|
|
264
307
|
time_axis: Optional[str] = "time"
|
|
265
308
|
freq: float = 1.0 # Oscillation frequency in Hz
|
|
266
309
|
amp: float = 1.0 # Amplitude
|
|
@@ -268,30 +311,55 @@ class SinGeneratorSettings(ez.Settings):
|
|
|
268
311
|
|
|
269
312
|
|
|
270
313
|
class SinGenerator(GenAxisArray):
|
|
271
|
-
|
|
314
|
+
"""
|
|
315
|
+
Unit for :obj:`sin`.
|
|
316
|
+
"""
|
|
317
|
+
|
|
318
|
+
SETTINGS = SinGeneratorSettings
|
|
272
319
|
|
|
273
320
|
def construct_generator(self):
|
|
274
321
|
self.STATE.gen = sin(
|
|
275
322
|
axis=self.SETTINGS.time_axis,
|
|
276
323
|
freq=self.SETTINGS.freq,
|
|
277
324
|
amp=self.SETTINGS.amp,
|
|
278
|
-
phase=self.SETTINGS.phase
|
|
325
|
+
phase=self.SETTINGS.phase,
|
|
279
326
|
)
|
|
280
327
|
|
|
281
328
|
|
|
282
329
|
class OscillatorSettings(ez.Settings):
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
330
|
+
"""Settings for :obj:`Oscillator`"""
|
|
331
|
+
|
|
332
|
+
n_time: int
|
|
333
|
+
"""Number of samples to output per block."""
|
|
334
|
+
|
|
335
|
+
fs: float
|
|
336
|
+
"""Sampling rate of signal output in Hz"""
|
|
337
|
+
|
|
338
|
+
n_ch: int = 1
|
|
339
|
+
"""Number of channels to output per block"""
|
|
340
|
+
|
|
341
|
+
dispatch_rate: Optional[Union[float, str]] = None
|
|
342
|
+
"""(Hz) | 'realtime' | 'ext_clock'"""
|
|
343
|
+
|
|
344
|
+
freq: float = 1.0
|
|
345
|
+
"""Oscillation frequency in Hz"""
|
|
346
|
+
|
|
347
|
+
amp: float = 1.0
|
|
348
|
+
"""Amplitude"""
|
|
349
|
+
|
|
350
|
+
phase: float = 0.0
|
|
351
|
+
"""Phase offset (in radians)"""
|
|
352
|
+
|
|
353
|
+
sync: bool = False
|
|
354
|
+
"""Adjust `freq` to sync with sampling rate"""
|
|
291
355
|
|
|
292
356
|
|
|
293
357
|
class Oscillator(ez.Collection):
|
|
294
|
-
|
|
358
|
+
"""
|
|
359
|
+
:obj:`Collection` that chains :obj:`Counter` and :obj:`SinGenerator`.
|
|
360
|
+
"""
|
|
361
|
+
|
|
362
|
+
SETTINGS = OscillatorSettings
|
|
295
363
|
|
|
296
364
|
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
297
365
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -334,11 +402,18 @@ class Oscillator(ez.Collection):
|
|
|
334
402
|
|
|
335
403
|
class RandomGeneratorSettings(ez.Settings):
|
|
336
404
|
loc: float = 0.0
|
|
405
|
+
"""loc argument for :obj:`numpy.random.normal`"""
|
|
406
|
+
|
|
337
407
|
scale: float = 1.0
|
|
408
|
+
"""scale argument for :obj:`numpy.random.normal`"""
|
|
338
409
|
|
|
339
410
|
|
|
340
411
|
class RandomGenerator(ez.Unit):
|
|
341
|
-
|
|
412
|
+
"""
|
|
413
|
+
Replaces input data with random data and yields the result.
|
|
414
|
+
"""
|
|
415
|
+
|
|
416
|
+
SETTINGS = RandomGeneratorSettings
|
|
342
417
|
|
|
343
418
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
344
419
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -354,12 +429,15 @@ class RandomGenerator(ez.Unit):
|
|
|
354
429
|
|
|
355
430
|
|
|
356
431
|
class NoiseSettings(ez.Settings):
|
|
432
|
+
"""
|
|
433
|
+
See :obj:`CounterSettings` and :obj:`RandomGeneratorSettings`.
|
|
434
|
+
"""
|
|
435
|
+
|
|
357
436
|
n_time: int # Number of samples to output per block
|
|
358
437
|
fs: float # Sampling rate of signal output in Hz
|
|
359
438
|
n_ch: int = 1 # Number of channels to output
|
|
360
|
-
dispatch_rate: Optional[
|
|
361
|
-
|
|
362
|
-
] = None # (Hz), 'realtime', or 'ext_clock'
|
|
439
|
+
dispatch_rate: Optional[Union[float, str]] = None
|
|
440
|
+
"""(Hz), 'realtime', or 'ext_clock'"""
|
|
363
441
|
loc: float = 0.0 # DC offset
|
|
364
442
|
scale: float = 1.0 # Scale (in standard deviations)
|
|
365
443
|
|
|
@@ -368,7 +446,11 @@ WhiteNoiseSettings = NoiseSettings
|
|
|
368
446
|
|
|
369
447
|
|
|
370
448
|
class WhiteNoise(ez.Collection):
|
|
371
|
-
|
|
449
|
+
"""
|
|
450
|
+
A :obj:`Collection` that chains a :obj:`Counter` and :obj:`RandomGenerator`.
|
|
451
|
+
"""
|
|
452
|
+
|
|
453
|
+
SETTINGS = NoiseSettings
|
|
372
454
|
|
|
373
455
|
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
374
456
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -403,7 +485,11 @@ PinkNoiseSettings = NoiseSettings
|
|
|
403
485
|
|
|
404
486
|
|
|
405
487
|
class PinkNoise(ez.Collection):
|
|
406
|
-
|
|
488
|
+
"""
|
|
489
|
+
A :obj:`Collection` that chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`.
|
|
490
|
+
"""
|
|
491
|
+
|
|
492
|
+
SETTINGS = PinkNoiseSettings
|
|
407
493
|
|
|
408
494
|
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
409
495
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -415,7 +501,9 @@ class PinkNoise(ez.Collection):
|
|
|
415
501
|
self.WHITE_NOISE.apply_settings(self.SETTINGS)
|
|
416
502
|
self.FILTER.apply_settings(
|
|
417
503
|
ButterworthFilterSettings(
|
|
418
|
-
axis="time",
|
|
504
|
+
axis="time",
|
|
505
|
+
order=1,
|
|
506
|
+
cutoff=self.SETTINGS.fs * 0.01, # Hz
|
|
419
507
|
)
|
|
420
508
|
)
|
|
421
509
|
|
|
@@ -435,7 +523,7 @@ class AddState(ez.State):
|
|
|
435
523
|
class Add(ez.Unit):
|
|
436
524
|
"""Add two signals together. Assumes compatible/similar axes/dimensions."""
|
|
437
525
|
|
|
438
|
-
STATE
|
|
526
|
+
STATE = AddState
|
|
439
527
|
|
|
440
528
|
INPUT_SIGNAL_A = ez.InputStream(AxisArray)
|
|
441
529
|
INPUT_SIGNAL_B = ez.InputStream(AxisArray)
|
|
@@ -459,6 +547,8 @@ class Add(ez.Unit):
|
|
|
459
547
|
|
|
460
548
|
|
|
461
549
|
class EEGSynthSettings(ez.Settings):
|
|
550
|
+
"""See :obj:`OscillatorSettings`."""
|
|
551
|
+
|
|
462
552
|
fs: float = 500.0 # Hz
|
|
463
553
|
n_time: int = 100
|
|
464
554
|
alpha_freq: float = 10.5 # Hz
|
|
@@ -466,7 +556,12 @@ class EEGSynthSettings(ez.Settings):
|
|
|
466
556
|
|
|
467
557
|
|
|
468
558
|
class EEGSynth(ez.Collection):
|
|
469
|
-
|
|
559
|
+
"""
|
|
560
|
+
A :obj:`Collection` that chains a :obj:`Clock` to both :obj:`PinkNoise`
|
|
561
|
+
and :obj:`Oscillator`, then :obj:`Add` s the result.
|
|
562
|
+
"""
|
|
563
|
+
|
|
564
|
+
SETTINGS = EEGSynthSettings
|
|
470
565
|
|
|
471
566
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
472
567
|
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import pywt
|
|
7
|
+
import ezmsg.core as ez
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.generator import consumer
|
|
10
|
+
|
|
11
|
+
from .base import GenAxisArray
|
|
12
|
+
from .filterbank import filterbank, FilterbankMode, MinPhaseMode
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@consumer
|
|
16
|
+
def cwt(
|
|
17
|
+
scales: typing.Union[list, tuple, npt.NDArray],
|
|
18
|
+
wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet],
|
|
19
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE,
|
|
20
|
+
axis: str = "time",
|
|
21
|
+
) -> typing.Generator[AxisArray, AxisArray, None]:
|
|
22
|
+
"""
|
|
23
|
+
Build a generator to perform a continuous wavelet transform on sent AxisArray messages.
|
|
24
|
+
The function is equivalent to the `pywt.cwt` function, but is designed to work with streaming data.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
scales: The wavelet scales to use.
|
|
28
|
+
wavelet: Wavelet object or name of wavelet to use.
|
|
29
|
+
min_phase: See filterbank MinPhaseMode for details.
|
|
30
|
+
axis: The target axis for operation. Note that this will be moved to the -1th dimension
|
|
31
|
+
because fft and matrix multiplication are much faster on the last axis.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
A Generator object that expects `.send(axis_array)` of continuous data
|
|
35
|
+
"""
|
|
36
|
+
msg_out: typing.Optional[AxisArray] = None
|
|
37
|
+
|
|
38
|
+
# Check parameters
|
|
39
|
+
scales = np.array(scales)
|
|
40
|
+
assert np.all(scales > 0), "Scales must be positive."
|
|
41
|
+
assert scales.ndim == 1, "Scales must be a 1D list, tuple, or array."
|
|
42
|
+
if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
|
|
43
|
+
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
|
44
|
+
precision = 10
|
|
45
|
+
|
|
46
|
+
# State variables
|
|
47
|
+
neg_rt_scales = -np.sqrt(scales)[:, None]
|
|
48
|
+
int_psi, wave_xvec = pywt.integrate_wavelet(wavelet, precision=precision)
|
|
49
|
+
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
|
50
|
+
template: typing.Optional[AxisArray] = None
|
|
51
|
+
fbgen: typing.Optional[typing.Generator[AxisArray, AxisArray, None]] = None
|
|
52
|
+
last_conv_samp: typing.Optional[npt.NDArray] = None
|
|
53
|
+
|
|
54
|
+
# Reset if input changed
|
|
55
|
+
check_input = {
|
|
56
|
+
"kind": None, # Need to recalc kernels at same complexity as input
|
|
57
|
+
"gain": None, # Need to recalc freqs
|
|
58
|
+
"shape": None, # Need to recalc template and buffer
|
|
59
|
+
"key": None, # Buffer obsolete
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
while True:
|
|
63
|
+
msg_in: AxisArray = yield msg_out
|
|
64
|
+
ax_idx = msg_in.get_axis_idx(axis)
|
|
65
|
+
in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :]
|
|
66
|
+
|
|
67
|
+
b_reset = msg_in.data.dtype.kind != check_input["kind"]
|
|
68
|
+
b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"]
|
|
69
|
+
b_reset = b_reset or in_shape != check_input["shape"]
|
|
70
|
+
b_reset = b_reset or msg_in.key != check_input["key"]
|
|
71
|
+
b_reset = b_reset and msg_in.data.size > 0
|
|
72
|
+
if b_reset:
|
|
73
|
+
check_input["kind"] = msg_in.data.dtype.kind
|
|
74
|
+
check_input["gain"] = msg_in.axes[axis].gain
|
|
75
|
+
check_input["shape"] = in_shape
|
|
76
|
+
check_input["key"] = msg_in.key
|
|
77
|
+
|
|
78
|
+
# convert int_psi, wave_xvec to the same precision as the data
|
|
79
|
+
dt_data = msg_in.data.dtype # _check_dtype(msg_in.data)
|
|
80
|
+
dt_cplx = np.result_type(dt_data, np.complex64)
|
|
81
|
+
dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data
|
|
82
|
+
int_psi = np.asarray(int_psi, dtype=dt_psi)
|
|
83
|
+
# TODO: Currently int_psi cannot be made non-complex once it is complex.
|
|
84
|
+
|
|
85
|
+
# Calculate waves for each scale
|
|
86
|
+
wave_xvec = np.asarray(wave_xvec, dtype=msg_in.data.real.dtype)
|
|
87
|
+
wave_range = wave_xvec[-1] - wave_xvec[0]
|
|
88
|
+
step = wave_xvec[1] - wave_xvec[0]
|
|
89
|
+
int_psi_scales = []
|
|
90
|
+
for scale in scales:
|
|
91
|
+
reix = (np.arange(scale * wave_range + 1) / (scale * step)).astype(int)
|
|
92
|
+
if reix[-1] >= int_psi.size:
|
|
93
|
+
reix = np.extract(reix < int_psi.size, reix)
|
|
94
|
+
int_psi_scales.append(int_psi[reix][::-1])
|
|
95
|
+
|
|
96
|
+
# CONV is probably best because we often get huge kernels.
|
|
97
|
+
fbgen = filterbank(
|
|
98
|
+
int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
freqs = (
|
|
102
|
+
pywt.scale2frequency(wavelet, scales, precision)
|
|
103
|
+
/ msg_in.axes[axis].gain
|
|
104
|
+
)
|
|
105
|
+
fstep = (freqs[1] - freqs[0]) if len(freqs) > 1 else 1.0
|
|
106
|
+
# Create output template
|
|
107
|
+
dummy_shape = in_shape + (len(scales), 0)
|
|
108
|
+
template = AxisArray(
|
|
109
|
+
np.zeros(
|
|
110
|
+
dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data
|
|
111
|
+
),
|
|
112
|
+
dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1 :] + ["freq", axis],
|
|
113
|
+
axes={
|
|
114
|
+
**msg_in.axes,
|
|
115
|
+
"freq": AxisArray.Axis("Hz", offset=freqs[0], gain=fstep),
|
|
116
|
+
},
|
|
117
|
+
)
|
|
118
|
+
last_conv_samp = np.zeros(
|
|
119
|
+
dummy_shape[:-1] + (1,), dtype=template.data.dtype
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
conv_msg = fbgen.send(msg_in)
|
|
123
|
+
|
|
124
|
+
# Prepend with last_conv_samp before doing diff
|
|
125
|
+
dat = np.concatenate((last_conv_samp, conv_msg.data), axis=-1)
|
|
126
|
+
coef = neg_rt_scales * np.diff(dat, axis=-1)
|
|
127
|
+
# Store last_conv_samp for next iteration.
|
|
128
|
+
last_conv_samp = conv_msg.data[..., -1:]
|
|
129
|
+
|
|
130
|
+
if template.data.dtype.kind != "c":
|
|
131
|
+
coef = coef.real
|
|
132
|
+
|
|
133
|
+
# pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have
|
|
134
|
+
# that luxury when streaming.
|
|
135
|
+
# d = (coef.shape[-1] - msg_in.data.shape[ax_idx]) / 2.
|
|
136
|
+
# coef = coef[..., math.floor(d):-math.ceil(d)]
|
|
137
|
+
msg_out = replace(
|
|
138
|
+
template, data=coef, axes={**template.axes, axis: msg_in.axes[axis]}
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class CWTSettings(ez.Settings):
|
|
143
|
+
"""
|
|
144
|
+
Settings for :obj:`CWT`
|
|
145
|
+
See :obj:`cwt` for argument details.
|
|
146
|
+
"""
|
|
147
|
+
|
|
148
|
+
scales: typing.Union[list, tuple, npt.NDArray]
|
|
149
|
+
wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet]
|
|
150
|
+
min_phase: MinPhaseMode = MinPhaseMode.NONE
|
|
151
|
+
axis: str = "time"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class CWT(GenAxisArray):
|
|
155
|
+
"""
|
|
156
|
+
:obj:`Unit` for :obj:`cwt`.
|
|
157
|
+
"""
|
|
158
|
+
|
|
159
|
+
SETTINGS = CWTSettings
|
|
160
|
+
|
|
161
|
+
def construct_generator(self):
|
|
162
|
+
self.STATE.gen = cwt(
|
|
163
|
+
scales=self.SETTINGS.scales,
|
|
164
|
+
wavelet=self.SETTINGS.wavelet,
|
|
165
|
+
min_phase=self.SETTINGS.min_phase,
|
|
166
|
+
axis=self.SETTINGS.axis,
|
|
167
|
+
)
|