ezmsg-sigproc 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +4 -1
- ezmsg/sigproc/affinetransform.py +124 -0
- ezmsg/sigproc/aggregate.py +103 -0
- ezmsg/sigproc/bandpower.py +53 -0
- ezmsg/sigproc/butterworthfilter.py +44 -6
- ezmsg/sigproc/downsample.py +52 -26
- ezmsg/sigproc/ewmfilter.py +11 -3
- ezmsg/sigproc/filter.py +82 -14
- ezmsg/sigproc/sampler.py +173 -200
- ezmsg/sigproc/scaler.py +127 -0
- ezmsg/sigproc/signalinjector.py +67 -0
- ezmsg/sigproc/slicer.py +98 -0
- ezmsg/sigproc/spectral.py +9 -132
- ezmsg/sigproc/spectrogram.py +68 -0
- ezmsg/sigproc/spectrum.py +158 -0
- ezmsg/sigproc/synth.py +179 -80
- ezmsg/sigproc/window.py +212 -110
- {ezmsg_sigproc-1.2.1.dist-info → ezmsg_sigproc-1.2.3.dist-info}/METADATA +15 -13
- ezmsg_sigproc-1.2.3.dist-info/RECORD +23 -0
- {ezmsg_sigproc-1.2.1.dist-info → ezmsg_sigproc-1.2.3.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/__version__.py +0 -1
- ezmsg_sigproc-1.2.1.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.1.dist-info/top_level.txt +0 -1
- {ezmsg_sigproc-1.2.1.dist-info → ezmsg_sigproc-1.2.3.dist-info}/LICENSE.txt +0 -0
ezmsg/sigproc/synth.py
CHANGED
|
@@ -1,15 +1,42 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import
|
|
2
|
+
from collections import deque
|
|
3
3
|
from dataclasses import dataclass, replace, field
|
|
4
|
+
import time
|
|
5
|
+
from typing import Optional, Generator, AsyncGenerator, Union
|
|
4
6
|
|
|
5
|
-
import ezmsg.core as ez
|
|
6
7
|
import numpy as np
|
|
7
|
-
|
|
8
|
+
import ezmsg.core as ez
|
|
9
|
+
from ezmsg.util.generator import consumer, GenAxisArray
|
|
8
10
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
11
|
|
|
10
12
|
from .butterworthfilter import ButterworthFilter, ButterworthFilterSettings
|
|
11
13
|
|
|
12
|
-
|
|
14
|
+
|
|
15
|
+
# CLOCK -- generate events at a specified rate #
|
|
16
|
+
def clock(
    dispatch_rate: Optional[float]
) -> Generator[ez.Flag, None, None]:
    """
    Synchronous generator that yields an `ez.Flag` at a fixed rate.

    :param dispatch_rate: Dispatch rate in Hz. If None, flags are yielded as fast
        as the consumer asks for them (no sleeping at all).
    :return: A generator that yields `ez.Flag` objects forever.
    """
    start_time = time.time()
    tick = -1
    while True:
        if dispatch_rate is not None:
            tick += 1
            # Deadlines are computed from the start time rather than from the
            # previous tick, so timing error does not accumulate.
            deadline = start_time + tick / dispatch_rate
            time.sleep(max(0, deadline - time.time()))
        yield ez.Flag()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
async def aclock(
    dispatch_rate: Optional[float]
) -> AsyncGenerator[ez.Flag, None]:
    """
    Asynchronous counterpart of `clock`: yields an `ez.Flag` at a fixed rate.

    :param dispatch_rate: Dispatch rate in Hz. If None, flags are yielded without
        awaiting any sleep.
    :return: An async generator that yields `ez.Flag` objects forever.
    """
    start_time = time.time()
    tick = -1
    while True:
        if dispatch_rate is not None:
            tick += 1
            deadline = start_time + tick / dispatch_rate
            # asyncio.sleep returns promptly for a non-positive delay, so a late
            # tick is dispatched immediately instead of raising.
            await asyncio.sleep(deadline - time.time())
        yield ez.Flag()
|
|
13
40
|
|
|
14
41
|
|
|
15
42
|
class ClockSettings(ez.Settings):
|
|
@@ -19,6 +46,7 @@ class ClockSettings(ez.Settings):
|
|
|
19
46
|
|
|
20
47
|
class ClockState(ez.State):
    """Mutable state of the `Clock` unit."""
    # Settings currently in effect; set in initialize and replaced by on_settings.
    cur_settings: ClockSettings
    # Async flag generator built by construct_generator (via `aclock`); generate() pulls from it.
    gen: AsyncGenerator
|
|
22
50
|
|
|
23
51
|
|
|
24
52
|
class Clock(ez.Unit):
|
|
@@ -30,17 +58,95 @@ class Clock(ez.Unit):
|
|
|
30
58
|
|
|
31
59
|
def initialize(self) -> None:
    """Adopt the unit's configured settings and build the clock generator."""
    settings = self.SETTINGS
    self.STATE.cur_settings = settings
    self.construct_generator()
|
|
62
|
+
|
|
63
|
+
def construct_generator(self):
    """(Re)create the async clock generator from the current settings."""
    rate = self.STATE.cur_settings.dispatch_rate
    self.STATE.gen = aclock(rate)
|
|
33
65
|
|
|
34
66
|
@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: ClockSettings) -> None:
    """Store the newly received settings and restart the clock generator with them."""
    self.STATE.cur_settings = msg
    # Rebuilding resets the generator's start time and tick count.
    self.construct_generator()
|
|
37
70
|
|
|
38
71
|
@ez.publisher(OUTPUT_CLOCK)
async def generate(self) -> AsyncGenerator:
    """Publish every truthy flag produced by the clock generator on OUTPUT_CLOCK."""
    while True:
        # Re-read STATE.gen on each pass so a generator swapped in by
        # on_settings takes effect immediately.
        flag = await self.STATE.gen.__anext__()
        if not flag:
            continue
        yield self.OUTPUT_CLOCK, flag
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# COUNTER - Generate incrementing integer. fs and dispatch_rate parameters combine to give many options. #
|
|
80
|
+
async def acounter(
    n_time: int,  # Number of samples to output per block
    fs: Optional[float],  # Sampling rate of signal output in Hz
    n_ch: int = 1,  # Number of channels to synthesize

    # Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
    # Note: if dispatch_rate is a float then time offsets will be synthetic and the
    # system will run faster or slower than wall clock time.
    dispatch_rate: Optional[Union[float, str]] = None,

    # If set to an integer, counter will rollover at this number.
    mod: Optional[int] = None,
) -> AsyncGenerator[AxisArray, None]:
    """
    Asynchronously generate blocks of an incrementing integer counter.

    Each yielded AxisArray has dims ``["time", "ch"]`` and shape ``(n_time, n_ch)``;
    all channels carry the same count.

    :param n_time: Number of samples per block.
    :param fs: Sampling rate in Hz. Used for the time axis, for pacing in
        'realtime' mode, and for synthetic offsets.
    :param n_ch: Number of (identical) channels.
    :param dispatch_rate:
        - None: yield blocks as fast as they are requested; offsets are synthetic.
        - float: pace output to about `dispatch_rate` blocks per second, measured
          from the first request; offsets are still synthetic.
        - 'realtime' (case-insensitive): yield each block once the wall-clock time
          of its final sample (at `fs`) has passed; offsets are wall-clock times.
        - 'ext_clock' (case-insensitive): no pacing here; the caller decides when
          to request the next block; offset is the wall-clock time at yield.
        Any other string is treated like None.
    :param mod: If not None, counts wrap around as ``count % mod``.
    :return: An async generator that yields AxisArray blocks forever.
    """

    # TODO: Adapt this to use ezmsg.util.rate?

    counter_start: int = 0  # count of the next block's first sample (wrapped if `mod` is set)

    b_realtime = False
    b_manual_dispatch = False
    b_ext_clock = False
    if dispatch_rate is not None:
        if isinstance(dispatch_rate, str):
            if dispatch_rate.lower() == "realtime":
                b_realtime = True
            elif dispatch_rate.lower() == "ext_clock":
                b_ext_clock = True
        else:
            b_manual_dispatch = True

    n_sent: int = 0  # Total number of samples yielded so far.
    clock_zero: float = time.time()  # Wall-clock time associated with the first sample.

    while True:
        # 1. Sleep, if necessary, until we are at the end of the current block.
        if b_realtime:
            n_next = n_sent + n_time
            t_next = clock_zero + n_next / fs
            await asyncio.sleep(t_next - time.time())
        elif b_manual_dispatch:
            n_disp_next = 1 + n_sent / n_time
            t_disp_next = clock_zero + n_disp_next / dispatch_rate
            await asyncio.sleep(t_disp_next - time.time())

        # 2. Prepare counter data.
        block_samp = np.arange(counter_start, counter_start + n_time)[:, np.newaxis]
        if mod is not None:
            block_samp %= mod
        block_samp = np.tile(block_samp, (1, n_ch))

        # 3. Prepare offset - the time associated with block_samp[0].
        if b_realtime:
            offset = t_next - n_time / fs
        elif b_ext_clock:
            offset = time.time()
        else:
            # Purely synthetic: seconds since the first sample, starting at 0.0
            # (clock_zero is deliberately not added).
            offset = n_sent / fs

        # 4. Yield output.
        yield AxisArray(
            block_samp,
            dims=["time", "ch"],
            axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset)},
        )

        # 5. Advance state for the next block. This does not index into
        #    block_samp, so an empty block (n_time == 0) cannot raise here.
        counter_start += n_time
        if mod is not None:
            # Keep the start bounded. The next block is also reduced modulo
            # `mod`, so the emitted sequence continues without a seam.
            counter_start %= mod
        n_sent += n_time
|
|
44
150
|
|
|
45
151
|
|
|
46
152
|
class CounterSettings(ez.Settings):
|
|
@@ -57,6 +163,8 @@ class CounterSettings(ez.Settings):
|
|
|
57
163
|
n_ch: int = 1 # Number of channels to synthesize
|
|
58
164
|
|
|
59
165
|
# Message dispatch rate (Hz), 'realtime', 'ext_clock', or None (fast as possible)
|
|
166
|
+
# Note: if dispatch_rate is a float then time offsets will be synthetic and the
|
|
167
|
+
# system will run faster or slower than wall clock time.
|
|
60
168
|
dispatch_rate: Optional[Union[float, str]] = None
|
|
61
169
|
|
|
62
170
|
# If set to an integer, counter will rollover
|
|
@@ -64,9 +172,9 @@ class CounterSettings(ez.Settings):
|
|
|
64
172
|
|
|
65
173
|
|
|
66
174
|
class CounterState(ez.State):
    """Mutable state of the `Counter` unit."""
    # Async block generator built by construct_generator via `acounter`.
    gen: AsyncGenerator[AxisArray, Optional[ez.Flag]]
    # Settings currently in effect.
    cur_settings: CounterSettings
    # Set by construct_generator. run_generator waits on it, clears it, and stops
    # pulling from the old generator as soon as it is set again.
    new_generator: asyncio.Event
|
|
70
178
|
|
|
71
179
|
|
|
72
180
|
class Counter(ez.Unit):
|
|
@@ -79,9 +187,8 @@ class Counter(ez.Unit):
|
|
|
79
187
|
INPUT_SETTINGS = ez.InputStream(CounterSettings)
|
|
80
188
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
81
189
|
|
|
82
|
-
async def initialize(self) -> None:
    """Create the regeneration event, then validate and apply the configured settings."""
    # The event must exist first: validate_settings appears to end by calling
    # construct_generator, which sets this event.
    self.STATE.new_generator = asyncio.Event()
    self.validate_settings(self.SETTINGS)
|
|
86
193
|
|
|
87
194
|
@ez.subscriber(INPUT_SETTINGS)
|
|
@@ -93,53 +200,64 @@ class Counter(ez.Unit):
|
|
|
93
200
|
settings.dispatch_rate, str
|
|
94
201
|
) and self.SETTINGS.dispatch_rate not in ["realtime", "ext_clock"]:
|
|
95
202
|
raise ValueError(f"Unknown dispatch_rate: {self.SETTINGS.dispatch_rate}")
|
|
96
|
-
|
|
97
203
|
self.STATE.cur_settings = settings
|
|
98
|
-
|
|
204
|
+
self.construct_generator()
|
|
205
|
+
|
|
206
|
+
def construct_generator(self):
    """Build a fresh `acounter` generator from the current settings and notify run_generator."""
    settings = self.STATE.cur_settings
    self.STATE.gen = acounter(
        settings.n_time,
        settings.fs,
        n_ch=settings.n_ch,
        dispatch_rate=settings.dispatch_rate,
        mod=settings.mod,
    )
    # Wakes run_generator and ends its inner loop over the previous generator.
    self.STATE.new_generator.set()
|
|
215
|
+
|
|
99
216
|
@ez.subscriber(INPUT_CLOCK)
@ez.publisher(OUTPUT_SIGNAL)
async def on_clock(self, clock: ez.Flag):
    """Emit one block per external clock flag; ignored unless dispatch_rate is 'ext_clock'."""
    if self.STATE.cur_settings.dispatch_rate != 'ext_clock':
        return
    block = await self.STATE.gen.__anext__()
    yield self.OUTPUT_SIGNAL, block
|
|
102
222
|
|
|
103
223
|
@ez.publisher(OUTPUT_SIGNAL)
async def run_generator(self) -> AsyncGenerator:
    """
    Free-running publisher: each time a new generator is constructed, publish
    its blocks until the next one replaces it. In 'ext_clock' mode it publishes
    nothing, because on_clock drives output instead.
    """
    while True:
        await self.STATE.new_generator.wait()
        self.STATE.new_generator.clear()

        if self.STATE.cur_settings.dispatch_rate != 'ext_clock':
            # Keep pulling until construct_generator signals a replacement.
            while not self.STATE.new_generator.is_set():
                block = await self.STATE.gen.__anext__()
                yield self.OUTPUT_SIGNAL, block
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@consumer
def sin(
    axis: Optional[str] = "time",
    freq: float = 1.0,  # Oscillation frequency in Hz
    amp: float = 1.0,  # Amplitude
    phase: float = 0.0,  # Phase offset (in radians)
) -> Generator[AxisArray, AxisArray, None]:
    """
    Map AxisArrays of sample counts to a sinusoid, ``amp * sin(2*pi*freq*t + phase)``.

    :param axis: Name of the axis whose gain converts counts to seconds. If None,
        the first dimension of each received message is used.
    :param freq: Oscillation frequency in Hz.
    :param amp: Amplitude.
    :param phase: Phase offset in radians.
    :return: A primed generator. Each `.send(AxisArray)` returns a copy of the input
        whose data is replaced by the sinusoid.
    """
    omega = 2.0 * np.pi * freq  # angular frequency, rad/s; fixed for this generator
    msg_out = AxisArray(np.array([]), dims=[""])  # placeholder yielded while priming

    while True:
        msg_in = yield msg_out
        # msg_in.data is expected to hold sample counts.
        axis_name = axis if axis is not None else msg_in.dims[0]
        # gain is seconds per sample (1 / fs), so gain * count = time in seconds.
        radians = (omega * msg_in.get_axis(axis_name).gain) * msg_in.data
        msg_out = replace(msg_in, data=amp * np.sin(radians + phase))
|
|
143
261
|
|
|
144
262
|
|
|
145
263
|
class SinGeneratorSettings(ez.Settings):
|
|
@@ -149,35 +267,16 @@ class SinGeneratorSettings(ez.Settings):
|
|
|
149
267
|
phase: float = 0.0 # Phase offset (in radians)
|
|
150
268
|
|
|
151
269
|
|
|
152
|
-
class SinGenerator(GenAxisArray):
    """Unit that converts incoming sample-count AxisArrays into a sinusoid using `sin`."""
    SETTINGS: SinGeneratorSettings

    def construct_generator(self):
        """Build the `sin` generator from this unit's settings."""
        settings = self.SETTINGS
        self.STATE.gen = sin(
            axis=settings.time_axis,
            freq=settings.freq,
            amp=settings.amp,
            phase=settings.phase,
        )
|
|
181
280
|
|
|
182
281
|
|
|
183
282
|
class OscillatorSettings(ez.Settings):
|
ezmsg/sigproc/window.py
CHANGED
|
@@ -1,33 +1,207 @@
|
|
|
1
1
|
from dataclasses import replace
|
|
2
|
+
import traceback
|
|
3
|
+
from typing import AsyncGenerator, Optional, Tuple, List, Generator
|
|
2
4
|
|
|
3
5
|
import ezmsg.core as ez
|
|
4
6
|
import numpy as np
|
|
5
7
|
import numpy.typing as npt
|
|
6
8
|
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
-
|
|
9
|
-
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis, sliding_win_oneaxis
|
|
10
|
+
from ezmsg.util.generator import consumer
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@consumer
def windowing(
    axis: Optional[str] = None,
    newaxis: Optional[str] = None,
    window_dur: Optional[float] = None,
    window_shift: Optional[float] = None,
    zero_pad_until: str = "input"
) -> Generator[List[AxisArray], AxisArray, None]:
    """
    Window function that generates windows of data from an input `AxisArray`.

    :param axis: The axis along which to segment windows.
        If None, defaults to the first dimension of the first seen AxisArray.
    :param newaxis: Optional new axis for the output. If None, no new axes will be added.
        If a string, windows will be stacked in a new axis with key `newaxis`, immediately preceding the windowed axis.
        If `newaxis` is already a dimension of the input, a warning is logged (once) and it is treated as None.
    :param window_dur: The duration of the window in seconds.
        If None, the function acts as a passthrough and all other parameters are ignored.
    :param window_shift: The shift of the window in seconds.
        If None (default), windowing operates in "1:1 mode", where each input yields exactly one most-recent window.
    :param zero_pad_until: Determines how the function initializes the buffer.
        Can be one of "input" (default), "full", "shift", or "none". If `window_shift` is None then this field is
        ignored and "input" is always used.
        "input" (default) initializes the buffer with the input then prepends with zeros to the window size.
            The first input will always yield at least one output.
        "shift" fills the buffer until `window_shift`.
            No outputs will be yielded until at least `window_shift` data has been seen.
        "none" does not pad the buffer. No outputs will be yielded until at least `window_dur` data has been seen.
        "full" is accepted but currently handled exactly like "input".
    :return:
        A (primed) generator that accepts .send(an AxisArray object) and yields a list of windowed
        AxisArray objects. The list will always be length-1 if `newaxis` is not None or `window_shift` is None.
        The buffer and output metadata are rebuilt whenever the sample shape or sampling rate of the input changes.
    """
    # TODO: The return should be an AxisArray. i.e., always add a new axis. The Unit can do a multi-yield-per-pub
    #  if the parameterization does not expect a newaxis.

    if window_shift is None and zero_pad_until != "input":
        ez.logger.warning("`zero_pad_until` must be 'input' if `window_shift` is None. "
                          f"Ignoring received argument value: {zero_pad_until}")
        zero_pad_until = "input"
    elif window_shift is not None and zero_pad_until == "input":
        ez.logger.warning("windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
                          "of the first input. We recommend using 'shift' when `window_shift` is float-valued.")
    # Placeholder yielded once while the generator is primed.
    axis_arr_out = [AxisArray(np.array([]), dims=[""])]

    # State variables
    prev_samp_shape: Optional[Tuple[int, ...]] = None
    prev_fs: Optional[float] = None
    buffer: Optional[npt.NDArray] = None
    window_samples: Optional[int] = None
    window_shift_samples: Optional[int] = None
    shift_deficit: int = 0  # Number of incoming samples to ignore. Only relevant when shift > window.
    newaxis_warn_flag: bool = False
    mod_ax: Optional[str] = None  # The key of the modified axis in the output's .axes
    out_axes: Optional[dict] = None  # Output axes; the entry at `mod_ax` receives per-window offsets.
    out_template: Optional[AxisArray] = None  # Template for building return values.

    while True:
        axis_arr_in = yield axis_arr_out

        if window_dur is None:
            axis_arr_out = [axis_arr_in]
            continue

        if axis is None:
            axis = axis_arr_in.dims[0]
        axis_idx = axis_arr_in.get_axis_idx(axis)
        axis_info = axis_arr_in.get_axis(axis)
        fs = 1.0 / axis_info.gain

        if (not newaxis_warn_flag) and newaxis is not None and newaxis in axis_arr_in.dims:
            ez.logger.warning(f"newaxis {newaxis} present in input dims and will be ignored.")
            newaxis_warn_flag = True
        b_newaxis = newaxis is not None and newaxis not in axis_arr_in.dims

        samp_shape = axis_arr_in.data.shape[:axis_idx] + axis_arr_in.data.shape[axis_idx + 1:]
        window_samples = int(window_dur * fs)
        b_1to1 = window_shift is None
        if not b_1to1:
            window_shift_samples = int(window_shift * fs)

        # If buffer unset or input stats changed, create a new buffer
        if buffer is None or samp_shape != prev_samp_shape or fs != prev_fs:
            if zero_pad_until == "none":
                req_samples = window_samples
            elif zero_pad_until == "shift" and not b_1to1:
                req_samples = window_shift_samples
            else:  # i.e. zero_pad_until == "input" (and, currently, "full")
                req_samples = axis_arr_in.data.shape[axis_idx]
            n_zero = max(0, window_samples - req_samples)
            buffer_shape = axis_arr_in.data.shape[:axis_idx] + (n_zero,) + axis_arr_in.data.shape[axis_idx + 1:]
            buffer = np.zeros(buffer_shape)
            prev_samp_shape = samp_shape
            prev_fs = fs
            # A new buffer starts a new stream. Skips owed to the old buffer no
            # longer apply, and the cached output metadata (axis gain, newaxis
            # gain derived from fs) would be stale, so rebuild it below.
            shift_deficit = 0
            out_template = None

        # Add new data to buffer.
        # Currently we just concatenate the new time samples and clip the output
        # np.roll actually returns a copy, and there's no way to construct a
        # rolling view of the data. In current numpy implementations, np.concatenate
        # is generally faster than np.roll and slicing anyway, but this could still
        # be a performance bottleneck for large memory arrays.
        buffer = np.concatenate((buffer, axis_arr_in.data), axis=axis_idx)
        # Note: if we ever move to using a circular buffer without copies then we need to create copies somewhere,
        #  because currently the outputs are merely views into the buffer.

        # Create a vector of buffer timestamps to track axis `offset` in output(s)
        buffer_offset = np.arange(buffer.shape[axis_idx]).astype(float)
        # Adjust so first _new_ sample at index 0
        buffer_offset -= buffer_offset[-axis_arr_in.data.shape[axis_idx]]
        # Convert form indices to 'units' (probably seconds).
        buffer_offset *= axis_info.gain
        buffer_offset += axis_info.offset

        if not b_1to1 and shift_deficit > 0:
            n_skip = min(buffer.shape[axis_idx], shift_deficit)
            if n_skip > 0:
                buffer = slice_along_axis(buffer, np.s_[n_skip:], axis_idx)
                buffer_offset = buffer_offset[n_skip:]
                shift_deficit -= n_skip

        # Prepare reusable parts of output
        if out_template is None:
            out_dims = axis_arr_in.dims
            if b_newaxis:
                out_dims = out_dims[:axis_idx] + [newaxis] + out_dims[axis_idx:]
                out_axes = {
                    **axis_arr_in.axes,
                    newaxis: AxisArray.Axis(
                        unit=axis_info.unit,
                        gain=0.0 if b_1to1 else axis_info.gain * window_shift_samples,
                        offset=0.0  # offset modified below
                    )
                }
                mod_ax = newaxis
            else:
                # No newaxis requested, or it already exists in the input (warned
                # above) and is ignored: windows stay along `axis`, dims unchanged.
                out_axes = {
                    **axis_arr_in.axes,
                    axis: replace(axis_info, offset=0.0)  # offset modified below.
                }
                mod_ax = axis
            out_template = replace(axis_arr_in, data=np.zeros([0 for _ in out_dims]), dims=out_dims)

        # Generate outputs.
        axis_arr_out: List[AxisArray] = []
        if b_1to1:
            # one-to-one mode -- Each send yields exactly one window containing only the most recent samples.
            buffer = slice_along_axis(buffer, np.s_[-window_samples:], axis_idx)
            axis_arr_out.append(replace(
                out_template,
                data=np.expand_dims(buffer, axis=axis_idx) if b_newaxis else buffer,
                axes={
                    **out_axes,
                    mod_ax: replace(out_axes[mod_ax], offset=buffer_offset[-window_samples])
                }
            ))
        elif buffer.shape[axis_idx] >= window_samples:
            # Deterministic window shifts.
            win_view = sliding_win_oneaxis(buffer, window_samples, axis_idx)
            win_view = slice_along_axis(win_view, np.s_[::window_shift_samples], axis_idx)
            offset_view = sliding_win_oneaxis(buffer_offset, window_samples, 0)[::window_shift_samples]
            # Place in output
            if b_newaxis:
                axis_arr_out.append(replace(
                    out_template,
                    data=win_view,
                    axes={**out_axes, mod_ax: replace(out_axes[mod_ax], offset=offset_view[0, 0])}
                ))
            else:
                for win_ix in range(win_view.shape[axis_idx]):
                    axis_arr_out.append(replace(
                        out_template,
                        data=slice_along_axis(win_view, win_ix, axis_idx),
                        axes={
                            **out_axes,
                            mod_ax: replace(out_axes[mod_ax], offset=offset_view[win_ix, 0])
                        }
                    ))

            # Drop expired beginning of buffer and update shift_deficit
            multi_shift = window_shift_samples * win_view.shape[axis_idx]
            shift_deficit = max(0, multi_shift - buffer.shape[axis_idx])
            buffer = slice_along_axis(buffer, np.s_[multi_shift:], axis_idx)
|
|
10
192
|
|
|
11
193
|
|
|
12
194
|
class WindowSettings(ez.Settings):
    """
    Settings for the `Window` unit.

    Window.construct_generator forwards every field, by name, to `windowing`.
    """
    axis: Optional[str] = None  # Axis to window along. If None, `windowing` uses the first dim of the first message.
    newaxis: Optional[str] = None  # new axis for output. No new axes if None
    window_dur: Optional[float] = None  # Sec. passthrough if None
    window_shift: Optional[float] = None  # Sec. Use "1:1 mode" if None
    # One of "full", "shift", "input", "none"; forced to "input" by `windowing` when window_shift is None.
    # NOTE(review): `windowing` has no dedicated "full" branch, so "full" currently behaves like "input".
    zero_pad_until: str = "full"
|
|
21
200
|
|
|
22
201
|
|
|
23
202
|
class WindowState(ez.State):
    """Mutable state of the `Window` unit."""
    cur_settings: WindowSettings  # Settings currently in effect.
    gen: Generator  # Primed `windowing` generator built by construct_generator; fed by on_signal.
|
|
31
205
|
|
|
32
206
|
|
|
33
207
|
class Window(ez.Unit):
|
|
@@ -40,105 +214,33 @@ class Window(ez.Unit):
|
|
|
40
214
|
|
|
41
215
|
def initialize(self) -> None:
    """Adopt the unit's configured settings and build the windowing generator."""
    settings = self.SETTINGS
    self.STATE.cur_settings = settings
    self.construct_generator()
|
|
43
218
|
|
|
44
219
|
@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: WindowSettings) -> None:
    """Apply new settings by rebuilding the windowing generator."""
    self.STATE.cur_settings = msg
    # The buffer lives inside the generator, so rebuilding it discards buffered data.
    self.construct_generator()
|
|
223
|
+
|
|
224
|
+
def construct_generator(self):
    """Create a new `windowing` generator from the current settings."""
    s = self.STATE.cur_settings
    self.STATE.gen = windowing(
        axis=s.axis,
        newaxis=s.newaxis,
        window_dur=s.window_dur,
        window_shift=s.window_shift,
        zero_pad_until=s.zero_pad_until,
    )
|
|
48
232
|
|
|
49
233
|
@ez.subscriber(INPUT_SIGNAL)
@ez.publisher(OUTPUT_SIGNAL)
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
    """
    Send `msg` to the windowing generator and publish each window it returns.

    Failures are logged instead of propagated. An exception raised inside the
    generator finishes it, so later messages only reach the StopIteration branch
    until new settings rebuild the generator.
    """
    try:
        # TODO: Refactor window generator so it always returns an axis array.
        #  Then, if the configuration is such that a new "win" axis is not expected,
        #  then iterate over the "win" axis -- dropping the "win" axis in the process.
        out_msgs = self.STATE.gen.send(msg)
        for out_msg in out_msgs:
            yield self.OUTPUT_SIGNAL, out_msg
    except (StopIteration, GeneratorExit):
        ez.logger.debug(f"Window closed in {self.address}")
    except Exception:
        # Log at error level: this unit stops emitting output after such a
        # failure, and an info-level traceback is easily missed.
        ez.logger.error(traceback.format_exc())
|