ezmsg-sigproc 2.2.0__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/sampler.py CHANGED
@@ -1,18 +1,19 @@
1
1
  import asyncio
2
2
  from collections import deque
3
+ import copy
3
4
  import traceback
4
5
  import typing
5
6
 
6
7
  import numpy as np
7
- import numpy.typing as npt
8
8
  import ezmsg.core as ez
9
9
  from ezmsg.util.messages.axisarray import (
10
10
  AxisArray,
11
- slice_along_axis,
12
11
  )
13
12
  from ezmsg.util.messages.util import replace
14
13
 
15
14
  from .util.profile import profile_subpub
15
+ from .util.axisarray_buffer import HybridAxisArrayBuffer
16
+ from .util.buffer import UpdateStrategy
16
17
  from .util.message import SampleMessage, SampleTriggerMessage
17
18
  from .base import (
18
19
  BaseStatefulTransformer,
@@ -43,6 +44,7 @@ class SamplerSettings(ez.Settings):
43
44
  None (default) will choose the first axis in the first input.
44
45
  Note: (for now) the axis must exist in the msg .axes and be of type AxisArray.LinearAxis
45
46
  """
47
+
46
48
  period: tuple[float, float] | None = None
47
49
  """Optional default period (in seconds) if unspecified in SampleTriggerMessage."""
48
50
 
@@ -51,20 +53,25 @@ class SamplerSettings(ez.Settings):
51
53
 
52
54
  estimate_alignment: bool = True
53
55
  """
54
- If true, use message timestamp fields and reported sampling rate to estimate sample-accurate alignment for samples.
56
+ If true, use message timestamp fields and reported sampling rate to estimate
57
+ sample-accurate alignment for samples.
55
58
  If false, sampling will be limited to incoming message rate -- "Block timing"
56
59
  NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
57
60
  "realtime" operation for estimate_alignment to operate correctly.
58
61
  """
59
62
 
63
+ buffer_update_strategy: UpdateStrategy = "immediate"
64
+ """
65
+ The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
66
+ If you expect to push data much more frequently than triggers, then "on_demand"
67
+ might be more efficient. For most other scenarios, "immediate" is best.
68
+ """
69
+
60
70
 
61
71
  @processor_state
62
72
  class SamplerState:
63
- fs: float = 0.0
64
- offset: float | None = None
65
- buffer: npt.NDArray | None = None
73
+ buffer: HybridAxisArrayBuffer | None = None
66
74
  triggers: deque[SampleTriggerMessage] | None = None
67
- n_samples: int = 0
68
75
 
69
76
 
70
77
  class SamplerTransformer(
@@ -73,6 +80,16 @@ class SamplerTransformer(
73
80
  def __call__(
74
81
  self, message: AxisArray | SampleTriggerMessage
75
82
  ) -> list[SampleMessage]:
83
+ # TODO: Currently we have a single entry point that accepts both
84
+ # data and trigger messages and we choose a code path based on
85
+ # the message type. However, in the future we will likely replace
86
+ # SampleTriggerMessage with an agumented form of AxisArray,
87
+ # leveraging its attrs field, which makes this a bit harder.
88
+ # We should probably force callers of this object to explicitly
89
+ # call `push_trigger` for trigger messages. This will also
90
+ # simplify typing somewhat because `push_trigger` should not
91
+ # return anything yet we currently have it returning an empty
92
+ # list just to be compatible with __call__.
76
93
  if isinstance(message, AxisArray):
77
94
  return super().__call__(message)
78
95
  else:
@@ -82,102 +99,75 @@ class SamplerTransformer(
82
99
  # Compute hash based on message properties that require state reset
83
100
  axis = self.settings.axis or message.dims[0]
84
101
  axis_idx = message.get_axis_idx(axis)
85
- fs = 1.0 / message.get_axis(axis).gain
86
102
  sample_shape = (
87
103
  message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
88
104
  )
89
- return hash((fs, sample_shape, axis_idx, message.key))
105
+ return hash((sample_shape, message.key))
90
106
 
91
107
  def _reset_state(self, message: AxisArray) -> None:
92
- axis = self.settings.axis or message.dims[0]
93
- axis_idx = message.get_axis_idx(axis)
94
- axis_info = message.get_axis(axis)
95
- self._state.fs = 1.0 / axis_info.gain
96
- self._state.buffer = None
108
+ self._state.buffer = HybridAxisArrayBuffer(
109
+ duration=self.settings.buffer_dur,
110
+ axis=self.settings.axis or message.dims[0],
111
+ update_strategy=self.settings.buffer_update_strategy,
112
+ overflow_strategy="warn-overwrite", # True circular buffer
113
+ )
97
114
  if self._state.triggers is None:
98
115
  self._state.triggers = deque()
99
116
  self._state.triggers.clear()
100
- self._state.n_samples = message.data.shape[axis_idx]
101
117
 
102
118
  def _process(self, message: AxisArray) -> list[SampleMessage]:
103
- axis = self.settings.axis or message.dims[0]
104
- axis_idx = message.get_axis_idx(axis)
105
- axis_info = message.get_axis(axis)
106
- self._state.offset = axis_info.offset
107
-
108
- # Update buffer
109
- self._state.buffer = (
110
- message.data
111
- if self._state.buffer is None
112
- else np.concatenate((self._state.buffer, message.data), axis=axis_idx)
113
- )
119
+ self._state.buffer.write(message)
114
120
 
115
- # Calculate timestamps associated with buffer.
116
- buffer_offset = np.arange(self._state.buffer.shape[axis_idx], dtype=float)
117
- buffer_offset -= buffer_offset[-message.data.shape[axis_idx]]
118
- buffer_offset *= axis_info.gain
119
- buffer_offset += axis_info.offset
121
+ # How much data in the buffer?
122
+ buff_t_range = (
123
+ self._state.buffer.axis_first_value,
124
+ self._state.buffer.axis_final_value,
125
+ )
120
126
 
121
- # ... for each trigger, collect the message (if possible) and append to msg_out
122
- msg_out: list[SampleMessage] = []
123
- for trig in list(self._state.triggers):
127
+ # Process in reverse order so that we can remove triggers safely as we iterate.
128
+ msgs_out: list[SampleMessage] = []
129
+ for trig_ix in range(len(self._state.triggers) - 1, -1, -1):
130
+ trig = self._state.triggers[trig_ix]
124
131
  if trig.period is None:
125
- # This trigger was malformed; drop it.
126
- self._state.triggers.remove(trig)
132
+ ez.logger.warning("Sampling failed: trigger period not specified")
133
+ del self._state.triggers[trig_ix]
134
+ continue
135
+
136
+ trig_range = trig.timestamp + np.array(trig.period)
127
137
 
128
138
  # If the previous iteration had insufficient data for the trigger timestamp + period,
129
139
  # and buffer-management removed data required for the trigger, then we will never be able
130
140
  # to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
131
- if (trig.timestamp + trig.period[0]) < buffer_offset[0]:
141
+ if trig_range[0] < buff_t_range[0]:
132
142
  ez.logger.warning(
133
- f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
134
- f"requested sample period start: {trig.timestamp + trig.period[0]}"
143
+ f"Sampling failed: Buffer span {buff_t_range} begins beyond the "
144
+ f"requested sample period start: {trig_range[0]}"
135
145
  )
136
- self._state.triggers.remove(trig)
146
+ del self._state.triggers[trig_ix]
147
+ continue
137
148
 
138
- t_start = trig.timestamp + trig.period[0]
139
- if t_start >= buffer_offset[0]:
140
- start = np.searchsorted(buffer_offset, t_start)
141
- stop = start + int(
142
- np.round(self._state.fs * (trig.period[1] - trig.period[0]))
143
- )
144
- if self._state.buffer.shape[axis_idx] > stop:
145
- # Trigger period fully enclosed in buffer.
146
- msg_out.append(
147
- SampleMessage(
148
- trigger=trig,
149
- sample=replace(
150
- message,
151
- data=slice_along_axis(
152
- self._state.buffer, slice(start, stop), axis_idx
153
- ),
154
- axes={
155
- **message.axes,
156
- axis: replace(
157
- axis_info, offset=buffer_offset[start]
158
- ),
159
- },
160
- ),
161
- )
162
- )
163
- self._state.triggers.remove(trig)
164
-
165
- # Trim buffer
166
- buf_len = int(self.settings.buffer_dur * self._state.fs)
167
- self._state.buffer = slice_along_axis(
168
- self._state.buffer, np.s_[-buf_len:], axis_idx
169
- )
149
+ if trig_range[1] > buff_t_range[1]:
150
+ # We don't *yet* have enough data to satisfy this trigger.
151
+ continue
152
+
153
+ # We know we have enough data in the buffer to satisfy this trigger.
154
+ buff_idx = self._state.buffer.axis_searchsorted(trig_range, side="right")
155
+ self._state.buffer.seek(buff_idx[0]) # FFWD to starting position.
156
+ buff_axarr = self._state.buffer.peek(buff_idx[1] - buff_idx[0])
157
+ self._state.buffer.seek(-buff_idx[0]) # Rewind it back.
158
+ # Note: buffer will trim itself as needed based on buffer_dur.
159
+
160
+ # Prepare output and drop trigger
161
+ msgs_out.append(SampleMessage(trigger=copy.copy(trig), sample=buff_axarr))
162
+ del self._state.triggers[trig_ix]
170
163
 
171
- return msg_out
164
+ msgs_out.reverse() # in-place
165
+ return msgs_out
172
166
 
173
167
  def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
174
168
  # Input is a trigger message that we will use to sample the buffer.
175
169
 
176
- if (
177
- self._state.buffer is None
178
- or not self._state.fs
179
- or self._state.offset is None
180
- ):
170
+ if self._state.buffer is None:
181
171
  # We've yet to see any data; drop the trigger.
182
172
  return []
183
173
 
@@ -194,11 +184,9 @@ class SamplerTransformer(
194
184
  return []
195
185
 
196
186
  # Check that period is compatible with buffer duration.
197
- max_buf_len = int(np.round(self.settings.buffer_dur * self._state.fs))
198
- req_buf_len = int(np.round((_period[1] - _period[0]) * self._state.fs))
199
- if req_buf_len >= max_buf_len:
187
+ if (_period[1] - _period[0]) > self.settings.buffer_dur:
200
188
  ez.logger.warning(
201
- f"Sampling failed: {_period=} >= {self.settings.buffer_dur=}"
189
+ f"Sampling failed: trigger period {_period=} >= buffer capacity {self.settings.buffer_dur=}"
202
190
  )
203
191
  return []
204
192
 
@@ -206,7 +194,7 @@ class SamplerTransformer(
206
194
  if not self.settings.estimate_alignment:
207
195
  # Override the trigger timestamp with the next sample's likely timestamp.
208
196
  trigger_ts = (
209
- self._state.offset + (self.state.n_samples + 1) / self._state.fs
197
+ self._state.buffer.axis_final_value + self._state.buffer.axis_gain
210
198
  )
211
199
 
212
200
  new_trig_msg = replace(
@@ -0,0 +1,379 @@
1
+ import math
2
+ import typing
3
+
4
+ from array_api_compat import get_namespace
5
+ import numpy as np
6
+ from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis
7
+ from ezmsg.util.messages.util import replace
8
+
9
+ from .buffer import HybridBuffer
10
+
11
+
12
Array = typing.TypeVar("Array")  # Placeholder for any array type (NumPy, CuPy, ...).


class HybridAxisBuffer:
    """
    A buffer that intelligently handles ezmsg.util.messages.AxisArray _axes_ objects.

    LinearAxis is maintained internally by tracking its offset, gain, and the number
    of samples that have passed through; no per-sample values are stored.
    CoordinateAxis has its data values maintained in a `HybridBuffer`.

    Both kinds of axis hold at most ``int(duration / gain)`` samples (for a
    CoordinateAxis the gain is estimated from the first block of coordinates). For a
    LinearAxis this limit is enforced by discarding the oldest samples on write,
    mimicking a circular data buffer sized with the same formula.

    Args:
        duration: The desired duration of the buffer, in axis units (seconds for a
            time axis).
        **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
            (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
            They are only used when managing a CoordinateAxis.
    """

    _coords_buffer: HybridBuffer | None
    _coords_template: CoordinateAxis | None
    _coords_gain_estimate: float | None = None
    _linear_axis: LinearAxis | None
    _linear_n_available: int

    def __init__(self, duration: float, **kwargs):
        self.duration = duration
        self.buffer_kwargs = kwargs
        # Delay initialization until the first message arrives
        self._coords_buffer = None
        self._coords_template = None
        self._linear_axis = None
        self._linear_n_available = 0

    @property
    def capacity(self) -> int:
        """The maximum number of samples that can be stored in the buffer (0 before init)."""
        if self._coords_buffer is not None:
            return self._coords_buffer.capacity
        if self._linear_axis is not None:
            # Same sizing rule as the CoordinateAxis path and as a data buffer sized
            # with int(duration / gain). Using ceil here could disagree by one sample,
            # which would offset every reported timestamp by one sample period once
            # the buffer starts discarding old samples.
            return int(self.duration / self._linear_axis.gain)
        return 0

    def available(self) -> int:
        """Number of samples currently held (unread) in the buffer."""
        if self._coords_buffer is None:
            return self._linear_n_available
        return self._coords_buffer.available()

    def is_empty(self) -> bool:
        return self.available() == 0

    def is_full(self) -> bool:
        if self._coords_buffer is not None:
            return self._coords_buffer.is_full()
        return 0 < self.capacity == self.available()

    def _initialize(self, first_axis: LinearAxis | CoordinateAxis) -> None:
        """Configure the buffer from the first axis seen (CoordinateAxis has `.data`)."""
        if hasattr(first_axis, "data"):
            # Initialize a CoordinateAxis buffer; estimate the gain from the mean spacing.
            # NOTE(review): identical first/last coordinates give a zero gain estimate,
            # which makes the capacity computation below fail.
            if len(first_axis.data) > 1:
                _axis_gain = (first_axis.data[-1] - first_axis.data[0]) / (
                    len(first_axis.data) - 1
                )
            else:
                _axis_gain = 1.0
            self._coords_gain_estimate = _axis_gain
            capacity = int(self.duration / _axis_gain)
            self._coords_buffer = HybridBuffer(
                get_namespace(first_axis.data),
                capacity,
                other_shape=(),
                dtype=first_axis.data.dtype,
                **self.buffer_kwargs,
            )
            self._coords_template = replace(first_axis, data=first_axis.data[:0].copy())
        else:
            # Initialize a LinearAxis buffer (a copy, so later offset edits are private).
            self._linear_axis = replace(first_axis, offset=first_axis.offset)
            self._linear_n_available = 0

    def write(self, axis: LinearAxis | CoordinateAxis, n_samples: int) -> None:
        """
        Append the axis information for `n_samples` newly received samples.

        Raises:
            TypeError: If `axis` is not the same class as the axis that initialized
                the buffer.
            ValueError: (LinearAxis only) if `axis.gain` differs from the initial gain.
        """
        if self._linear_axis is None and self._coords_buffer is None:
            self._initialize(axis)

        if self._coords_buffer is not None:
            if axis.__class__ is not self._coords_template.__class__:
                raise TypeError(
                    f"Buffer initialized with {self._coords_template.__class__.__name__}, "
                    f"but received {axis.__class__.__name__}."
                )
            self._coords_buffer.write(axis.data)
            return

        if axis.__class__ is not self._linear_axis.__class__:
            raise TypeError(
                f"Buffer initialized with {self._linear_axis.__class__.__name__}, "
                f"but received {axis.__class__.__name__}."
            )
        if axis.gain != self._linear_axis.gain:
            raise ValueError(
                f"Buffer initialized with gain={self._linear_axis.gain}, "
                f"but received gain={axis.gain}."
            )
        # Anchor on the incoming offset: the samples already held are assumed to
        # immediately precede the new block, so the oldest one is n_available
        # samples earlier.
        self._linear_axis.offset = axis.offset - self._linear_n_available * axis.gain
        self._linear_n_available += n_samples
        # Simulate overflow by discarding the oldest excess samples. Doing this after
        # the append also covers a single write that is larger than the capacity.
        # (Assumes the paired data buffer likewise keeps only the newest samples.)
        n_excess = self._linear_n_available - self.capacity
        if n_excess > 0:
            self.seek(n_excess)

    def peek(self, n_samples: int | None = None) -> LinearAxis | CoordinateAxis | None:
        """
        Return the axis describing the oldest unread samples without advancing.

        Returns None if the buffer has not been initialized yet. For a LinearAxis the
        returned axis does not depend on `n_samples`.
        """
        if self._coords_buffer is not None:
            return replace(
                self._coords_template, data=self._coords_buffer.peek(n_samples)
            )
        if self._linear_axis is None:
            return None
        # Return a shallow copy.
        return replace(self._linear_axis, offset=self._linear_axis.offset)

    def seek(self, n_samples: int) -> int:
        """Advance the read head by up to `n_samples`; returns the number skipped."""
        if self._coords_buffer is not None:
            return self._coords_buffer.seek(n_samples)
        if self._linear_axis is None:
            return 0
        n_to_seek = min(n_samples, self._linear_n_available)
        self._linear_n_available -= n_to_seek
        self._linear_axis.offset += n_to_seek * self._linear_axis.gain
        return n_to_seek

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer."""
        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0
        return self.seek(n_to_discard)

    @property
    def final_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the last sample in the buffer, or
        None when the buffer is empty. This does not advance the read head.
        """
        if self.available() == 0:
            return None
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_last()[0]
        if self._linear_axis is not None:
            return self._linear_axis.value(self._linear_n_available - 1)
        return None

    @property
    def first_value(self) -> float | None:
        """
        The axis-value (timestamp, typically) of the first sample in the buffer, or
        None when the buffer is empty. This does not advance the read head.
        """
        if self.available() == 0:
            return None
        if self._coords_buffer is not None:
            return self._coords_buffer.peek_at(0)[0]
        if self._linear_axis is not None:
            return self._linear_axis.value(0)
        return None

    @property
    def gain(self) -> float | None:
        """Sample spacing (estimated for a CoordinateAxis); None before init."""
        if self._coords_buffer is not None:
            return self._coords_gain_estimate
        if self._linear_axis is not None:
            return self._linear_axis.gain
        return None

    def searchsorted(
        self,
        values: typing.Union[float, Array],
        side: str = "left",
        *,
        atol: float = 1e-3,
    ) -> typing.Union[int, Array]:
        """
        Find indices where `values` would be inserted to keep the axis values sorted.

        Like ``numpy.searchsorted``, results lie in ``[0, available()]``.

        Args:
            values: A scalar or array of axis values (e.g., timestamps).
            side: "left" gives the index of the first sample >= value; "right" gives
                the index of the first sample > value.
            atol: LinearAxis only. Tolerance, in fractional samples, within which a
                value is treated as landing exactly on a sample. It is absolute, so
                it does not grow with the index the way a relative tolerance would.

        Returns:
            An int for scalar input on the LinearAxis path, otherwise an integer
            array shaped like `values`.
        """
        if self._coords_buffer is not None:
            return self._coords_buffer.xp.searchsorted(
                self._coords_buffer.peek(self.available()), values, side=side
            )

        n_available = self.available()
        is_scalar = np.ndim(values) == 0
        if n_available == 0:
            if is_scalar:
                return 0
            _xp = get_namespace(values)
            return _xp.zeros_like(values, dtype=int)

        lin = self._linear_axis
        f_inds = (np.asarray(values, dtype=float) - lin.offset) / lin.gain
        nearest = np.round(f_inds)
        on_sample = np.abs(f_inds - nearest) <= atol
        # Between samples, both sides insert before the next sample (ceil). Exactly
        # on a sample, "right" inserts after it.
        res = np.where(
            on_sample, nearest + (1 if side == "right" else 0), np.ceil(f_inds)
        )
        res = np.clip(res, 0, n_available).astype(int)
        return int(res) if is_scalar else res
class HybridAxisArrayBuffer:
    """A buffer that intelligently handles ezmsg.util.messages.AxisArray objects.

    This buffer defers its own initialization until the first message arrives,
    allowing it to automatically configure its size, shape, dtype, and array backend
    (e.g., NumPy, CuPy) based on the message content and a desired buffer duration.

    Data along the target axis is stored in a `HybridBuffer`, and the target axis
    itself in a `HybridAxisBuffer`. Both are sized to the axis buffer's capacity so
    that they discard old samples in lockstep.

    Args:
        duration: The desired duration of the buffer in seconds.
        axis: The name of the axis to buffer along.
        **kwargs: Additional keyword arguments to pass to the underlying HybridBuffer
            (e.g., `update_strategy`, `threshold`, `overflow_strategy`, `max_size`).
    """

    _data_buffer: HybridBuffer | None
    _axis_buffer: HybridAxisBuffer
    _template_msg: AxisArray | None

    def __init__(self, duration: float, axis: str = "time", **kwargs):
        self.duration = duration
        self._axis = axis
        self.buffer_kwargs = kwargs
        self._axis_buffer = HybridAxisBuffer(duration=duration, **kwargs)
        # Delay initialization until the first message arrives
        self._data_buffer = None
        self._template_msg = None

    def available(self) -> int:
        """The total number of unread samples currently available in the buffer."""
        if self._data_buffer is None:
            return 0
        return self._data_buffer.available()

    def is_empty(self) -> bool:
        return self.available() == 0

    def is_full(self) -> bool:
        """True when the data buffer holds `capacity` samples; False before init."""
        if self._data_buffer is None:
            return False
        return 0 < self._data_buffer.capacity == self.available()

    @property
    def axis_first_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the first sample in the buffer."""
        return self._axis_buffer.first_value

    @property
    def axis_final_value(self) -> float | None:
        """The axis-value (timestamp, typically) of the last sample in the buffer."""
        return self._axis_buffer.final_value

    def _initialize(self, first_msg: AxisArray) -> None:
        """Build the template message, axis buffer and data buffer from `first_msg`.

        `first_msg` is expected to already have the target axis in position 0.
        """
        _xp = get_namespace(first_msg.data)
        # Template: everything except the data (length 0) and the target axis.
        # Use a fresh empty array: a `data[:0]` view would keep the whole first
        # message's array alive for as long as the buffer exists.
        self._template_msg = replace(
            first_msg,
            data=_xp.empty_like(first_msg.data[:0]),
            axes={k: v for k, v in first_msg.axes.items() if k != self._axis},
        )

        self._axis_buffer._initialize(first_msg.axes[self._axis])

        # Size the data buffer exactly like the axis buffer so that both discard
        # the same number of samples on overflow and timestamps stay aligned.
        self._data_buffer = HybridBuffer(
            _xp,
            self._axis_buffer.capacity,
            other_shape=first_msg.data.shape[1:],
            dtype=first_msg.data.dtype,
            **self.buffer_kwargs,
        )

    def write(self, msg: AxisArray) -> None:
        """Adds an AxisArray message to the buffer, initializing on the first call."""
        in_axis_idx = msg.get_axis_idx(self._axis)
        if in_axis_idx > 0:
            # This class assumes that the target axis is the first axis.
            # If it is not, we move it to the front.
            dims = list(msg.dims)
            dims.insert(0, dims.pop(in_axis_idx))
            _xp = get_namespace(msg.data)
            msg = replace(msg, data=_xp.moveaxis(msg.data, in_axis_idx, 0), dims=dims)

        if self._data_buffer is None:
            self._initialize(msg)

        self._data_buffer.write(msg.data)
        self._axis_buffer.write(msg.axes[self._axis], msg.data.shape[0])

    def peek(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray without advancing the read head."""
        if self._data_buffer is None:
            return None

        data_array = self._data_buffer.peek(n_samples)
        if data_array is None:
            return None

        out_axis = self._axis_buffer.peek(n_samples)
        return replace(
            self._template_msg,
            data=data_array,
            axes={**self._template_msg.axes, self._axis: out_axis},
        )

    def peek_axis(
        self, n_samples: int | None = None
    ) -> LinearAxis | CoordinateAxis | None:
        """Retrieves the axis data without advancing the read head."""
        if self._data_buffer is None:
            return None
        return self._axis_buffer.peek(n_samples)

    def seek(self, n_samples: int) -> int:
        """Advances the read pointer by n_samples (negative values rewind)."""
        if self._data_buffer is None:
            return 0

        skipped_data_count = self._data_buffer.seek(n_samples)
        # Move the axis by however far the data buffer actually moved.
        axis_skipped = self._axis_buffer.seek(skipped_data_count)
        assert (
            axis_skipped == skipped_data_count
        ), f"Axis buffer skipped {axis_skipped} samples, but data buffer skipped {skipped_data_count}."

        return skipped_data_count

    def read(self, n_samples: int | None = None) -> AxisArray | None:
        """Retrieves the oldest unread data as a new AxisArray and advances the read head."""
        retrieved_axis_array = self.peek(n_samples)

        if retrieved_axis_array is None or retrieved_axis_array.shape[0] == 0:
            return None

        self.seek(retrieved_axis_array.shape[0])
        return retrieved_axis_array

    def prune(self, n_samples: int) -> int:
        """Discards all but the last n_samples from the buffer."""
        if self._data_buffer is None:
            return 0

        n_to_discard = self.available() - n_samples
        if n_to_discard <= 0:
            return 0

        return self.seek(n_to_discard)

    @property
    def axis_gain(self) -> float | None:
        """
        The gain of the target axis, which is the time step between samples.
        This is typically the sampling period (e.g., 1 / fs).
        """
        return self._axis_buffer.gain

    def axis_searchsorted(
        self, values: typing.Union[float, Array], side: str = "left"
    ) -> typing.Union[int, Array]:
        """
        Find the indices into which the given values would be inserted
        into the target axis data to maintain order.
        """
        return self._axis_buffer.searchsorted(values, side=side)