ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/synth.py CHANGED
@@ -1,15 +1,42 @@
1
1
  import asyncio
2
- import time
2
+ from collections import deque
3
3
  from dataclasses import dataclass, replace, field
4
+ import time
5
+ from typing import Optional, Generator, AsyncGenerator, Union
4
6
 
5
- import ezmsg.core as ez
6
7
  import numpy as np
7
-
8
+ import ezmsg.core as ez
9
+ from ezmsg.util.generator import consumer, GenAxisArray
8
10
  from ezmsg.util.messages.axisarray import AxisArray
9
11
 
10
12
  from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
11
13
 
12
- from typing import Optional, AsyncGenerator, Union
14
+
15
+ # CLOCK -- generate events at a specified rate #
16
+ def clock(
17
+ dispatch_rate: Optional[float]
18
+ ) -> Generator[ez.Flag, None, None]:
19
+ n_dispatch = -1
20
+ t_0 = time.time()
21
+ while True:
22
+ if dispatch_rate is not None:
23
+ n_dispatch += 1
24
+ t_next = t_0 + n_dispatch / dispatch_rate
25
+ time.sleep(max(0, t_next - time.time()))
26
+ yield ez.Flag()
27
+
28
+
29
+ async def aclock(
30
+ dispatch_rate: Optional[float]
31
+ ) -> AsyncGenerator[ez.Flag, None]:
32
+ t_0 = time.time()
33
+ n_dispatch = -1
34
+ while True:
35
+ if dispatch_rate is not None:
36
+ n_dispatch += 1
37
+ t_next = t_0 + n_dispatch / dispatch_rate
38
+ await asyncio.sleep(t_next - time.time())
39
+ yield ez.Flag()
13
40
 
14
41
 
15
42
  class ClockSettings(ez.Settings):
@@ -19,6 +46,7 @@ class ClockSettings(ez.Settings):
19
46
 
20
47
  class ClockState(ez.State):
21
48
  cur_settings: ClockSettings
49
+ gen: AsyncGenerator
22
50
 
23
51
 
24
52
  class Clock(ez.Unit):
@@ -30,17 +58,95 @@ class Clock(ez.Unit):
30
58
 
31
59
  def initialize(self) -> None:
32
60
  self.STATE.cur_settings = self.SETTINGS
61
+ self.construct_generator()
62
+
63
+ def construct_generator(self):
64
+ self.STATE.gen = aclock(self.STATE.cur_settings.dispatch_rate)
33
65
 
34
66
  @ez.subscriber(INPUT_SETTINGS)
35
67
  async def on_settings(self, msg: ClockSettings) -> None:
36
68
  self.STATE.cur_settings = msg
69
+ self.construct_generator()
37
70
 
38
71
  @ez.publisher(OUTPUT_CLOCK)
39
72
  async def generate(self) -> AsyncGenerator:
40
73
  while True:
41
- if self.STATE.cur_settings.dispatch_rate is not None:
42
- await asyncio.sleep(1.0 / self.STATE.cur_settings.dispatch_rate)
43
- yield self.OUTPUT_CLOCK, ez.Flag
74
+ out = await self.STATE.gen.__anext__()
75
+ if out:
76
+ yield self.OUTPUT_CLOCK, out
77
+
78
+
79
+ # COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
80
+ async def acounter(
81
+ n_time: int, # Number of samples to output per block
82
+ fs: Optional[float], # Sampling rate of signal output in Hz
83
+ n_ch: int = 1, # Number of channels to synthesize
84
+
85
+ # Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
86
+ # Note: if dispatch_rate is a float then time offsets will be synthetic and the
87
+ # system will run faster or slower than wall clock time.
88
+ dispatch_rate: Optional[Union[float, str]] = None,
89
+
90
+ # If set to an integer, counter will rollover at this number.
91
+ mod: Optional[int] = None,
92
+ ) -> AsyncGenerator[AxisArray, None]:
93
+
94
+ # TODO: Adapt this to use ezmsg.util.rate?
95
+
96
+ counter_start: int = 0 # next sample's first value
97
+
98
+ b_realtime = False
99
+ b_manual_dispatch = False
100
+ b_ext_clock = False
101
+ if dispatch_rate is not None:
102
+ if isinstance(dispatch_rate, str):
103
+ if dispatch_rate.lower() == "realtime":
104
+ b_realtime = True
105
+ elif dispatch_rate.lower() == "ext_clock":
106
+ b_ext_clock = True
107
+ else:
108
+ b_manual_dispatch = True
109
+
110
+ n_sent: int = 0 # It is convenient to know how many samples we have sent.
111
+ clock_zero: float = time.time() # time associated with first sample
112
+
113
+ while True:
114
+ # 1. Sleep, if necessary, until we are at the end of the current block
115
+ if b_realtime:
116
+ n_next = n_sent + n_time
117
+ t_next = clock_zero + n_next / fs
118
+ await asyncio.sleep(t_next - time.time())
119
+ elif b_manual_dispatch:
120
+ n_disp_next = 1 + n_sent / n_time
121
+ t_disp_next = clock_zero + n_disp_next / dispatch_rate
122
+ await asyncio.sleep(t_disp_next - time.time())
123
+
124
+ # 2. Prepare counter data.
125
+ block_samp = np.arange(counter_start, counter_start + n_time)[:, np.newaxis]
126
+ if mod is not None:
127
+ block_samp %= mod
128
+ block_samp = np.tile(block_samp, (1, n_ch))
129
+
130
+ # 3. Prepare offset - the time associated with block_samp[0]
131
+ if b_realtime:
132
+ offset = t_next - n_time / fs
133
+ elif b_ext_clock:
134
+ offset = time.time()
135
+ else:
136
+ # Purely synthetic.
137
+ offset = n_sent / fs
138
+ # offset += clock_zero # ??
139
+
140
+ # 4. yield output
141
+ yield AxisArray(
142
+ block_samp,
143
+ dims=["time", "ch"],
144
+ axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset)},
145
+ )
146
+
147
+ # 5. Update state for next iteration (after next yield)
148
+ counter_start = block_samp[-1, 0] + 1 # do not % mod
149
+ n_sent += n_time
44
150
 
45
151
 
46
152
  class CounterSettings(ez.Settings):
@@ -57,6 +163,8 @@ class CounterSettings(ez.Settings):
57
163
  n_ch: int = 1 # Number of channels to synthesize
58
164
 
59
165
  # Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
166
+ # Note: if dispatch_rate is a float then time offsets will be synthetic and the
167
+ # system will run faster or slower than wall clock time.
60
168
  dispatch_rate: Optional[Union[float, str]] = None
61
169
 
62
170
  # If set to an integer, counter will rollover
@@ -64,9 +172,9 @@ class CounterSettings(ez.Settings):
64
172
 
65
173
 
66
174
  class CounterState(ez.State):
175
+ gen: AsyncGenerator[AxisArray, Optional[ez.Flag]]
67
176
  cur_settings: CounterSettings
68
- samp: int = 0 # current sample counter
69
- clock_event: asyncio.Event
177
+ new_generator: asyncio.Event
70
178
 
71
179
 
72
180
  class Counter(ez.Unit):
@@ -79,9 +187,8 @@ class Counter(ez.Unit):
79
187
  INPUT_SETTINGS = ez.InputStream(CounterSettings)
80
188
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
81
189
 
82
- def initialize(self) -> None:
83
- self.STATE.clock_event = asyncio.Event()
84
- self.STATE.clock_event.clear()
190
+ async def initialize(self) -> None:
191
+ self.STATE.new_generator = asyncio.Event()
85
192
  self.validate_settings(self.SETTINGS)
86
193
 
87
194
  @ez.subscriber(INPUT_SETTINGS)
@@ -93,53 +200,64 @@ class Counter(ez.Unit):
93
200
  settings.dispatch_rate, str
94
201
  ) and self.SETTINGS.dispatch_rate not in ["realtime", "ext_clock"]:
95
202
  raise ValueError(f"Unknown dispatch_rate: {self.SETTINGS.dispatch_rate}")
96
-
97
203
  self.STATE.cur_settings = settings
98
-
204
+ self.construct_generator()
205
+
206
+ def construct_generator(self):
207
+ self.STATE.gen = acounter(
208
+ self.STATE.cur_settings.n_time,
209
+ self.STATE.cur_settings.fs,
210
+ n_ch=self.STATE.cur_settings.n_ch,
211
+ dispatch_rate=self.STATE.cur_settings.dispatch_rate,
212
+ mod=self.STATE.cur_settings.mod
213
+ )
214
+ self.STATE.new_generator.set()
215
+
99
216
  @ez.subscriber(INPUT_CLOCK)
100
- async def on_clock(self, _: ez.Flag):
101
- self.STATE.clock_event.set()
217
+ @ez.publisher(OUTPUT_SIGNAL)
218
+ async def on_clock(self, clock: ez.Flag):
219
+ if self.STATE.cur_settings.dispatch_rate == 'ext_clock':
220
+ out = await self.STATE.gen.__anext__()
221
+ yield self.OUTPUT_SIGNAL, out
102
222
 
103
223
  @ez.publisher(OUTPUT_SIGNAL)
104
- async def publish(self) -> AsyncGenerator:
224
+ async def run_generator(self) -> AsyncGenerator:
105
225
  while True:
106
- block_dur = self.STATE.cur_settings.n_time / self.STATE.cur_settings.fs
107
-
108
- dispatch_rate = self.STATE.cur_settings.dispatch_rate
109
- if dispatch_rate is not None:
110
- if isinstance(dispatch_rate, str):
111
- if dispatch_rate == "realtime":
112
- await asyncio.sleep(block_dur)
113
- elif dispatch_rate == "ext_clock":
114
- await self.STATE.clock_event.wait()
115
- self.STATE.clock_event.clear()
116
- else:
117
- await asyncio.sleep(1.0 / dispatch_rate)
118
-
119
- block_samp = np.arange(self.STATE.cur_settings.n_time)[:, np.newaxis]
120
-
121
- t_samp = block_samp + self.STATE.samp
122
- self.STATE.samp = t_samp[-1] + 1
123
-
124
- if self.STATE.cur_settings.mod is not None:
125
- t_samp %= self.STATE.cur_settings.mod
126
- self.STATE.samp %= self.STATE.cur_settings.mod
127
-
128
- t_samp = np.tile(t_samp, (1, self.STATE.cur_settings.n_ch))
129
-
130
- offset_adj = self.STATE.cur_settings.n_time / self.STATE.cur_settings.fs
131
-
132
- out = AxisArray(
133
- t_samp,
134
- dims=["time", "ch"],
135
- axes=dict(
136
- time=AxisArray.Axis.TimeAxis(
137
- fs=self.STATE.cur_settings.fs, offset=time.time() - offset_adj
138
- )
139
- ),
140
- )
226
+
227
+ await self.STATE.new_generator.wait()
228
+ self.STATE.new_generator.clear()
229
+
230
+ if self.STATE.cur_settings.dispatch_rate == 'ext_clock':
231
+ continue
232
+
233
+ while not self.STATE.new_generator.is_set():
234
+ out = await self.STATE.gen.__anext__()
235
+ yield self.OUTPUT_SIGNAL, out
236
+
237
+
238
+ @consumer
239
+ def sin(
240
+ axis: Optional[str] = "time",
241
+ freq: float = 1.0, # Oscillation frequency in Hz
242
+ amp: float = 1.0, # Amplitude
243
+ phase: float = 0.0, # Phase offset (in radians)
244
+ ) -> Generator[AxisArray, AxisArray, None]:
245
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
246
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
247
+
248
+ ang_freq = 2.0 * np.pi * freq
249
+
250
+ while True:
251
+ axis_arr_in = yield axis_arr_out
252
+ # axis_arr_in is expected to be sample counts
253
+
254
+ axis_name = axis
255
+ if axis_name is None:
256
+ axis_name = axis_arr_in.dims[0]
141
257
 
142
- yield self.OUTPUT_SIGNAL, out
258
+ w = (ang_freq * axis_arr_in.get_axis(axis_name).gain) * axis_arr_in.data
259
+ out_data = amp * np.sin(w + phase)
260
+ axis_arr_out = replace(axis_arr_in, data=out_data)
143
261
 
144
262
 
145
263
  class SinGeneratorSettings(ez.Settings):
@@ -149,35 +267,16 @@ class SinGeneratorSettings(ez.Settings):
149
267
  phase: float = 0.0 # Phase offset (in radians)
150
268
 
151
269
 
152
- class SinGeneratorState(ez.State):
153
- ang_freq: float # pre-calculated angular frequency in radians
154
-
155
-
156
- class SinGenerator(ez.Unit):
270
+ class SinGenerator(GenAxisArray):
157
271
  SETTINGS: SinGeneratorSettings
158
- STATE: SinGeneratorState
159
-
160
- INPUT_SIGNAL = ez.InputStream(AxisArray)
161
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
162
-
163
- def initialize(self) -> None:
164
- self.STATE.ang_freq = 2.0 * np.pi * self.SETTINGS.freq
165
272
 
166
- @ez.subscriber(INPUT_SIGNAL)
167
- @ez.publisher(OUTPUT_SIGNAL)
168
- async def generate(self, msg: AxisArray) -> AsyncGenerator:
169
- """
170
- msg is assumed to be a monotonically increasing counter ..
171
- .. or at least a counter with an intelligently chosen modulus
172
- """
173
- axis_name = self.SETTINGS.time_axis
174
- if axis_name is None:
175
- axis_name = msg.dims[0]
176
- fs = 1.0 / msg.get_axis(axis_name).gain
177
- t_sec = msg.data / fs
178
- w = self.STATE.ang_freq * t_sec
179
- out_data = self.SETTINGS.amp * np.sin(w + self.SETTINGS.phase)
180
- yield (self.OUTPUT_SIGNAL, replace(msg, data=out_data))
273
+ def construct_generator(self):
274
+ self.STATE.gen = sin(
275
+ axis=self.SETTINGS.time_axis,
276
+ freq=self.SETTINGS.freq,
277
+ amp=self.SETTINGS.amp,
278
+ phase=self.SETTINGS.phase
279
+ )
181
280
 
182
281
 
183
282
  class OscillatorSettings(ez.Settings):
ezmsg/sigproc/window.py CHANGED
@@ -1,33 +1,207 @@
1
1
  from dataclasses import replace
2
+ import traceback
3
+ from typing import AsyncGenerator, Optional, Tuple, List, Generator
2
4
 
3
5
  import ezmsg.core as ez
4
6
  import numpy as np
5
7
  import numpy.typing as npt
6
8
 
7
- from ezmsg.util.messages.axisarray import AxisArray
8
-
9
- from typing import AsyncGenerator, Optional, Tuple, List
9
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis, sliding_win_oneaxis
10
+ from ezmsg.util.generator import consumer
11
+
12
+
13
+ @consumer
14
+ def windowing(
15
+ axis: Optional[str] = None,
16
+ newaxis: Optional[str] = None,
17
+ window_dur: Optional[float] = None,
18
+ window_shift: Optional[float] = None,
19
+ zero_pad_until: str = "input"
20
+ ) -> Generator[AxisArray, List[AxisArray], None]:
21
+ """
22
+ Window function that generates windows of data from an input `AxisArray`.
23
+ :param axis: The axis along which to segment windows.
24
+ If None, defaults to the first dimension of the first seen AxisArray.
25
+ :param newaxis: Optional new axis for the output. If None, no new axes will be added.
26
+ If a string, windows will be stacked in a new axis with key `newaxis`, immediately preceding the windowed axis.
27
+ :param window_dur: The duration of the window in seconds.
28
+ If None, the function acts as a passthrough and all other parameters are ignored.
29
+ :param window_shift: The shift of the window in seconds.
30
+ If None (default), windowing operates in "1:1 mode", where each input yields exactly one most-recent window.
31
+ :param zero_pad_until: Determines how the function initializes the buffer.
32
+ Can be one of "input" (default), "full", "shift", or "none". If `window_shift` is None then this field is
33
+ ignored and "input" is always used.
34
+ "input" (default) initializes the buffer with the first input, zero-padded at the front up to the window size ("full" currently takes this same padding branch).
35
+ The first input will always yield at least one output.
36
+ "shift" pre-fills the buffer with zeros, leaving room for `window_shift` of data.
37
+ No outputs will be yielded until at least `window_shift` data has been seen.
38
+ "none" does not pad the buffer. No outputs will be yielded until at least `window_dur` data has been seen.
39
+ :return:
40
+ A (primed) generator that accepts .send(an AxisArray object) and yields a list of windowed
41
+ AxisArray objects. The list will always be length-1 if `newaxis` is not None or `window_shift` is None.
42
+ """
43
+ # TODO: The return should be an AxisArray. i.e., always add a new axis. The Unit can do a multi-yield-per-pub
44
+ # if the parameterization does not expect a newaxis.
45
+
46
+ if window_shift is None and zero_pad_until != "input":
47
+ ez.logger.warning("`zero_pad_until` must be 'input' if `window_shift` is None. "
48
+ f"Ignoring received argument value: {zero_pad_until}")
49
+ zero_pad_until = "input"
50
+ elif window_shift is not None and zero_pad_until == "input":
51
+ ez.logger.warning("windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
52
+ "of the first input. We recommend using 'shift' when `window_shift` is float-valued.")
53
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
54
+ axis_arr_out = [AxisArray(np.array([]), dims=[""])]
55
+
56
+ # State variables
57
+ prev_samp_shape: Optional[Tuple[int, ...]] = None
58
+ prev_fs: Optional[float] = None
59
+ buffer: Optional[npt.NDArray] = None
60
+ window_samples: Optional[int] = None
61
+ window_shift_samples: Optional[int] = None
62
+ shift_deficit: int = 0 # Number of incoming samples to ignore. Only relevant when shift > window.
63
+ newaxis_warn_flag: bool = False
64
+ mod_ax: Optional[str] = None # The key of the modified axis in the output's .axes
65
+ out_template: Optional[AxisArray] = None # Template for building return values.
66
+
67
+ while True:
68
+ axis_arr_in = yield axis_arr_out
69
+
70
+ if window_dur is None:
71
+ axis_arr_out = [axis_arr_in]
72
+ continue
73
+
74
+ if axis is None:
75
+ axis = axis_arr_in.dims[0]
76
+ axis_idx = axis_arr_in.get_axis_idx(axis)
77
+ axis_info = axis_arr_in.get_axis(axis)
78
+ fs = 1.0 / axis_info.gain
79
+
80
+ if (not newaxis_warn_flag) and newaxis is not None and newaxis in axis_arr_in.dims:
81
+ ez.logger.warning(f"newaxis {newaxis} present in input dims and will be ignored.")
82
+ newaxis_warn_flag = True
83
+ b_newaxis = newaxis is not None and newaxis not in axis_arr_in.dims
84
+
85
+ samp_shape = axis_arr_in.data.shape[:axis_idx] + axis_arr_in.data.shape[axis_idx + 1:]
86
+ window_samples = int(window_dur * fs)
87
+ b_1to1 = window_shift is None
88
+ if not b_1to1:
89
+ window_shift_samples = int(window_shift * fs)
90
+
91
+ # If buffer unset or input stats changed, create a new buffer
92
+ if buffer is None or samp_shape != prev_samp_shape or fs != prev_fs:
93
+ if zero_pad_until == "none":
94
+ req_samples = window_samples
95
+ elif zero_pad_until == "shift" and not b_1to1:
96
+ req_samples = window_shift_samples
97
+ else: # i.e. zero_pad_until == "input"
98
+ req_samples = axis_arr_in.data.shape[axis_idx]
99
+ n_zero = max(0, window_samples - req_samples)
100
+ buffer_shape = axis_arr_in.data.shape[:axis_idx] + (n_zero,) + axis_arr_in.data.shape[axis_idx + 1:]
101
+ buffer = np.zeros(buffer_shape)
102
+ prev_samp_shape = samp_shape
103
+ prev_fs = fs
104
+
105
+ # Add new data to buffer.
106
+ # Currently we just concatenate the new time samples and clip the output
107
+ # np.roll actually returns a copy, and there's no way to construct a
108
+ # rolling view of the data. In current numpy implementations, np.concatenate
109
+ # is generally faster than np.roll and slicing anyway, but this could still
110
+ # be a performance bottleneck for large memory arrays.
111
+ buffer = np.concatenate((buffer, axis_arr_in.data), axis=axis_idx)
112
+ # Note: if we ever move to using a circular buffer without copies then we need to create copies somewhere,
113
+ # because currently the outputs are merely views into the buffer.
114
+
115
+ # Create a vector of buffer timestamps to track axis `offset` in output(s)
116
+ buffer_offset = np.arange(buffer.shape[axis_idx]).astype(float)
117
+ # Shift so that the first _new_ sample has index 0
118
+ buffer_offset -= buffer_offset[-axis_arr_in.data.shape[axis_idx]]
119
+ # Convert from indices to 'units' (probably seconds).
120
+ buffer_offset *= axis_info.gain
121
+ buffer_offset += axis_info.offset
122
+
123
+ if not b_1to1 and shift_deficit > 0:
124
+ n_skip = min(buffer.shape[axis_idx], shift_deficit)
125
+ if n_skip > 0:
126
+ buffer = slice_along_axis(buffer, np.s_[n_skip:], axis_idx)
127
+ buffer_offset = buffer_offset[n_skip:]
128
+ shift_deficit -= n_skip
129
+
130
+ # Prepare reusable parts of output
131
+ if out_template is None:
132
+ out_dims = axis_arr_in.dims
133
+ if newaxis is None:
134
+ out_axes = {
135
+ **axis_arr_in.axes,
136
+ axis: replace(axis_info, offset=0.0) # offset modified below.
137
+ }
138
+ mod_ax = axis
139
+ else:
140
+ out_dims = out_dims[:axis_idx] + [newaxis] + out_dims[axis_idx:]
141
+ out_axes = {
142
+ **axis_arr_in.axes,
143
+ newaxis: AxisArray.Axis(
144
+ unit=axis_info.unit,
145
+ gain=0.0 if b_1to1 else axis_info.gain * window_shift_samples,
146
+ offset=0.0 # offset modified below
147
+ )
148
+ }
149
+ mod_ax = newaxis
150
+ out_template = replace(axis_arr_in, data=np.zeros([0 for _ in out_dims]), dims=out_dims)
151
+
152
+ # Generate outputs.
153
+ axis_arr_out: List[AxisArray] = []
154
+ if b_1to1:
155
+ # one-to-one mode -- Each send yields exactly one window containing only the most recent samples.
156
+ buffer = slice_along_axis(buffer, np.s_[-window_samples:], axis_idx)
157
+ axis_arr_out.append(replace(
158
+ out_template,
159
+ data=np.expand_dims(buffer, axis=axis_idx) if b_newaxis else buffer,
160
+ axes={
161
+ **out_axes,
162
+ mod_ax: replace(out_axes[mod_ax], offset=buffer_offset[-window_samples])
163
+ }
164
+ ))
165
+ elif buffer.shape[axis_idx] >= window_samples:
166
+ # Deterministic window shifts.
167
+ win_view = sliding_win_oneaxis(buffer, window_samples, axis_idx)
168
+ win_view = slice_along_axis(win_view, np.s_[::window_shift_samples], axis_idx)
169
+ offset_view = sliding_win_oneaxis(buffer_offset, window_samples, 0)[::window_shift_samples]
170
+ # Place in output
171
+ if b_newaxis:
172
+ axis_arr_out.append(replace(
173
+ out_template,
174
+ data=win_view,
175
+ axes={**out_axes, mod_ax: replace(out_axes[mod_ax], offset=offset_view[0, 0])}
176
+ ))
177
+ else:
178
+ for win_ix in range(win_view.shape[axis_idx]):
179
+ axis_arr_out.append(replace(
180
+ out_template,
181
+ data=slice_along_axis(win_view, win_ix, axis_idx),
182
+ axes={
183
+ **out_axes,
184
+ mod_ax: replace(out_axes[mod_ax], offset=offset_view[win_ix, 0])
185
+ }
186
+ ))
187
+
188
+ # Drop expired beginning of buffer and update shift_deficit
189
+ multi_shift = window_shift_samples * win_view.shape[axis_idx]
190
+ shift_deficit = max(0, multi_shift - buffer.shape[axis_idx])
191
+ buffer = slice_along_axis(buffer, np.s_[multi_shift:], axis_idx)
10
192
 
11
193
 
12
194
  class WindowSettings(ez.Settings):
13
195
  axis: Optional[str] = None
14
- newaxis: Optional[
15
- str
16
- ] = None # Optional new axis for output. If "None" - no new axes on output
17
- window_dur: Optional[
18
- float
19
- ] = None # Sec. If "None" -- passthrough; window_shift is ignored.
20
- window_shift: Optional[float] = None # Sec. If "None", activate "1:1 mode"
196
+ newaxis: Optional[str] = None # new axis for output. No new axes if None
197
+ window_dur: Optional[float] = None # Sec. passthrough if None
198
+ window_shift: Optional[float] = None # Sec. Use "1:1 mode" if None
199
+ zero_pad_until: str = "full" # "full", "shift", "input", "none"
21
200
 
22
201
 
23
202
  class WindowState(ez.State):
24
203
  cur_settings: WindowSettings
25
-
26
- samp_shape: Optional[Tuple[int, ...]] = None # Shape of individual sample
27
- out_fs: Optional[float] = None
28
- buffer: Optional[npt.NDArray] = None
29
- window_samples: Optional[int] = None
30
- window_shift_samples: Optional[int] = None
204
+ gen: Generator
31
205
 
32
206
 
33
207
  class Window(ez.Unit):
@@ -40,105 +214,33 @@ class Window(ez.Unit):
40
214
 
41
215
  def initialize(self) -> None:
42
216
  self.STATE.cur_settings = self.SETTINGS
217
+ self.construct_generator()
43
218
 
44
219
  @ez.subscriber(INPUT_SETTINGS)
45
220
  async def on_settings(self, msg: WindowSettings) -> None:
46
221
  self.STATE.cur_settings = msg
47
- self.STATE.out_fs = None # This should trigger a reallocation
222
+ self.construct_generator()
223
+
224
+ def construct_generator(self):
225
+ self.STATE.gen = windowing(
226
+ axis=self.STATE.cur_settings.axis,
227
+ newaxis=self.STATE.cur_settings.newaxis,
228
+ window_dur=self.STATE.cur_settings.window_dur,
229
+ window_shift=self.STATE.cur_settings.window_shift,
230
+ zero_pad_until=self.STATE.cur_settings.zero_pad_until
231
+ )
48
232
 
49
233
  @ez.subscriber(INPUT_SIGNAL)
50
234
  @ez.publisher(OUTPUT_SIGNAL)
51
235
  async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
52
- if self.STATE.cur_settings.window_dur is None:
53
- yield self.OUTPUT_SIGNAL, msg
54
- return
55
-
56
- axis_name = self.STATE.cur_settings.axis
57
- if axis_name is None:
58
- axis_name = msg.dims[0]
59
- axis_idx = msg.get_axis_idx(axis_name)
60
- axis = msg.get_axis(axis_name)
61
- fs = 1.0 / axis.gain
62
-
63
- # Create a view of data with time axis at dim 0
64
- time_view = np.moveaxis(msg.data, axis_idx, 0)
65
- samp_shape = time_view.shape[1:]
66
-
67
- # Pre(re?)allocate buffer
68
- window_samples = int(self.STATE.cur_settings.window_dur * fs)
69
- if (
70
- (self.STATE.samp_shape != samp_shape)
71
- or (self.STATE.out_fs != fs)
72
- or self.STATE.buffer is None
73
- ):
74
- self.STATE.buffer = np.zeros(tuple([window_samples] + list(samp_shape)))
75
-
76
- self.STATE.window_samples = window_samples
77
- self.STATE.samp_shape = samp_shape
78
- self.STATE.out_fs = fs
79
-
80
- self.STATE.window_shift_samples = None
81
- if self.STATE.cur_settings.window_shift is not None:
82
- self.STATE.window_shift_samples = int(
83
- fs * self.STATE.cur_settings.window_shift
84
- )
85
-
86
- # Currently we just concatenate the new time samples and clip the output
87
- # np.roll actually returns a copy, and there's no way to construct a
88
- # rolling view of the data. In current numpy implementations, np.concatenate
89
- # is generally faster than np.roll and slicing anyway, but this could still
90
- # be a performance bottleneck for large memory arrays.
91
- self.STATE.buffer = np.concatenate((self.STATE.buffer, time_view), axis=0)
92
-
93
- buffer_offset = np.arange(self.STATE.buffer.shape[0] + time_view.shape[0])
94
- buffer_offset -= self.STATE.buffer.shape[0] + 1
95
- buffer_offset = (buffer_offset * axis.gain) + axis.offset
96
-
97
- outputs: List[Tuple[npt.NDArray, float]] = []
98
-
99
- if self.STATE.window_shift_samples is None: # one-to-one mode
100
- self.STATE.buffer = self.STATE.buffer[-self.STATE.window_samples :, ...]
101
- buffer_offset = buffer_offset[-self.STATE.window_samples :]
102
- outputs.append((self.STATE.buffer, buffer_offset[0]))
103
-
104
- else:
105
- yieldable_size = self.STATE.window_samples + self.STATE.window_shift_samples
106
- while self.STATE.buffer.shape[0] >= yieldable_size:
107
- outputs.append(
108
- (
109
- self.STATE.buffer[: self.STATE.window_samples, ...],
110
- buffer_offset[0],
111
- )
112
- )
113
- self.STATE.buffer = self.STATE.buffer[
114
- self.STATE.window_shift_samples :, ...
115
- ]
116
- buffer_offset = buffer_offset[self.STATE.window_shift_samples :]
117
-
118
- for out_view, offset in outputs:
119
- out_view = np.moveaxis(out_view, 0, axis_idx)
120
-
121
- if (
122
- self.STATE.cur_settings.newaxis is not None
123
- and self.STATE.cur_settings.newaxis != self.STATE.cur_settings.axis
124
- ):
125
- new_gain = 0.0
126
- if self.STATE.window_shift_samples is not None:
127
- new_gain = axis.gain * self.STATE.window_shift_samples
128
-
129
- out_axis = replace(axis, unit=axis.unit, gain=new_gain, offset=offset)
130
- out_axes = {**msg.axes, **{self.STATE.cur_settings.newaxis: out_axis}}
131
- out_dims = [self.STATE.cur_settings.newaxis] + msg.dims
132
- out_view = out_view[np.newaxis, ...]
133
-
134
- yield self.OUTPUT_SIGNAL, replace(
135
- msg, data=out_view, dims=out_dims, axes=out_axes
136
- )
137
-
138
- else:
139
- if axis_name in msg.axes:
140
- out_axes = msg.axes
141
- out_axes[axis_name] = replace(axis, offset=offset)
142
- yield self.OUTPUT_SIGNAL, replace(msg, data=out_view, axes=out_axes)
143
- else:
144
- yield self.OUTPUT_SIGNAL, replace(msg, data=out_view)
236
+ try:
237
+ # TODO: Refactor window generator so it always returns an axis array.
238
+ # Then, if the configuration is such that a new "win" axis is not expected,
239
+ # then iterate over the "win" axis -- dropping the "win" axis in the process.
240
+ out_msgs = self.STATE.gen.send(msg)
241
+ for out_msg in out_msgs:
242
+ yield self.OUTPUT_SIGNAL, out_msg
243
+ except (StopIteration, GeneratorExit):
244
+ ez.logger.debug(f"Window closed in {self.address}")
245
+ except Exception:
246
+ ez.logger.info(traceback.format_exc())