ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0

ezmsg/sigproc/transpose.py (new file)

@@ -0,0 +1,134 @@
"""
Transpose or permute array dimensions.

.. note::
    This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
    enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
    Memory layout optimization (C/F order) only applies to NumPy arrays.
"""

from types import EllipsisType

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import (
    AxisArray,
    replace,
)


class TransposeSettings(ez.Settings):
    """
    Settings for :obj:`Transpose` node.

    Fields:
        axes: Desired output dimension order, given as dimension names, integer
            indices, or a single Ellipsis standing in for all unlisted dimensions.
            ``None`` leaves the order unchanged.
        order: Optional memory layout ("C" or "F") to enforce on the output
            (NumPy arrays only).
    """

    axes: tuple[int | str | EllipsisType, ...] | None = None
    order: str | None = None


@processor_state
class TransposeState:
    axes_ints: tuple[int, ...] | None = None


class TransposeTransformer(BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]):
    """
    Permute the dimensions of incoming messages according to ``axes`` and
    optionally enforce a C- or F-contiguous memory layout on the result
    (NumPy arrays only).
    """

    def _hash_message(self, message: AxisArray) -> int:
        return hash(tuple(message.dims))

    def _reset_state(self, message: AxisArray) -> None:
        if self.settings.axes is None:
            self._state.axes_ints = None
        else:
            ell_ix = [ix for ix, ax in enumerate(self.settings.axes) if ax is Ellipsis]
            if len(ell_ix) > 1:
                raise ValueError("Only one Ellipsis is allowed in axes.")
            ell_ix = ell_ix[0] if len(ell_ix) == 1 else len(message.dims)
            prefix = []
            for ax in self.settings.axes[:ell_ix]:
                if isinstance(ax, int):
                    prefix.append(ax)
                else:
                    if ax not in message.dims:
                        raise ValueError(f"Axis {ax} not found in message dims.")
                    prefix.append(message.dims.index(ax))
            suffix = []
            for ax in self.settings.axes[ell_ix + 1 :]:
                if isinstance(ax, int):
                    suffix.append(ax)
                else:
                    if ax not in message.dims:
                        raise ValueError(f"Axis {ax} not found in message dims.")
                    suffix.append(message.dims.index(ax))
            ells = [_ for _ in range(message.data.ndim) if _ not in prefix and _ not in suffix]
            re_ix = tuple(prefix + ells + suffix)
            if re_ix == tuple(range(message.data.ndim)):
                self._state.axes_ints = None
            else:
                self._state.axes_ints = re_ix
        if self.settings.order is not None and self.settings.order.upper()[0] not in [
            "C",
            "F",
        ]:
            raise ValueError("order must be 'C' or 'F'.")

    def __call__(self, message: AxisArray) -> AxisArray:
        if self.settings.axes is None and self.settings.order is None:
            # Passthrough
            return message
        return super().__call__(message)

    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)
        if self.state.axes_ints is None:
            # No transpose required
            if self.settings.order is None:
                # No memory relayout required.
                # Note: this branch should be unreachable because it is short-circuited by the passthrough above.
                msg_out = message
            else:
                # Memory layout optimization only applies to numpy arrays
                if is_numpy_array(message.data):
                    msg_out = replace(
                        message,
                        data=np.require(message.data, requirements=self.settings.order.upper()[0]),
                    )
                else:
                    msg_out = message
        else:
            dims_out = [message.dims[ix] for ix in self.state.axes_ints]
            data_out = xp.permute_dims(message.data, axes=self.state.axes_ints)
            if self.settings.order is not None and is_numpy_array(data_out):
                # Memory layout optimization only applies to numpy arrays
                data_out = np.require(data_out, requirements=self.settings.order.upper()[0])
            msg_out = replace(
                message,
                data=data_out,
                dims=dims_out,
            )
        return msg_out


class Transpose(BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]):
    SETTINGS = TransposeSettings


def transpose(
    axes: tuple[int | str | EllipsisType, ...] | None = None, order: str | None = None
) -> TransposeTransformer:
    return TransposeTransformer(TransposeSettings(axes=axes, order=order))
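
As a quick orientation to the new module, here is a minimal sketch of applying the `transpose` helper to an `AxisArray` directly, outside of an ezmsg pipeline. The message construction below (dims, data shape, omitted axis metadata) is illustrative only and not taken from the diff:

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.transpose import transpose

# Hypothetical 4-sample x 8-channel message with the time axis first.
msg = AxisArray(data=np.arange(32).reshape(4, 8), dims=["time", "ch"])

# Move "ch" to the front; Ellipsis keeps the remaining dims in their original order.
proc = transpose(axes=("ch", ...), order="C")
out = proc(msg)
# out.dims == ["ch", "time"]; out.data.shape == (8, 4); out.data is C-contiguous.

Because `_hash_message` hashes `message.dims`, the name/Ellipsis specification is resolved to integer indices once per dims signature and reused until the incoming dimension order changes.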

ezmsg/sigproc/util/__init__.py (file without changes)

ezmsg/sigproc/util/asio.py (new file)

@@ -0,0 +1,25 @@
"""
Backwards-compatible re-exports from ezmsg.baseproc.util.asio.

New code should import directly from ezmsg.baseproc instead.
"""

import warnings

warnings.warn(
    "Importing from 'ezmsg.sigproc.util.asio' is deprecated. Please import from 'ezmsg.baseproc.util.asio' instead.",
    DeprecationWarning,
    stacklevel=2,
)

from ezmsg.baseproc.util.asio import (  # noqa: E402
    CoroutineExecutionError,
    SyncToAsyncGeneratorWrapper,
    run_coroutine_sync,
)

__all__ = [
    "CoroutineExecutionError",
    "SyncToAsyncGeneratorWrapper",
    "run_coroutine_sync",
]
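
Since this shim only warns and re-exports, migrating a consumer is a one-line import change; a sketch of the before/after:

# Deprecated path (emits DeprecationWarning on import):
from ezmsg.sigproc.util.asio import run_coroutine_sync

# Preferred path going forward:
from ezmsg.baseproc.util.asio import run_coroutine_sync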
ezmsg/sigproc/util/axisarray_buffer.py (new file)

@@ -0,0 +1,365 @@
"""AxisArray support for .buffer.HybridBuffer."""

import math
import typing

import numpy as np
from array_api_compat import get_namespace
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis, LinearAxis
from ezmsg.util.messages.util import replace

from .buffer import HybridBuffer

Array = typing.TypeVar("Array")


class HybridAxisBuffer:
    """
    A buffer that intelligently handles ezmsg.util.messages.AxisArray _axes_ objects.
    LinearAxis is maintained internally by tracking its offset, gain, and the number
    of samples that have passed through.
    CoordinateAxis has its data values maintained in a `HybridBuffer`.

    Args:
        duration: The desired duration of the buffer in seconds. This is non-limiting
            when managing a LinearAxis.
        **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
            (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
    """

    _coords_buffer: HybridBuffer | None
    _coords_template: CoordinateAxis | None
    _coords_gain_estimate: float | None = None
    _linear_axis: LinearAxis | None
    _linear_n_available: int

    def __init__(self, duration: float, **kwargs):
        self.duration = duration
        self.buffer_kwargs = kwargs
        # Delay initialization until the first message arrives
        self._coords_buffer = None
        self._coords_template = None
        self._linear_axis = None
        self._linear_n_available = 0

    @property
    def capacity(self) -> int:
        """The maximum number of samples that can be stored in the buffer."""
        if self._coords_buffer is not None:
            return self._coords_buffer.capacity
        elif self._linear_axis is not None:
            return int(math.ceil(self.duration / self._linear_axis.gain))
        else:
            return 0

    def available(self) -> int:
        if self._coords_buffer is None:
            return self._linear_n_available
        return self._coords_buffer.available()

    def is_empty(self) -> bool:
        return self.available() == 0

    def is_full(self) -> bool:
        if self._coords_buffer is not None:
            return self._coords_buffer.is_full()
        return 0 < self.capacity == self.available()

    def _initialize(self, first_axis: LinearAxis | CoordinateAxis) -> None:
        if hasattr(first_axis, "data"):
            # Initialize a CoordinateAxis buffer
            if len(first_axis.data) > 1:
                _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (len(first_axis.data) - 1)
            else:
                _axis_gain = 1.0
            self._coords_gain_estimate = _axis_gain
            capacity = int(self.duration / _axis_gain)
            self._coords_buffer = HybridBuffer(
                get_namespace(first_axis.data),
                capacity,
                other_shape=(),
                dtype=first_axis.data.dtype,
                **self.buffer_kwargs,
            )
            self._coords_template = replace(first_axis, data=first_axis.data[:0].copy())
        else:
            # Initialize a LinearAxis buffer
            self._linear_axis = replace(first_axis, offset=first_axis.offset)
            self._linear_n_available = 0

    def write(self, axis: LinearAxis | CoordinateAxis, n_samples: int) -> None:
        if self._linear_axis is None and self._coords_buffer is None:
            self._initialize(axis)

        if self._coords_buffer is not None:
            if axis.__class__ is not self._coords_template.__class__:
                raise TypeError(
                    f"Buffer initialized with {self._coords_template.__class__.__name__}, "
                    f"but received {axis.__class__.__name__}."
                )
            self._coords_buffer.write(axis.data)
        else:
            if axis.__class__ is not self._linear_axis.__class__:
                raise TypeError(
                    f"Buffer initialized with {self._linear_axis.__class__.__name__}, "
                    f"but received {axis.__class__.__name__}."
                )
            if axis.gain != self._linear_axis.gain:
                raise ValueError(
                    f"Buffer initialized with gain={self._linear_axis.gain}, but received gain={axis.gain}."
                )
            if self._linear_n_available + n_samples > self.capacity:
                # Simulate overflow by advancing the offset and decreasing
                # the number of available samples.
                n_to_discard = self._linear_n_available + n_samples - self.capacity
                self.seek(n_to_discard)
            # Update the offset corresponding to the oldest sample in the buffer
            # by anchoring on the new offset and accounting for the samples already available.
            self._linear_axis.offset = axis.offset - self._linear_n_available * axis.gain
            self._linear_n_available += n_samples

    def peek(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis:
        if self._coords_buffer is not None:
            return replace(self._coords_template, data=self._coords_buffer.peek(n_samples))
        else:
            # Return a shallow copy.
            return replace(self._linear_axis, offset=self._linear_axis.offset)

    def seek(self, n_samples: int) -> int:
        if self._coords_buffer is not None:
            return self._coords_buffer.seek(n_samples)
        else:
            n_to_seek = min(n_samples, self._linear_n_available)
            self._linear_n_available -= n_to_seek
            self._linear_axis.offset += n_to_seek * self._linear_axis.gain
            return n_to_seek

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer."""
        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0
        return self.seek(n_to_discard)

    @property
    def final_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the last sample in the buffer.
        This does not advance the read head.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_last()[0]
        elif self._linear_axis is not None:
            return self._linear_axis.value(self._linear_n_available - 1)
        else:
            return None

    @property
    def first_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the first sample in the buffer.
        This does not advance the read head.
        """
        if self.available() == 0:
            return None
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_at(0)[0]
        elif self._linear_axis is not None:
            return self._linear_axis.value(0)
        else:
            return None

    @property
    def gain(self) -> float | None:
        if self._coords_buffer is not None:
            return self._coords_gain_estimate
        elif self._linear_axis is not None:
            return self._linear_axis.gain
        else:
            return None

    def searchsorted(self, values: typing.Union[float, Array], side: str = "left") -> typing.Union[int, Array]:
        if self._coords_buffer is not None:
            return self._coords_buffer.xp.searchsorted(self._coords_buffer.peek(self.available()), values, side=side)
        else:
            if self.available() == 0:
                if isinstance(values, float):
                    return 0
                else:
                    _xp = get_namespace(values)
                    return _xp.zeros_like(values, dtype=int)

            f_inds = (values - self._linear_axis.offset) / self._linear_axis.gain
            res = np.ceil(f_inds)
            if side == "right":
                res[np.isclose(f_inds, res)] += 1
            return res.astype(int)


class HybridAxisArrayBuffer:
    """A buffer that intelligently handles ezmsg.util.messages.AxisArray objects.

    This buffer defers its own initialization until the first message arrives,
    allowing it to automatically configure its size, shape, dtype, and array backend
    (e.g., NumPy, CuPy) based on the message content and a desired buffer duration.

    Args:
        duration: The desired duration of the buffer in seconds.
        axis: The name of the axis to buffer along.
        **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
            (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
    """

    _data_buffer: HybridBuffer | None
    _axis_buffer: HybridAxisBuffer
    _template_msg: AxisArray | None

    def __init__(self, duration: float, axis: str = "time", **kwargs):
        self.duration = duration
        self._axis = axis
        self.buffer_kwargs = kwargs
        self._axis_buffer = HybridAxisBuffer(duration=duration, **kwargs)
        # Delay initialization until the first message arrives
        self._data_buffer = None
        self._template_msg = None

    def available(self) -> int:
        """The total number of unread samples currently available in the buffer."""
        if self._data_buffer is None:
            return 0
        return self._data_buffer.available()

    def is_empty(self) -> bool:
        return self.available() == 0

    def is_full(self) -> bool:
        return 0 < self._data_buffer.capacity == self.available()

    @property
    def axis_first_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the first sample in the buffer."""
        return self._axis_buffer.first_value

    @property
    def axis_final_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the last sample in the buffer."""
        return self._axis_buffer.final_value

    def _initialize(self, first_msg: AxisArray) -> None:
        # Create a template message that has everything, except the data have length 0
        # and the target axis is missing.
        self._template_msg = replace(
            first_msg,
            data=first_msg.data[:0],
            axes={k: v for k, v in first_msg.axes.items() if k != self._axis},
        )

        in_axis = first_msg.axes[self._axis]
        self._axis_buffer._initialize(in_axis)

        capacity = int(self.duration / self._axis_buffer.gain)
        self._data_buffer = HybridBuffer(
            get_namespace(first_msg.data),
            capacity,
            other_shape=first_msg.data.shape[1:],
            dtype=first_msg.data.dtype,
            **self.buffer_kwargs,
        )

    def write(self, msg: AxisArray) -> None:
        """Adds an AxisArray message to the buffer, initializing on the first call."""
        in_axis_idx = msg.get_axis_idx(self._axis)
        if in_axis_idx > 0:
            # This class assumes that the target axis is the first axis.
            # If it is not, we move it to the front.
            dims = list(msg.dims)
            dims.insert(0, dims.pop(in_axis_idx))
            _xp = get_namespace(msg.data)
            msg = replace(msg, data=_xp.moveaxis(msg.data, in_axis_idx, 0), dims=dims)

        if self._data_buffer is None:
            self._initialize(msg)

        self._data_buffer.write(msg.data)
        self._axis_buffer.write(msg.axes[self._axis], msg.shape[0])

    def peek(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray without advancing the read head."""

        if self._data_buffer is None:
            return None

        data_array = self._data_buffer.peek(n_samples)

        if data_array is None:
            return None

        out_axis = self._axis_buffer.peek(n_samples)

        return replace(
            self._template_msg,
            data=data_array,
            axes={**self._template_msg.axes, self._axis: out_axis},
        )

    def peek_axis(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis | None:
        """Retrieves the axis data without advancing the read head."""
        if self._data_buffer is None:
            return None

        out_axis = self._axis_buffer.peek(n_samples)

        if out_axis is None:
            return None

        return out_axis

    def seek(self, n_samples: int) -> int:
        """Advances the read pointer by n_samples."""
        if self._data_buffer is None:
            return 0

        skipped_data_count = self._data_buffer.seek(n_samples)
        axis_skipped = self._axis_buffer.seek(skipped_data_count)
        assert (
            axis_skipped == skipped_data_count
        ), f"Axis buffer skipped {axis_skipped} samples, but data buffer skipped {skipped_data_count}."

        return skipped_data_count

    def read(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray and advances the read head."""
        retrieved_axis_array = self.peek(n_samples)

        if retrieved_axis_array is None or retrieved_axis_array.shape[0] == 0:
            return None

        self.seek(retrieved_axis_array.shape[0])

        return retrieved_axis_array

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer."""
        if self._data_buffer is None:
            return 0

        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0

        return self.seek(n_to_discard)

    @property
    def axis_gain(self) -> float | None:
        """
        The gain of the target axis, which is the time step between samples.
        This is typically the sampling period (i.e., 1 / fs).
        """
        return self._axis_buffer.gain

    def axis_searchsorted(self, values: typing.Union[float, Array], side: str = "left") -> typing.Union[int, Array]:
        """
        Find the indices into which the given values would be inserted
        into the target axis data to maintain order.
        """
        return self._axis_buffer.searchsorted(values, side=side)