ezmsg-sigproc 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/synth.py CHANGED
@@ -1,31 +1,62 @@
1
1
  import asyncio
2
- from dataclasses import replace
2
+ import time
3
+ from dataclasses import dataclass, replace, field
3
4
 
4
5
  import ezmsg.core as ez
5
6
  import numpy as np
6
7
 
7
- from .messages import TSMessage
8
- from .butterworthfilter import (
9
- ButterworthFilter,
10
- ButterworthFilterSettings,
11
- ButterworthFilterDesign
12
- )
8
+ from ezmsg.util.messages.axisarray import AxisArray
9
+
10
+ from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
13
11
 
14
12
  from typing import Optional, AsyncGenerator, Union
15
13
 
16
14
 
class ClockSettings(ez.Settings):
    # Message dispatch rate (Hz), or None (fast as possible)
    dispatch_rate: Optional[float]


class ClockState(ez.State):
    # Settings currently in effect; replaced whenever INPUT_SETTINGS delivers a message.
    cur_settings: ClockSettings


class Clock(ez.Unit):
    """Publish an ``ez.Flag`` tick on OUTPUT_CLOCK, paced by ``dispatch_rate``.

    With a numeric ``dispatch_rate`` the unit sleeps ``1 / dispatch_rate`` seconds
    before every tick. It does not subtract the time spent publishing, so the
    period is at least that long. With ``None`` it never sleeps.
    """

    SETTINGS: ClockSettings
    STATE: ClockState

    INPUT_SETTINGS = ez.InputStream(ClockSettings)
    OUTPUT_CLOCK = ez.OutputStream(ez.Flag)

    def initialize(self) -> None:
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: ClockSettings) -> None:
        """Swap in new settings; they take effect from the next tick."""
        self.STATE.cur_settings = msg

    @ez.publisher(OUTPUT_CLOCK)
    async def generate(self) -> AsyncGenerator:
        """Yield one tick per loop iteration, forever."""
        while True:
            rate = self.STATE.cur_settings.dispatch_rate
            if rate is not None:
                await asyncio.sleep(1.0 / rate)
            yield self.OUTPUT_CLOCK, ez.Flag
44
+
45
+
17
46
  class CounterSettings(ez.Settings):
18
47
  """
48
+ TODO: Adapt this to use ezmsg.util.rate?
19
49
  NOTE: This module uses asyncio.sleep to delay appropriately in realtime mode.
20
50
  This method of sleeping/yielding execution priority has quirky behavior with
21
51
  sub-millisecond sleep periods which may result in unexpected behavior (e.g.
22
52
  fs = 2000, n_time = 1, realtime = True -- may result in ~1400 msgs/sec)
23
53
  """
54
+
24
55
  n_time: int # Number of samples to output per block
25
56
  fs: float # Sampling rate of signal output in Hz
26
- n_ch: int = 1 # Number of channels to synthesize
57
+ n_ch: int = 1 # Number of channels to synthesize
27
58
 
28
- # Message dispatch rate (Hz), 'realtime', or None (fast as possible)
59
+ # Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
29
60
  dispatch_rate: Optional[Union[float, str]] = None
30
61
 
31
62
  # If set to an integer, counter will rollover
@@ -33,77 +64,117 @@ class CounterSettings(ez.Settings):
33
64
 
34
65
 
class CounterState(ez.State):
    # Settings currently in effect (set by Counter.validate_settings).
    cur_settings: CounterSettings
    samp: int = 0  # current sample counter (value of the next sample to emit)
    # Set by INPUT_CLOCK ticks; awaited before each block in 'ext_clock' mode.
    clock_event: asyncio.Event


class Counter(ez.Unit):
    """Generates monotonically increasing counter.

    Each published AxisArray has dims ``["time", "ch"]`` and shape
    ``(n_time, n_ch)``. Every channel carries the same counter values, and the
    counter wraps modulo ``mod`` when ``mod`` is set.

    Pacing before each block is controlled by ``dispatch_rate``:

    * ``"realtime"``: sleep one block duration (``n_time / fs``).
    * ``"ext_clock"``: wait for a tick on INPUT_CLOCK. Ticks that arrive while
      a block is being produced coalesce into a single wake-up.
    * a number: sleep ``1 / dispatch_rate`` seconds.
    * ``None``: no wait.
    """

    SETTINGS: CounterSettings
    STATE: CounterState

    INPUT_CLOCK = ez.InputStream(ez.Flag)
    INPUT_SETTINGS = ez.InputStream(CounterSettings)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        self.STATE.clock_event = asyncio.Event()
        self.STATE.clock_event.clear()
        self.validate_settings(self.SETTINGS)

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: CounterSettings) -> None:
        self.validate_settings(msg)

    def validate_settings(self, settings: CounterSettings) -> None:
        """Validate ``settings`` and, if valid, make them the active settings.

        Raises:
            ValueError: if ``settings.dispatch_rate`` is a string other than
                ``"realtime"`` or ``"ext_clock"``.
        """
        # Check the *incoming* settings object. Checking self.SETTINGS here
        # would accept a bad runtime update and report the wrong value.
        dispatch_rate = settings.dispatch_rate
        if isinstance(dispatch_rate, str) and dispatch_rate not in [
            "realtime",
            "ext_clock",
        ]:
            raise ValueError(f"Unknown dispatch_rate: {dispatch_rate}")

        self.STATE.cur_settings = settings

    @ez.subscriber(INPUT_CLOCK)
    async def on_clock(self, _: ez.Flag):
        self.STATE.clock_event.set()

    @ez.publisher(OUTPUT_SIGNAL)
    async def publish(self) -> AsyncGenerator:
        while True:
            # Pace this block using the settings active *before* the wait.
            pacing = self.STATE.cur_settings
            dispatch_rate = pacing.dispatch_rate
            if dispatch_rate is not None:
                if isinstance(dispatch_rate, str):
                    if dispatch_rate == "realtime":
                        await asyncio.sleep(pacing.n_time / pacing.fs)
                    elif dispatch_rate == "ext_clock":
                        await self.STATE.clock_event.wait()
                        self.STATE.clock_event.clear()
                else:
                    await asyncio.sleep(1.0 / dispatch_rate)

            # Build the block from the settings active *after* the wait, so an
            # update received while waiting applies to this block.
            settings = self.STATE.cur_settings

            block_samp = np.arange(settings.n_time)[:, np.newaxis]
            t_samp = block_samp + self.STATE.samp
            # Keep the stored counter a plain int; t_samp[-1] is a 1-element row.
            self.STATE.samp = int(t_samp[-1, 0]) + 1

            if settings.mod is not None:
                t_samp %= settings.mod
                self.STATE.samp %= settings.mod

            t_samp = np.tile(t_samp, (1, settings.n_ch))

            # Offset is the current wall-clock time minus one block duration.
            block_dur = settings.n_time / settings.fs
            out = AxisArray(
                t_samp,
                dims=["time", "ch"],
                axes=dict(
                    time=AxisArray.Axis.TimeAxis(
                        fs=settings.fs, offset=time.time() - block_dur
                    )
                ),
            )

            yield self.OUTPUT_SIGNAL, out
77
143
 
78
144
 
class SinGeneratorSettings(ez.Settings):
    # Axis whose sample rate (1 / gain) converts counter values to seconds.
    # If None, the incoming message's first dimension is used.
    time_axis: Optional[str] = "time"
    freq: float = 1.0  # Oscillation frequency in Hz
    amp: float = 1.0  # Amplitude
    phase: float = 0.0  # Phase offset (in radians)


class SinGeneratorState(ez.State):
    ang_freq: float  # pre-calculated angular frequency in radians


class SinGenerator(ez.Unit):
    """Map an incoming sample counter to ``amp * sin(2*pi*freq*t + phase)``."""

    SETTINGS: SinGeneratorSettings
    STATE: SinGeneratorState

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        self.STATE.ang_freq = 2.0 * np.pi * self.SETTINGS.freq

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def generate(self, msg: AxisArray) -> AsyncGenerator:
        """
        msg is assumed to be a monotonically increasing counter ..
        .. or at least a counter with an intelligently chosen modulus
        """
        configured_axis = self.SETTINGS.time_axis
        time_axis = msg.dims[0] if configured_axis is None else configured_axis

        sample_rate = 1.0 / msg.get_axis(time_axis).gain
        seconds = msg.data / sample_rate
        angle = self.STATE.ang_freq * seconds

        sinusoid = self.SETTINGS.amp * np.sin(angle + self.SETTINGS.phase)
        yield (self.OUTPUT_SIGNAL, replace(msg, data=sinusoid))
@@ -112,8 +183,8 @@ class SinGenerator(ez.Unit):
112
183
  class OscillatorSettings(ez.Settings):
113
184
  n_time: int # Number of samples to output per block
114
185
  fs: float # Sampling rate of signal output in Hz
115
- n_ch: int = 1 # Number of channels to output per block
116
- dispatch_rate: Optional[Union[float, str]] = None # (Hz)
186
+ n_ch: int = 1 # Number of channels to output per block
187
+ dispatch_rate: Optional[Union[float, str]] = None # (Hz) | 'realtime' | 'ext_clock'
117
188
  freq: float = 1.0 # Oscillation frequency in Hz
118
189
  amp: float = 1.0 # Amplitude
119
190
  phase: float = 0.0 # Phase offset (in radians)
@@ -121,16 +192,15 @@ class OscillatorSettings(ez.Settings):
121
192
 
122
193
 
123
194
  class Oscillator(ez.Collection):
124
-
125
195
  SETTINGS: OscillatorSettings
126
196
 
127
- OUTPUT_SIGNAL = ez.OutputStream(TSMessage)
197
+ INPUT_CLOCK = ez.InputStream(ez.Flag)
198
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
128
199
 
129
200
  COUNTER = Counter()
130
201
  SIN = SinGenerator()
131
202
 
132
203
  def configure(self) -> None:
133
-
134
204
  # Calculate synchronous settings if necessary
135
205
  freq = self.SETTINGS.freq
136
206
  mod = None
@@ -145,54 +215,72 @@ class Oscillator(ez.Collection):
145
215
  fs=self.SETTINGS.fs,
146
216
  n_ch=self.SETTINGS.n_ch,
147
217
  dispatch_rate=self.SETTINGS.dispatch_rate,
148
- mod=mod
218
+ mod=mod,
149
219
  )
150
220
  )
151
221
 
152
222
  self.SIN.apply_settings(
153
223
  SinGeneratorSettings(
154
- freq=freq,
155
- amp=self.SETTINGS.amp,
156
- phase=self.SETTINGS.phase
224
+ freq=freq, amp=self.SETTINGS.amp, phase=self.SETTINGS.phase
157
225
  )
158
226
  )
159
227
 
160
228
  def network(self) -> ez.NetworkDefinition:
161
229
  return (
230
+ (self.INPUT_CLOCK, self.COUNTER.INPUT_CLOCK),
162
231
  (self.COUNTER.OUTPUT_SIGNAL, self.SIN.INPUT_SIGNAL),
163
- (self.SIN.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)
232
+ (self.SIN.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
164
233
  )
165
234
 
166
235
 
class RandomGeneratorSettings(ez.Settings):
    loc: float = 0.0  # mean of the normal distribution
    scale: float = 1.0  # standard deviation of the normal distribution


class RandomGenerator(ez.Unit):
    """Replace each incoming message's data with normal samples of the same shape."""

    SETTINGS: RandomGeneratorSettings

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def generate(self, msg: AxisArray) -> AsyncGenerator:
        # Only the input's shape and axes are kept; its values are discarded.
        noise = np.random.normal(
            loc=self.SETTINGS.loc, scale=self.SETTINGS.scale, size=msg.shape
        )
        yield self.OUTPUT_SIGNAL, replace(msg, data=noise)
176
255
 
177
256
 
class NoiseSettings(ez.Settings):
    n_time: int  # Number of samples to output per block
    fs: float  # Sampling rate of signal output in Hz
    n_ch: int = 1  # Number of channels to output
    # Message dispatch rate: a number (Hz), 'realtime', 'ext_clock', or None
    dispatch_rate: Optional[Union[float, str]] = None
    loc: float = 0.0  # DC offset
    scale: float = 1.0  # Scale (in standard deviations)


# Alias kept so WhiteNoise can be configured by its own name.
WhiteNoiseSettings = NoiseSettings
185
269
 
186
- class WhiteNoise(ez.Collection):
187
270
 
271
+ class WhiteNoise(ez.Collection):
188
272
  SETTINGS: NoiseSettings
189
273
 
190
- OUTPUT_SIGNAL = ez.OutputStream(TSMessage)
274
+ INPUT_CLOCK = ez.InputStream(ez.Flag)
275
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
191
276
 
192
277
  COUNTER = Counter()
193
278
  RANDOM = RandomGenerator()
194
279
 
195
280
  def configure(self) -> None:
281
+ self.RANDOM.apply_settings(
282
+ RandomGeneratorSettings(loc=self.SETTINGS.loc, scale=self.SETTINGS.scale)
283
+ )
196
284
 
197
285
  self.COUNTER.apply_settings(
198
286
  CounterSettings(
@@ -200,37 +288,124 @@ class WhiteNoise(ez.Collection):
200
288
  fs=self.SETTINGS.fs,
201
289
  n_ch=self.SETTINGS.n_ch,
202
290
  dispatch_rate=self.SETTINGS.dispatch_rate,
203
- mod=None
291
+ mod=None,
204
292
  )
205
293
  )
206
294
 
207
295
  def network(self) -> ez.NetworkDefinition:
208
296
  return (
297
+ (self.INPUT_CLOCK, self.COUNTER.INPUT_CLOCK),
209
298
  (self.COUNTER.OUTPUT_SIGNAL, self.RANDOM.INPUT_SIGNAL),
210
- (self.RANDOM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)
299
+ (self.RANDOM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
211
300
  )
212
301
 
# Alias kept so PinkNoise can be configured by its own name.
PinkNoiseSettings = NoiseSettings


class PinkNoise(ez.Collection):
    """White noise passed through a first-order Butterworth filter.

    The filter's cutoff is 1% of ``fs``. NOTE(review): this is presumably a
    low-pass; confirm against ButterworthFilter's band semantics.
    """

    SETTINGS: PinkNoiseSettings

    INPUT_CLOCK = ez.InputStream(ez.Flag)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    WHITE_NOISE = WhiteNoise()
    FILTER = ButterworthFilter()

    def configure(self) -> None:
        self.WHITE_NOISE.apply_settings(self.SETTINGS)

        cutoff_hz = self.SETTINGS.fs * 0.01
        filter_settings = ButterworthFilterSettings(
            axis="time", order=1, cutoff=cutoff_hz
        )
        self.FILTER.apply_settings(filter_settings)

    def network(self) -> ez.NetworkDefinition:
        # clock -> white noise -> filter -> collection output
        return (
            (self.INPUT_CLOCK, self.WHITE_NOISE.INPUT_CLOCK),
            (self.WHITE_NOISE.OUTPUT_SIGNAL, self.FILTER.INPUT_SIGNAL),
            (self.FILTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
        )
329
+
330
+
class AddState(ez.State):
    # Messages received on each input, consumed pairwise in arrival order.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)


class Add(ez.Unit):
    """Add two signals together. Assumes compatible/similar axes/dimensions.

    The n-th message from input A is summed with the n-th message from input B.
    The output keeps A's metadata. Both queues are unbounded, so a persistently
    faster input accumulates pending messages.
    """

    STATE: AddState

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        self.STATE.queue_a.put_nowait(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        self.STATE.queue_b.put_nowait(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> AsyncGenerator:
        while True:
            lhs = await self.STATE.queue_a.get()
            rhs = await self.STATE.queue_b.get()

            summed = lhs.data + rhs.data
            yield (self.OUTPUT_SIGNAL, replace(lhs, data=summed))
360
+
361
+
class EEGSynthSettings(ez.Settings):
    fs: float = 500.0  # Hz
    n_time: int = 100  # samples per block
    alpha_freq: float = 10.5  # Hz
    n_ch: int = 8


class EEGSynth(ez.Collection):
    """Synthetic multichannel signal: a sinusoid at ``alpha_freq`` plus pink noise.

    A single Clock ticks once per block (``fs / n_time`` Hz). It drives both
    generators in 'ext_clock' mode, so their blocks are produced in lockstep
    before being summed.
    """

    SETTINGS: EEGSynthSettings

    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    CLOCK = Clock()
    NOISE = PinkNoise()
    OSC = Oscillator()
    ADD = Add()

    def configure(self) -> None:
        blocks_per_sec = self.SETTINGS.fs / self.SETTINGS.n_time
        self.CLOCK.apply_settings(ClockSettings(dispatch_rate=blocks_per_sec))

        # Block geometry and pacing shared by both generators.
        shared = dict(
            n_time=self.SETTINGS.n_time,
            fs=self.SETTINGS.fs,
            n_ch=self.SETTINGS.n_ch,
            dispatch_rate="ext_clock",
        )

        self.OSC.apply_settings(
            OscillatorSettings(freq=self.SETTINGS.alpha_freq, **shared)
        )
        self.NOISE.apply_settings(PinkNoiseSettings(scale=5.0, **shared))

    def network(self) -> ez.NetworkDefinition:
        return (
            (self.CLOCK.OUTPUT_CLOCK, self.OSC.INPUT_CLOCK),
            (self.CLOCK.OUTPUT_CLOCK, self.NOISE.INPUT_CLOCK),
            (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A),
            (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B),
            (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
        )
ezmsg/sigproc/window.py CHANGED
@@ -1,88 +1,88 @@
1
- from dataclasses import dataclass, replace
1
+ from dataclasses import replace
2
2
 
3
3
  import ezmsg.core as ez
4
4
  import numpy as np
5
+ import numpy.typing as npt
5
6
 
6
- from ezmsg.util.messages import view2d, shape2d
7
- from ezmsg.sigproc.messages import TSMessage
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
8
 
9
- from typing import (
10
- AsyncGenerator,
11
- Optional,
12
- Tuple
13
- )
14
-
15
- @dataclass( frozen = True )
16
- class WindowSettingsMessage:
17
- window_dur: Optional[float] = None # Sec. If "None" -- passthrough; window_shift is ignored.
18
- window_shift: Optional[float] = None # Sec. If "None", activate "1:1 mode"
9
+ from typing import AsyncGenerator, Optional, Tuple, List
19
10
 
20
11
 
class WindowSettings(ez.Settings):
    # Axis to window along. If None, the message's first dimension is used.
    axis: Optional[str] = None
    # Optional new axis for output. If "None" - no new axes on output
    newaxis: Optional[str] = None
    # Sec. If "None" -- passthrough; window_shift is ignored.
    window_dur: Optional[float] = None
    # Sec. If "None", activate "1:1 mode"
    window_shift: Optional[float] = None


class WindowState(ez.State):
    cur_settings: WindowSettings

    samp_shape: Optional[Tuple[int, ...]] = None  # Shape of individual sample
    out_fs: Optional[float] = None  # sample rate seen on the last message
    buffer: Optional[npt.NDArray] = None  # pending samples, windowed axis first
    window_samples: Optional[int] = None  # window length in samples
    window_shift_samples: Optional[int] = None  # hop in samples (None = 1:1 mode)
33
31
 
34
32
 
class Window(ez.Unit):
    """Cut a stream into fixed-duration windows along one axis.

    Behaviour depends on the active ``WindowSettings``:

    * ``window_dur is None``: passthrough; each input is re-published unchanged.
    * ``window_shift is None`` ("1:1 mode"): one output per input, holding the
      most recent ``window_dur`` seconds. Outputs are zero-padded until enough
      data has arrived.
    * otherwise: zero or more outputs per input, each ``window_dur`` long, with
      successive windows starting ``window_shift`` seconds apart.

    If ``newaxis`` is set and differs from the windowed axis, each output gains
    a leading length-1 dimension with that name.
    """

    STATE: WindowState
    SETTINGS: WindowSettings

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
    INPUT_SETTINGS = ez.InputStream(WindowSettings)

    def initialize(self) -> None:
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: WindowSettings) -> None:
        self.STATE.cur_settings = msg
        self.STATE.out_fs = None  # This should trigger a reallocation

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
        """Buffer ``msg`` and yield every window that is now complete.

        Raises:
            ValueError: if ``window_dur`` or ``window_shift`` rounds down to
                zero samples at the incoming sample rate.
        """
        settings = self.STATE.cur_settings
        if settings.window_dur is None:
            yield self.OUTPUT_SIGNAL, msg
            return

        axis_name = settings.axis if settings.axis is not None else msg.dims[0]
        axis_idx = msg.get_axis_idx(axis_name)
        axis = msg.get_axis(axis_name)
        fs = 1.0 / axis.gain

        # Create a view of data with the windowed axis at dim 0
        time_view = np.moveaxis(msg.data, axis_idx, 0)
        samp_shape = time_view.shape[1:]

        window_samples = int(settings.window_dur * fs)
        if window_samples < 1:
            # A 0-sample window would make buffer[-0:] keep everything (unbounded growth).
            raise ValueError(
                f"window_dur={settings.window_dur} is shorter than one sample at fs={fs}"
            )

        window_shift_samples = None
        if settings.window_shift is not None:
            window_shift_samples = int(fs * settings.window_shift)
            if window_shift_samples < 1:
                # A 0-sample hop would never drain the buffer (infinite loop).
                raise ValueError(
                    f"window_shift={settings.window_shift} is shorter than one sample at fs={fs}"
                )

        # Pre(re?)allocate a zero-filled buffer when the sample shape, sample
        # rate or window length changes.
        if (
            self.STATE.buffer is None
            or self.STATE.samp_shape != samp_shape
            or self.STATE.out_fs != fs
            or self.STATE.window_samples != window_samples
        ):
            self.STATE.buffer = np.zeros(tuple([window_samples] + list(samp_shape)))

        self.STATE.window_samples = window_samples
        self.STATE.window_shift_samples = window_shift_samples
        self.STATE.samp_shape = samp_shape
        self.STATE.out_fs = fs

        # Number of samples buffered *before* this message. Must be captured
        # before concatenation so the offsets below line up with the data.
        n_prev = self.STATE.buffer.shape[0]

        # Currently we just concatenate the new samples and clip the output.
        # np.roll returns a copy and numpy cannot express a rolling view, so
        # np.concatenate is used instead. This could still be a performance
        # bottleneck for large memory arrays.
        self.STATE.buffer = np.concatenate((self.STATE.buffer, time_view), axis=0)

        # Buffer index i holds msg sample (i - n_prev), whose coordinate along
        # the axis is axis.offset + (i - n_prev) * axis.gain. Older buffered
        # samples are assumed contiguous with this message.
        buffer_offset = (
            np.arange(self.STATE.buffer.shape[0]) - n_prev
        ) * axis.gain + axis.offset

        outputs: List[Tuple[npt.NDArray, float]] = []

        if window_shift_samples is None:  # one-to-one mode
            self.STATE.buffer = self.STATE.buffer[-window_samples:, ...]
            buffer_offset = buffer_offset[-window_samples:]
            outputs.append((self.STATE.buffer, float(buffer_offset[0])))
        else:
            yieldable_size = window_samples + window_shift_samples
            while self.STATE.buffer.shape[0] >= yieldable_size:
                outputs.append(
                    (
                        self.STATE.buffer[:window_samples, ...],
                        float(buffer_offset[0]),
                    )
                )
                self.STATE.buffer = self.STATE.buffer[window_shift_samples:, ...]
                buffer_offset = buffer_offset[window_shift_samples:]

        # Compare with the *resolved* axis name so a newaxis equal to the
        # windowed axis (incl. the default first dim) never duplicates a dim.
        add_newaxis = settings.newaxis is not None and settings.newaxis != axis_name

        for out_view, offset in outputs:
            out_view = np.moveaxis(out_view, 0, axis_idx)

            if add_newaxis:
                # New axis spacing is the hop between windows (0.0 in 1:1 mode).
                new_gain = 0.0
                if window_shift_samples is not None:
                    new_gain = axis.gain * window_shift_samples

                out_axis = replace(axis, gain=new_gain, offset=offset)
                out_axes = {**msg.axes, settings.newaxis: out_axis}
                out_dims = [settings.newaxis] + msg.dims
                yield self.OUTPUT_SIGNAL, replace(
                    msg, data=out_view[np.newaxis, ...], dims=out_dims, axes=out_axes
                )

            elif axis_name in msg.axes:
                # Build a fresh dict. Mutating msg.axes would alter the input
                # message and every message already yielded from this call.
                out_axes = {**msg.axes, axis_name: replace(axis, offset=offset)}
                yield self.OUTPUT_SIGNAL, replace(msg, data=out_view, axes=out_axes)

            else:
                yield self.OUTPUT_SIGNAL, replace(msg, data=out_view)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ezmsg-sigproc
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Home-page: https://github.com/iscoe/ezmsg
6
6
  Author: Griffin Milsap
@@ -10,7 +10,7 @@ Classifier: Operating System :: OS Independent
10
10
  Requires-Python: >=3.8
11
11
  Description-Content-Type: text/markdown
12
12
  License-File: LICENSE.txt
13
- Requires-Dist: ezmsg (>=2.1.0)
13
+ Requires-Dist: ezmsg (>=3.3.0)
14
14
  Requires-Dist: numpy (>=1.19.5)
15
15
  Requires-Dist: scipy (>=1.6.3)
16
16
  Provides-Extra: test