ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,159 @@
1
+ """
2
+ Coordinate space transformations for streaming data.
3
+
4
+ This module provides utilities and ezmsg nodes for transforming between
5
+ Cartesian (x, y) and polar (r, theta) coordinate systems.
6
+
7
+ .. note::
8
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
9
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
10
+ """
11
+
12
+ from enum import Enum
13
+ from typing import Tuple
14
+
15
+ import ezmsg.core as ez
16
+ import numpy as np
17
+ import numpy.typing as npt
18
+ from array_api_compat import get_namespace, is_array_api_obj
19
+ from ezmsg.baseproc import (
20
+ BaseTransformer,
21
+ BaseTransformerUnit,
22
+ )
23
+ from ezmsg.util.messages.axisarray import AxisArray, replace
24
+
25
+ # -- Utility functions for coordinate transformations --
26
+
27
+
28
def _get_namespace_or_numpy(*args: npt.ArrayLike):
    """
    Pick the array namespace to use for a set of operands.

    The first argument recognised by ``is_array_api_obj`` decides the namespace
    (via ``get_namespace``). If no argument is recognised (e.g. all are Python
    scalars), plain ``numpy`` is returned.
    """
    first_array = next((candidate for candidate in args if is_array_api_obj(candidate)), None)
    if first_array is None:
        return np
    return get_namespace(first_array)
34
+
35
+
36
def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
    """
    Build the complex value ``r * exp(1j * theta)`` from polar components.

    The exponential is evaluated in the namespace of the first array operand
    (NumPy when neither operand is an array).
    """
    namespace = _get_namespace_or_numpy(r, theta)
    unit_phasor = namespace.exp(1j * theta)
    return r * unit_phasor
40
+
41
+
42
def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert complex value(s) to polar coordinates.

    Returns:
        ``(r, theta)`` where ``r = |z|`` and ``theta`` is the four-quadrant angle
        of ``z`` in radians, computed as ``atan2(imag(z), real(z))``.
    """
    xp = _get_namespace_or_numpy(z)
    # Array API namespaces expose ``atan2``. When ``z`` is not an array, the plain
    # ``numpy`` fallback is used, and NumPy only gained the ``atan2`` alias in 2.0,
    # so fall back to ``arctan2`` when the alias is missing.
    atan2 = getattr(xp, "atan2", None) or xp.arctan2
    return xp.abs(z), atan2(xp.imag(z), xp.real(z))
46
+
47
+
48
def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
    """
    Combine Cartesian components into complex form, ``x + 1j * y``.

    Uses only arithmetic operators, so no array namespace lookup is needed.
    """
    imaginary_part = 1j * y
    return x + imaginary_part
51
+
52
+
53
def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Split complex value(s) into Cartesian components.

    Returns:
        ``(x, y)``: the real and imaginary parts of ``z``.
    """
    namespace = _get_namespace_or_numpy(z)
    real_part = namespace.real(z)
    imag_part = namespace.imag(z)
    return real_part, imag_part
57
+
58
+
59
def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert Cartesian ``(x, y)`` to polar ``(r, theta)``.

    Goes through the complex representation ``x + 1j * y``.
    """
    z = cart2z(x, y)
    return z2polar(z)
62
+
63
+
64
def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """
    Convert polar ``(r, theta)`` to Cartesian ``(x, y)``.

    Goes through the complex representation ``r * exp(1j * theta)``.
    """
    z = polar2z(r, theta)
    return z2cart(z)
67
+
68
+
69
+ # -- ezmsg transformer classes --
70
+
71
+
72
class CoordinateMode(str, Enum):
    """
    Direction of the coordinate conversion.

    Members subclass ``str``, so each compares equal to its plain string value
    (e.g. ``CoordinateMode.CART2POL == "cart2pol"``).
    """

    CART2POL = "cart2pol"
    """Cartesian ``(x, y)`` in, polar ``(r, theta)`` out."""

    POL2CART = "pol2cart"
    """Polar ``(r, theta)`` in, Cartesian ``(x, y)`` out."""
80
+
81
+
82
class CoordinateSpacesSettings(ez.Settings):
    """
    Settings for :obj:`CoordinateSpaces` / :obj:`CoordinateSpacesTransformer`.

    See the field docstrings below for argument details.
    """

    mode: CoordinateMode = CoordinateMode.CART2POL
    """The transformation mode: 'cart2pol' or 'pol2cart'."""

    axis: str | None = None
    """
    The name of the axis containing the coordinate components.
    Defaults to the last axis. Must have exactly 2 elements (x,y or r,theta);
    the transformer raises ``ValueError`` otherwise.
    """
97
+
98
+
99
class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, AxisArray, AxisArray]):
    """
    Transform between Cartesian and polar coordinate systems.

    The input must have exactly 2 elements along the configured axis (default:
    the message's last dimension):

    - cart2pol: expects ``(x, y)``, outputs ``(r, theta)`` with theta in radians.
    - pol2cart: expects ``(r, theta)`` with theta in radians, outputs ``(x, y)``.

    If the axis carries coordinate labels (an axis object with ``data``), they are
    replaced by ``["r", "theta"]`` or ``["x", "y"]`` to match the output.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        data = message.data
        xp = get_namespace(data)
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)

        n_components = data.shape[axis_idx]
        if n_components != 2:
            raise ValueError(
                f"Coordinate transformation requires exactly 2 elements along axis '{axis}', "
                f"got {n_components}."
            )

        def _component(index: int):
            # Pick one element along the coordinate axis, keeping all other dims whole.
            selector = tuple(index if dim == axis_idx else slice(None) for dim in range(data.ndim))
            return data[selector]

        to_polar = self.settings.mode == CoordinateMode.CART2POL
        convert = cart2pol if to_polar else pol2cart
        first, second = convert(_component(0), _component(1))

        # Re-assemble the two outputs along the original coordinate axis.
        result = xp.stack([first, second], axis=axis_idx)

        # Labels are strings, so they stay in NumPy regardless of the data backend.
        axes = message.axes
        if axis in axes and hasattr(axes[axis], "data"):
            labels = np.array(["r", "theta"] if to_polar else ["x", "y"])
            axes = {**axes, axis: replace(axes[axis], data=labels)}

        return replace(message, data=result, axes=axes)
148
+
149
+
150
class CoordinateSpaces(
    BaseTransformerUnit[CoordinateSpacesSettings, AxisArray, AxisArray, CoordinateSpacesTransformer]
):
    """
    ezmsg Unit that applies :obj:`CoordinateSpacesTransformer`, converting between
    Cartesian and polar coordinates.

    Configure it with :obj:`CoordinateSpacesSettings`.
    """

    SETTINGS = CoordinateSpacesSettings
ezmsg/sigproc/decimate.py CHANGED
@@ -1,37 +1,65 @@
1
- import ezmsg.core as ez
2
-
3
- import scipy.signal
1
+ import typing
4
2
 
3
+ import ezmsg.core as ez
4
+ from ezmsg.baseproc import BaseTransformerUnit
5
5
  from ezmsg.util.messages.axisarray import AxisArray
6
6
 
7
+ from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
7
8
  from .downsample import Downsample, DownsampleSettings
8
- from .filter import Filter, FilterCoefficients, FilterSettings
9
+ from .filter import BACoeffs, SOSCoeffs
10
+
11
+
12
class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeffs]):
    """
    A :obj:`ChebyshevFilterTransformer` whose design function skips filter design
    when no real decimation would happen.

    The design function takes the input sampling rate ``fs``. It returns ``None``
    (no filter) when ``fs`` is ``None`` or when the implied integer downsampling
    factor ``int(fs / (2.5 * Wn))`` is below 2. Otherwise it returns the filter
    coefficients from the parent class's design function.
    """

    def get_design_function(
        self,
    ) -> typing.Callable[[float | None], BACoeffs | SOSCoeffs | None]:
        def cheby_opt_design_fun(fs: float | None) -> BACoeffs | SOSCoeffs | None:
            if fs is None:
                return None
            # Decimate configures Wn = 0.4 * target_rate (with wn_hz=True), so
            # 2.5 * Wn is the target rate and this is int(fs / target_rate).
            ds_factor = int(fs / (2.5 * self.settings.Wn))
            if ds_factor < 2:
                # Factor of 1 => no downsampling, so no anti-aliasing filter is designed.
                return None
            # Zero-argument super() does not work inside a nested function, hence the
            # explicit form; the parent design function is fetched at call time.
            partial_fun = super(ChebyForDecimateTransformer, self).get_design_function()
            return partial_fun(fs)

        return cheby_opt_design_fun
31
+
32
+
33
class ChebyForDecimate(BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]):
    """
    Unit wrapping :obj:`ChebyForDecimateTransformer`; used as the lowpass stage
    (``FILTER``) of the :obj:`Decimate` collection.
    """

    SETTINGS = ChebyshevFilterSettings
9
35
 
10
36
 
11
37
  class Decimate(ez.Collection):
12
- SETTINGS: DownsampleSettings
38
+ """
39
+ A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
40
+ and a :obj:`Downsample` node.
41
+ """
42
+
43
+ SETTINGS = DownsampleSettings
13
44
 
14
45
  INPUT_SIGNAL = ez.InputStream(AxisArray)
15
46
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
16
47
 
17
- FILTER = Filter()
48
+ FILTER = ChebyForDecimate()
18
49
  DOWNSAMPLE = Downsample()
19
50
 
20
51
  def configure(self) -> None:
52
+ cheby_settings = ChebyshevFilterSettings(
53
+ order=8,
54
+ ripple_tol=0.05,
55
+ Wn=0.4 * self.SETTINGS.target_rate,
56
+ btype="lowpass",
57
+ axis=self.SETTINGS.axis,
58
+ wn_hz=True,
59
+ )
60
+ self.FILTER.apply_settings(cheby_settings)
21
61
  self.DOWNSAMPLE.apply_settings(self.SETTINGS)
22
62
 
23
- if self.SETTINGS.factor < 1:
24
- raise ValueError("Decimation factor must be >= 1 (no decimation")
25
- elif self.SETTINGS.factor == 1:
26
- filt = FilterCoefficients()
27
- else:
28
- # See scipy.signal.decimate for IIR Filter Condition
29
- b, a = scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
30
- system = scipy.signal.dlti(b, a)
31
- filt = FilterCoefficients(b=system.num, a=system.den) # type: ignore
32
-
33
- self.FILTER.apply_settings(FilterSettings(filt=filt))
34
-
35
63
  def network(self) -> ez.NetworkDefinition:
36
64
  return (
37
65
  (self.INPUT_SIGNAL, self.FILTER.INPUT_SIGNAL),
@@ -0,0 +1,78 @@
1
+ import ezmsg.core as ez
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ from ezmsg.baseproc import (
5
+ BaseStatefulTransformer,
6
+ BaseTransformerUnit,
7
+ processor_state,
8
+ )
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
11
+
12
+
13
class DenormalizeSettings(ez.Settings):
    """
    Settings for :obj:`DenormalizeTransformer`.
    """

    low_rate: float = 2.0
    """Low end of the probable per-channel rate after denormalization (Hz)."""

    high_rate: float = 40.0
    """High end of the probable per-channel rate after denormalization (Hz)."""

    distribution: str = "uniform"
    """How per-channel rates are sampled: 'uniform', 'normal', or 'constant'."""
22
+
23
+
24
@processor_state
class DenormalizeState:
    """Per-channel scaling parameters, populated when the transformer state is reset."""

    # Multiplicative scale applied to the normalized input.
    gains: npt.NDArray | None = None
    # Additive per-channel offset (the centre of the output distribution).
    offsets: npt.NDArray | None = None
28
+
29
+
30
class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
    """
    Scales data from a normalized distribution (mean=0, std=1) to a denormalized
    distribution using random per-channel offsets and gains designed to keep the
    99.9% CIs between 0 and 2x the offset.

    Offsets are drawn once per channel (along the ``"ch"`` axis) on state reset,
    according to ``settings.distribution`` and bounded by
    ``[settings.low_rate, settings.high_rate]``. Output values are clipped at 0.

    This is useful for simulating realistic firing rates from normalized data.
    """

    def _reset_state(self, message: AxisArray) -> None:
        ax_ix = message.get_axis_idx("ch")
        nch = message.data.shape[ax_ix]
        # Broadcastable shape: nch along "ch", 1 on every other axis.
        # For 2-D data this is (nch, 1) or (1, nch), exactly as before.
        arr_size = tuple(nch if ix == ax_ix else 1 for ix in range(message.data.ndim))
        low = self.settings.low_rate
        high = self.settings.high_rate
        distribution = self.settings.distribution
        if distribution == "uniform":
            # Previously hard-coded to (2.0, 40.0); honour the configured bounds.
            offsets = np.random.uniform(low, high, size=arr_size)
        elif distribution == "normal":
            # +/-3 sigma spans [low, high]; clip the rare draws outside it.
            offsets = np.random.normal(
                loc=(low + high) / 2.0,
                scale=(high - low) / 6.0,
                size=arr_size,
            )
            offsets = np.clip(offsets, a_min=low, a_max=high)
        elif distribution == "constant":
            offsets = np.full(shape=arr_size, fill_value=(low + high) / 2.0)
        else:
            raise ValueError(f"Invalid distribution: {self.settings.distribution}")
        self.state.offsets = offsets
        # Input has std == 1. For a standard normal, 99.9% of samples lie within
        # +/-3.29 std devs, so gain = offset / 3.29 maps that interval onto
        # [0, 2 * offset].
        self.state.gains = offsets / 3.29

    def _process(self, message: AxisArray) -> AxisArray:
        denorm = message.data * self.state.gains + self.state.offsets
        # Clamp the small lower tail that falls below 0.
        return replace(
            message,
            data=np.clip(denorm, a_min=0.0, a_max=None),
        )
75
+
76
+
77
class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
    """Unit wrapping :obj:`DenormalizeTransformer`, configured by :obj:`DenormalizeSettings`."""

    SETTINGS = DenormalizeSettings
@@ -0,0 +1,28 @@
1
+ import scipy.signal as sps
2
+ from ezmsg.baseproc import BaseTransformerUnit
3
+ from ezmsg.util.messages.axisarray import AxisArray, replace
4
+
5
+ from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
6
+
7
+
8
class DetrendTransformer(EWMATransformer):
    """
    Detrend the data using an exponentially weighted moving average (EWMA)
    estimate of the mean.

    The running mean is ``m[n] = alpha * x[n] + (1 - alpha) * m[n-1]``, evaluated
    with ``scipy.signal.lfilter``. The filter state ``zi`` is carried in
    ``self._state`` between messages, and the output is ``x - m``.

    NOTE(review): relies on EWMATransformer's state reset having populated
    ``self._state.alpha`` and ``self._state.zi``; confirm in ewma.py.
    """

    def _process(self, message):
        axis_name = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis_name)
        alpha = self._state.alpha
        # First-order IIR: b = [alpha], a = [1, alpha - 1]
        #   <=> m[n] = alpha * x[n] + (1 - alpha) * m[n-1]
        numerator = [alpha]
        denominator = [1.0, alpha - 1.0]
        ewma_mean, next_zi = sps.lfilter(
            numerator,
            denominator,
            message.data,
            axis=axis_idx,
            zi=self._state.zi,
        )
        self._state.zi = next_zi
        return replace(message, data=message.data - ewma_mean)
25
+
26
+
27
class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
    """Unit wrapping :obj:`DetrendTransformer`; configured with :obj:`EWMASettings`."""

    SETTINGS = EWMASettings
ezmsg/sigproc/diff.py ADDED
@@ -0,0 +1,82 @@
1
+ """
2
+ Compute differences along an axis.
3
+
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
8
+
9
+ import ezmsg.core as ez
10
+ import numpy as np
11
+ import numpy.typing as npt
12
+ from array_api_compat import get_namespace
13
+ from ezmsg.baseproc import (
14
+ BaseStatefulTransformer,
15
+ BaseTransformerUnit,
16
+ processor_state,
17
+ )
18
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
19
+ from ezmsg.util.messages.util import replace
20
+
21
+
22
class DiffSettings(ez.Settings):
    """
    Settings for :obj:`DiffTransformer`.

    Fields:
        axis: Name of the axis to difference along. ``None`` lets the transformer
            pick the message's first dimension.
        scale_by_fs: If True, divide each difference by the sample interval,
            yielding a finite-difference derivative.
    """

    axis: str | None = None
    scale_by_fs: bool = False
25
+
26
+
27
@processor_state
class DiffState:
    """Values carried from one message to the next by :obj:`DiffTransformer`."""

    # Length-1 slice (along the diff axis) holding the previous message's last sample.
    last_dat: npt.NDArray | None = None
    # Axis coordinate of that sample; used only with scale_by_fs on coordinate axes.
    last_time: float | None = None
31
+
32
+
33
class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
    """
    Compute first differences along an axis, continuous across messages.

    The last sample of each message is kept in state and prepended to the next, so
    the output has the same length along the axis as the input. After a state reset
    the carried sample is the message's own first sample, so the first output is 0.

    With ``settings.scale_by_fs`` the differences are divided by the sample interval:
    the axis ``gain`` for axes without ``data``, or the successive differences of
    the coordinate values for axes that carry ``data``.
    """

    def _axis_name(self, message: AxisArray) -> str:
        """Resolve the configured axis, defaulting to the message's first dimension."""
        return self.settings.axis or message.dims[0]

    @staticmethod
    def _initial_last_time(coords) -> float:
        """Extrapolate a virtual coordinate one sample before ``coords[0]``."""
        if len(coords) > 1:
            return 2 * coords[0] - coords[1]
        # Only one coordinate: no spacing information, so use a fixed 0.001 step.
        return coords[0] - 0.001

    def _hash_message(self, message: AxisArray) -> int:
        # Resolve the axis the same way as _process so axis=None works.
        ax_idx = message.get_axis_idx(self._axis_name(message))
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _reset_state(self, message) -> None:
        axis = self._axis_name(message)
        ax_idx = message.get_axis_idx(axis)
        self.state.last_dat = slice_along_axis(message.data, slice(0, 1), axis=ax_idx)
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            # Guard against an empty coordinate axis (would raise IndexError).
            if hasattr(ax_info, "data") and len(ax_info.data) > 0:
                self.state.last_time = self._initial_last_time(ax_info.data)

    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)
        axis = self._axis_name(message)
        ax_idx = message.get_axis_idx(axis)
        n_samples = message.data.shape[ax_idx]

        if n_samples > 0 and (self.state.last_dat is None or self.state.last_dat.shape[ax_idx] == 0):
            # State was reset on an empty message: seed from this one so the
            # output keeps the input's length.
            self.state.last_dat = slice_along_axis(message.data, slice(0, 1), axis=ax_idx)

        diffs = xp.diff(
            xp.concat((self.state.last_dat, message.data), axis=ax_idx),
            axis=ax_idx,
        )

        if n_samples == 0:
            # Nothing new: keep the carried sample/time intact for the next message.
            return replace(message, data=diffs)

        # Prepare last_dat for next iteration
        self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)

        # Scale by fs if requested. This converts the diff to a derivative,
        # e.g. diff of position becomes velocity.
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                if self.state.last_time is None:
                    self.state.last_time = self._initial_last_time(ax_info.data)
                # ax_info.data is typically numpy for metadata, so use np.diff here
                dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
                # Reshape dt so it broadcasts along the diff axis only.
                dt_shape = [1] * message.data.ndim
                dt_shape[ax_idx] = -1
                # Out-of-place division: in-place `/=` fails for integer input.
                diffs = diffs / xp.asarray(np.reshape(dt, dt_shape))
                self.state.last_time = ax_info.data[-1]  # For next iteration
            else:
                diffs = diffs / ax_info.gain

        return replace(message, data=diffs)
75
+
76
+
77
class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
    """Unit wrapping :obj:`DiffTransformer`, configured by :obj:`DiffSettings`."""

    SETTINGS = DiffSettings
79
+
80
+
81
def diff(axis: str = "time", scale_by_fs: bool = False) -> DiffTransformer:
    """
    Build a :obj:`DiffTransformer`.

    Args:
        axis: Axis to difference along.
        scale_by_fs: If True, divide by the sample interval to obtain a derivative.

    Returns:
        A configured :obj:`DiffTransformer`.
    """
    settings = DiffSettings(axis=axis, scale_by_fs=scale_by_fs)
    return DiffTransformer(settings)
@@ -1,63 +1,111 @@
1
- from dataclasses import replace
2
-
3
- from ezmsg.util.messages.axisarray import AxisArray
4
-
5
1
  import ezmsg.core as ez
6
2
  import numpy as np
7
-
8
- from typing import (
9
- AsyncGenerator,
10
- Optional,
3
+ from ezmsg.baseproc import (
4
+ BaseStatefulTransformer,
5
+ BaseTransformerUnit,
6
+ processor_state,
7
+ )
8
+ from ezmsg.util.messages.axisarray import (
9
+ AxisArray,
10
+ replace,
11
+ slice_along_axis,
11
12
  )
12
13
 
13
14
 
14
15
class DownsampleSettings(ez.Settings):
    """
    Settings for :obj:`Downsample` node.

    The integer factor comes from ``factor`` when given, otherwise from
    ``target_rate``. With neither set, no downsampling happens (factor 1).
    """

    axis: str = "time"
    """Name of the axis to downsample along."""

    target_rate: float | None = None
    """Desired rate after downsampling. The factor used is ``int(input_rate / target_rate)``
    (at least 1), so the achieved rate is the closest integer-factor rate at or above the target."""

    factor: int | None = None
    """Explicit integer downsampling factor; when set, ``target_rate`` is ignored."""
34
29
 
35
- @ez.subscriber(INPUT_SETTINGS)
36
- async def on_settings(self, msg: DownsampleSettings) -> None:
37
- self.STATE.cur_settings = msg
38
30
 
39
- @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
40
- @ez.publisher(OUTPUT_SIGNAL)
41
- async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
42
- if self.STATE.cur_settings.factor < 1:
43
- raise ValueError("Downsample factor must be at least 1 (no downsampling)")
31
@processor_state
class DownsampleState:
    """Running state of :obj:`DownsampleTransformer`."""

    q: int = 0
    """Integer downsampling factor, resolved from the settings when state is reset."""

    s_idx: int = 0
    """Position of the next message's first sample in the stream-wide modulo-``q`` counter."""
38
+
39
+
40
class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
    """
    Downsampled data simply comprise every `factor`th sample.
    This should only be used following appropriate lowpass filtering.
    If your pipeline does not already have lowpass filtering then consider
    using the :obj:`Decimate` collection instead.

    Sample selection is continuous across messages: a modulo-``q`` counter is
    carried in state, so every ``q``-th sample of the stream is kept regardless
    of how the stream is chunked.
    """

    def _hash_message(self, message: AxisArray) -> int:
        return hash((message.axes[self.settings.axis].gain, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        axis_info = message.get_axis(self.settings.axis)

        if self.settings.factor is not None:
            q = self.settings.factor
            if q < 1:
                # A factor < 1 breaks the modulo counter in _process (0 keeps every
                # sample while setting the output gain to 0).
                raise ValueError("Downsample factor must be at least 1 (no downsampling)")
        elif self.settings.target_rate is None:
            q = 1
        else:
            if self.settings.target_rate <= 0:
                # Avoid ZeroDivisionError / a meaningless negative factor below.
                raise ValueError(f"Downsample target_rate must be positive, got {self.settings.target_rate}.")
            # 1 / gain is the input rate, so this is int(input_rate / target_rate).
            q = int(1 / (axis_info.gain * self.settings.target_rate))
            if q < 1:
                ez.logger.warning(
                    f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis_info.gain}. "
                    "Setting factor to 1."
                )
                q = 1
        self._state.q = q
        self._state.s_idx = 0

    def _process(self, message: AxisArray) -> AxisArray:
        axis = self.settings.axis
        axis_info = message.get_axis(axis)
        axis_idx = message.get_axis_idx(axis)

        n_samples = message.data.shape[axis_idx]
        # Position of each incoming sample in the stream-wide modulo-q counter.
        samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
        if n_samples > 0:
            # Update state for next iteration.
            self._state.s_idx = samples[-1] + 1

        # Keep only samples whose counter value is 0.
        pub_samples = np.where(samples == 0)[0]
        if len(pub_samples) > 0:
            # The first kept sample determines the new axis offset.
            n_step = pub_samples[0].item()
            data_slice = pub_samples
        else:
            n_step = 0
            data_slice = slice(None, 0, None)
        return replace(
            message,
            data=slice_along_axis(message.data, data_slice, axis=axis_idx),
            axes={
                **message.axes,
                axis: replace(
                    axis_info,
                    gain=axis_info.gain * self._state.q,
                    offset=axis_info.offset + axis_info.gain * n_step,
                ),
            },
        )
100
+
101
+
102
class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
    """
    Unit wrapping :obj:`DownsampleTransformer`. It performs no lowpass filtering
    itself; :obj:`Decimate` chains a filter in front of it.
    """

    SETTINGS = DownsampleSettings
104
+
105
+
106
def downsample(
    axis: str = "time",
    target_rate: float | None = None,
    factor: int | None = None,
) -> DownsampleTransformer:
    """
    Build a :obj:`DownsampleTransformer`.

    Args:
        axis: Axis to downsample along.
        target_rate: Desired output rate; ignored if ``factor`` is given.
        factor: Explicit integer downsampling factor.

    Returns:
        A configured :obj:`DownsampleTransformer`.
    """
    settings = DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
    return DownsampleTransformer(settings)