ezmsg-sigproc 2.2.0__py3-none-any.whl → 2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +16 -3
- ezmsg/sigproc/aggregate.py +69 -0
- ezmsg/sigproc/denormalize.py +86 -0
- ezmsg/sigproc/fbcca.py +332 -0
- ezmsg/sigproc/filter.py +16 -0
- ezmsg/sigproc/filterbankdesign.py +136 -0
- ezmsg/sigproc/firfilter.py +119 -0
- ezmsg/sigproc/kaiser.py +110 -0
- ezmsg/sigproc/resample.py +186 -185
- ezmsg/sigproc/sampler.py +71 -83
- ezmsg/sigproc/util/axisarray_buffer.py +379 -0
- ezmsg/sigproc/util/buffer.py +470 -0
- ezmsg/sigproc/window.py +12 -10
- {ezmsg_sigproc-2.2.0.dist-info → ezmsg_sigproc-2.4.0.dist-info}/METADATA +1 -1
- {ezmsg_sigproc-2.2.0.dist-info → ezmsg_sigproc-2.4.0.dist-info}/RECORD +17 -10
- {ezmsg_sigproc-2.2.0.dist-info → ezmsg_sigproc-2.4.0.dist-info}/WHEEL +1 -1
- {ezmsg_sigproc-2.2.0.dist-info → ezmsg_sigproc-2.4.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from collections import deque
|
|
3
|
+
import copy
|
|
3
4
|
import traceback
|
|
4
5
|
import typing
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
|
-
import numpy.typing as npt
|
|
8
8
|
import ezmsg.core as ez
|
|
9
9
|
from ezmsg.util.messages.axisarray import (
|
|
10
10
|
AxisArray,
|
|
11
|
-
slice_along_axis,
|
|
12
11
|
)
|
|
13
12
|
from ezmsg.util.messages.util import replace
|
|
14
13
|
|
|
15
14
|
from .util.profile import profile_subpub
|
|
15
|
+
from .util.axisarray_buffer import HybridAxisArrayBuffer
|
|
16
|
+
from .util.buffer import UpdateStrategy
|
|
16
17
|
from .util.message import SampleMessage, SampleTriggerMessage
|
|
17
18
|
from .base import (
|
|
18
19
|
BaseStatefulTransformer,
|
|
@@ -43,6 +44,7 @@ class SamplerSettings(ez.Settings):
|
|
|
43
44
|
None (default) will choose the first axis in the first input.
|
|
44
45
|
Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
|
|
45
46
|
"""
|
|
47
|
+
|
|
46
48
|
period: tuple[float, float] | None = None
|
|
47
49
|
"""Optional default period (in seconds) if unspecified in SampleTriggerMessage."""
|
|
48
50
|
|
|
@@ -51,20 +53,25 @@ class SamplerSettings(ez.Settings):
|
|
|
51
53
|
|
|
52
54
|
estimate_alignment: bool = True
|
|
53
55
|
"""
|
|
54
|
-
If true, use message timestamp fields and reported sampling rate to estimate
|
|
56
|
+
If true, use message timestamp fields and reported sampling rate to estimate
|
|
57
|
+
sample-accurate alignment for samples.
|
|
55
58
|
If false, sampling will be limited to incoming message rate -- "Block timing"
|
|
56
59
|
NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
|
|
57
60
|
"realtime" operation for estimate_alignment to operate correctly.
|
|
58
61
|
"""
|
|
59
62
|
|
|
63
|
+
buffer_update_strategy: UpdateStrategy = "immediate"
|
|
64
|
+
"""
|
|
65
|
+
The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
|
|
66
|
+
If you expect to push data much more frequently than triggers, then "on_demand"
|
|
67
|
+
might be more efficient. For most other scenarios, "immediate" is best.
|
|
68
|
+
"""
|
|
69
|
+
|
|
60
70
|
|
|
61
71
|
@processor_state
|
|
62
72
|
class SamplerState:
|
|
63
|
-
|
|
64
|
-
offset: float | None = None
|
|
65
|
-
buffer: npt.NDArray | None = None
|
|
73
|
+
buffer: HybridAxisArrayBuffer | None = None
|
|
66
74
|
triggers: deque[SampleTriggerMessage] | None = None
|
|
67
|
-
n_samples: int = 0
|
|
68
75
|
|
|
69
76
|
|
|
70
77
|
class SamplerTransformer(
|
|
@@ -73,6 +80,16 @@ class SamplerTransformer(
|
|
|
73
80
|
def __call__(
|
|
74
81
|
self, message: AxisArray | SampleTriggerMessage
|
|
75
82
|
) -> list[SampleMessage]:
|
|
83
|
+
# TODO: Currently we have a single entry point that accepts both
|
|
84
|
+
# data and trigger messages and we choose a code path based on
|
|
85
|
+
# the message type. However, in the future we will likely replace
|
|
86
|
+
# SampleTriggerMessage with an augmented form of AxisArray,
|
|
87
|
+
# leveraging its attrs field, which makes this a bit harder.
|
|
88
|
+
# We should probably force callers of this object to explicitly
|
|
89
|
+
# call `push_trigger` for trigger messages. This will also
|
|
90
|
+
# simplify typing somewhat because `push_trigger` should not
|
|
91
|
+
# return anything yet we currently have it returning an empty
|
|
92
|
+
# list just to be compatible with __call__.
|
|
76
93
|
if isinstance(message, AxisArray):
|
|
77
94
|
return super().__call__(message)
|
|
78
95
|
else:
|
|
@@ -82,102 +99,75 @@ class SamplerTransformer(
|
|
|
82
99
|
# Compute hash based on message properties that require state reset
|
|
83
100
|
axis = self.settings.axis or message.dims[0]
|
|
84
101
|
axis_idx = message.get_axis_idx(axis)
|
|
85
|
-
fs = 1.0 / message.get_axis(axis).gain
|
|
86
102
|
sample_shape = (
|
|
87
103
|
message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
|
|
88
104
|
)
|
|
89
|
-
return hash((
|
|
105
|
+
return hash((sample_shape, message.key))
|
|
90
106
|
|
|
91
107
|
def _reset_state(self, message: AxisArray) -> None:
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
108
|
+
self._state.buffer = HybridAxisArrayBuffer(
|
|
109
|
+
duration=self.settings.buffer_dur,
|
|
110
|
+
axis=self.settings.axis or message.dims[0],
|
|
111
|
+
update_strategy=self.settings.buffer_update_strategy,
|
|
112
|
+
overflow_strategy="warn-overwrite", # True circular buffer
|
|
113
|
+
)
|
|
97
114
|
if self._state.triggers is None:
|
|
98
115
|
self._state.triggers = deque()
|
|
99
116
|
self._state.triggers.clear()
|
|
100
|
-
self._state.n_samples = message.data.shape[axis_idx]
|
|
101
117
|
|
|
102
118
|
def _process(self, message: AxisArray) -> list[SampleMessage]:
|
|
103
|
-
|
|
104
|
-
axis_idx = message.get_axis_idx(axis)
|
|
105
|
-
axis_info = message.get_axis(axis)
|
|
106
|
-
self._state.offset = axis_info.offset
|
|
107
|
-
|
|
108
|
-
# Update buffer
|
|
109
|
-
self._state.buffer = (
|
|
110
|
-
message.data
|
|
111
|
-
if self._state.buffer is None
|
|
112
|
-
else np.concatenate((self._state.buffer, message.data), axis=axis_idx)
|
|
113
|
-
)
|
|
119
|
+
self._state.buffer.write(message)
|
|
114
120
|
|
|
115
|
-
#
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
121
|
+
# How much data in the buffer?
|
|
122
|
+
buff_t_range = (
|
|
123
|
+
self._state.buffer.axis_first_value,
|
|
124
|
+
self._state.buffer.axis_final_value,
|
|
125
|
+
)
|
|
120
126
|
|
|
121
|
-
#
|
|
122
|
-
|
|
123
|
-
for
|
|
127
|
+
# Process in reverse order so that we can remove triggers safely as we iterate.
|
|
128
|
+
msgs_out: list[SampleMessage] = []
|
|
129
|
+
for trig_ix in range(len(self._state.triggers) - 1, -1, -1):
|
|
130
|
+
trig = self._state.triggers[trig_ix]
|
|
124
131
|
if trig.period is None:
|
|
125
|
-
|
|
126
|
-
self._state.triggers
|
|
132
|
+
ez.logger.warning("Sampling failed: trigger period not specified")
|
|
133
|
+
del self._state.triggers[trig_ix]
|
|
134
|
+
continue
|
|
135
|
+
|
|
136
|
+
trig_range = trig.timestamp + np.array(trig.period)
|
|
127
137
|
|
|
128
138
|
# If the previous iteration had insufficient data for the trigger timestamp + period,
|
|
129
139
|
# and buffer-management removed data required for the trigger, then we will never be able
|
|
130
140
|
# to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
|
|
131
|
-
if
|
|
141
|
+
if trig_range[0] < buff_t_range[0]:
|
|
132
142
|
ez.logger.warning(
|
|
133
|
-
f"Sampling failed: Buffer span {
|
|
134
|
-
f"requested sample period start: {
|
|
143
|
+
f"Sampling failed: Buffer span {buff_t_range} begins beyond the "
|
|
144
|
+
f"requested sample period start: {trig_range[0]}"
|
|
135
145
|
)
|
|
136
|
-
self._state.triggers
|
|
146
|
+
del self._state.triggers[trig_ix]
|
|
147
|
+
continue
|
|
137
148
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
self._state.buffer, slice(start, stop), axis_idx
|
|
153
|
-
),
|
|
154
|
-
axes={
|
|
155
|
-
**message.axes,
|
|
156
|
-
axis: replace(
|
|
157
|
-
axis_info, offset=buffer_offset[start]
|
|
158
|
-
),
|
|
159
|
-
},
|
|
160
|
-
),
|
|
161
|
-
)
|
|
162
|
-
)
|
|
163
|
-
self._state.triggers.remove(trig)
|
|
164
|
-
|
|
165
|
-
# Trim buffer
|
|
166
|
-
buf_len = int(self.settings.buffer_dur * self._state.fs)
|
|
167
|
-
self._state.buffer = slice_along_axis(
|
|
168
|
-
self._state.buffer, np.s_[-buf_len:], axis_idx
|
|
169
|
-
)
|
|
149
|
+
if trig_range[1] > buff_t_range[1]:
|
|
150
|
+
# We don't *yet* have enough data to satisfy this trigger.
|
|
151
|
+
continue
|
|
152
|
+
|
|
153
|
+
# We know we have enough data in the buffer to satisfy this trigger.
|
|
154
|
+
buff_idx = self._state.buffer.axis_searchsorted(trig_range, side="right")
|
|
155
|
+
self._state.buffer.seek(buff_idx[0]) # FFWD to starting position.
|
|
156
|
+
buff_axarr = self._state.buffer.peek(buff_idx[1] - buff_idx[0])
|
|
157
|
+
self._state.buffer.seek(-buff_idx[0]) # Rewind it back.
|
|
158
|
+
# Note: buffer will trim itself as needed based on buffer_dur.
|
|
159
|
+
|
|
160
|
+
# Prepare output and drop trigger
|
|
161
|
+
msgs_out.append(SampleMessage(trigger=copy.copy(trig), sample=buff_axarr))
|
|
162
|
+
del self._state.triggers[trig_ix]
|
|
170
163
|
|
|
171
|
-
|
|
164
|
+
msgs_out.reverse() # in-place
|
|
165
|
+
return msgs_out
|
|
172
166
|
|
|
173
167
|
def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
|
|
174
168
|
# Input is a trigger message that we will use to sample the buffer.
|
|
175
169
|
|
|
176
|
-
if
|
|
177
|
-
self._state.buffer is None
|
|
178
|
-
or not self._state.fs
|
|
179
|
-
or self._state.offset is None
|
|
180
|
-
):
|
|
170
|
+
if self._state.buffer is None:
|
|
181
171
|
# We've yet to see any data; drop the trigger.
|
|
182
172
|
return []
|
|
183
173
|
|
|
@@ -194,11 +184,9 @@ class SamplerTransformer(
|
|
|
194
184
|
return []
|
|
195
185
|
|
|
196
186
|
# Check that period is compatible with buffer duration.
|
|
197
|
-
|
|
198
|
-
req_buf_len = int(np.round((_period[1] - _period[0]) * self._state.fs))
|
|
199
|
-
if req_buf_len >= max_buf_len:
|
|
187
|
+
if (_period[1] - _period[0]) > self.settings.buffer_dur:
|
|
200
188
|
ez.logger.warning(
|
|
201
|
-
f"Sampling failed: {_period=} >= {self.settings.buffer_dur=}"
|
|
189
|
+
f"Sampling failed: trigger period {_period=} >= buffer capacity {self.settings.buffer_dur=}"
|
|
202
190
|
)
|
|
203
191
|
return []
|
|
204
192
|
|
|
@@ -206,7 +194,7 @@ class SamplerTransformer(
|
|
|
206
194
|
if not self.settings.estimate_alignment:
|
|
207
195
|
# Override the trigger timestamp with the next sample's likely timestamp.
|
|
208
196
|
trigger_ts = (
|
|
209
|
-
self._state.
|
|
197
|
+
self._state.buffer.axis_final_value + self._state.buffer.axis_gain
|
|
210
198
|
)
|
|
211
199
|
|
|
212
200
|
new_trig_msg = replace(
|
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
from array_api_compat import get_namespace
|
|
5
|
+
import numpy as np
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis
|
|
7
|
+
from ezmsg.util.messages.util import replace
|
|
8
|
+
|
|
9
|
+
from .buffer import HybridBuffer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
Array = typing.TypeVar("Array")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HybridAxisBuffer:
    """
    A buffer that intelligently handles ezmsg.util.messages.AxisArray _axes_ objects.
    LinearAxis is maintained internally by tracking its offset, gain, and the number
    of samples that have passed through.
    CoordinateAxis has its data values maintained in a `HybridBuffer`.

    Args:
        duration: The desired duration of the buffer in seconds. This is non-limiting
            when managing a LinearAxis.
        **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
            (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
    """

    # Exactly one of `_coords_buffer` / `_linear_axis` is set once initialized;
    # which one depends on the type of the first axis that is seen.
    _coords_buffer: HybridBuffer | None
    _coords_template: CoordinateAxis | None
    _coords_gain_estimate: float | None = None
    _linear_axis: LinearAxis | None
    _linear_n_available: int

    def __init__(self, duration: float, **kwargs):
        self.duration = duration
        self.buffer_kwargs = kwargs
        # Delay initialization until the first message arrives
        self._coords_buffer = None
        self._coords_template = None
        self._linear_axis = None
        self._linear_n_available = 0

    @property
    def capacity(self) -> int:
        """The maximum number of samples that can be stored in the buffer."""
        if self._coords_buffer is not None:
            return self._coords_buffer.capacity
        elif self._linear_axis is not None:
            return int(math.ceil(self.duration / self._linear_axis.gain))
        else:
            # Not initialized yet.
            return 0

    def available(self) -> int:
        """Return the number of axis samples currently tracked by the buffer."""
        if self._coords_buffer is None:
            return self._linear_n_available
        return self._coords_buffer.available()

    def is_empty(self) -> bool:
        """Return True when no samples are available."""
        return self.available() == 0

    def is_full(self) -> bool:
        """Return True when the number of available samples equals a non-zero capacity."""
        if self._coords_buffer is not None:
            return self._coords_buffer.is_full()
        return 0 < self.capacity == self.available()

    def _initialize(self, first_axis: LinearAxis | CoordinateAxis) -> None:
        """
        Configure the buffer from the first axis object.

        An axis with a `data` attribute is treated as a CoordinateAxis and gets a
        backing `HybridBuffer`; anything else is treated as a LinearAxis.
        """
        if hasattr(first_axis, "data"):
            # Initialize a CoordinateAxis buffer
            if len(first_axis.data) > 1:
                # Mean step between the first and last coordinate values.
                _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (
                    len(first_axis.data) - 1
                )
            else:
                # Cannot estimate a step from fewer than two values.
                _axis_gain = 1.0
            self._coords_gain_estimate = _axis_gain
            capacity = int(self.duration / _axis_gain)
            self._coords_buffer = HybridBuffer(
                get_namespace(first_axis.data),
                capacity,
                other_shape=(),
                dtype=first_axis.data.dtype,
                **self.buffer_kwargs,
            )
            # Zero-length copy used as the template for objects returned by `peek`.
            self._coords_template = replace(first_axis, data=first_axis.data[:0].copy())
        else:
            # Initialize a LinearAxis buffer
            self._linear_axis = replace(first_axis, offset=first_axis.offset)
            self._linear_n_available = 0

    def write(self, axis: LinearAxis | CoordinateAxis, n_samples: int) -> None:
        """
        Record the axis information for `n_samples` newly written samples.

        Args:
            axis: The axis object accompanying the new samples.
            n_samples: Number of samples being added (used on the LinearAxis path).

        Raises:
            TypeError: If `axis` is not the same class the buffer was initialized with.
            ValueError: If a LinearAxis arrives with a different gain.
        """
        if self._linear_axis is None and self._coords_buffer is None:
            self._initialize(axis)

        if self._coords_buffer is not None:
            if axis.__class__ is not self._coords_template.__class__:
                raise TypeError(
                    f"Buffer initialized with {self._coords_template.__class__.__name__}, "
                    f"but received {axis.__class__.__name__}."
                )
            self._coords_buffer.write(axis.data)
        else:
            if axis.__class__ is not self._linear_axis.__class__:
                raise TypeError(
                    f"Buffer initialized with {self._linear_axis.__class__.__name__}, "
                    f"but received {axis.__class__.__name__}."
                )
            if axis.gain != self._linear_axis.gain:
                raise ValueError(
                    f"Buffer initialized with gain={self._linear_axis.gain}, "
                    f"but received gain={axis.gain}."
                )
            if self._linear_n_available + n_samples > self.capacity:
                # Simulate overflow by advancing the offset and decreasing
                # the number of available samples.
                n_to_discard = self._linear_n_available + n_samples - self.capacity
                self.seek(n_to_discard)
            # Update the offset corresponding to the oldest sample in the buffer
            # by anchoring on the new offset and accounting for the samples already available.
            self._linear_axis.offset = (
                axis.offset - self._linear_n_available * axis.gain
            )
            self._linear_n_available += n_samples

    def peek(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis:
        """
        Return an axis object describing the oldest samples without consuming them.

        `n_samples` is forwarded to the coordinate buffer; it is not used on the
        LinearAxis path, where the offset alone locates the first sample.
        """
        if self._coords_buffer is not None:
            return replace(
                self._coords_template, data=self._coords_buffer.peek(n_samples)
            )
        else:
            # Return a shallow copy.
            return replace(self._linear_axis, offset=self._linear_axis.offset)

    def seek(self, n_samples: int) -> int:
        """
        Move the read position by `n_samples` and return how many were skipped.

        On the LinearAxis path the request is clamped to the available count. A
        negative `n_samples` passes through `min()` unchanged, so it moves the
        offset backwards and increases the available count.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.seek(n_samples)
        else:
            n_to_seek = min(n_samples, self._linear_n_available)
            self._linear_n_available -= n_to_seek
            self._linear_axis.offset += n_to_seek * self._linear_axis.gain
            return n_to_seek

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer."""
        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0
        return self.seek(n_to_discard)

    @property
    def final_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the last sample in the buffer.
        This does not advance the read head.
        """
        # NOTE(review): unlike `first_value`, there is no empty-buffer guard here.
        # For an initialized but empty LinearAxis buffer this evaluates `value(-1)`.
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_last()[0]
        elif self._linear_axis is not None:
            return self._linear_axis.value(self._linear_n_available - 1)
        else:
            return None

    @property
    def first_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the first sample in the buffer.
        This does not advance the read head.
        """
        if self.available() == 0:
            return None
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_at(0)[0]
        elif self._linear_axis is not None:
            return self._linear_axis.value(0)
        else:
            return None

    @property
    def gain(self) -> float | None:
        """
        The step between consecutive axis values, or None before initialization.

        For a CoordinateAxis this is the estimate computed from the first axis seen.
        """
        if self._coords_buffer is not None:
            return self._coords_gain_estimate
        elif self._linear_axis is not None:
            return self._linear_axis.gain
        else:
            return None

    def searchsorted(
        self, values: typing.Union[float, Array], side: str = "left"
    ) -> typing.Union[int, Array]:
        """
        Find the buffer indices at which `values` would be inserted to keep order.

        Args:
            values: A scalar axis value or an array of axis values.
            side: "left" or "right", with the same meaning as in `searchsorted`
                of the array namespace.

        Returns:
            Integer index (scalar input) or integer array of indices (array input),
            relative to the oldest available sample.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.xp.searchsorted(
                self._coords_buffer.peek(self.available()), values, side=side
            )
        else:
            if self.available() == 0:
                # Plain Python ints are scalars too and have no array namespace.
                if isinstance(values, (int, float)):
                    return 0
                else:
                    _xp = get_namespace(values)
                    return _xp.zeros_like(values, dtype=int)

            # Fractional index of each value on the regularly spaced axis.
            f_inds = (values - self._linear_axis.offset) / self._linear_axis.gain
            res = np.ceil(f_inds)
            if side == "right":
                # A value that lands (within tolerance) on an existing sample
                # must be inserted after it.
                on_sample = np.isclose(f_inds, res)
                if np.ndim(res) == 0:
                    # Scalars do not support item assignment.
                    if on_sample:
                        res = res + 1
                else:
                    res[on_sample] += 1
            return res.astype(int)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class HybridAxisArrayBuffer:
    """A buffer that intelligently handles ezmsg.util.messages.AxisArray objects.

    This buffer defers its own initialization until the first message arrives,
    allowing it to automatically configure its size, shape, dtype, and array backend
    (e.g., NumPy, CuPy) based on the message content and a desired buffer duration.

    Args:
        duration: The desired duration of the buffer in seconds.
        axis: The name of the axis to buffer along.
        **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
            (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
    """

    _data_buffer: HybridBuffer | None
    _axis_buffer: HybridAxisBuffer
    _template_msg: AxisArray | None

    def __init__(self, duration: float, axis: str = "time", **kwargs):
        self.duration = duration
        self._axis = axis
        self.buffer_kwargs = kwargs
        self._axis_buffer = HybridAxisBuffer(duration=duration, **kwargs)
        # Delay initialization until the first message arrives
        self._data_buffer = None
        self._template_msg = None

    def available(self) -> int:
        """The total number of unread samples currently available in the buffer."""
        if self._data_buffer is None:
            return 0
        return self._data_buffer.available()

    def is_empty(self) -> bool:
        """Return True when no unread samples are available."""
        return self.available() == 0

    def is_full(self) -> bool:
        """Return True when the available sample count equals a non-zero capacity."""
        if self._data_buffer is None:
            # Nothing has been written yet, so there is no capacity to fill.
            return False
        return 0 < self._data_buffer.capacity == self.available()

    @property
    def axis_first_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the first sample in the buffer."""
        return self._axis_buffer.first_value

    @property
    def axis_final_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the last sample in the buffer."""
        return self._axis_buffer.final_value

    def _initialize(self, first_msg: AxisArray) -> None:
        """Build the template message and the data/axis buffers from the first message."""
        # Create a template message that has everything except the data are length 0
        # and the target axis is missing.
        self._template_msg = replace(
            first_msg,
            data=first_msg.data[:0],
            axes={k: v for k, v in first_msg.axes.items() if k != self._axis},
        )

        in_axis = first_msg.axes[self._axis]
        self._axis_buffer._initialize(in_axis)

        # NOTE(review): this truncates, whereas HybridAxisBuffer.capacity uses
        # math.ceil for a LinearAxis; the two can differ by one sample when
        # duration / gain is not an exact integer -- confirm this is intended.
        capacity = int(self.duration / self._axis_buffer.gain)
        self._data_buffer = HybridBuffer(
            get_namespace(first_msg.data),
            capacity,
            other_shape=first_msg.data.shape[1:],
            dtype=first_msg.data.dtype,
            **self.buffer_kwargs,
        )

    def write(self, msg: AxisArray) -> None:
        """Adds an AxisArray message to the buffer, initializing on the first call."""
        in_axis_idx = msg.get_axis_idx(self._axis)
        if in_axis_idx > 0:
            # This class assumes that the target axis is the first axis.
            # If it is not, we move it to the front.
            dims = list(msg.dims)
            dims.insert(0, dims.pop(in_axis_idx))
            _xp = get_namespace(msg.data)
            msg = replace(msg, data=_xp.moveaxis(msg.data, in_axis_idx, 0), dims=dims)

        if self._data_buffer is None:
            self._initialize(msg)

        self._data_buffer.write(msg.data)
        self._axis_buffer.write(msg.axes[self._axis], msg.shape[0])

    def peek(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray without advancing the read head."""

        if self._data_buffer is None:
            return None

        data_array = self._data_buffer.peek(n_samples)

        if data_array is None:
            return None

        out_axis = self._axis_buffer.peek(n_samples)

        # Re-attach the buffered axis to the axes kept in the template.
        return replace(
            self._template_msg,
            data=data_array,
            axes={**self._template_msg.axes, self._axis: out_axis},
        )

    def peek_axis(
        self, n_samples: int | None = None
    ) -> LinearAxis | CoordinateAxis | None:
        """Retrieves the axis data without advancing the read head."""
        if self._data_buffer is None:
            return None

        out_axis = self._axis_buffer.peek(n_samples)

        if out_axis is None:
            return None

        return out_axis

    def seek(self, n_samples: int) -> int:
        """Advances the read pointer by n_samples."""
        if self._data_buffer is None:
            return 0

        # The data buffer decides how far it can move; the axis buffer follows.
        skipped_data_count = self._data_buffer.seek(n_samples)
        axis_skipped = self._axis_buffer.seek(skipped_data_count)
        assert (
            axis_skipped == skipped_data_count
        ), f"Axis buffer skipped {axis_skipped} samples, but data buffer skipped {skipped_data_count}."

        return skipped_data_count

    def read(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray and advances the read head."""
        retrieved_axis_array = self.peek(n_samples)

        if retrieved_axis_array is None or retrieved_axis_array.shape[0] == 0:
            return None

        self.seek(retrieved_axis_array.shape[0])

        return retrieved_axis_array

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer."""
        if self._data_buffer is None:
            return 0

        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0

        return self.seek(n_to_discard)

    @property
    def axis_gain(self) -> float | None:
        """
        The gain of the target axis, which is the time step between samples.
        This is typically the sampling period (i.e., 1 / fs).
        """
        return self._axis_buffer.gain

    def axis_searchsorted(
        self, values: typing.Union[float, Array], side: str = "left"
    ) -> typing.Union[int, Array]:
        """
        Find the indices into which the given values would be inserted
        into the target axis data to maintain order.
        """
        return self._axis_buffer.searchsorted(values, side=side)
|