ezmsg-sigproc 2.4.1-py3-none-any.whl → 2.6.0-py3-none-any.whl
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
- ezmsg/sigproc/affinetransform.py +13 -38
- ezmsg/sigproc/aggregate.py +13 -30
- ezmsg/sigproc/bandpower.py +7 -15
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +123 -0
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/decimate.py +2 -6
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +6 -14
- ezmsg/sigproc/ewma.py +11 -27
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +31 -56
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +33 -70
- ezmsg/sigproc/filterbankdesign.py +5 -12
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +1 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +98 -36
- ezmsg/sigproc/math/invert.py +1 -3
- ezmsg/sigproc/math/log.py +2 -6
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +2 -4
- ezmsg/sigproc/resample.py +13 -34
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +17 -35
- ezmsg/sigproc/scaler.py +8 -18
- ezmsg/sigproc/signalinjector.py +6 -16
- ezmsg/sigproc/slicer.py +9 -28
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +12 -32
- ezmsg/sigproc/transpose.py +7 -18
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +10 -26
- ezmsg/sigproc/util/buffer.py +18 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +5 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +6 -15
- ezmsg/sigproc/window.py +24 -78
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
- ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/spectrum.py
CHANGED
@@ -1,14 +1,14 @@
 import enum
-from functools import partial
 import typing
+from functools import partial
 
+import ezmsg.core as ez
 import numpy as np
 import numpy.typing as npt
-import ezmsg.core as ez
 from ezmsg.util.messages.axisarray import (
     AxisArray,
-    slice_along_axis,
     replace,
+    slice_along_axis,
 )
 
 from .base import (
@@ -127,17 +127,13 @@ class SpectrumState:
     window: npt.NDArray | None = None
 
 
-class SpectrumTransformer(
-    BaseStatefulTransformer[SpectrumSettings, AxisArray, AxisArray, SpectrumState]
-):
+class SpectrumTransformer(BaseStatefulTransformer[SpectrumSettings, AxisArray, AxisArray, SpectrumState]):
     def _hash_message(self, message: AxisArray) -> int:
         axis = self.settings.axis or message.dims[0]
         ax_idx = message.get_axis_idx(axis)
         ax_info = message.axes[axis]
         targ_len = message.data.shape[ax_idx]
-        return hash(
-            (targ_len, message.data.ndim, message.data.dtype.kind, ax_idx, ax_info.gain)
-        )
+        return hash((targ_len, message.data.ndim, message.data.dtype.kind, ax_idx, ax_info.gain))
 
     def _reset_state(self, message: AxisArray) -> None:
         axis = self.settings.axis or message.dims[0]
@@ -156,8 +152,7 @@ class SpectrumTransformer(
             + [1] * (message.data.ndim - 1 - ax_idx)
         )
         if self.settings.transform != SpectralTransform.RAW_COMPLEX and not (
-            self.settings.transform == SpectralTransform.REAL
-            or self.settings.transform == SpectralTransform.IMAG
+            self.settings.transform == SpectralTransform.REAL or self.settings.transform == SpectralTransform.IMAG
         ):
             scale = np.sum(window**2.0) * ax_info.gain
 
@@ -170,30 +165,21 @@ class SpectrumTransformer(
         if (not b_complex) and self.settings.output == SpectralOutput.POSITIVE:
             # If input is not complex and desired output is SpectralOutput.POSITIVE, we can save some computation
             # by using rfft and rfftfreq.
-            self.state.fftfun = partial(
-                np.fft.rfft, n=nfft, axis=ax_idx, norm=self.settings.norm
-            )
+            self.state.fftfun = partial(np.fft.rfft, n=nfft, axis=ax_idx, norm=self.settings.norm)
             freqs = np.fft.rfftfreq(nfft, d=ax_info.gain * targ_len / nfft)
         else:
-            self.state.fftfun = partial(
-                np.fft.fft, n=nfft, axis=ax_idx, norm=self.settings.norm
-            )
+            self.state.fftfun = partial(np.fft.fft, n=nfft, axis=ax_idx, norm=self.settings.norm)
             freqs = np.fft.fftfreq(nfft, d=ax_info.gain * targ_len / nfft)
             if self.settings.output == SpectralOutput.POSITIVE:
                 self.state.f_sl = slice(None, nfft // 2 + 1 - (nfft % 2))
             elif self.settings.output == SpectralOutput.NEGATIVE:
                 freqs = np.fft.fftshift(freqs, axes=-1)
                 self.state.f_sl = slice(None, nfft // 2 + 1)
-            elif (
-                self.settings.do_fftshift
-                and self.settings.output == SpectralOutput.FULL
-            ):
+            elif self.settings.do_fftshift and self.settings.output == SpectralOutput.FULL:
                 freqs = np.fft.fftshift(freqs, axes=-1)
         freqs = freqs[self.state.f_sl]
         freqs = freqs.tolist()  # To please type checking
-        self.state.freq_axis = AxisArray.LinearAxis(
-            unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0]
-        )
+        self.state.freq_axis = AxisArray.LinearAxis(unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0])
         self.state.new_dims = (
             message.dims[:ax_idx]
             + [
@@ -232,11 +218,7 @@ class SpectrumTransformer(
         ax_idx = message.get_axis_idx(axis)
         targ_len = message.data.shape[ax_idx]
 
-        new_axes = {
-            k: v
-            for k, v in message.axes.items()
-            if k not in [self.settings.out_axis, axis]
-        }
+        new_axes = {k: v for k, v in message.axes.items() if k not in [self.settings.out_axis, axis]}
         new_axes[self.settings.out_axis or axis] = self.state.freq_axis
 
         if self.state.window is not None:
@@ -261,9 +243,7 @@ class SpectrumTransformer(
         return msg_out
 
 
-class Spectrum(
-    BaseTransformerUnit[SpectrumSettings, AxisArray, AxisArray, SpectrumTransformer]
-):
+class Spectrum(BaseTransformerUnit[SpectrumSettings, AxisArray, AxisArray, SpectrumTransformer]):
     SETTINGS = SpectrumSettings
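The frequency-axis bookkeeping in _reset_state reduces to a few NumPy calls. A minimal standalone sketch of the real-input/SpectralOutput.POSITIVE path (the fs, targ_len, and nfft values here are illustrative, not taken from the package):

import numpy as np

fs = 500.0       # sampling rate in Hz; the time-axis gain is 1 / fs
targ_len = 128   # length of the transformed axis
nfft = 128       # FFT length (equal to targ_len here)

# Real input + SpectralOutput.POSITIVE -> rfft / rfftfreq suffice.
x = np.random.randn(targ_len)
spec = np.fft.rfft(x, n=nfft)

# Same d= expression as above; gain * targ_len / nfft reduces to gain when nfft == targ_len.
freqs = np.fft.rfftfreq(nfft, d=(1 / fs) * targ_len / nfft)

# The output axis is linear, so it is fully described by gain (bin width) and offset (first bin).
gain, offset = freqs[1] - freqs[0], freqs[0]
assert np.allclose(freqs, offset + gain * np.arange(len(freqs)))
print(spec.shape, gain, offset)  # (65,) 3.90625 0.0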
ezmsg/sigproc/transpose.py
CHANGED
@@ -1,6 +1,7 @@
 from types import EllipsisType
-
+
 import ezmsg.core as ez
+import numpy as np
 from ezmsg.util.messages.axisarray import (
     AxisArray,
     replace,
@@ -30,9 +31,7 @@ class TransposeState:
     axes_ints: tuple[int, ...] | None = None
 
 
-class TransposeTransformer(
-    BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]
-):
+class TransposeTransformer(BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]):
     """
     Downsampled data simply comprise every `factor`th sample.
     This should only be used following appropriate lowpass filtering.
@@ -67,11 +66,7 @@ class TransposeTransformer(
             if ax not in message.dims:
                 raise ValueError(f"Axis {ax} not found in message dims.")
             suffix.append(message.dims.index(ax))
-        ells = [
-            _
-            for _ in range(message.data.ndim)
-            if _ not in prefix and _ not in suffix
-        ]
+        ells = [_ for _ in range(message.data.ndim) if _ not in prefix and _ not in suffix]
         re_ix = tuple(prefix + ells + suffix)
         if re_ix == tuple(range(message.data.ndim)):
             self._state.axes_ints = None
@@ -100,17 +95,13 @@ class TransposeTransformer(
             # If the memory is already contiguous in the correct order, np.require won't do anything.
             msg_out = replace(
                 message,
-                data=np.require(
-                    message.data, requirements=self.settings.order.upper()[0]
-                ),
+                data=np.require(message.data, requirements=self.settings.order.upper()[0]),
             )
         else:
             dims_out = [message.dims[ix] for ix in self.state.axes_ints]
             data_out = np.transpose(message.data, axes=self.state.axes_ints)
             if self.settings.order is not None:
-                data_out = np.require(
-                    data_out, requirements=self.settings.order.upper()[0]
-                )
+                data_out = np.require(data_out, requirements=self.settings.order.upper()[0])
             msg_out = replace(
                 message,
                 data=data_out,
@@ -119,9 +110,7 @@ class TransposeTransformer(
         return msg_out
 
 
-class Transpose(
-    BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]
-):
+class Transpose(BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]):
     SETTINGS = TransposeSettings
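For context, the np.require pattern above enforces a memory layout after an axis permutation, copying only when necessary. A standalone illustration (array shapes are illustrative):

import numpy as np

a = np.arange(24).reshape(2, 3, 4)   # e.g. dims ("time", "ch", "freq")
axes_ints = (2, 0, 1)                # reorder to ("freq", "time", "ch")

out = np.transpose(a, axes=axes_ints)  # a view; memory layout is unchanged
print(out.flags["C_CONTIGUOUS"])       # False

# An order setting of "C"/"F" maps to requirements="C"/"F"; np.require copies only if needed.
out_c = np.require(out, requirements="C")
print(out_c.flags["C_CONTIGUOUS"])     # True

# Already-compliant arrays pass through untouched.
print(np.require(a, requirements="C") is a)  # True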
ezmsg/sigproc/util/asio.py
CHANGED
@@ -1,156 +1,25 @@
-[old lines 1-25 not rendered in the source diff: imports and the head of the run_coroutine_sync signature and docstring]
-        The result of the coroutine execution
-
-    Raises:
-        CoroutineExecutionError: If execution fails due to threading or event loop issues
-        TimeoutError: If execution exceeds the timeout period
-        Exception: Any exception raised by the coroutine
-    """
-
-    def run_in_new_loop() -> T:
-        """
-        Creates and runs a new event loop in the current thread.
-        Ensures proper cleanup of the loop.
-        """
-        new_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(new_loop)
-        try:
-            return new_loop.run_until_complete(
-                asyncio.wait_for(coroutine, timeout=timeout)
-            )
-        finally:
-            with contextlib.suppress(Exception):
-                # Clean up any pending tasks
-                pending = asyncio.all_tasks(new_loop)
-                for task in pending:
-                    task.cancel()
-                new_loop.run_until_complete(
-                    asyncio.gather(*pending, return_exceptions=True)
-                )
-            new_loop.close()
-
-    try:
-        loop = asyncio.get_running_loop()
-    except RuntimeError:
-        try:
-            return asyncio.run(asyncio.wait_for(coroutine, timeout=timeout))
-        except Exception as e:
-            raise CoroutineExecutionError(
-                f"Failed to execute coroutine: {str(e)}"
-            ) from e
-
-    if threading.current_thread() is threading.main_thread():
-        if not loop.is_running():
-            try:
-                return loop.run_until_complete(
-                    asyncio.wait_for(coroutine, timeout=timeout)
-                )
-            except Exception as e:
-                raise CoroutineExecutionError(
-                    f"Failed to execute coroutine in main loop: {str(e)}"
-                ) from e
-        else:
-            with ThreadPoolExecutor() as pool:
-                try:
-                    future = pool.submit(run_in_new_loop)
-                    return future.result(timeout=timeout)
-                except Exception as e:
-                    raise CoroutineExecutionError(
-                        f"Failed to execute coroutine in thread: {str(e)}"
-                    ) from e
-    else:
-        try:
-            future = asyncio.run_coroutine_threadsafe(coroutine, loop)
-            return future.result(timeout=timeout)
-        except Exception as e:
-            raise CoroutineExecutionError(
-                f"Failed to execute coroutine threadsafe: {str(e)}"
-            ) from e
-
-
-class SyncToAsyncGeneratorWrapper:
-    """
-    A wrapper for synchronous generators to be used in an async context.
-    """
-
-    def __init__(self, gen):
-        self._gen = gen
-        self._closed = False
-        # Prime the generator to ready for first send/next call
-        try:
-            is_not_primed = inspect.getgeneratorstate(self._gen) is inspect.GEN_CREATED
-        except AttributeError as e:
-            raise TypeError(
-                "The provided generator is not a valid generator object"
-            ) from e
-        if is_not_primed:
-            try:
-                next(self._gen)
-            except StopIteration:
-                self._closed = True
-            except Exception as e:
-                raise RuntimeError(f"Failed to prime generator: {e}") from e
-
-    async def asend(self, value):
-        if self._closed:
-            raise StopAsyncIteration("Generator is closed")
-        try:
-            return await asyncio.to_thread(self._gen.send, value)
-        except StopIteration as e:
-            self._closed = True
-            raise StopAsyncIteration("Generator is closed") from e
-        except Exception as e:
-            raise RuntimeError(f"Error while sending value to generator: {e}") from e
-
-    async def __anext__(self):
-        if self._closed:
-            raise StopAsyncIteration("Generator is closed")
-        try:
-            return await asyncio.to_thread(self._gen.__next__)
-        except StopIteration as e:
-            self._closed = True
-            raise StopAsyncIteration("Generator is closed") from e
-        except Exception as e:
-            raise RuntimeError(
-                f"Error while getting next value from generator: {e}"
-            ) from e
-
-    async def aclose(self):
-        if self._closed:
-            return
-        try:
-            await asyncio.to_thread(self._gen.close)
-        except Exception as e:
-            raise RuntimeError(f"Error while closing generator: {e}") from e
-        finally:
-            self._closed = True
-
-    def __aiter__(self):
-        return self
-
-    def __getattr__(self, name):
-        return getattr(self._gen, name)
+"""
+Backwards-compatible re-exports from ezmsg.baseproc.util.asio.
+
+New code should import directly from ezmsg.baseproc instead.
+"""
+
+import warnings
+
+warnings.warn(
+    "Importing from 'ezmsg.sigproc.util.asio' is deprecated. Please import from 'ezmsg.baseproc.util.asio' instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
+from ezmsg.baseproc.util.asio import (  # noqa: E402
+    CoroutineExecutionError,
+    SyncToAsyncGeneratorWrapper,
+    run_coroutine_sync,
+)
+
+__all__ = [
+    "CoroutineExecutionError",
+    "SyncToAsyncGeneratorWrapper",
+    "run_coroutine_sync",
+]
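The replacement module is a deprecation shim: importing the old path emits a DeprecationWarning, then re-exports the moved names from ezmsg.baseproc. A sketch of what downstream code sees (assumes ezmsg-baseproc is installed):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old import path still resolves; the module-level warn() fires on first import only.
    from ezmsg.sigproc.util.asio import run_coroutine_sync  # noqa: F401

assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Warning-free equivalent going forward:
# from ezmsg.baseproc.util.asio import run_coroutine_sync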
ezmsg/sigproc/util/axisarray_buffer.py
CHANGED

@@ -1,14 +1,13 @@
 import math
 import typing
 
-from array_api_compat import get_namespace
 import numpy as np
-from …  [rest of line not rendered in the source diff]
+from array_api_compat import get_namespace
+from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis, LinearAxis
 from ezmsg.util.messages.util import replace
 
 from .buffer import HybridBuffer
 
-
 Array = typing.TypeVar("Array")
 
 
@@ -68,9 +67,7 @@ class HybridAxisBuffer:
         if hasattr(first_axis, "data"):
             # Initialize a CoordinateAxis buffer
             if len(first_axis.data) > 1:
-                _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (
-                    len(first_axis.data) - 1
-                )
+                _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (len(first_axis.data) - 1)
             else:
                 _axis_gain = 1.0
             self._coords_gain_estimate = _axis_gain
@@ -107,8 +104,7 @@ class HybridAxisBuffer:
             )
             if axis.gain != self._linear_axis.gain:
                 raise ValueError(
-                    f"Buffer initialized with gain={self._linear_axis.gain}, "
-                    f"but received gain={axis.gain}."
+                    f"Buffer initialized with gain={self._linear_axis.gain}, but received gain={axis.gain}."
                 )
             if self._linear_n_available + n_samples > self.capacity:
                 # Simulate overflow by advancing the offset and decreasing
@@ -117,16 +113,12 @@ class HybridAxisBuffer:
                 self.seek(n_to_discard)
             # Update the offset corresponding to the oldest sample in the buffer
             # by anchoring on the new offset and accounting for the samples already available.
-            self._linear_axis.offset = (
-                axis.offset - self._linear_n_available * axis.gain
-            )
+            self._linear_axis.offset = axis.offset - self._linear_n_available * axis.gain
            self._linear_n_available += n_samples
 
     def peek(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis:
         if self._coords_buffer is not None:
-            return replace(
-                self._coords_template, data=self._coords_buffer.peek(n_samples)
-            )
+            return replace(self._coords_template, data=self._coords_buffer.peek(n_samples))
         else:
             # Return a shallow copy.
             return replace(self._linear_axis, offset=self._linear_axis.offset)
@@ -184,13 +176,9 @@ class HybridAxisBuffer:
         else:
             return None
 
-    def searchsorted(
-        self, values: typing.Union[float, Array], side: str = "left"
-    ) -> typing.Union[int, Array]:
+    def searchsorted(self, values: typing.Union[float, Array], side: str = "left") -> typing.Union[int, Array]:
         if self._coords_buffer is not None:
-            return self._coords_buffer.xp.searchsorted(
-                self._coords_buffer.peek(self.available()), values, side=side
-            )
+            return self._coords_buffer.xp.searchsorted(self._coords_buffer.peek(self.available()), values, side=side)
         else:
             if self.available() == 0:
                 if isinstance(values, float):
@@ -312,9 +300,7 @@ class HybridAxisArrayBuffer:
             axes={**self._template_msg.axes, self._axis: out_axis},
         )
 
-    def peek_axis(
-        self, n_samples: int | None = None
-    ) -> LinearAxis | CoordinateAxis | None:
+    def peek_axis(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis | None:
         """Retrieves the axis data without advancing the read head."""
         if self._data_buffer is None:
             return None
@@ -369,9 +355,7 @@ class HybridAxisArrayBuffer:
         """
         return self._axis_buffer.gain
 
-    def axis_searchsorted(
-        self, values: typing.Union[float, Array], side: str = "left"
-    ) -> typing.Union[int, Array]:
+    def axis_searchsorted(self, values: typing.Union[float, Array], side: str = "left") -> typing.Union[int, Array]:
         """
         Find the indices into which the given values would be inserted
         into the target axis data to maintain order.
ezmsg/sigproc/util/buffer.py
CHANGED
@@ -63,9 +63,7 @@ class HybridBuffer:
         self._buff_unread = 0  # Number of unread samples in the circular buffer
         self._buff_read = 0  # Tracks samples read and still in buffer
         self._deque_len = 0  # Number of unread samples in the deque
-        self._last_overflow = (
-            0  # Tracks the last overflow count, overwritten or skipped
-        )
+        self._last_overflow = 0  # Tracks the last overflow count, overwritten or skipped
         self._warned = False  # Tracks if we've warned already (for warn_once)
 
     @property
@@ -96,9 +94,7 @@ class HybridBuffer:
             block = block[:, self.xp.newaxis]
 
         if block.shape[1:] != other_shape:
-            raise ValueError(
-                f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}"
-            )
+            raise ValueError(f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}")
 
         # Most overflow strategies are handled during flush, but there are a couple
         # scenarios that can be evaluated on write to give immediate feedback.
@@ -117,8 +113,7 @@ class HybridBuffer:
         self._deque_len += block.shape[0]
 
         if self._update_strategy == "immediate" or (
-            self._update_strategy == "threshold"
-            and (0 < self._threshold <= self._deque_len)
+            self._update_strategy == "threshold" and (0 < self._threshold <= self._deque_len)
         ):
             self.flush()
 
@@ -128,9 +123,7 @@ class HybridBuffer:
        from the buffer.
        """
        if n_samples > self.available():
-            raise ValueError(
-                f"Requested {n_samples} samples, but only {self.available()} are available."
-            )
+            raise ValueError(f"Requested {n_samples} samples, but only {self.available()} are available.")
        n_overflow = 0
        if self._deque and (n_samples > self._buff_unread):
            # We would cause a flush, but would that cause an overflow?
@@ -161,14 +154,10 @@ class HybridBuffer:
         n_overflow = self._estimate_overflow(n_samples)
         if n_overflow > 0:
             first_read = self._buff_unread
-            if (n_overflow - first_read) < self.capacity or (
-                self._overflow_strategy == "drop"
-            ):
+            if (n_overflow - first_read) < self.capacity or (self._overflow_strategy == "drop"):
                 # We can prevent the overflow (or at least *some* if using "drop"
                 # strategy) by reading the samples in the buffer first to make room.
-                data = self.xp.empty(
-                    (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
-                )
+                data = self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
                 self.peek(first_read, out=data[:first_read])
                 offset += first_read
                 self.seek(first_read)
@@ -204,13 +193,9 @@ class HybridBuffer:
         if n_samples is None:
             n_samples = self.available()
         elif n_samples > self.available():
-            raise ValueError(
-                f"Requested to peek {n_samples} samples, but only {self.available()} are available."
-            )
+            raise ValueError(f"Requested to peek {n_samples} samples, but only {self.available()} are available.")
         if out is not None and out.shape[0] < n_samples:
-            raise ValueError(
-                f"Output array shape {out.shape} is smaller than requested {n_samples} samples."
-            )
+            raise ValueError(f"Output array shape {out.shape} is smaller than requested {n_samples} samples.")
 
         if n_samples == 0:
             return self._buffer[:0]
@@ -224,9 +209,7 @@ class HybridBuffer:
             out = (
                 out
                 if out is not None
-                else self.xp.empty(
-                    (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
-                )
+                else self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
             )
             out[:part1_len] = self._buffer[self._tail :]
             out[part1_len:] = self._buffer[:part2_len]
@@ -258,9 +241,7 @@ class HybridBuffer:
         if not allow_flush and idx >= self._buff_unread:
             # The requested sample is in the deque.
             idx -= self._buff_unread
-            deq_splits = self.xp.cumsum(
-                [0] + [_.shape[0] for _ in self._deque], dtype=int
-            )
+            deq_splits = self.xp.cumsum([0] + [_.shape[0] for _ in self._deque], dtype=int)
             arr_idx = self.xp.searchsorted(deq_splits, idx, side="right") - 1
             idx -= deq_splits[arr_idx]
             return self._deque[arr_idx][idx : idx + 1]
@@ -334,7 +315,8 @@ class HybridBuffer:
         if n_overflow > 0 and (not self._warn_once or not self._warned):
             self._warned = True
             warnings.warn(
-                f"Buffer overflow: {n_new} samples received, but only {self._capacity - self._buff_unread} available. "
+                f"Buffer overflow: {n_new} samples received, "
+                f"but only {self._capacity - self._buff_unread} available. "
                 f"Overwriting {n_overflow} previous samples.",
                 RuntimeWarning,
             )
@@ -347,10 +329,9 @@ class HybridBuffer:
                 break
             n_to_copy = min(block.shape[0], samples_to_copy - copied_samples)
             start_idx = block.shape[0] - n_to_copy
-            self._buffer[
-                samples_to_copy - copied_samples - n_to_copy : samples_to_copy
-                - copied_samples
-            ] = block[start_idx:]
+            self._buffer[samples_to_copy - copied_samples - n_to_copy : samples_to_copy - copied_samples] = block[
+                start_idx:
+            ]
             copied_samples += n_to_copy
 
         self._head = 0
@@ -362,9 +343,7 @@ class HybridBuffer:
         else:
             if n_overflow > 0:
                 if self._overflow_strategy == "raise":
-                    raise OverflowError(
-                        f"Buffer overflow: {n_new} samples received, but only {n_free} available."
-                    )
+                    raise OverflowError(f"Buffer overflow: {n_new} samples received, but only {n_free} available.")
                 elif self._overflow_strategy == "warn-overwrite":
                     if not self._warn_once or not self._warned:
                         self._warned = True
@@ -430,9 +409,7 @@ class HybridBuffer:
             return
 
         other_shape = self._buffer.shape[1:]
-        max_capacity = self._max_size / (
-            self._buffer.dtype.itemsize * math.prod(other_shape)
-        )
+        max_capacity = self._max_size / (self._buffer.dtype.itemsize * math.prod(other_shape))
         if min_capacity > max_capacity:
             raise OverflowError(
                 f"Cannot grow buffer to {min_capacity} samples, "
@@ -440,9 +417,7 @@ class HybridBuffer:
         )
 
         new_capacity = min(max_capacity, max(self._capacity * 2, min_capacity))
-        new_buffer = self.xp.empty(
-            (new_capacity, *other_shape), dtype=self._buffer.dtype
-        )
+        new_buffer = self.xp.empty((new_capacity, *other_shape), dtype=self._buffer.dtype)
 
         # Copy existing data to new buffer
         total_samples = self._buff_read + self._buff_unread
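The two-step copy at the end of peek handles reads whose unread region wraps past the physical end of the circular buffer. Isolated with illustrative values:

import numpy as np

capacity = 8
buffer = np.arange(capacity)  # physical storage: [0 1 2 3 4 5 6 7]
tail = 6                      # position of the oldest unread sample
n_samples = 4                 # read crosses the end: indices 6, 7, then 0, 1

part1_len = capacity - tail          # samples up to the physical end
part2_len = n_samples - part1_len    # samples that wrapped to the front

out = np.empty(n_samples, dtype=buffer.dtype)
out[:part1_len] = buffer[tail:]       # [6 7]
out[part1_len:] = buffer[:part2_len]  # [0 1]
print(out)                            # [6 7 0 1]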
ezmsg/sigproc/util/message.py
CHANGED
@@ -1,31 +1,17 @@
-[old lines 1-19 not rendered in the source diff: module docstring, imports, and the SampleTriggerMessage definition]
-@dataclass
-class SampleMessage:
-    trigger: SampleTriggerMessage
-    """The time, window, and value (if any) associated with the trigger."""
-
-    sample: AxisArray
-    """The data sampled around the trigger."""
-
-
-def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
-    """Check if the message is a SampleMessage."""
-    return hasattr(message, "trigger")
+"""
+Backwards-compatible re-exports from ezmsg.baseproc.util.message.
+
+New code should import directly from ezmsg.baseproc instead.
+"""
+
+from ezmsg.baseproc.util.message import (
+    SampleMessage,
+    SampleTriggerMessage,
+    is_sample_message,
+)
+
+__all__ = [
+    "SampleMessage",
+    "SampleTriggerMessage",
+    "is_sample_message",
+]
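The removed is_sample_message was a duck-typed typing.TypeGuard: any object with a trigger attribute passes, and type checkers then narrow the type in the guarded branch. A self-contained recreation of the pattern, with hypothetical _Trigger/_Sample stand-ins for the real message classes:

import typing
from dataclasses import dataclass


@dataclass
class _Trigger:
    timestamp: float


@dataclass
class _Sample:
    trigger: _Trigger


def is_sample_message(message: typing.Any) -> typing.TypeGuard[_Sample]:
    # Same duck-typed check as the removed code: presence of .trigger is enough.
    return hasattr(message, "trigger")


msg: typing.Any = _Sample(trigger=_Trigger(timestamp=0.0))
if is_sample_message(msg):
    print(msg.trigger.timestamp)  # checkers treat msg as _Sample in this branch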