ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +123 -0
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +336 -0
  25. ezmsg/sigproc/fir_pmc.py +209 -0
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +232 -0
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
  60. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/spectrum.py CHANGED
@@ -1,14 +1,14 @@
1
1
  import enum
2
- from functools import partial
3
2
  import typing
3
+ from functools import partial
4
4
 
5
+ import ezmsg.core as ez
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- import ezmsg.core as ez
8
8
  from ezmsg.util.messages.axisarray import (
9
9
  AxisArray,
10
- slice_along_axis,
11
10
  replace,
11
+ slice_along_axis,
12
12
  )
13
13
 
14
14
  from .base import (
@@ -127,17 +127,13 @@ class SpectrumState:
127
127
  window: npt.NDArray | None = None
128
128
 
129
129
 
130
- class SpectrumTransformer(
131
- BaseStatefulTransformer[SpectrumSettings, AxisArray, AxisArray, SpectrumState]
132
- ):
130
+ class SpectrumTransformer(BaseStatefulTransformer[SpectrumSettings, AxisArray, AxisArray, SpectrumState]):
133
131
  def _hash_message(self, message: AxisArray) -> int:
134
132
  axis = self.settings.axis or message.dims[0]
135
133
  ax_idx = message.get_axis_idx(axis)
136
134
  ax_info = message.axes[axis]
137
135
  targ_len = message.data.shape[ax_idx]
138
- return hash(
139
- (targ_len, message.data.ndim, message.data.dtype.kind, ax_idx, ax_info.gain)
140
- )
136
+ return hash((targ_len, message.data.ndim, message.data.dtype.kind, ax_idx, ax_info.gain))
141
137
 
142
138
  def _reset_state(self, message: AxisArray) -> None:
143
139
  axis = self.settings.axis or message.dims[0]
@@ -156,8 +152,7 @@ class SpectrumTransformer(
156
152
  + [1] * (message.data.ndim - 1 - ax_idx)
157
153
  )
158
154
  if self.settings.transform != SpectralTransform.RAW_COMPLEX and not (
159
- self.settings.transform == SpectralTransform.REAL
160
- or self.settings.transform == SpectralTransform.IMAG
155
+ self.settings.transform == SpectralTransform.REAL or self.settings.transform == SpectralTransform.IMAG
161
156
  ):
162
157
  scale = np.sum(window**2.0) * ax_info.gain
163
158
 
@@ -170,30 +165,21 @@ class SpectrumTransformer(
170
165
  if (not b_complex) and self.settings.output == SpectralOutput.POSITIVE:
171
166
  # If input is not complex and desired output is SpectralOutput.POSITIVE, we can save some computation
172
167
  # by using rfft and rfftfreq.
173
- self.state.fftfun = partial(
174
- np.fft.rfft, n=nfft, axis=ax_idx, norm=self.settings.norm
175
- )
168
+ self.state.fftfun = partial(np.fft.rfft, n=nfft, axis=ax_idx, norm=self.settings.norm)
176
169
  freqs = np.fft.rfftfreq(nfft, d=ax_info.gain * targ_len / nfft)
177
170
  else:
178
- self.state.fftfun = partial(
179
- np.fft.fft, n=nfft, axis=ax_idx, norm=self.settings.norm
180
- )
171
+ self.state.fftfun = partial(np.fft.fft, n=nfft, axis=ax_idx, norm=self.settings.norm)
181
172
  freqs = np.fft.fftfreq(nfft, d=ax_info.gain * targ_len / nfft)
182
173
  if self.settings.output == SpectralOutput.POSITIVE:
183
174
  self.state.f_sl = slice(None, nfft // 2 + 1 - (nfft % 2))
184
175
  elif self.settings.output == SpectralOutput.NEGATIVE:
185
176
  freqs = np.fft.fftshift(freqs, axes=-1)
186
177
  self.state.f_sl = slice(None, nfft // 2 + 1)
187
- elif (
188
- self.settings.do_fftshift
189
- and self.settings.output == SpectralOutput.FULL
190
- ):
178
+ elif self.settings.do_fftshift and self.settings.output == SpectralOutput.FULL:
191
179
  freqs = np.fft.fftshift(freqs, axes=-1)
192
180
  freqs = freqs[self.state.f_sl]
193
181
  freqs = freqs.tolist() # To please type checking
194
- self.state.freq_axis = AxisArray.LinearAxis(
195
- unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0]
196
- )
182
+ self.state.freq_axis = AxisArray.LinearAxis(unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0])
197
183
  self.state.new_dims = (
198
184
  message.dims[:ax_idx]
199
185
  + [
@@ -232,11 +218,7 @@ class SpectrumTransformer(
232
218
  ax_idx = message.get_axis_idx(axis)
233
219
  targ_len = message.data.shape[ax_idx]
234
220
 
235
- new_axes = {
236
- k: v
237
- for k, v in message.axes.items()
238
- if k not in [self.settings.out_axis, axis]
239
- }
221
+ new_axes = {k: v for k, v in message.axes.items() if k not in [self.settings.out_axis, axis]}
240
222
  new_axes[self.settings.out_axis or axis] = self.state.freq_axis
241
223
 
242
224
  if self.state.window is not None:
@@ -261,9 +243,7 @@ class SpectrumTransformer(
261
243
  return msg_out
262
244
 
263
245
 
264
- class Spectrum(
265
- BaseTransformerUnit[SpectrumSettings, AxisArray, AxisArray, SpectrumTransformer]
266
- ):
246
+ class Spectrum(BaseTransformerUnit[SpectrumSettings, AxisArray, AxisArray, SpectrumTransformer]):
267
247
  SETTINGS = SpectrumSettings
268
248
 
269
249
 
ezmsg/sigproc/transpose.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from types import EllipsisType
2
- import numpy as np
2
+
3
3
  import ezmsg.core as ez
4
+ import numpy as np
4
5
  from ezmsg.util.messages.axisarray import (
5
6
  AxisArray,
6
7
  replace,
@@ -30,9 +31,7 @@ class TransposeState:
30
31
  axes_ints: tuple[int, ...] | None = None
31
32
 
32
33
 
33
- class TransposeTransformer(
34
- BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]
35
- ):
34
+ class TransposeTransformer(BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]):
36
35
  """
37
36
  Downsampled data simply comprise every `factor`th sample.
38
37
  This should only be used following appropriate lowpass filtering.
@@ -67,11 +66,7 @@ class TransposeTransformer(
67
66
  if ax not in message.dims:
68
67
  raise ValueError(f"Axis {ax} not found in message dims.")
69
68
  suffix.append(message.dims.index(ax))
70
- ells = [
71
- _
72
- for _ in range(message.data.ndim)
73
- if _ not in prefix and _ not in suffix
74
- ]
69
+ ells = [_ for _ in range(message.data.ndim) if _ not in prefix and _ not in suffix]
75
70
  re_ix = tuple(prefix + ells + suffix)
76
71
  if re_ix == tuple(range(message.data.ndim)):
77
72
  self._state.axes_ints = None
@@ -100,17 +95,13 @@ class TransposeTransformer(
100
95
  # If the memory is already contiguous in the correct order, np.require won't do anything.
101
96
  msg_out = replace(
102
97
  message,
103
- data=np.require(
104
- message.data, requirements=self.settings.order.upper()[0]
105
- ),
98
+ data=np.require(message.data, requirements=self.settings.order.upper()[0]),
106
99
  )
107
100
  else:
108
101
  dims_out = [message.dims[ix] for ix in self.state.axes_ints]
109
102
  data_out = np.transpose(message.data, axes=self.state.axes_ints)
110
103
  if self.settings.order is not None:
111
- data_out = np.require(
112
- data_out, requirements=self.settings.order.upper()[0]
113
- )
104
+ data_out = np.require(data_out, requirements=self.settings.order.upper()[0])
114
105
  msg_out = replace(
115
106
  message,
116
107
  data=data_out,
@@ -119,9 +110,7 @@ class TransposeTransformer(
119
110
  return msg_out
120
111
 
121
112
 
122
- class Transpose(
123
- BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]
124
- ):
113
+ class Transpose(BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]):
125
114
  SETTINGS = TransposeSettings
126
115
 
127
116
 
ezmsg/sigproc/util/asio.py CHANGED
@@ -1,156 +1,25 @@
1
- import asyncio
2
- from concurrent.futures import ThreadPoolExecutor
3
- import contextlib
4
- import inspect
5
- import threading
6
- from typing import Any, Coroutine, TypeVar
7
-
8
- T = TypeVar("T")
9
-
10
-
11
- class CoroutineExecutionError(Exception):
12
- """Custom exception for coroutine execution failures"""
13
-
14
- pass
15
-
16
-
17
- def run_coroutine_sync(coroutine: Coroutine[Any, Any, T], timeout: float = 30) -> T:
18
- """
19
- Executes an asyncio coroutine synchronously, with enhanced error handling.
20
-
21
- Args:
22
- coroutine: The asyncio coroutine to execute
23
- timeout: Maximum time in seconds to wait for coroutine completion (default: 30)
24
-
25
- Returns:
26
- The result of the coroutine execution
27
-
28
- Raises:
29
- CoroutineExecutionError: If execution fails due to threading or event loop issues
30
- TimeoutError: If execution exceeds the timeout period
31
- Exception: Any exception raised by the coroutine
32
- """
33
-
34
- def run_in_new_loop() -> T:
35
- """
36
- Creates and runs a new event loop in the current thread.
37
- Ensures proper cleanup of the loop.
38
- """
39
- new_loop = asyncio.new_event_loop()
40
- asyncio.set_event_loop(new_loop)
41
- try:
42
- return new_loop.run_until_complete(
43
- asyncio.wait_for(coroutine, timeout=timeout)
44
- )
45
- finally:
46
- with contextlib.suppress(Exception):
47
- # Clean up any pending tasks
48
- pending = asyncio.all_tasks(new_loop)
49
- for task in pending:
50
- task.cancel()
51
- new_loop.run_until_complete(
52
- asyncio.gather(*pending, return_exceptions=True)
53
- )
54
- new_loop.close()
55
-
56
- try:
57
- loop = asyncio.get_running_loop()
58
- except RuntimeError:
59
- try:
60
- return asyncio.run(asyncio.wait_for(coroutine, timeout=timeout))
61
- except Exception as e:
62
- raise CoroutineExecutionError(
63
- f"Failed to execute coroutine: {str(e)}"
64
- ) from e
65
-
66
- if threading.current_thread() is threading.main_thread():
67
- if not loop.is_running():
68
- try:
69
- return loop.run_until_complete(
70
- asyncio.wait_for(coroutine, timeout=timeout)
71
- )
72
- except Exception as e:
73
- raise CoroutineExecutionError(
74
- f"Failed to execute coroutine in main loop: {str(e)}"
75
- ) from e
76
- else:
77
- with ThreadPoolExecutor() as pool:
78
- try:
79
- future = pool.submit(run_in_new_loop)
80
- return future.result(timeout=timeout)
81
- except Exception as e:
82
- raise CoroutineExecutionError(
83
- f"Failed to execute coroutine in thread: {str(e)}"
84
- ) from e
85
- else:
86
- try:
87
- future = asyncio.run_coroutine_threadsafe(coroutine, loop)
88
- return future.result(timeout=timeout)
89
- except Exception as e:
90
- raise CoroutineExecutionError(
91
- f"Failed to execute coroutine threadsafe: {str(e)}"
92
- ) from e
93
-
94
-
95
- class SyncToAsyncGeneratorWrapper:
96
- """
97
- A wrapper for synchronous generators to be used in an async context.
98
- """
99
-
100
- def __init__(self, gen):
101
- self._gen = gen
102
- self._closed = False
103
- # Prime the generator to ready for first send/next call
104
- try:
105
- is_not_primed = inspect.getgeneratorstate(self._gen) is inspect.GEN_CREATED
106
- except AttributeError as e:
107
- raise TypeError(
108
- "The provided generator is not a valid generator object"
109
- ) from e
110
- if is_not_primed:
111
- try:
112
- next(self._gen)
113
- except StopIteration:
114
- self._closed = True
115
- except Exception as e:
116
- raise RuntimeError(f"Failed to prime generator: {e}") from e
117
-
118
- async def asend(self, value):
119
- if self._closed:
120
- raise StopAsyncIteration("Generator is closed")
121
- try:
122
- return await asyncio.to_thread(self._gen.send, value)
123
- except StopIteration as e:
124
- self._closed = True
125
- raise StopAsyncIteration("Generator is closed") from e
126
- except Exception as e:
127
- raise RuntimeError(f"Error while sending value to generator: {e}") from e
128
-
129
- async def __anext__(self):
130
- if self._closed:
131
- raise StopAsyncIteration("Generator is closed")
132
- try:
133
- return await asyncio.to_thread(self._gen.__next__)
134
- except StopIteration as e:
135
- self._closed = True
136
- raise StopAsyncIteration("Generator is closed") from e
137
- except Exception as e:
138
- raise RuntimeError(
139
- f"Error while getting next value from generator: {e}"
140
- ) from e
141
-
142
- async def aclose(self):
143
- if self._closed:
144
- return
145
- try:
146
- await asyncio.to_thread(self._gen.close)
147
- except Exception as e:
148
- raise RuntimeError(f"Error while closing generator: {e}") from e
149
- finally:
150
- self._closed = True
151
-
152
- def __aiter__(self):
153
- return self
154
-
155
- def __getattr__(self, name):
156
- return getattr(self._gen, name)
1
+ """
2
+ Backwards-compatible re-exports from ezmsg.baseproc.util.asio.
3
+
4
+ New code should import directly from ezmsg.baseproc instead.
5
+ """
6
+
7
+ import warnings
8
+
9
+ warnings.warn(
10
+ "Importing from 'ezmsg.sigproc.util.asio' is deprecated. Please import from 'ezmsg.baseproc.util.asio' instead.",
11
+ DeprecationWarning,
12
+ stacklevel=2,
13
+ )
14
+
15
+ from ezmsg.baseproc.util.asio import ( # noqa: E402
16
+ CoroutineExecutionError,
17
+ SyncToAsyncGeneratorWrapper,
18
+ run_coroutine_sync,
19
+ )
20
+
21
+ __all__ = [
22
+ "CoroutineExecutionError",
23
+ "SyncToAsyncGeneratorWrapper",
24
+ "run_coroutine_sync",
25
+ ]
ezmsg/sigproc/util/axisarray_buffer.py CHANGED
@@ -1,14 +1,13 @@
1
1
  import math
2
2
  import typing
3
3
 
4
- from array_api_compat import get_namespace
5
4
  import numpy as np
6
- from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis
5
+ from array_api_compat import get_namespace
6
+ from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis, LinearAxis
7
7
  from ezmsg.util.messages.util import replace
8
8
 
9
9
  from .buffer import HybridBuffer
10
10
 
11
-
12
11
  Array = typing.TypeVar("Array")
13
12
 
14
13
 
@@ -68,9 +67,7 @@ class HybridAxisBuffer:
68
67
  if hasattr(first_axis, "data"):
69
68
  # Initialize a CoordinateAxis buffer
70
69
  if len(first_axis.data) > 1:
71
- _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (
72
- len(first_axis.data) - 1
73
- )
70
+ _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (len(first_axis.data) - 1)
74
71
  else:
75
72
  _axis_gain = 1.0
76
73
  self._coords_gain_estimate = _axis_gain
@@ -107,8 +104,7 @@ class HybridAxisBuffer:
107
104
  )
108
105
  if axis.gain != self._linear_axis.gain:
109
106
  raise ValueError(
110
- f"Buffer initialized with gain={self._linear_axis.gain}, "
111
- f"but received gain={axis.gain}."
107
+ f"Buffer initialized with gain={self._linear_axis.gain}, but received gain={axis.gain}."
112
108
  )
113
109
  if self._linear_n_available + n_samples > self.capacity:
114
110
  # Simulate overflow by advancing the offset and decreasing
@@ -117,16 +113,12 @@ class HybridAxisBuffer:
117
113
  self.seek(n_to_discard)
118
114
  # Update the offset corresponding to the oldest sample in the buffer
119
115
  # by anchoring on the new offset and accounting for the samples already available.
120
- self._linear_axis.offset = (
121
- axis.offset - self._linear_n_available * axis.gain
122
- )
116
+ self._linear_axis.offset = axis.offset - self._linear_n_available * axis.gain
123
117
  self._linear_n_available += n_samples
124
118
 
125
119
  def peek(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis:
126
120
  if self._coords_buffer is not None:
127
- return replace(
128
- self._coords_template, data=self._coords_buffer.peek(n_samples)
129
- )
121
+ return replace(self._coords_template, data=self._coords_buffer.peek(n_samples))
130
122
  else:
131
123
  # Return a shallow copy.
132
124
  return replace(self._linear_axis, offset=self._linear_axis.offset)
@@ -184,13 +176,9 @@ class HybridAxisBuffer:
184
176
  else:
185
177
  return None
186
178
 
187
- def searchsorted(
188
- self, values: typing.Union[float, Array], side: str = "left"
189
- ) -> typing.Union[int, Array]:
179
+ def searchsorted(self, values: typing.Union[float, Array], side: str = "left") -> typing.Union[int, Array]:
190
180
  if self._coords_buffer is not None:
191
- return self._coords_buffer.xp.searchsorted(
192
- self._coords_buffer.peek(self.available()), values, side=side
193
- )
181
+ return self._coords_buffer.xp.searchsorted(self._coords_buffer.peek(self.available()), values, side=side)
194
182
  else:
195
183
  if self.available() == 0:
196
184
  if isinstance(values, float):
@@ -312,9 +300,7 @@ class HybridAxisArrayBuffer:
312
300
  axes={**self._template_msg.axes, self._axis: out_axis},
313
301
  )
314
302
 
315
- def peek_axis(
316
- self, n_samples: int | None = None
317
- ) -> LinearAxis | CoordinateAxis | None:
303
+ def peek_axis(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis | None:
318
304
  """Retrieves the axis data without advancing the read head."""
319
305
  if self._data_buffer is None:
320
306
  return None
@@ -369,9 +355,7 @@ class HybridAxisArrayBuffer:
369
355
  """
370
356
  return self._axis_buffer.gain
371
357
 
372
- def axis_searchsorted(
373
- self, values: typing.Union[float, Array], side: str = "left"
374
- ) -> typing.Union[int, Array]:
358
+ def axis_searchsorted(self, values: typing.Union[float, Array], side: str = "left") -> typing.Union[int, Array]:
375
359
  """
376
360
  Find the indices into which the given values would be inserted
377
361
  into the target axis data to maintain order.
ezmsg/sigproc/util/buffer.py CHANGED
@@ -63,9 +63,7 @@ class HybridBuffer:
63
63
  self._buff_unread = 0 # Number of unread samples in the circular buffer
64
64
  self._buff_read = 0 # Tracks samples read and still in buffer
65
65
  self._deque_len = 0 # Number of unread samples in the deque
66
- self._last_overflow = (
67
- 0 # Tracks the last overflow count, overwritten or skipped
68
- )
66
+ self._last_overflow = 0 # Tracks the last overflow count, overwritten or skipped
69
67
  self._warned = False # Tracks if we've warned already (for warn_once)
70
68
 
71
69
  @property
@@ -96,9 +94,7 @@ class HybridBuffer:
96
94
  block = block[:, self.xp.newaxis]
97
95
 
98
96
  if block.shape[1:] != other_shape:
99
- raise ValueError(
100
- f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}"
101
- )
97
+ raise ValueError(f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}")
102
98
 
103
99
  # Most overflow strategies are handled during flush, but there are a couple
104
100
  # scenarios that can be evaluated on write to give immediate feedback.
@@ -117,8 +113,7 @@ class HybridBuffer:
117
113
  self._deque_len += block.shape[0]
118
114
 
119
115
  if self._update_strategy == "immediate" or (
120
- self._update_strategy == "threshold"
121
- and (0 < self._threshold <= self._deque_len)
116
+ self._update_strategy == "threshold" and (0 < self._threshold <= self._deque_len)
122
117
  ):
123
118
  self.flush()
124
119
 
@@ -128,9 +123,7 @@ class HybridBuffer:
128
123
  from the buffer.
129
124
  """
130
125
  if n_samples > self.available():
131
- raise ValueError(
132
- f"Requested {n_samples} samples, but only {self.available()} are available."
133
- )
126
+ raise ValueError(f"Requested {n_samples} samples, but only {self.available()} are available.")
134
127
  n_overflow = 0
135
128
  if self._deque and (n_samples > self._buff_unread):
136
129
  # We would cause a flush, but would that cause an overflow?
@@ -161,14 +154,10 @@ class HybridBuffer:
161
154
  n_overflow = self._estimate_overflow(n_samples)
162
155
  if n_overflow > 0:
163
156
  first_read = self._buff_unread
164
- if (n_overflow - first_read) < self.capacity or (
165
- self._overflow_strategy == "drop"
166
- ):
157
+ if (n_overflow - first_read) < self.capacity or (self._overflow_strategy == "drop"):
167
158
  # We can prevent the overflow (or at least *some* if using "drop"
168
159
  # strategy) by reading the samples in the buffer first to make room.
169
- data = self.xp.empty(
170
- (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
171
- )
160
+ data = self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
172
161
  self.peek(first_read, out=data[:first_read])
173
162
  offset += first_read
174
163
  self.seek(first_read)
@@ -204,13 +193,9 @@ class HybridBuffer:
204
193
  if n_samples is None:
205
194
  n_samples = self.available()
206
195
  elif n_samples > self.available():
207
- raise ValueError(
208
- f"Requested to peek {n_samples} samples, but only {self.available()} are available."
209
- )
196
+ raise ValueError(f"Requested to peek {n_samples} samples, but only {self.available()} are available.")
210
197
  if out is not None and out.shape[0] < n_samples:
211
- raise ValueError(
212
- f"Output array shape {out.shape} is smaller than requested {n_samples} samples."
213
- )
198
+ raise ValueError(f"Output array shape {out.shape} is smaller than requested {n_samples} samples.")
214
199
 
215
200
  if n_samples == 0:
216
201
  return self._buffer[:0]
@@ -224,9 +209,7 @@ class HybridBuffer:
224
209
  out = (
225
210
  out
226
211
  if out is not None
227
- else self.xp.empty(
228
- (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
229
- )
212
+ else self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
230
213
  )
231
214
  out[:part1_len] = self._buffer[self._tail :]
232
215
  out[part1_len:] = self._buffer[:part2_len]
@@ -258,9 +241,7 @@ class HybridBuffer:
258
241
  if not allow_flush and idx >= self._buff_unread:
259
242
  # The requested sample is in the deque.
260
243
  idx -= self._buff_unread
261
- deq_splits = self.xp.cumsum(
262
- [0] + [_.shape[0] for _ in self._deque], dtype=int
263
- )
244
+ deq_splits = self.xp.cumsum([0] + [_.shape[0] for _ in self._deque], dtype=int)
264
245
  arr_idx = self.xp.searchsorted(deq_splits, idx, side="right") - 1
265
246
  idx -= deq_splits[arr_idx]
266
247
  return self._deque[arr_idx][idx : idx + 1]
@@ -334,7 +315,8 @@ class HybridBuffer:
334
315
  if n_overflow > 0 and (not self._warn_once or not self._warned):
335
316
  self._warned = True
336
317
  warnings.warn(
337
- f"Buffer overflow: {n_new} samples received, but only {self._capacity - self._buff_unread} available. "
318
+ f"Buffer overflow: {n_new} samples received, "
319
+ f"but only {self._capacity - self._buff_unread} available. "
338
320
  f"Overwriting {n_overflow} previous samples.",
339
321
  RuntimeWarning,
340
322
  )
@@ -347,10 +329,9 @@ class HybridBuffer:
347
329
  break
348
330
  n_to_copy = min(block.shape[0], samples_to_copy - copied_samples)
349
331
  start_idx = block.shape[0] - n_to_copy
350
- self._buffer[
351
- samples_to_copy - copied_samples - n_to_copy : samples_to_copy
352
- - copied_samples
353
- ] = block[start_idx:]
332
+ self._buffer[samples_to_copy - copied_samples - n_to_copy : samples_to_copy - copied_samples] = block[
333
+ start_idx:
334
+ ]
354
335
  copied_samples += n_to_copy
355
336
 
356
337
  self._head = 0
@@ -362,9 +343,7 @@ class HybridBuffer:
362
343
  else:
363
344
  if n_overflow > 0:
364
345
  if self._overflow_strategy == "raise":
365
- raise OverflowError(
366
- f"Buffer overflow: {n_new} samples received, but only {n_free} available."
367
- )
346
+ raise OverflowError(f"Buffer overflow: {n_new} samples received, but only {n_free} available.")
368
347
  elif self._overflow_strategy == "warn-overwrite":
369
348
  if not self._warn_once or not self._warned:
370
349
  self._warned = True
@@ -430,9 +409,7 @@ class HybridBuffer:
430
409
  return
431
410
 
432
411
  other_shape = self._buffer.shape[1:]
433
- max_capacity = self._max_size / (
434
- self._buffer.dtype.itemsize * math.prod(other_shape)
435
- )
412
+ max_capacity = self._max_size / (self._buffer.dtype.itemsize * math.prod(other_shape))
436
413
  if min_capacity > max_capacity:
437
414
  raise OverflowError(
438
415
  f"Cannot grow buffer to {min_capacity} samples, "
@@ -440,9 +417,7 @@ class HybridBuffer:
440
417
  )
441
418
 
442
419
  new_capacity = min(max_capacity, max(self._capacity * 2, min_capacity))
443
- new_buffer = self.xp.empty(
444
- (new_capacity, *other_shape), dtype=self._buffer.dtype
445
- )
420
+ new_buffer = self.xp.empty((new_capacity, *other_shape), dtype=self._buffer.dtype)
446
421
 
447
422
  # Copy existing data to new buffer
448
423
  total_samples = self._buff_read + self._buff_unread
ezmsg/sigproc/util/message.py CHANGED
@@ -1,31 +1,17 @@
1
- import time
2
- import typing
3
- from dataclasses import dataclass, field
4
-
5
- from ezmsg.util.messages.axisarray import AxisArray
6
-
7
-
8
- @dataclass(unsafe_hash=True)
9
- class SampleTriggerMessage:
10
- timestamp: float = field(default_factory=time.time)
11
- """Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
12
-
13
- period: tuple[float, float] | None = None
14
- """The period around the timestamp, in seconds"""
15
-
16
- value: typing.Any = None
17
- """A value or 'label' associated with the trigger."""
18
-
19
-
20
- @dataclass
21
- class SampleMessage:
22
- trigger: SampleTriggerMessage
23
- """The time, window, and value (if any) associated with the trigger."""
24
-
25
- sample: AxisArray
26
- """The data sampled around the trigger."""
27
-
28
-
29
- def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
30
- """Check if the message is a SampleMessage."""
31
- return hasattr(message, "trigger")
1
+ """
2
+ Backwards-compatible re-exports from ezmsg.baseproc.util.message.
3
+
4
+ New code should import directly from ezmsg.baseproc instead.
5
+ """
6
+
7
+ from ezmsg.baseproc.util.message import (
8
+ SampleMessage,
9
+ SampleTriggerMessage,
10
+ is_sample_message,
11
+ )
12
+
13
+ __all__ = [
14
+ "SampleMessage",
15
+ "SampleTriggerMessage",
16
+ "is_sample_message",
17
+ ]