ezmsg-sigproc 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,119 @@
1
+ import functools
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import scipy.signal
7
+
8
+ from .filter import (
9
+ FilterBaseSettings,
10
+ FilterByDesignTransformer,
11
+ BACoeffs,
12
+ BaseFilterByDesignTransformerUnit,
13
+ )
14
+
15
+
16
class FIRFilterSettings(FilterBaseSettings):
    """Settings for :obj:`FIRFilter`. See scipy.signal.firwin for more details."""

    # axis and coef_type are inherited from FilterBaseSettings

    order: int = 0
    """
    Number of filter taps. Despite its name, this value is passed unchanged to scipy.signal.firwin
    as ``numtaps`` (the length of the filter, i.e. the polynomial order + 1).
    No filter is designed unless this is greater than 0.
    """

    cutoff: float | npt.ArrayLike | None = None
    """
    Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
    (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
    the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
    cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
    not be included in cutoff.
    """

    width: float | None = None
    """
    If width is not None, then assume it is the approximate width of the transition region (expressed in
    the same units as fs) for use in Kaiser FIR filter design. In this case, the window argument is ignored.
    """

    window: str | None = "hamming"
    """
    Desired window to use. See scipy.signal.get_window for a list of windows and required parameters.
    """

    pass_zero: bool | str = True
    """
    If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
    be a string argument for the desired filter type (equivalent to btype in IIR design functions).
    {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
    """

    scale: bool = True
    """
    Set to True to scale the coefficients so that the frequency response is exactly unity at a certain
    frequency. That frequency is either:
        * 0 (DC) if the first passband starts at 0 (i.e. pass_zero is True)
        * fs/2 (the Nyquist frequency) if the first passband ends at fs/2
          (i.e. the filter is a single band highpass filter);
        * center of first passband otherwise
    """

    wn_hz: bool = True
    """
    Set False if `cutoff` and `width` are normalized from 0 to 1, where 1 is the Nyquist frequency.
    When True (default) they are interpreted in the same units as the sampling rate fs.
    """
67
+
68
+
69
def firwin_design_fun(
    fs: float,
    order: int = 0,
    cutoff: float | npt.ArrayLike | None = None,
    width: float | None = None,
    window: str | None = "hamming",
    pass_zero: bool | str = True,
    scale: bool = True,
    wn_hz: bool = True,
) -> BACoeffs | None:
    """
    Design a windowed FIR filter with `order` taps and return the filter coefficients.
    See :obj:`FIRFilterSettings` for argument description.

    Args:
        fs: Sampling rate of the signal to be filtered. Only forwarded to firwin when `wn_hz` is True.

    Returns:
        A ``(b, a)`` tuple where ``b`` holds the filter taps as designed by firwin and ``a`` is ``[1.0]``,
        or None when no filter can be designed (`order` is not positive or `cutoff` is None).
    """
    # Without a positive tap count or a cutoff there is nothing to design.
    # The cutoff guard mirrors kaiser_design_fun; previously a missing cutoff
    # fell through to firwin and failed there with an opaque error.
    if order <= 0 or cutoff is None:
        return None

    taps = scipy.signal.firwin(
        numtaps=order,
        cutoff=cutoff,
        width=width,
        window=window,
        pass_zero=pass_zero,
        scale=scale,
        # With fs=None, firwin treats cutoff/width as normalized to Nyquist.
        fs=fs if wn_hz else None,
    )
    # An FIR filter has no feedback terms, so the denominator is just [1.0].
    return (taps, np.array([1.0]))
98
+
99
+
100
class FIRFilterTransformer(FilterByDesignTransformer[FIRFilterSettings, BACoeffs]):
    """Filter-by-design transformer whose design function is :func:`firwin_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | None]:
        """Return :func:`firwin_design_fun` with everything except ``fs`` bound from the settings."""
        cfg = self.settings
        bound_names = (
            "order",
            "cutoff",
            "width",
            "window",
            "pass_zero",
            "scale",
            "wn_hz",
        )
        design_kwargs = {name: getattr(cfg, name) for name in bound_names}
        return functools.partial(firwin_design_fun, **design_kwargs)
114
+
115
+
116
class FIRFilter(
    BaseFilterByDesignTransformerUnit[FIRFilterSettings, FIRFilterTransformer]
):
    """Unit pairing :obj:`FIRFilterSettings` with :obj:`FIRFilterTransformer`."""

    SETTINGS = FIRFilterSettings
@@ -0,0 +1,110 @@
1
+ import functools
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import scipy.signal
7
+
8
+ from .filter import (
9
+ FilterBaseSettings,
10
+ FilterByDesignTransformer,
11
+ BACoeffs,
12
+ BaseFilterByDesignTransformerUnit,
13
+ )
14
+
15
+
16
class KaiserFilterSettings(FilterBaseSettings):
    """Settings for :obj:`KaiserFilter`"""

    # `axis` and `coef_type` come from FilterBaseSettings.

    cutoff: float | npt.ArrayLike | None = None
    """
    Either a single cutoff frequency or an array of band edges, in the same units as fs.
    A scalar cutoff is the half-amplitude point (attenuation of -6dB). Band edges must be positive,
    monotonically increasing, and lie strictly between 0 and fs/2 (neither 0 nor fs/2 may be included).
    """

    ripple: float | None = None
    """
    Upper bound, in dB, on how far the magnitude of the filter's frequency response may deviate from
    the desired response, transition intervals excluded.
    See scipy.signal.kaiserord for more information.
    """

    width: float | None = None
    """
    Approximate width of the transition region (same units as fs) used for the Kaiser FIR filter design.
    See scipy.signal.kaiserord for more information.
    """

    pass_zero: bool | str = True
    """
    True gives a gain of 1 at frequency 0 (the “DC gain”); False gives a DC gain of 0. A string naming
    the desired filter type is also accepted (equivalent to btype in IIR design functions).
    {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
    """

    wn_hz: bool = True
    """
    Set False if cutoff and width are normalized from 0 to 1, where 1 is the Nyquist frequency
    """
55
+
56
+
57
def kaiser_design_fun(
    fs: float,
    cutoff: float | npt.ArrayLike | None = None,
    ripple: float | None = None,
    width: float | None = None,
    pass_zero: bool | str = True,
    wn_hz: bool = True,
) -> BACoeffs | None:
    """
    Design a Kaiser-windowed FIR filter and return the filter coefficients.

    The number of taps and the Kaiser window's beta are not given directly; they are derived
    from `ripple` and `width` with scipy.signal.kaiserord.
    See :obj:`KaiserFilterSettings` for argument description.

    Args:
        fs: Sampling rate of the signal to be filtered. Used to normalize `width` and forwarded
            to firwin, but only when `wn_hz` is True.

    Returns:
        A ``(b, a)`` tuple where ``b`` holds the filter taps as designed by firwin and ``a`` is ``[1.0]``,
        or None if any of `cutoff`, `ripple` or `width` is None.
    """
    if ripple is None or width is None or cutoff is None:
        return None

    # kaiserord is given the transition width relative to the Nyquist frequency,
    # so a width expressed in the units of fs is converted first.
    width = width / (0.5 * fs) if wn_hz else width
    n_taps, beta = scipy.signal.kaiserord(ripple, width)
    # Always use an odd number of taps -- presumably because firwin rejects an even
    # tap count when a passband includes the Nyquist frequency (e.g. highpass).
    if n_taps % 2 == 0:
        n_taps += 1
    taps = scipy.signal.firwin(
        numtaps=n_taps,
        cutoff=cutoff,
        window=("kaiser", beta),  # type: ignore
        pass_zero=pass_zero,  # type: ignore
        scale=False,
        # With fs=None, firwin treats cutoff as normalized to Nyquist.
        fs=fs if wn_hz else None,
    )

    # An FIR filter has no feedback terms, so the denominator is just [1.0].
    return (taps, np.array([1.0]))
89
+
90
+
91
class KaiserFilterTransformer(
    FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]
):
    """Filter-by-design transformer whose design function is :func:`kaiser_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | None]:
        """Return :func:`kaiser_design_fun` with everything except ``fs`` bound from the settings."""
        cfg = self.settings
        bound_names = ("cutoff", "ripple", "width", "pass_zero", "wn_hz")
        design_kwargs = {name: getattr(cfg, name) for name in bound_names}
        return functools.partial(kaiser_design_fun, **design_kwargs)
105
+
106
+
107
class KaiserFilter(
    BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]
):
    """Unit pairing :obj:`KaiserFilterSettings` with :obj:`KaiserFilterTransformer`."""

    SETTINGS = KaiserFilterSettings
ezmsg/sigproc/resample.py CHANGED
@@ -1,13 +1,11 @@
1
1
  import asyncio
2
- import dataclasses
2
+ import math
3
3
  import time
4
- import typing
5
4
 
6
5
  import numpy as np
7
- import numpy.typing as npt
8
6
  import scipy.interpolate
9
7
  import ezmsg.core as ez
10
- from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
11
9
  from ezmsg.util.messages.util import replace
12
10
 
13
11
  from .base import (
@@ -15,6 +13,8 @@ from .base import (
15
13
  BaseConsumerUnit,
16
14
  processor_state,
17
15
  )
16
+ from .util.axisarray_buffer import HybridAxisArrayBuffer, HybridAxisBuffer
17
+ from .util.buffer import UpdateStrategy
18
18
 
19
19
 
20
20
  class ResampleSettings(ez.Settings):
@@ -23,7 +23,7 @@ class ResampleSettings(ez.Settings):
23
23
  resample_rate: float | None = None
24
24
  """target resample rate in Hz. If None, the resample rate will be determined by the reference signal."""
25
25
 
26
- max_chunk_delay: float = 0.0
26
+ max_chunk_delay: float = np.inf
27
27
  """Maximum delay between outputs in seconds. If the delay exceeds this value, the transformer will extrapolate."""
28
28
 
29
29
  fill_value: str = "extrapolate"
@@ -34,23 +34,49 @@ class ResampleSettings(ez.Settings):
34
34
  See scipy.interpolate.interp1d for more options.
35
35
  """
36
36
 
37
+ buffer_duration: float = 2.0
37
38
 
38
- @dataclasses.dataclass
39
- class ResampleBuffer:
40
- data: npt.NDArray
41
- tvec: npt.NDArray
42
- template: AxisArray
43
- last_update: float
39
+ buffer_update_strategy: UpdateStrategy = "immediate"
40
+ """
41
+ The buffer update strategy. See :obj:`ezmsg.sigproc.util.buffer.UpdateStrategy`.
42
+ If you expect to push data much more frequently than it is resampled, then "on_demand"
43
+ might be more efficient. For most other scenarios, "immediate" is best.
44
+ """
44
45
 
45
46
 
46
47
@processor_state
class ResampleState:
    src_buffer: HybridAxisArrayBuffer | None = None
    """
    Holds the incoming signal data, from which the interpolation function is trained.
    It is rarely empty: some data is usually held back so that interpolation
    (and, optionally, extrapolation) stays accurate.
    """

    ref_axis_buffer: HybridAxisBuffer | None = None
    """
    Holds the reference axis (usually a time axis); the interpolation function is
    evaluated at its values.
    With resample_rate None, it is filled with the axis of incoming _reference_ messages.
    With a prescribed float resample_rate, it is filled with a synthetic axis generated
    from the incoming signal messages.
    """

    last_ref_ax_val: float | None = None
    """
    Most recent reference-axis value that was returned. It tells us what the _next_
    returned value should be and keeps us from returning the same value twice.
    TODO: We can eliminate this variable if we maintain "by convention" that the
    reference axis always has 1 value at its start that we exclude from the resampling.
    """

    last_write_time: float = -np.inf
    """
    Wall clock time of the most recent write to the signal buffer. Used to decide
    whether the reference axis must be extrapolated because no update arrived
    within max_chunk_delay.
    """
54
80
 
55
81
 
56
82
  class ResampleProcessor(
@@ -60,169 +86,149 @@ class ResampleProcessor(
60
86
  ax_idx: int = message.get_axis_idx(self.settings.axis)
61
87
  sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
62
88
  ax = message.axes[self.settings.axis]
63
- in_fs = (1 / ax.gain) if hasattr(ax, "gain") else None
64
- return hash((message.key, in_fs) + sample_shape)
89
+ gain = ax.gain if hasattr(ax, "gain") else None
90
+ return hash((message.key, gain) + sample_shape)
65
91
 
66
92
  def _reset_state(self, message: AxisArray) -> None:
67
93
  """
68
94
  Reset the internal state based on the incoming message.
69
- If resample_rate is None, the output is driven by the reference signal.
70
- The input will still determine the template (except the primary axis) and the buffer.
71
95
  """
72
- ax_idx: int = message.get_axis_idx(self.settings.axis)
73
- ax = message.axes[self.settings.axis]
74
- in_dat = message.data
75
- in_tvec = (
76
- ax.data
77
- if hasattr(ax, "data")
78
- else ax.value(np.arange(in_dat.shape[ax_idx]))
79
- )
80
- if ax_idx != 0:
81
- in_dat = np.moveaxis(in_dat, ax_idx, 0)
82
-
83
- if self.settings.resample_rate is None:
84
- # Output is driven by input.
85
- # We cannot include the resampled axis until we see reference data.
86
- out_axes = {
87
- k: v for k, v in message.axes.items() if k != self.settings.axis
88
- }
89
- # last_t_out also driven by reference data.
90
- # self.state.last_t_out = None
91
- else:
92
- out_axes = {
93
- **message.axes,
94
- self.settings.axis: AxisArray.TimeAxis(
95
- fs=self.settings.resample_rate, offset=in_tvec[0]
96
- ),
97
- }
98
- self.state.last_t_out = in_tvec[0] - 1 / self.settings.resample_rate
99
- template = replace(message, data=in_dat[:0], axes=out_axes)
100
- self.state.signal_buffer = ResampleBuffer(
101
- data=in_dat[:0],
102
- tvec=in_tvec[:0],
103
- template=template,
104
- last_update=time.time(),
105
- )
106
-
107
- def _process(self, message: AxisArray) -> None:
108
- # The incoming message will be added to the buffer.
109
- buf = self.state.signal_buffer
110
-
111
- # If our outputs are driven by reference signal, create the template's output axis if not already created.
112
- if (
113
- self.settings.resample_rate is None
114
- and self.settings.axis not in self.state.signal_buffer.template.axes
115
- ):
116
- buf = self.state.signal_buffer
117
- buf.template.axes[self.settings.axis] = self.state.ref_axis[0]
118
- if hasattr(buf.template.axes[self.settings.axis], "gain"):
119
- buf.template = replace(
120
- buf.template,
121
- axes={
122
- **buf.template.axes,
123
- self.settings.axis: replace(
124
- buf.template.axes[self.settings.axis],
125
- offset=self.state.last_t_out,
126
- ),
127
- },
128
- )
129
- # Note: last_t_out was set on the first call to push_reference.
130
-
131
- # Append the new data to the buffer
132
- ax_idx: int = message.get_axis_idx(self.settings.axis)
133
- in_dat: npt.NDArray = message.data
134
- if ax_idx != 0:
135
- in_dat = np.moveaxis(in_dat, ax_idx, 0)
136
- ax = message.axes[self.settings.axis]
137
- in_tvec = (
138
- ax.data if hasattr(ax, "data") else ax.value(np.arange(in_dat.shape[0]))
96
+ self.state.src_buffer = HybridAxisArrayBuffer(
97
+ duration=self.settings.buffer_duration,
98
+ axis=self.settings.axis,
99
+ update_strategy=self.settings.buffer_update_strategy,
100
+ overflow_strategy="grow",
139
101
  )
140
- buf.data = np.concatenate((buf.data, in_dat), axis=0)
141
- buf.tvec = np.hstack((buf.tvec, in_tvec))
142
- buf.last_update = time.time()
102
+ if self.settings.resample_rate is not None:
103
+ # If we are resampling at a prescribed rate, then we synthesize a reference axis
104
+ self.state.ref_axis_buffer = HybridAxisBuffer(
105
+ duration=self.settings.buffer_duration,
106
+ )
107
+ in_ax = message.axes[self.settings.axis]
108
+ out_gain = 1 / self.settings.resample_rate
109
+ t0 = in_ax.data[0] if hasattr(in_ax, "data") else in_ax.value(0)
110
+ self.state.last_ref_ax_val = t0 - out_gain
111
+ self.state.last_write_time = -np.inf
143
112
 
144
113
  def push_reference(self, message: AxisArray) -> None:
145
114
  ax = message.axes[self.settings.axis]
146
115
  ax_idx = message.get_axis_idx(self.settings.axis)
147
- n_new = message.data.shape[ax_idx]
148
- if self.state.ref_axis[1] == 0:
149
- self.state.ref_axis = (ax, n_new)
150
- else:
151
- if hasattr(ax, "gain"):
152
- # Rate and offset don't need to change; we simply increment our sample counter.
153
- self.state.ref_axis = (
154
- self.state.ref_axis[0],
155
- self.state.ref_axis[1] + n_new,
156
- )
157
- else:
158
- # Extend our time axis with the new data.
159
- new_tvec = np.concatenate(
160
- (self.state.ref_axis[0].data, ax.data), axis=0
161
- )
162
- self.state.ref_axis = (
163
- replace(self.state.ref_axis[0], data=new_tvec),
164
- self.state.ref_axis[1] + n_new,
165
- )
166
-
167
- if self.settings.resample_rate is None and self.state.last_t_out is None:
168
- # This reference axis will become THE output axis.
169
- # If last_t_out has not previously been set, we set it to the sample before this reference data.
170
- if hasattr(self.state.ref_axis[0], "gain"):
171
- ref_tvec = self.state.ref_axis[0].value(np.arange(2))
172
- else:
173
- ref_tvec = self.state.ref_axis[0].data[:2]
174
- self.state.last_t_out = 2 * ref_tvec[0] - ref_tvec[1]
116
+ if self.state.ref_axis_buffer is None:
117
+ self.state.ref_axis_buffer = HybridAxisBuffer(
118
+ duration=self.settings.buffer_duration,
119
+ update_strategy=self.settings.buffer_update_strategy,
120
+ overflow_strategy="grow",
121
+ )
122
+ t0 = ax.data[0] if hasattr(ax, "data") else ax.value(0)
123
+ self.state.last_ref_ax_val = t0 - ax.gain
124
+ self.state.ref_axis_buffer.write(ax, n_samples=message.data.shape[ax_idx])
175
125
 
176
- def __next__(self) -> AxisArray:
177
- buf = self.state.signal_buffer
126
+ def _process(self, message: AxisArray) -> None:
127
+ """
128
+ Add a new data message to the buffer and update the reference axis if needed.
129
+ """
130
+ # Note: The src_buffer will copy and permute message if ax_idx != 0
131
+ self.state.src_buffer.write(message)
178
132
 
179
- if buf is None:
133
+ # If we are resampling at a prescribed rate (i.e., not by reference msgs),
134
+ # then we use this opportunity to extend our synthetic reference axis.
135
+ ax_idx = message.get_axis_idx(self.settings.axis)
136
+ if self.settings.resample_rate is not None and message.data.shape[ax_idx] > 0:
137
+ in_ax = message.axes[self.settings.axis]
138
+ in_t_end = (
139
+ in_ax.data[-1]
140
+ if hasattr(in_ax, "data")
141
+ else in_ax.value(message.data.shape[ax_idx] - 1)
142
+ )
143
+ out_gain = 1 / self.settings.resample_rate
144
+ prev_t_end = self.state.last_ref_ax_val
145
+ n_synth = math.ceil((in_t_end - prev_t_end) * self.settings.resample_rate)
146
+ synth_ref_axis = LinearAxis(
147
+ unit="s", gain=out_gain, offset=prev_t_end + out_gain
148
+ )
149
+ self.state.ref_axis_buffer.write(synth_ref_axis, n_samples=n_synth)
150
+
151
+ self.state.last_write_time = time.time()
152
+
153
+ def __next__(self) -> AxisArray:
154
+ if self.state.src_buffer is None or self.state.ref_axis_buffer is None:
155
+ # If we have not received any data, or we require reference data
156
+ # that we do not yet have, then return an empty template.
180
157
  return AxisArray(data=np.array([]), dims=[""], axes={}, key="null")
181
158
 
182
- # buffer is empty or ref-driven && empty-reference; return the empty template
183
- if (buf.tvec.size == 0) or (
184
- self.settings.resample_rate is None and self.state.ref_axis[1] < 3
185
- ):
186
- # Note: empty template's primary axis' offset might be meaningless.
187
- return buf.template
188
-
189
- # Identify the output timestamps at which we will resample the buffer
190
- b_project = False
191
- if self.settings.resample_rate is None:
192
- # Rely on reference signal to determine output timestamps
193
- if hasattr(self.state.ref_axis[0], "data"):
194
- ref_tvec = self.state.ref_axis[0].data
195
- else:
196
- n_avail = self.state.ref_axis[1]
197
- ref_tvec = self.state.ref_axis[0].value(np.arange(n_avail))
159
+ src = self.state.src_buffer
160
+ ref = self.state.ref_axis_buffer
161
+
162
+ # If we have no reference or the source is insufficient for interpolation
163
+ # then return the empty template
164
+ if ref.is_empty() or src.available() < 3:
165
+ src_axarr = src.peek(0)
166
+ return replace(
167
+ src_axarr,
168
+ axes={
169
+ **src_axarr.axes,
170
+ self.settings.axis: ref.peek(0),
171
+ },
172
+ )
173
+
174
+ # Build the reference xvec.
175
+ # Note: The reference axis buffer may grow upon `.peek()`
176
+ # as it flushes data from its deque to its buffer.
177
+ ref_ax = ref.peek()
178
+ if hasattr(ref_ax, "data"):
179
+ ref_xvec = ref_ax.data
198
180
  else:
199
- # Get output timestamps from resample_rate and what we've collected so far
200
- t_begin = self.state.last_t_out + 1 / self.settings.resample_rate
201
- t_end = buf.tvec[-1]
202
- if self.settings.max_chunk_delay > 0 and time.time() > (
203
- buf.last_update + self.settings.max_chunk_delay
204
- ):
205
- # We've waiting too long between pushes. We will have to extrapolate.
206
- b_project = True
207
- t_end += self.settings.max_chunk_delay
208
- ref_tvec = np.arange(t_begin, t_end, 1 / self.settings.resample_rate)
209
-
210
- # Which samples can we resample?
211
- b_ref = ref_tvec > self.state.last_t_out
181
+ ref_xvec = ref_ax.value(np.arange(ref.available()))
182
+
183
+ # If we do not rely on an external reference, and we have not received new data in a while,
184
+ # then extrapolate our reference vector out beyond the delay limit.
185
+ b_project = self.settings.resample_rate is not None and time.time() > (
186
+ self.state.last_write_time + self.settings.max_chunk_delay
187
+ )
188
+ if b_project:
189
+ n_append = math.ceil(self.settings.max_chunk_delay / ref_ax.gain)
190
+ xvec_append = ref_xvec[-1] + np.arange(1, n_append + 1) * ref_ax.gain
191
+ ref_xvec = np.hstack((ref_xvec, xvec_append))
192
+
193
+ # Get source to train interpolation
194
+ src_axarr = src.peek()
195
+ src_axis = src_axarr.axes[self.settings.axis]
196
+ x = (
197
+ src_axis.data
198
+ if hasattr(src_axis, "data")
199
+ else src_axis.value(np.arange(src_axarr.data.shape[0]))
200
+ )
201
+
202
+ # Only resample at reference values that have not been interpolated over previously.
203
+ b_ref = ref_xvec > self.state.last_ref_ax_val
212
204
  if not b_project:
213
- b_ref = np.logical_and(b_ref, ref_tvec <= buf.tvec[-1])
205
+ # Not extrapolating -- Do not resample beyond the end of the source buffer.
206
+ b_ref = np.logical_and(b_ref, ref_xvec <= x[-1])
214
207
  ref_idx = np.where(b_ref)[0]
215
208
 
216
- if len(ref_idx) < 2:
217
- # Not enough data to resample; return the empty template.
218
- return buf.template
209
+ if len(ref_idx) == 0:
210
+ # Nothing to interpolate over; return empty data
211
+ null_ref = (
212
+ replace(ref_ax, data=ref_ax.data[:0])
213
+ if hasattr(ref_ax, "data")
214
+ else ref_ax
215
+ )
216
+ return replace(
217
+ src_axarr,
218
+ data=src_axarr.data[:0, ...],
219
+ axes={**src_axarr.axes, self.settings.axis: null_ref},
220
+ )
221
+
222
+ xnew = ref_xvec[ref_idx]
223
+
224
+ # Identify source data indices around ref tvec with some padding for better interpolation.
225
+ src_start_ix = max(
226
+ 0, np.where(x > xnew[0])[0][0] - 2 if np.any(x > xnew[0]) else 0
227
+ )
228
+
229
+ x = x[src_start_ix:]
230
+ y = src_axarr.data[src_start_ix:]
219
231
 
220
- tnew = ref_tvec[ref_idx]
221
- # Slice buf to minimal range around tnew with some padding for better interpolation.
222
- buf_start_ix = max(0, np.searchsorted(buf.tvec, tnew[0]) - 2)
223
- buf_stop_ix = np.searchsorted(buf.tvec, tnew[-1], side="right") + 2
224
- x = buf.tvec[buf_start_ix:buf_stop_ix]
225
- y = buf.data[buf_start_ix:buf_stop_ix]
226
232
  if (
227
233
  isinstance(self.settings.fill_value, str)
228
234
  and self.settings.fill_value == "last"
@@ -240,37 +246,32 @@ class ResampleProcessor(
240
246
  fill_value=fill_value,
241
247
  assume_sorted=True,
242
248
  )
243
- resampled_data = f(tnew)
244
- if hasattr(buf.template.axes[self.settings.axis], "data"):
245
- repl_axis = replace(buf.template.axes[self.settings.axis], data=tnew)
249
+
250
+ # Calculate output
251
+ resampled_data = f(xnew)
252
+
253
+ # Create output message
254
+ if hasattr(ref_ax, "data"):
255
+ out_ax = replace(ref_ax, data=xnew)
246
256
  else:
247
- repl_axis = replace(buf.template.axes[self.settings.axis], offset=tnew[0])
257
+ out_ax = replace(ref_ax, offset=xnew[0])
248
258
  result = replace(
249
- buf.template,
259
+ src_axarr,
250
260
  data=resampled_data,
251
261
  axes={
252
- **buf.template.axes,
253
- self.settings.axis: repl_axis,
262
+ **src_axarr.axes,
263
+ self.settings.axis: out_ax,
254
264
  },
255
265
  )
256
266
 
257
- # Update state to move past samples that are no longer be needed
258
- self.state.last_t_out = tnew[-1]
259
- buf.data = buf.data[max(0, buf_stop_ix - 3) :]
260
- buf.tvec = buf.tvec[max(0, buf_stop_ix - 3) :]
261
- buf.last_update = time.time()
262
-
263
- if self.settings.resample_rate is None:
264
- # Update self.state.ref_axis to remove samples that have been used in the output
265
- if hasattr(self.state.ref_axis[0], "data"):
266
- new_ref_ax = replace(
267
- self.state.ref_axis[0],
268
- data=self.state.ref_axis[0].data[ref_idx[-1] + 1 :],
269
- )
270
- else:
271
- next_offset = self.state.ref_axis[0].value(ref_idx[-1] + 1)
272
- new_ref_ax = replace(self.state.ref_axis[0], offset=next_offset)
273
- self.state.ref_axis = (new_ref_ax, self.state.ref_axis[1] - len(ref_idx))
267
+ # Update the state. For state buffers, seek beyond samples that are no longer needed.
268
+ # src: keep at least 1 sample before the final resampled value
269
+ seek_ix = np.where(x >= xnew[-1])[0]
270
+ if len(seek_ix) > 0:
271
+ self.state.src_buffer.seek(max(0, src_start_ix + seek_ix[0] - 1))
272
+ # ref: remove samples that have been sent to output
273
+ self.state.ref_axis_buffer.seek(ref_idx[-1] + 1)
274
+ self.state.last_ref_ax_val = xnew[-1]
274
275
 
275
276
  return result
276
277