ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +119 -104
  6. ezmsg/sigproc/bandpower.py +58 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -84
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
ezmsg/sigproc/math/difference.py CHANGED
@@ -1,18 +1,41 @@
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ..base import BaseTransformer, BaseTransformerUnit


class ConstDifferenceSettings(ez.Settings):
    value: float = 0.0
    """number to subtract or be subtracted from the input data"""

    subtrahend: bool = True
    """If True (default) then value is subtracted from the input data. If False, the input data is subtracted from value."""


class ConstDifferenceTransformer(
    BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]
):
    """
    Difference between each message's data and a constant.

    ``data - value`` when ``settings.subtrahend`` is True, otherwise ``value - data``.
    All other message fields are carried over unchanged via ``replace``.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        const = self.settings.value
        if self.settings.subtrahend:
            diff = message.data - const
        else:
            diff = const - message.data
        return replace(message, data=diff)


class ConstDifference(
    BaseTransformerUnit[
        ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer
    ]
):
    """ezmsg Unit wrapping :obj:`ConstDifferenceTransformer`."""

    SETTINGS = ConstDifferenceSettings


def const_difference(
    value: float = 0.0, subtrahend: bool = True
) -> ConstDifferenceTransformer:
    """
    result = (in_data - value) if subtrahend else (value - in_data)
    https://en.wikipedia.org/wiki/Template:Arithmetic_operations

    Args:
        value: Number to subtract from, or to have subtracted from, the input data.
        subtrahend: If True (default) then value is subtracted from the input data.
            If False, the input data is subtracted from value.

    Returns: :obj:`ConstDifferenceTransformer`.
    """
    settings = ConstDifferenceSettings(value=value, subtrahend=subtrahend)
    return ConstDifferenceTransformer(settings)


  # class DifferenceSettings(ez.Settings):
ezmsg/sigproc/math/invert.py CHANGED
@@ -1,35 +1,23 @@
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ..base import BaseTransformer, BaseTransformerUnit


class InvertTransformer(BaseTransformer[None, AxisArray, AxisArray]):
    """Replace the data payload with its element-wise reciprocal (``1 / data``)."""

    def _process(self, message: AxisArray) -> AxisArray:
        reciprocal = 1 / message.data
        return replace(message, data=reciprocal)


class Invert(BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]):
    """ezmsg Unit wrapping :obj:`InvertTransformer`; it takes no settings."""

    # SETTINGS = None


def invert() -> InvertTransformer:
    """
    Take the inverse of the data.

    Returns: :obj:`InvertTransformer`.
    """
    return InvertTransformer()
ezmsg/sigproc/math/log.py CHANGED
@@ -1,19 +1,39 @@
import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ..base import BaseTransformer, BaseTransformerUnit


class LogSettings(ez.Settings):
    base: float = 10.0
    """The base of the logarithm. Default is 10."""

    clip_zero: bool = False
    """If True, clip the data to the minimum positive value of the data type before taking the log."""


class LogTransformer(BaseTransformer[LogSettings, AxisArray, AxisArray]):
    """Logarithm of the data payload in ``settings.base``, computed as ln(x) / ln(base)."""

    def _process(self, message: AxisArray) -> AxisArray:
        values = message.data
        wants_clip = self.settings.clip_zero
        if (
            wants_clip
            and np.any(values <= 0)
            and np.issubdtype(values.dtype, np.floating)
        ):
            # Only floating-point data is clipped: non-positive entries are raised to
            # the smallest positive normal value of the dtype before taking the log.
            values = np.clip(values, a_min=np.finfo(values.dtype).tiny, a_max=None)
        ln_base = np.log(self.settings.base)
        return replace(message, data=np.log(values) / ln_base)


class Log(BaseTransformerUnit[LogSettings, AxisArray, AxisArray, LogTransformer]):
    """ezmsg Unit wrapping :obj:`LogTransformer`."""

    SETTINGS = LogSettings


def log(
    base: float = 10.0,
    clip_zero: bool = False,
) -> LogTransformer:
    """
    Take the logarithm of the data. See :obj:`np.log` for more details.

    Args:
        base: The base of the logarithm. Default is 10.
        clip_zero: If True, clip the data to the minimum positive value of the data type before taking the log.

    Returns: :obj:`LogTransformer`.
    """
    settings = LogSettings(base=base, clip_zero=clip_zero)
    return LogTransformer(settings)
ezmsg/sigproc/math/scale.py CHANGED
@@ -1,40 +1,32 @@
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ..base import BaseTransformer, BaseTransformerUnit


class ScaleSettings(ez.Settings):
    scale: float = 1.0
    """Factor by which to scale the data magnitude."""


class ScaleTransformer(BaseTransformer[ScaleSettings, AxisArray, AxisArray]):
    """Multiply the data payload by the constant ``settings.scale``."""

    def _process(self, message: AxisArray) -> AxisArray:
        factor = self.settings.scale
        scaled = factor * message.data
        return replace(message, data=scaled)


class Scale(BaseTransformerUnit[ScaleSettings, AxisArray, AxisArray, ScaleTransformer]):
    """ezmsg Unit wrapping :obj:`ScaleTransformer`."""

    SETTINGS = ScaleSettings


def scale(scale: float = 1.0) -> ScaleTransformer:
    """
    Scale the data by a constant factor.

    Args:
        scale: Factor by which to scale the data magnitude.

    Returns: :obj:`ScaleTransformer`
    """
    settings = ScaleSettings(scale=scale)
    return ScaleTransformer(settings)
ezmsg/sigproc/quantize.py ADDED
@@ -0,0 +1,71 @@
import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, replace

from .base import BaseTransformer, BaseTransformerUnit


class QuantizeSettings(ez.Settings):
    """
    Settings for the Quantizer.
    """

    max_val: float
    """
    Clip the data to this maximum value before quantization and map the [min_val max_val] range to the quantized range.
    """

    min_val: float = 0.0
    """
    Clip the data to this minimum value before quantization and map the [min_val max_val] range to the quantized range.
    Default: 0
    """

    bits: int = 8
    """
    Number of bits for quantization.
    Note: The data type will be integer of the next power of 2 greater than or equal to this value.
    Default: 8
    """


class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray]):
    """
    Linearly quantize data onto unsigned integers.

    Data is clipped to ``[min_val, max_val]``, mapped linearly onto
    ``[0, 2**bits - 1]``, rounded to the nearest integer and cast to the
    smallest of bool / uint8 / uint16 / uint32 / uint64 that holds ``bits``.

    Raises:
        ValueError: if ``max_val`` is not strictly greater than ``min_val``.
    """

    def _process(
        self,
        message: AxisArray,
    ) -> AxisArray:
        min_val = self.settings.min_val
        max_val = self.settings.max_val
        bits = self.settings.bits

        expected_range = max_val - min_val
        # `not (x > 0)` also rejects NaN bounds.
        if not expected_range > 0:
            raise ValueError(
                f"Quantize requires max_val > min_val (got min_val={min_val}, max_val={max_val})."
            )
        scale_factor = 2**bits - 1

        # Determine appropriate integer type based on bits
        if bits <= 1:
            dtype = bool
        elif bits <= 8:
            dtype = np.uint8
        elif bits <= 16:
            dtype = np.uint16
        elif bits <= 32:
            dtype = np.uint32
        else:
            dtype = np.uint64

        data = message.data.clip(min_val, max_val)
        data = (data - min_val) / expected_range

        # Scale to the quantized range [0, 2^bits - 1]
        data = np.rint(scale_factor * data)

        if dtype is np.uint64:
            # 2**64 - 1 is not representable in float64 and rounds up to 2**64, which
            # overflows when cast to uint64. Clamp to the largest float64 strictly below
            # 2**64. Unlike shrinking max_val, this holds for any min_val/max_val sign.
            # NOTE(review): float32 inputs can hit the same rounding at bits > 24.
            data = np.minimum(data, np.nextafter(np.float64(2.0**64), 0.0))

        # Create a new AxisArray with the quantized data
        return replace(message, data=data.astype(dtype))


class QuantizerUnit(
    BaseTransformerUnit[QuantizeSettings, AxisArray, AxisArray, QuantizeTransformer]
):
    """ezmsg Unit wrapping :obj:`QuantizeTransformer`."""

    SETTINGS = QuantizeSettings
ezmsg/sigproc/resample.py ADDED
@@ -0,0 +1,298 @@
import asyncio
import dataclasses
import time
import typing

import numpy as np
import numpy.typing as npt
import scipy.interpolate
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .base import (
    BaseStatefulProcessor,
    BaseConsumerUnit,
    processor_state,
)


class ResampleSettings(ez.Settings):
    axis: str = "time"

    resample_rate: float | None = None
    """target resample rate in Hz. If None, the resample rate will be determined by the reference signal."""

    max_chunk_delay: float = 0.0
    """Maximum delay between outputs in seconds. If the delay exceeds this value, the transformer will extrapolate."""

    fill_value: str = "extrapolate"
    """
    Value to use for out-of-bounds samples.
    If 'extrapolate', the transformer will extrapolate.
    If 'last', the transformer will use the last sample.
    See scipy.interpolate.interp1d for more options.
    """


@dataclasses.dataclass
class ResampleBuffer:
    # Buffered input samples, with the resampled axis moved to position 0.
    data: npt.NDArray
    # Timestamps of the rows of `data`.
    tvec: npt.NDArray
    # Empty message whose metadata and axes are copied into every output.
    template: AxisArray
    # time.time() of the last push into, or output from, this buffer.
    last_update: float


@processor_state
class ResampleState:
    signal_buffer: ResampleBuffer | None = None
    # (reference axis, number of reference samples not yet used for output).
    ref_axis: tuple[typing.Union[AxisArray.TimeAxis, AxisArray.CoordinateAxis], int] = (
        AxisArray.TimeAxis(fs=1.0),
        0,
    )
    # Timestamp of the most recent output sample (or of the virtual sample just
    # before the first output).
    last_t_out: float | None = None


class ResampleProcessor(
    BaseStatefulProcessor[ResampleSettings, AxisArray, AxisArray, ResampleState]
):
    """
    Resample a buffered signal along ``settings.axis`` by interpolation.

    Signal messages passed to the processor are appended to an internal buffer.
    Output timestamps come from ``settings.resample_rate`` when it is set.
    Otherwise they come from reference messages delivered via :meth:`push_reference`.
    Resampled chunks are pulled with ``next(processor)``. An empty template is
    returned when nothing new can be produced yet.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # Identity of the input stream: key, sample rate (if the axis has a gain),
        # and the shape of a single sample.
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        ax = message.axes[self.settings.axis]
        in_fs = (1 / ax.gain) if hasattr(ax, "gain") else None
        return hash((message.key, in_fs) + sample_shape)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Reset the internal state based on the incoming message.
        If resample_rate is None, the output is driven by the reference signal.
        The input will still determine the template (except the primary axis) and the buffer.
        """
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        ax = message.axes[self.settings.axis]
        in_dat = message.data
        in_tvec = (
            ax.data
            if hasattr(ax, "data")
            else ax.value(np.arange(in_dat.shape[ax_idx]))
        )
        if ax_idx != 0:
            in_dat = np.moveaxis(in_dat, ax_idx, 0)

        if self.settings.resample_rate is None:
            # Output is driven by input.
            # We cannot include the resampled axis until we see reference data.
            out_axes = {
                k: v for k, v in message.axes.items() if k != self.settings.axis
            }
            # last_t_out also driven by reference data.
            # self.state.last_t_out = None
        else:
            out_axes = {
                **message.axes,
                self.settings.axis: AxisArray.TimeAxis(
                    fs=self.settings.resample_rate, offset=in_tvec[0]
                ),
            }
            self.state.last_t_out = in_tvec[0] - 1 / self.settings.resample_rate
        template = replace(message, data=in_dat[:0], axes=out_axes)
        self.state.signal_buffer = ResampleBuffer(
            data=in_dat[:0],
            tvec=in_tvec[:0],
            template=template,
            last_update=time.time(),
        )

    def _ensure_template_axis(self) -> None:
        """
        In reference-driven mode, install the reference axis as the template's output axis.

        This does nothing until a reference has been seen. ``last_t_out`` is set by
        :meth:`push_reference`, so the placeholder axis from the initial state is
        never captured. A new template is created rather than mutating the existing
        axes dict, which earlier returned (empty) templates also share.
        """
        buf = self.state.signal_buffer
        if (
            buf is None
            or self.settings.resample_rate is not None
            or self.settings.axis in buf.template.axes
            or self.state.last_t_out is None
        ):
            return
        out_ax = self.state.ref_axis[0]
        if hasattr(out_ax, "gain"):
            out_ax = replace(out_ax, offset=self.state.last_t_out)
        buf.template = replace(
            buf.template,
            axes={**buf.template.axes, self.settings.axis: out_ax},
        )

    def _process(self, message: AxisArray) -> None:
        """Append the incoming signal message to the buffer; produces no output."""
        # If our outputs are driven by reference signal, create the template's
        # output axis once reference data is available.
        self._ensure_template_axis()

        buf = self.state.signal_buffer

        # Append the new data to the buffer
        ax_idx: int = message.get_axis_idx(self.settings.axis)
        in_dat: npt.NDArray = message.data
        if ax_idx != 0:
            in_dat = np.moveaxis(in_dat, ax_idx, 0)
        ax = message.axes[self.settings.axis]
        in_tvec = (
            ax.data if hasattr(ax, "data") else ax.value(np.arange(in_dat.shape[0]))
        )
        buf.data = np.concatenate((buf.data, in_dat), axis=0)
        buf.tvec = np.hstack((buf.tvec, in_tvec))
        buf.last_update = time.time()

    def push_reference(self, message: AxisArray) -> None:
        """Record reference timing: extend the reference axis and its pending sample count."""
        ax = message.axes[self.settings.axis]
        ax_idx = message.get_axis_idx(self.settings.axis)
        n_new = message.data.shape[ax_idx]
        if self.state.ref_axis[1] == 0:
            self.state.ref_axis = (ax, n_new)
        else:
            if hasattr(ax, "gain"):
                # Rate and offset don't need to change; we simply increment our sample counter.
                self.state.ref_axis = (
                    self.state.ref_axis[0],
                    self.state.ref_axis[1] + n_new,
                )
            else:
                # Extend our time axis with the new data.
                new_tvec = np.concatenate(
                    (self.state.ref_axis[0].data, ax.data), axis=0
                )
                self.state.ref_axis = (
                    replace(self.state.ref_axis[0], data=new_tvec),
                    self.state.ref_axis[1] + n_new,
                )

        if self.settings.resample_rate is None and self.state.last_t_out is None:
            # This reference axis will become THE output axis.
            # last_t_out is set to the (extrapolated) sample just before the reference data.
            if hasattr(self.state.ref_axis[0], "gain"):
                ref_tvec = self.state.ref_axis[0].value(np.arange(2))
            else:
                ref_tvec = self.state.ref_axis[0].data[:2]
            # A coordinate axis needs two timestamps to extrapolate backwards; with
            # fewer, defer until more reference data arrives.
            if len(ref_tvec) >= 2:
                self.state.last_t_out = 2 * ref_tvec[0] - ref_tvec[1]

    def __next__(self) -> AxisArray:
        """
        Produce the next resampled chunk from buffered data.

        Returns an empty message when nothing can be produced yet: a ``key="null"``
        AxisArray if no signal has been seen, otherwise the empty template.
        """
        buf = self.state.signal_buffer

        if buf is None:
            return AxisArray(data=np.array([]), dims=[""], axes={}, key="null")

        # Reference may have arrived since the last signal message.
        self._ensure_template_axis()

        # buffer is empty or ref-driven && empty-reference; return the empty template
        if (buf.tvec.size == 0) or (
            self.settings.resample_rate is None and self.state.ref_axis[1] < 3
        ):
            # Note: empty template's primary axis' offset might be meaningless.
            return buf.template

        # Identify the output timestamps at which we will resample the buffer
        b_project = False
        if self.settings.resample_rate is None:
            # Rely on reference signal to determine output timestamps
            if hasattr(self.state.ref_axis[0], "data"):
                ref_tvec = self.state.ref_axis[0].data
            else:
                n_avail = self.state.ref_axis[1]
                ref_tvec = self.state.ref_axis[0].value(np.arange(n_avail))
        else:
            # Get output timestamps from resample_rate and what we've collected so far
            t_begin = self.state.last_t_out + 1 / self.settings.resample_rate
            t_end = buf.tvec[-1]
            if self.settings.max_chunk_delay > 0 and time.time() > (
                buf.last_update + self.settings.max_chunk_delay
            ):
                # We've waited too long between pushes. We will have to extrapolate.
                b_project = True
                t_end += self.settings.max_chunk_delay
            ref_tvec = np.arange(t_begin, t_end, 1 / self.settings.resample_rate)

        # Which samples can we resample?
        b_ref = ref_tvec > self.state.last_t_out
        if not b_project:
            b_ref = np.logical_and(b_ref, ref_tvec <= buf.tvec[-1])
        ref_idx = np.where(b_ref)[0]

        if len(ref_idx) < 2:
            # Not enough data to resample; return the empty template.
            return buf.template

        tnew = ref_tvec[ref_idx]
        # Slice buf to minimal range around tnew with some padding for better interpolation.
        buf_start_ix = max(0, np.searchsorted(buf.tvec, tnew[0]) - 2)
        buf_stop_ix = np.searchsorted(buf.tvec, tnew[-1], side="right") + 2
        x = buf.tvec[buf_start_ix:buf_stop_ix]
        y = buf.data[buf_start_ix:buf_stop_ix]
        if (
            isinstance(self.settings.fill_value, str)
            and self.settings.fill_value == "last"
        ):
            fill_value = (y[0], y[-1])
        else:
            fill_value = self.settings.fill_value
        f = scipy.interpolate.interp1d(
            x,
            y,
            kind="linear",
            axis=0,
            copy=False,
            bounds_error=False,
            fill_value=fill_value,
            assume_sorted=True,
        )
        resampled_data = f(tnew)
        if hasattr(buf.template.axes[self.settings.axis], "data"):
            repl_axis = replace(buf.template.axes[self.settings.axis], data=tnew)
        else:
            repl_axis = replace(buf.template.axes[self.settings.axis], offset=tnew[0])
        result = replace(
            buf.template,
            data=resampled_data,
            axes={
                **buf.template.axes,
                self.settings.axis: repl_axis,
            },
        )

        # Update state to move past samples that are no longer needed
        self.state.last_t_out = tnew[-1]
        buf.data = buf.data[max(0, buf_stop_ix - 3) :]
        buf.tvec = buf.tvec[max(0, buf_stop_ix - 3) :]
        buf.last_update = time.time()

        if self.settings.resample_rate is None:
            # Drop every reference sample up to and including the last one used.
            # The pending count must shrink by the same amount (ref_idx[-1] + 1), not
            # by len(ref_idx). Otherwise it drifts whenever ref_idx does not start at 0.
            n_used = ref_idx[-1] + 1
            ref_ax = self.state.ref_axis[0]
            if hasattr(ref_ax, "data"):
                new_ref_ax = replace(ref_ax, data=ref_ax.data[n_used:])
            else:
                new_ref_ax = replace(ref_ax, offset=ref_ax.value(n_used))
            self.state.ref_axis = (new_ref_ax, self.state.ref_axis[1] - n_used)

        return result

    def send(self, message: AxisArray) -> AxisArray:
        """Push a signal message, then return the next resampled chunk."""
        self(message)
        return next(self)


class ResampleUnit(BaseConsumerUnit[ResampleSettings, AxisArray, ResampleProcessor]):
    """ezmsg Unit: signal on the consumer input, reference on INPUT_REFERENCE, output on OUTPUT_SIGNAL."""

    SETTINGS = ResampleSettings
    INPUT_REFERENCE = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.subscriber(INPUT_REFERENCE, zero_copy=True)
    async def on_reference(self, message: AxisArray):
        self.processor.push_reference(message)

    @ez.publisher(OUTPUT_SIGNAL)
    async def gen_resampled(self):
        # Poll the processor; publish non-empty chunks, otherwise yield briefly.
        while True:
            result: AxisArray = next(self.processor)
            if np.prod(result.data.shape) > 0:
                yield self.OUTPUT_SIGNAL, result
            else:
                await asyncio.sleep(0.001)