ezmsg-sigproc 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +1 -1
- ezmsg/sigproc/butterworthfilter.py +17 -27
- ezmsg/sigproc/decimate.py +7 -10
- ezmsg/sigproc/downsample.py +27 -33
- ezmsg/sigproc/ewmfilter.py +60 -54
- ezmsg/sigproc/filter.py +40 -24
- ezmsg/sigproc/messages.py +24 -44
- ezmsg/sigproc/sampler.py +173 -137
- ezmsg/sigproc/spectral.py +132 -0
- ezmsg/sigproc/synth.py +239 -64
- ezmsg/sigproc/window.py +92 -60
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/METADATA +2 -2
- ezmsg_sigproc-1.2.0.dist-info/RECORD +17 -0
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/timeseriesmessage.py +0 -1
- ezmsg_sigproc-1.1.1.dist-info/RECORD +0 -17
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/LICENSE.txt +0 -0
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/top_level.txt +0 -0
ezmsg/sigproc/synth.py
CHANGED
|
@@ -1,31 +1,62 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
|
|
2
|
+
import time
|
|
3
|
+
from dataclasses import dataclass, replace, field
|
|
3
4
|
|
|
4
5
|
import ezmsg.core as ez
|
|
5
6
|
import numpy as np
|
|
6
7
|
|
|
7
|
-
from .messages import
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
ButterworthFilterSettings,
|
|
11
|
-
ButterworthFilterDesign
|
|
12
|
-
)
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
|
|
10
|
+
from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
|
|
13
11
|
|
|
14
12
|
from typing import Optional, AsyncGenerator, Union
|
|
15
13
|
|
|
16
14
|
|
|
15
|
+
class ClockSettings(ez.Settings):
|
|
16
|
+
# Message dispatch rate (Hz), or None (fast as possible)
|
|
17
|
+
dispatch_rate: Optional[float]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ClockState(ez.State):
|
|
21
|
+
cur_settings: ClockSettings
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Clock(ez.Unit):
|
|
25
|
+
SETTINGS: ClockSettings
|
|
26
|
+
STATE: ClockState
|
|
27
|
+
|
|
28
|
+
INPUT_SETTINGS = ez.InputStream(ClockSettings)
|
|
29
|
+
OUTPUT_CLOCK = ez.OutputStream(ez.Flag)
|
|
30
|
+
|
|
31
|
+
def initialize(self) -> None:
|
|
32
|
+
self.STATE.cur_settings = self.SETTINGS
|
|
33
|
+
|
|
34
|
+
@ez.subscriber(INPUT_SETTINGS)
|
|
35
|
+
async def on_settings(self, msg: ClockSettings) -> None:
|
|
36
|
+
self.STATE.cur_settings = msg
|
|
37
|
+
|
|
38
|
+
@ez.publisher(OUTPUT_CLOCK)
|
|
39
|
+
async def generate(self) -> AsyncGenerator:
|
|
40
|
+
while True:
|
|
41
|
+
if self.STATE.cur_settings.dispatch_rate is not None:
|
|
42
|
+
await asyncio.sleep(1.0 / self.STATE.cur_settings.dispatch_rate)
|
|
43
|
+
yield self.OUTPUT_CLOCK, ez.Flag
|
|
44
|
+
|
|
45
|
+
|
|
17
46
|
class CounterSettings(ez.Settings):
|
|
18
47
|
"""
|
|
48
|
+
TODO: Adapt this to use ezmsg.util.rate?
|
|
19
49
|
NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
|
|
20
50
|
This method of sleeping/yielding execution priority has quirky behavior with
|
|
21
51
|
sub-millisecond sleep periods which may result in unexpected behavior (e.g.
|
|
22
52
|
fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
|
|
23
53
|
"""
|
|
54
|
+
|
|
24
55
|
n_time: int # Number of samples to output per block
|
|
25
56
|
fs: float # Sampling rate of signal output in Hz
|
|
26
|
-
n_ch: int = 1
|
|
57
|
+
n_ch: int = 1 # Number of channels to synthesize
|
|
27
58
|
|
|
28
|
-
# Message dispatch rate (Hz), 'realtime', or None (fast as possible)
|
|
59
|
+
# Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
|
|
29
60
|
dispatch_rate: Optional[Union[float, str]] = None
|
|
30
61
|
|
|
31
62
|
# If set to an integer, counter will rollover
|
|
@@ -33,77 +64,117 @@ class CounterSettings(ez.Settings):
|
|
|
33
64
|
|
|
34
65
|
|
|
35
66
|
class CounterState(ez.State):
|
|
67
|
+
cur_settings: CounterSettings
|
|
36
68
|
samp: int = 0 # current sample counter
|
|
69
|
+
clock_event: asyncio.Event
|
|
37
70
|
|
|
38
71
|
|
|
39
72
|
class Counter(ez.Unit):
|
|
40
|
-
"""
|
|
73
|
+
"""Generates monotonically increasing counter"""
|
|
41
74
|
|
|
42
75
|
SETTINGS: CounterSettings
|
|
43
76
|
STATE: CounterState
|
|
44
77
|
|
|
45
|
-
|
|
78
|
+
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
79
|
+
INPUT_SETTINGS = ez.InputStream(CounterSettings)
|
|
80
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
81
|
+
|
|
82
|
+
def initialize(self) -> None:
|
|
83
|
+
self.STATE.clock_event = asyncio.Event()
|
|
84
|
+
self.STATE.clock_event.clear()
|
|
85
|
+
self.validate_settings(self.SETTINGS)
|
|
86
|
+
|
|
87
|
+
@ez.subscriber(INPUT_SETTINGS)
|
|
88
|
+
async def on_settings(self, msg: CounterSettings) -> None:
|
|
89
|
+
self.validate_settings(msg)
|
|
90
|
+
|
|
91
|
+
def validate_settings(self, settings: CounterSettings) -> None:
|
|
92
|
+
if isinstance(
|
|
93
|
+
settings.dispatch_rate, str
|
|
94
|
+
) and self.SETTINGS.dispatch_rate not in ["realtime", "ext_clock"]:
|
|
95
|
+
raise ValueError(f"Unknown dispatch_rate: {self.SETTINGS.dispatch_rate}")
|
|
96
|
+
|
|
97
|
+
self.STATE.cur_settings = settings
|
|
98
|
+
|
|
99
|
+
@ez.subscriber(INPUT_CLOCK)
|
|
100
|
+
async def on_clock(self, _: ez.Flag):
|
|
101
|
+
self.STATE.clock_event.set()
|
|
46
102
|
|
|
47
103
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
48
104
|
async def publish(self) -> AsyncGenerator:
|
|
105
|
+
while True:
|
|
106
|
+
block_dur = self.STATE.cur_settings.n_time / self.STATE.cur_settings.fs
|
|
49
107
|
|
|
50
|
-
|
|
51
|
-
|
|
108
|
+
dispatch_rate = self.STATE.cur_settings.dispatch_rate
|
|
109
|
+
if dispatch_rate is not None:
|
|
110
|
+
if isinstance(dispatch_rate, str):
|
|
111
|
+
if dispatch_rate == "realtime":
|
|
112
|
+
await asyncio.sleep(block_dur)
|
|
113
|
+
elif dispatch_rate == "ext_clock":
|
|
114
|
+
await self.STATE.clock_event.wait()
|
|
115
|
+
self.STATE.clock_event.clear()
|
|
116
|
+
else:
|
|
117
|
+
await asyncio.sleep(1.0 / dispatch_rate)
|
|
52
118
|
|
|
53
|
-
|
|
119
|
+
block_samp = np.arange(self.STATE.cur_settings.n_time)[:, np.newaxis]
|
|
54
120
|
|
|
55
121
|
t_samp = block_samp + self.STATE.samp
|
|
56
122
|
self.STATE.samp = t_samp[-1] + 1
|
|
57
123
|
|
|
58
|
-
if self.
|
|
59
|
-
t_samp %= self.
|
|
60
|
-
self.STATE.samp %= self.
|
|
124
|
+
if self.STATE.cur_settings.mod is not None:
|
|
125
|
+
t_samp %= self.STATE.cur_settings.mod
|
|
126
|
+
self.STATE.samp %= self.STATE.cur_settings.mod
|
|
127
|
+
|
|
128
|
+
t_samp = np.tile(t_samp, (1, self.STATE.cur_settings.n_ch))
|
|
61
129
|
|
|
62
|
-
|
|
130
|
+
offset_adj = self.STATE.cur_settings.n_time / self.STATE.cur_settings.fs
|
|
63
131
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
132
|
+
out = AxisArray(
|
|
133
|
+
t_samp,
|
|
134
|
+
dims=["time", "ch"],
|
|
135
|
+
axes=dict(
|
|
136
|
+
time=AxisArray.Axis.TimeAxis(
|
|
137
|
+
fs=self.STATE.cur_settings.fs, offset=time.time() - offset_adj
|
|
138
|
+
)
|
|
139
|
+
),
|
|
70
140
|
)
|
|
71
141
|
|
|
72
|
-
|
|
73
|
-
await asyncio.sleep(
|
|
74
|
-
block_dur if self.SETTINGS.dispatch_rate == 'realtime'
|
|
75
|
-
else (1.0 / self.SETTINGS.dispatch_rate)
|
|
76
|
-
)
|
|
142
|
+
yield self.OUTPUT_SIGNAL, out
|
|
77
143
|
|
|
78
144
|
|
|
79
145
|
class SinGeneratorSettings(ez.Settings):
|
|
146
|
+
time_axis: Optional[str] = "time"
|
|
80
147
|
freq: float = 1.0 # Oscillation frequency in Hz
|
|
81
148
|
amp: float = 1.0 # Amplitude
|
|
82
149
|
phase: float = 0.0 # Phase offset (in radians)
|
|
83
150
|
|
|
84
151
|
|
|
85
152
|
class SinGeneratorState(ez.State):
|
|
86
|
-
ang_freq:
|
|
153
|
+
ang_freq: float # pre-calculated angular frequency in radians
|
|
87
154
|
|
|
88
155
|
|
|
89
156
|
class SinGenerator(ez.Unit):
|
|
90
157
|
SETTINGS: SinGeneratorSettings
|
|
91
158
|
STATE: SinGeneratorState
|
|
92
159
|
|
|
93
|
-
INPUT_SIGNAL = ez.InputStream(
|
|
94
|
-
OUTPUT_SIGNAL = ez.OutputStream(
|
|
160
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
161
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
95
162
|
|
|
96
163
|
def initialize(self) -> None:
|
|
97
164
|
self.STATE.ang_freq = 2.0 * np.pi * self.SETTINGS.freq
|
|
98
165
|
|
|
99
166
|
@ez.subscriber(INPUT_SIGNAL)
|
|
100
167
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
101
|
-
async def generate(self, msg:
|
|
168
|
+
async def generate(self, msg: AxisArray) -> AsyncGenerator:
|
|
102
169
|
"""
|
|
103
170
|
msg is assumed to be a monotonically increasing counter ..
|
|
104
171
|
.. or at least a counter with an intelligently chosen modulus
|
|
105
172
|
"""
|
|
106
|
-
|
|
173
|
+
axis_name = self.SETTINGS.time_axis
|
|
174
|
+
if axis_name is None:
|
|
175
|
+
axis_name = msg.dims[0]
|
|
176
|
+
fs = 1.0 / msg.get_axis(axis_name).gain
|
|
177
|
+
t_sec = msg.data / fs
|
|
107
178
|
w = self.STATE.ang_freq * t_sec
|
|
108
179
|
out_data = self.SETTINGS.amp * np.sin(w + self.SETTINGS.phase)
|
|
109
180
|
yield (self.OUTPUT_SIGNAL, replace(msg, data=out_data))
|
|
@@ -112,8 +183,8 @@ class SinGenerator(ez.Unit):
|
|
|
112
183
|
class OscillatorSettings(ez.Settings):
|
|
113
184
|
n_time: int # Number of samples to output per block
|
|
114
185
|
fs: float # Sampling rate of signal output in Hz
|
|
115
|
-
n_ch: int = 1
|
|
116
|
-
dispatch_rate: Optional[Union[float, str]] = None # (Hz)
|
|
186
|
+
n_ch: int = 1 # Number of channels to output per block
|
|
187
|
+
dispatch_rate: Optional[Union[float, str]] = None # (Hz) | 'realtime' | 'ext_clock'
|
|
117
188
|
freq: float = 1.0 # Oscillation frequency in Hz
|
|
118
189
|
amp: float = 1.0 # Amplitude
|
|
119
190
|
phase: float = 0.0 # Phase offset (in radians)
|
|
@@ -121,16 +192,15 @@ class OscillatorSettings(ez.Settings):
|
|
|
121
192
|
|
|
122
193
|
|
|
123
194
|
class Oscillator(ez.Collection):
|
|
124
|
-
|
|
125
195
|
SETTINGS: OscillatorSettings
|
|
126
196
|
|
|
127
|
-
|
|
197
|
+
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
198
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
128
199
|
|
|
129
200
|
COUNTER = Counter()
|
|
130
201
|
SIN = SinGenerator()
|
|
131
202
|
|
|
132
203
|
def configure(self) -> None:
|
|
133
|
-
|
|
134
204
|
# Calculate synchronous settings if necessary
|
|
135
205
|
freq = self.SETTINGS.freq
|
|
136
206
|
mod = None
|
|
@@ -145,54 +215,72 @@ class Oscillator(ez.Collection):
|
|
|
145
215
|
fs=self.SETTINGS.fs,
|
|
146
216
|
n_ch=self.SETTINGS.n_ch,
|
|
147
217
|
dispatch_rate=self.SETTINGS.dispatch_rate,
|
|
148
|
-
mod=mod
|
|
218
|
+
mod=mod,
|
|
149
219
|
)
|
|
150
220
|
)
|
|
151
221
|
|
|
152
222
|
self.SIN.apply_settings(
|
|
153
223
|
SinGeneratorSettings(
|
|
154
|
-
freq=freq,
|
|
155
|
-
amp=self.SETTINGS.amp,
|
|
156
|
-
phase=self.SETTINGS.phase
|
|
224
|
+
freq=freq, amp=self.SETTINGS.amp, phase=self.SETTINGS.phase
|
|
157
225
|
)
|
|
158
226
|
)
|
|
159
227
|
|
|
160
228
|
def network(self) -> ez.NetworkDefinition:
|
|
161
229
|
return (
|
|
230
|
+
(self.INPUT_CLOCK, self.COUNTER.INPUT_CLOCK),
|
|
162
231
|
(self.COUNTER.OUTPUT_SIGNAL, self.SIN.INPUT_SIGNAL),
|
|
163
|
-
(self.SIN.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)
|
|
232
|
+
(self.SIN.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
|
|
164
233
|
)
|
|
165
234
|
|
|
166
235
|
|
|
236
|
+
class RandomGeneratorSettings(ez.Settings):
|
|
237
|
+
loc: float = 0.0
|
|
238
|
+
scale: float = 1.0
|
|
239
|
+
|
|
240
|
+
|
|
167
241
|
class RandomGenerator(ez.Unit):
|
|
168
|
-
|
|
169
|
-
|
|
242
|
+
SETTINGS: RandomGeneratorSettings
|
|
243
|
+
|
|
244
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
245
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
170
246
|
|
|
171
247
|
@ez.subscriber(INPUT_SIGNAL)
|
|
172
248
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
173
|
-
async def generate(self, msg:
|
|
174
|
-
random_data = np.random.normal(
|
|
175
|
-
|
|
249
|
+
async def generate(self, msg: AxisArray) -> AsyncGenerator:
|
|
250
|
+
random_data = np.random.normal(
|
|
251
|
+
size=msg.shape, loc=self.SETTINGS.loc, scale=self.SETTINGS.scale
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
yield self.OUTPUT_SIGNAL, replace(msg, data=random_data)
|
|
176
255
|
|
|
177
256
|
|
|
178
257
|
class NoiseSettings(ez.Settings):
|
|
179
258
|
n_time: int # Number of samples to output per block
|
|
180
259
|
fs: float # Sampling rate of signal output in Hz
|
|
181
|
-
n_ch: int = 1
|
|
182
|
-
dispatch_rate: Optional[
|
|
260
|
+
n_ch: int = 1 # Number of channels to output
|
|
261
|
+
dispatch_rate: Optional[
|
|
262
|
+
Union[float, str]
|
|
263
|
+
] = None # (Hz), 'realtime', or 'ext_clock'
|
|
264
|
+
loc: float = 0.0 # DC offset
|
|
265
|
+
scale: float = 1.0 # Scale (in standard deviations)
|
|
266
|
+
|
|
183
267
|
|
|
184
268
|
WhiteNoiseSettings = NoiseSettings
|
|
185
269
|
|
|
186
|
-
class WhiteNoise(ez.Collection):
|
|
187
270
|
|
|
271
|
+
class WhiteNoise(ez.Collection):
|
|
188
272
|
SETTINGS: NoiseSettings
|
|
189
273
|
|
|
190
|
-
|
|
274
|
+
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
275
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
191
276
|
|
|
192
277
|
COUNTER = Counter()
|
|
193
278
|
RANDOM = RandomGenerator()
|
|
194
279
|
|
|
195
280
|
def configure(self) -> None:
|
|
281
|
+
self.RANDOM.apply_settings(
|
|
282
|
+
RandomGeneratorSettings(loc=self.SETTINGS.loc, scale=self.SETTINGS.scale)
|
|
283
|
+
)
|
|
196
284
|
|
|
197
285
|
self.COUNTER.apply_settings(
|
|
198
286
|
CounterSettings(
|
|
@@ -200,37 +288,124 @@ class WhiteNoise(ez.Collection):
|
|
|
200
288
|
fs=self.SETTINGS.fs,
|
|
201
289
|
n_ch=self.SETTINGS.n_ch,
|
|
202
290
|
dispatch_rate=self.SETTINGS.dispatch_rate,
|
|
203
|
-
mod=None
|
|
291
|
+
mod=None,
|
|
204
292
|
)
|
|
205
293
|
)
|
|
206
294
|
|
|
207
295
|
def network(self) -> ez.NetworkDefinition:
|
|
208
296
|
return (
|
|
297
|
+
(self.INPUT_CLOCK, self.COUNTER.INPUT_CLOCK),
|
|
209
298
|
(self.COUNTER.OUTPUT_SIGNAL, self.RANDOM.INPUT_SIGNAL),
|
|
210
|
-
(self.RANDOM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)
|
|
299
|
+
(self.RANDOM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
|
|
211
300
|
)
|
|
212
301
|
|
|
213
|
-
class PinkNoise(ez.Collection):
|
|
214
302
|
|
|
215
|
-
|
|
303
|
+
PinkNoiseSettings = NoiseSettings
|
|
216
304
|
|
|
217
|
-
|
|
305
|
+
|
|
306
|
+
class PinkNoise(ez.Collection):
|
|
307
|
+
SETTINGS: PinkNoiseSettings
|
|
308
|
+
|
|
309
|
+
INPUT_CLOCK = ez.InputStream(ez.Flag)
|
|
310
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
218
311
|
|
|
219
312
|
WHITE_NOISE = WhiteNoise()
|
|
220
313
|
FILTER = ButterworthFilter()
|
|
221
314
|
|
|
222
315
|
def configure(self) -> None:
|
|
223
|
-
|
|
224
|
-
self.WHITE_NOISE.apply_settings( self.SETTINGS )
|
|
316
|
+
self.WHITE_NOISE.apply_settings(self.SETTINGS)
|
|
225
317
|
self.FILTER.apply_settings(
|
|
226
318
|
ButterworthFilterSettings(
|
|
227
|
-
order =
|
|
228
|
-
cutoff = self.SETTINGS.fs * 0.01 # Hz
|
|
319
|
+
axis="time", order=1, cutoff=self.SETTINGS.fs * 0.01 # Hz
|
|
229
320
|
)
|
|
230
321
|
)
|
|
231
322
|
|
|
232
323
|
def network(self) -> ez.NetworkDefinition:
|
|
233
324
|
return (
|
|
325
|
+
(self.INPUT_CLOCK, self.WHITE_NOISE.INPUT_CLOCK),
|
|
234
326
|
(self.WHITE_NOISE.OUTPUT_SIGNAL, self.FILTER.INPUT_SIGNAL),
|
|
235
|
-
(self.FILTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)
|
|
236
|
-
)
|
|
327
|
+
(self.FILTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class AddState(ez.State):
|
|
332
|
+
queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
|
|
333
|
+
queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class Add(ez.Unit):
|
|
337
|
+
"""Add two signals together. Assumes compatible/similar axes/dimensions."""
|
|
338
|
+
|
|
339
|
+
STATE: AddState
|
|
340
|
+
|
|
341
|
+
INPUT_SIGNAL_A = ez.InputStream(AxisArray)
|
|
342
|
+
INPUT_SIGNAL_B = ez.InputStream(AxisArray)
|
|
343
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
344
|
+
|
|
345
|
+
@ez.subscriber(INPUT_SIGNAL_A)
|
|
346
|
+
async def on_a(self, msg: AxisArray) -> None:
|
|
347
|
+
self.STATE.queue_a.put_nowait(msg)
|
|
348
|
+
|
|
349
|
+
@ez.subscriber(INPUT_SIGNAL_B)
|
|
350
|
+
async def on_b(self, msg: AxisArray) -> None:
|
|
351
|
+
self.STATE.queue_b.put_nowait(msg)
|
|
352
|
+
|
|
353
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
354
|
+
async def output(self) -> AsyncGenerator:
|
|
355
|
+
while True:
|
|
356
|
+
a = await self.STATE.queue_a.get()
|
|
357
|
+
b = await self.STATE.queue_b.get()
|
|
358
|
+
|
|
359
|
+
yield (self.OUTPUT_SIGNAL, replace(a, data=a.data + b.data))
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
class EEGSynthSettings(ez.Settings):
|
|
363
|
+
fs: float = 500.0 # Hz
|
|
364
|
+
n_time: int = 100
|
|
365
|
+
alpha_freq: float = 10.5 # Hz
|
|
366
|
+
n_ch: int = 8
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class EEGSynth(ez.Collection):
|
|
370
|
+
SETTINGS: EEGSynthSettings
|
|
371
|
+
|
|
372
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
373
|
+
|
|
374
|
+
CLOCK = Clock()
|
|
375
|
+
NOISE = PinkNoise()
|
|
376
|
+
OSC = Oscillator()
|
|
377
|
+
ADD = Add()
|
|
378
|
+
|
|
379
|
+
def configure(self) -> None:
|
|
380
|
+
self.CLOCK.apply_settings(
|
|
381
|
+
ClockSettings(dispatch_rate=self.SETTINGS.fs / self.SETTINGS.n_time)
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
self.OSC.apply_settings(
|
|
385
|
+
OscillatorSettings(
|
|
386
|
+
n_time=self.SETTINGS.n_time,
|
|
387
|
+
fs=self.SETTINGS.fs,
|
|
388
|
+
n_ch=self.SETTINGS.n_ch,
|
|
389
|
+
dispatch_rate="ext_clock",
|
|
390
|
+
freq=self.SETTINGS.alpha_freq,
|
|
391
|
+
)
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
self.NOISE.apply_settings(
|
|
395
|
+
PinkNoiseSettings(
|
|
396
|
+
n_time=self.SETTINGS.n_time,
|
|
397
|
+
fs=self.SETTINGS.fs,
|
|
398
|
+
n_ch=self.SETTINGS.n_ch,
|
|
399
|
+
dispatch_rate="ext_clock",
|
|
400
|
+
scale=5.0,
|
|
401
|
+
)
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
def network(self) -> ez.NetworkDefinition:
|
|
405
|
+
return (
|
|
406
|
+
(self.CLOCK.OUTPUT_CLOCK, self.OSC.INPUT_CLOCK),
|
|
407
|
+
(self.CLOCK.OUTPUT_CLOCK, self.NOISE.INPUT_CLOCK),
|
|
408
|
+
(self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A),
|
|
409
|
+
(self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B),
|
|
410
|
+
(self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
|
|
411
|
+
)
|
ezmsg/sigproc/window.py
CHANGED
|
@@ -1,88 +1,88 @@
|
|
|
1
|
-
from dataclasses import
|
|
1
|
+
from dataclasses import replace
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
4
|
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
5
6
|
|
|
6
|
-
from ezmsg.util.messages import
|
|
7
|
-
from ezmsg.sigproc.messages import TSMessage
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
8
|
|
|
9
|
-
from typing import
|
|
10
|
-
AsyncGenerator,
|
|
11
|
-
Optional,
|
|
12
|
-
Tuple
|
|
13
|
-
)
|
|
14
|
-
|
|
15
|
-
@dataclass( frozen = True )
|
|
16
|
-
class WindowSettingsMessage:
|
|
17
|
-
window_dur: Optional[float] = None # Sec. If "None" -- passthrough; window_shift is ignored.
|
|
18
|
-
window_shift: Optional[float] = None # Sec. If "None", activate "1:1 mode"
|
|
9
|
+
from typing import AsyncGenerator, Optional, Tuple, List
|
|
19
10
|
|
|
20
11
|
|
|
21
|
-
class WindowSettings(
|
|
22
|
-
|
|
12
|
+
class WindowSettings(ez.Settings):
|
|
13
|
+
axis: Optional[str] = None
|
|
14
|
+
newaxis: Optional[
|
|
15
|
+
str
|
|
16
|
+
] = None # Optional new axis for output. If "None" - no new axes on output
|
|
17
|
+
window_dur: Optional[
|
|
18
|
+
float
|
|
19
|
+
] = None # Sec. If "None" -- passthrough; window_shift is ignored.
|
|
20
|
+
window_shift: Optional[float] = None # Sec. If "None", activate "1:1 mode"
|
|
23
21
|
|
|
24
22
|
|
|
25
23
|
class WindowState(ez.State):
|
|
26
|
-
cur_settings:
|
|
24
|
+
cur_settings: WindowSettings
|
|
27
25
|
|
|
28
26
|
samp_shape: Optional[Tuple[int, ...]] = None # Shape of individual sample
|
|
29
27
|
out_fs: Optional[float] = None
|
|
30
|
-
buffer: Optional[
|
|
28
|
+
buffer: Optional[npt.NDArray] = None
|
|
31
29
|
window_samples: Optional[int] = None
|
|
32
30
|
window_shift_samples: Optional[int] = None
|
|
33
31
|
|
|
34
32
|
|
|
35
33
|
class Window(ez.Unit):
|
|
36
|
-
|
|
37
34
|
STATE: WindowState
|
|
38
35
|
SETTINGS: WindowSettings
|
|
39
36
|
|
|
40
|
-
INPUT_SIGNAL = ez.InputStream(
|
|
41
|
-
OUTPUT_SIGNAL = ez.OutputStream(
|
|
42
|
-
INPUT_SETTINGS = ez.InputStream(
|
|
37
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
38
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
39
|
+
INPUT_SETTINGS = ez.InputStream(WindowSettings)
|
|
43
40
|
|
|
44
41
|
def initialize(self) -> None:
|
|
45
42
|
self.STATE.cur_settings = self.SETTINGS
|
|
46
43
|
|
|
47
44
|
@ez.subscriber(INPUT_SETTINGS)
|
|
48
|
-
async def on_settings(self, msg:
|
|
45
|
+
async def on_settings(self, msg: WindowSettings) -> None:
|
|
49
46
|
self.STATE.cur_settings = msg
|
|
50
|
-
self.STATE.out_fs = None
|
|
47
|
+
self.STATE.out_fs = None # This should trigger a reallocation
|
|
51
48
|
|
|
52
49
|
@ez.subscriber(INPUT_SIGNAL)
|
|
53
50
|
@ez.publisher(OUTPUT_SIGNAL)
|
|
54
|
-
async def on_signal(self, msg:
|
|
55
|
-
|
|
51
|
+
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
56
52
|
if self.STATE.cur_settings.window_dur is None:
|
|
57
53
|
yield self.OUTPUT_SIGNAL, msg
|
|
58
54
|
return
|
|
59
55
|
|
|
56
|
+
axis_name = self.STATE.cur_settings.axis
|
|
57
|
+
if axis_name is None:
|
|
58
|
+
axis_name = msg.dims[0]
|
|
59
|
+
axis_idx = msg.get_axis_idx(axis_name)
|
|
60
|
+
axis = msg.get_axis(axis_name)
|
|
61
|
+
fs = 1.0 / axis.gain
|
|
62
|
+
|
|
60
63
|
# Create a view of data with time axis at dim 0
|
|
61
|
-
time_view = np.moveaxis(msg.data,
|
|
64
|
+
time_view = np.moveaxis(msg.data, axis_idx, 0)
|
|
62
65
|
samp_shape = time_view.shape[1:]
|
|
63
66
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
self.STATE.
|
|
68
|
-
self.STATE.
|
|
69
|
-
|
|
67
|
+
# Pre(re?)allocate buffer
|
|
68
|
+
window_samples = int(self.STATE.cur_settings.window_dur * fs)
|
|
69
|
+
if (
|
|
70
|
+
(self.STATE.samp_shape != samp_shape)
|
|
71
|
+
or (self.STATE.out_fs != fs)
|
|
72
|
+
or self.STATE.buffer is None
|
|
73
|
+
):
|
|
74
|
+
self.STATE.buffer = np.zeros(tuple([window_samples] + list(samp_shape)))
|
|
75
|
+
|
|
76
|
+
self.STATE.window_samples = window_samples
|
|
77
|
+
self.STATE.samp_shape = samp_shape
|
|
78
|
+
self.STATE.out_fs = fs
|
|
79
|
+
|
|
80
|
+
self.STATE.window_shift_samples = None
|
|
81
|
+
if self.STATE.cur_settings.window_shift is not None:
|
|
82
|
+
self.STATE.window_shift_samples = int(
|
|
83
|
+
fs * self.STATE.cur_settings.window_shift
|
|
70
84
|
)
|
|
71
85
|
|
|
72
|
-
if self.STATE.cur_settings.window_shift is not None:
|
|
73
|
-
self.STATE.window_shift_samples = int(
|
|
74
|
-
self.STATE.out_fs * self.STATE.cur_settings.window_shift
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
if self.STATE.buffer is None:
|
|
78
|
-
self.STATE.buffer = np.zeros(tuple([self.STATE.window_samples] + list(self.STATE.samp_shape)))
|
|
79
|
-
elif self.STATE.buffer.shape[0] > self.STATE.window_samples:
|
|
80
|
-
self.STATE.buffer = self.STATE.buffer[:self.STATE.window_samples, ...]
|
|
81
|
-
elif self.STATE.buffer.shape[0] < self.STATE.window_samples:
|
|
82
|
-
extra_samples = self.STATE.window_samples - self.STATE.buffer.shape[0]
|
|
83
|
-
extra_samples = np.zeros(tuple([extra_samples] + list(self.STATE.samp_shape)))
|
|
84
|
-
self.STATE.buffer = np.concatenate((extra_samples, self.STATE.buffer), axis=0)
|
|
85
|
-
|
|
86
86
|
# Currently we just concatenate the new time samples and clip the output
|
|
87
87
|
# np.roll actually returns a copy, and there's no way to construct a
|
|
88
88
|
# rolling view of the data. In current numpy implementations, np.concatenate
|
|
@@ -90,23 +90,55 @@ class Window(ez.Unit):
|
|
|
90
90
|
# be a performance bottleneck for large memory arrays.
|
|
91
91
|
self.STATE.buffer = np.concatenate((self.STATE.buffer, time_view), axis=0)
|
|
92
92
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
93
|
+
buffer_offset = np.arange(self.STATE.buffer.shape[0] + time_view.shape[0])
|
|
94
|
+
buffer_offset -= self.STATE.buffer.shape[0] + 1
|
|
95
|
+
buffer_offset = (buffer_offset * axis.gain) + axis.offset
|
|
96
96
|
|
|
97
|
-
|
|
98
|
-
out_view = np.moveaxis(self.STATE.buffer, 0, msg.time_dim)
|
|
99
|
-
yield (self.OUTPUT_SIGNAL, replace(msg, data=out_view))
|
|
97
|
+
outputs: List[Tuple[npt.NDArray, float]] = []
|
|
100
98
|
|
|
101
|
-
|
|
99
|
+
if self.STATE.window_shift_samples is None: # one-to-one mode
|
|
100
|
+
self.STATE.buffer = self.STATE.buffer[-self.STATE.window_samples :, ...]
|
|
101
|
+
buffer_offset = buffer_offset[-self.STATE.window_samples :]
|
|
102
|
+
outputs.append((self.STATE.buffer, buffer_offset[0]))
|
|
102
103
|
|
|
104
|
+
else:
|
|
103
105
|
yieldable_size = self.STATE.window_samples + self.STATE.window_shift_samples
|
|
104
106
|
while self.STATE.buffer.shape[0] >= yieldable_size:
|
|
107
|
+
outputs.append(
|
|
108
|
+
(
|
|
109
|
+
self.STATE.buffer[: self.STATE.window_samples, ...],
|
|
110
|
+
buffer_offset[0],
|
|
111
|
+
)
|
|
112
|
+
)
|
|
113
|
+
self.STATE.buffer = self.STATE.buffer[
|
|
114
|
+
self.STATE.window_shift_samples :, ...
|
|
115
|
+
]
|
|
116
|
+
buffer_offset = buffer_offset[self.STATE.window_shift_samples :]
|
|
117
|
+
|
|
118
|
+
for out_view, offset in outputs:
|
|
119
|
+
out_view = np.moveaxis(out_view, 0, axis_idx)
|
|
120
|
+
|
|
121
|
+
if (
|
|
122
|
+
self.STATE.cur_settings.newaxis is not None
|
|
123
|
+
and self.STATE.cur_settings.newaxis != self.STATE.cur_settings.axis
|
|
124
|
+
):
|
|
125
|
+
new_gain = 0.0
|
|
126
|
+
if self.STATE.window_shift_samples is not None:
|
|
127
|
+
new_gain = axis.gain * self.STATE.window_shift_samples
|
|
128
|
+
|
|
129
|
+
out_axis = replace(axis, unit=axis.unit, gain=new_gain, offset=offset)
|
|
130
|
+
out_axes = {**msg.axes, **{self.STATE.cur_settings.newaxis: out_axis}}
|
|
131
|
+
out_dims = [self.STATE.cur_settings.newaxis] + msg.dims
|
|
132
|
+
out_view = out_view[np.newaxis, ...]
|
|
133
|
+
|
|
134
|
+
yield self.OUTPUT_SIGNAL, replace(
|
|
135
|
+
msg, data=out_view, dims=out_dims, axes=out_axes
|
|
136
|
+
)
|
|
105
137
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
138
|
+
else:
|
|
139
|
+
if axis_name in msg.axes:
|
|
140
|
+
out_axes = msg.axes
|
|
141
|
+
out_axes[axis_name] = replace(axis, offset=offset)
|
|
142
|
+
yield self.OUTPUT_SIGNAL, replace(msg, data=out_view, axes=out_axes)
|
|
143
|
+
else:
|
|
144
|
+
yield self.OUTPUT_SIGNAL, replace(msg, data=out_view)
|
|
{ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 1.1.1
|
|
3
|
+
Version: 1.2.0
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
5
|
Home-page: https://github.com/iscoe/ezmsg
|
|
6
6
|
Author: Griffin Milsap
|
|
@@ -10,7 +10,7 @@ Classifier: Operating System :: OS Independent
|
|
|
10
10
|
Requires-Python: >=3.8
|
|
11
11
|
Description-Content-Type: text/markdown
|
|
12
12
|
License-File: LICENSE.txt
|
|
13
|
-
Requires-Dist: ezmsg (>=
|
|
13
|
+
Requires-Dist: ezmsg (>=3.3.0)
|
|
14
14
|
Requires-Dist: numpy (>=1.19.5)
|
|
15
15
|
Requires-Dist: scipy (>=1.6.3)
|
|
16
16
|
Provides-Extra: test
|