ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,159 @@
1
+ """
2
+ Coordinate space transformations for streaming data.
3
+
4
+ This module provides utilities and ezmsg nodes for transforming between
5
+ Cartesian (x, y) and polar (r, theta) coordinate systems.
6
+
7
+ .. note::
8
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
9
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
10
+ """
11
+
12
+ from enum import Enum
13
+ from typing import Tuple
14
+
15
+ import ezmsg.core as ez
16
+ import numpy as np
17
+ import numpy.typing as npt
18
+ from array_api_compat import get_namespace, is_array_api_obj
19
+ from ezmsg.baseproc import (
20
+ BaseTransformer,
21
+ BaseTransformerUnit,
22
+ )
23
+ from ezmsg.util.messages.axisarray import AxisArray, replace
24
+
25
+ # -- Utility functions for coordinate transformations --
26
+
27
+
28
+ def _get_namespace_or_numpy(*args: npt.ArrayLike):
29
+ """Get array namespace if any arg is an array, otherwise return numpy."""
30
+ for arg in args:
31
+ if is_array_api_obj(arg):
32
+ return get_namespace(arg)
33
+ return np
34
+
35
+
36
def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
    """Build the complex value ``r * exp(1j * theta)`` from polar coordinates."""
    xp = _get_namespace_or_numpy(r, theta)
    unit_phasor = xp.exp(1j * theta)
    return r * unit_phasor
40
+
41
+
42
def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """Split a complex value into its polar form.

    Returns:
        A ``(r, theta)`` pair where ``r`` is ``abs(z)`` and ``theta`` is
        ``atan2(imag(z), real(z))``.
    """
    xp = _get_namespace_or_numpy(z)
    magnitude = xp.abs(z)
    angle = xp.atan2(xp.imag(z), xp.real(z))
    return magnitude, angle
46
+
47
+
48
def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
    """Combine Cartesian components into the complex value ``x + 1j * y``."""
    imaginary_part = 1j * y
    return x + imaginary_part
51
+
52
+
53
def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """Split a complex value into its Cartesian components.

    Returns:
        A ``(x, y)`` pair holding the real and imaginary parts of ``z``.
    """
    xp = _get_namespace_or_numpy(z)
    x_part = xp.real(z)
    y_part = xp.imag(z)
    return x_part, y_part
57
+
58
+
59
def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """Map Cartesian (x, y) to polar (r, theta) via an intermediate complex value."""
    as_complex = cart2z(x, y)
    return z2polar(as_complex)
62
+
63
+
64
def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
    """Map polar (r, theta) to Cartesian (x, y) via an intermediate complex value."""
    as_complex = polar2z(r, theta)
    return z2cart(as_complex)
67
+
68
+
69
+ # -- ezmsg transformer classes --
70
+
71
+
72
class CoordinateMode(str, Enum):
    """Direction of the coordinate conversion; members are also plain strings."""

    CART2POL = "cart2pol"
    """Cartesian (x, y) in, polar (r, theta) out."""

    POL2CART = "pol2cart"
    """Polar (r, theta) in, Cartesian (x, y) out."""
80
+
81
+
82
class CoordinateSpacesSettings(ez.Settings):
    """
    Configuration for :obj:`CoordinateSpaces` and :obj:`CoordinateSpacesTransformer`.
    """

    mode: CoordinateMode = CoordinateMode.CART2POL
    """Direction of the conversion: 'cart2pol' or 'pol2cart'."""

    axis: str | None = None
    """
    Name of the axis that holds the two coordinate components (x,y or r,theta).
    When None, the last axis of the message is used. The axis must have exactly 2 elements.
    """
97
+
98
+
99
class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, AxisArray, AxisArray]):
    """
    Convert a two-component axis between Cartesian and polar coordinates.

    The configured axis (or the last axis when none is configured) must hold
    exactly 2 elements:
    - cart2pol: input is (x, y), output is (r, theta)
    - pol2cart: input is (r, theta), output is (x, y)
    """

    def _process(self, message: AxisArray) -> AxisArray:
        """Convert one message and relabel the component axis when it carries labels.

        Raises:
            ValueError: If the component axis does not have exactly 2 elements.
        """
        data = message.data
        xp = get_namespace(data)
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)

        if data.shape[axis_idx] != 2:
            raise ValueError(
                f"Coordinate transformation requires exactly 2 elements along axis '{axis}', "
                f"got {message.data.shape[axis_idx]}."
            )

        def _component(position: int):
            # Index a single element along the component axis, keeping all other axes whole.
            index = [slice(None)] * data.ndim
            index[axis_idx] = position
            return data[tuple(index)]

        first_in = _component(0)
        second_in = _component(1)

        to_polar = self.settings.mode == CoordinateMode.CART2POL
        convert = cart2pol if to_polar else pol2cart
        first_out, second_out = convert(first_in, second_in)

        # Reassemble the two converted components along the axis they came from.
        result = xp.stack([first_out, second_out], axis=axis_idx)

        # Axis labels are strings, so they are always built with numpy.
        axes = message.axes
        if axis in axes and hasattr(axes[axis], "data"):
            labels = np.array(["r", "theta"]) if to_polar else np.array(["x", "y"])
            axes = {**axes, axis: replace(axes[axis], data=labels)}

        return replace(message, data=result, axes=axes)
148
+
149
+
150
class CoordinateSpaces(
    BaseTransformerUnit[CoordinateSpacesSettings, AxisArray, AxisArray, CoordinateSpacesTransformer]
):
    """
    ezmsg unit wrapping :obj:`CoordinateSpacesTransformer`.

    Configuration options are described on :obj:`CoordinateSpacesSettings`.
    """

    SETTINGS = CoordinateSpacesSettings
ezmsg/sigproc/decimate.py CHANGED
@@ -1,33 +1,39 @@
1
1
  import typing
2
2
 
3
3
  import ezmsg.core as ez
4
+ from ezmsg.baseproc import BaseTransformerUnit
4
5
  from ezmsg.util.messages.axisarray import AxisArray
5
6
 
6
- from .cheby import ChebyshevFilter, ChebyshevFilterSettings
7
+ from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
7
8
  from .downsample import Downsample, DownsampleSettings
8
- from .filter import FilterCoefsMultiType
9
+ from .filter import BACoeffs, SOSCoeffs
9
10
 
10
11
 
11
- class ChebyForDecimate(ChebyshevFilter):
12
+ class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeffs]):
12
13
  """
13
- A :obj:`ChebyshevFilter` node with a design filter method that additionally accepts a target sampling rate,
14
+ A :obj:`ChebyshevFilterTransformer` with a design filter method that additionally accepts a target sampling rate,
14
15
  and if the target rate cannot be achieved it returns None, else it returns the filter coefficients.
15
16
  """
16
17
 
17
- def design_filter(
18
+ def get_design_function(
18
19
  self,
19
- ) -> typing.Callable[[float], FilterCoefsMultiType | None]:
20
- def cheby_opt_design_fun(fs: float) -> FilterCoefsMultiType | None:
20
+ ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
21
+ def cheby_opt_design_fun(fs: float) -> BACoeffs | SOSCoeffs | None:
21
22
  if fs is None:
22
23
  return None
23
- ds_factor = int(fs / (2.5 * self.SETTINGS.Wn))
24
+ ds_factor = int(fs / (2.5 * self.settings.Wn))
24
25
  if ds_factor < 2:
25
26
  return None
26
- partial_fun = super(ChebyForDecimate, self).design_filter()
27
+ partial_fun = super(ChebyForDecimateTransformer, self).get_design_function()
27
28
  return partial_fun(fs)
29
+
28
30
  return cheby_opt_design_fun
29
31
 
30
32
 
33
+ class ChebyForDecimate(BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]):
34
+ SETTINGS = ChebyshevFilterSettings
35
+
36
+
31
37
  class Decimate(ez.Collection):
32
38
  """
33
39
  A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
@@ -43,7 +49,6 @@ class Decimate(ez.Collection):
43
49
  DOWNSAMPLE = Downsample()
44
50
 
45
51
  def configure(self) -> None:
46
-
47
52
  cheby_settings = ChebyshevFilterSettings(
48
53
  order=8,
49
54
  ripple_tol=0.05,
@@ -0,0 +1,78 @@
1
+ import ezmsg.core as ez
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ from ezmsg.baseproc import (
5
+ BaseStatefulTransformer,
6
+ BaseTransformerUnit,
7
+ processor_state,
8
+ )
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
11
+
12
+
13
class DenormalizeSettings(ez.Settings):
    """Settings for :obj:`DenormalizeTransformer`."""

    low_rate: float = 2.0
    """Lower bound of the probable rate after denormalization (Hz)."""

    high_rate: float = 40.0
    """Upper bound of the probable rate after denormalization (Hz)."""

    distribution: str = "uniform"
    """How per-channel rates are drawn: 'uniform', 'normal', or 'constant'."""
22
+
23
+
24
@processor_state
class DenormalizeState:
    """Per-channel scaling parameters held between messages."""

    # Multiplicative factor applied to the incoming data; None until state is reset.
    gains: npt.NDArray | None = None
    # Additive term applied after the gain; None until state is reset.
    offsets: npt.NDArray | None = None
28
+
29
+
30
class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
    """
    Scales data from a normalized distribution (mean=0, std=1) to a denormalized
    distribution using random per-channel offsets and gains designed to keep the
    99.9% CIs between 0 and 2x the offset.

    This is useful for simulating realistic firing rates from normalized data.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Draw per-channel offsets and derive the matching gains.

        Offsets are sampled according to ``settings.distribution`` within
        ``[settings.low_rate, settings.high_rate]`` and shaped so they broadcast
        against ``message.data`` along its "ch" axis, whatever the data rank.

        Raises:
            ValueError: If ``settings.distribution`` is not 'uniform', 'normal',
                or 'constant'.
        """
        ax_ix = message.get_axis_idx("ch")
        nch = message.data.shape[ax_ix]
        # Singleton on every axis except "ch" so the parameters broadcast
        # against data of any dimensionality (identical to the 2-D case before).
        bcast_shape = [1] * message.data.ndim
        bcast_shape[ax_ix] = nch
        arr_size = tuple(bcast_shape)

        low = self.settings.low_rate
        high = self.settings.high_rate
        midpoint = (low + high) / 2.0

        if self.settings.distribution == "uniform":
            # Honour the configured range rather than the default 2-40 Hz.
            self.state.offsets = np.random.uniform(low, high, size=arr_size)
        elif self.settings.distribution == "normal":
            # +/- 3 standard deviations span the configured range; outliers are clipped.
            self.state.offsets = np.random.normal(
                loc=midpoint,
                scale=(high - low) / 6.0,
                size=arr_size,
            )
            self.state.offsets = np.clip(
                self.state.offsets,
                a_min=low,
                a_max=high,
            )
        elif self.settings.distribution == "constant":
            self.state.offsets = np.full(
                shape=arr_size,
                fill_value=midpoint,
            )
        else:
            raise ValueError(f"Invalid distribution: {self.settings.distribution}")
        # Input has std == 1
        # Desired output has range from 0 to 2*self.state.offsets within 99.9% confidence interval
        # For a standard normal distribution, 99.9% of data is within +/- 3.29 std devs.
        # So, gain = offset / 3.29 to scale the std dev appropriately.
        self.state.gains = self.state.offsets / 3.29

    def _process(self, message: AxisArray) -> AxisArray:
        """Apply ``data * gains + offsets`` and clamp negative results to zero."""
        denorm = message.data * self.state.gains + self.state.offsets
        return replace(
            message,
            data=np.clip(denorm, a_min=0.0, a_max=None),
        )
75
+
76
+
77
class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
    """ezmsg unit wrapping :obj:`DenormalizeTransformer`."""

    SETTINGS = DenormalizeSettings
@@ -0,0 +1,28 @@
1
+ import scipy.signal as sps
2
+ from ezmsg.baseproc import BaseTransformerUnit
3
+ from ezmsg.util.messages.axisarray import AxisArray, replace
4
+
5
+ from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
6
+
7
+
8
class DetrendTransformer(EWMATransformer):
    """
    Remove the running mean from the data, where the mean is estimated with an
    exponentially weighted moving average (EWMA).
    """

    def _process(self, message):
        """Subtract the EWMA mean estimate from ``message.data``.

        The filter state ``zi`` is carried across calls so the estimate is
        continuous between messages.
        """
        # NOTE(review): alpha and zi are presumably initialised by EWMATransformer — confirm.
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        # First-order IIR: y[n] = alpha * x[n] + (1 - alpha) * y[n-1]
        numerator = [self._state.alpha]
        denominator = [1.0, self._state.alpha - 1.0]
        means, self._state.zi = sps.lfilter(
            numerator,
            denominator,
            message.data,
            axis=axis_idx,
            zi=self._state.zi,
        )
        detrended = message.data - means
        return replace(message, data=detrended)
25
+
26
+
27
class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
    """ezmsg unit wrapping :obj:`DetrendTransformer`; configured with :obj:`EWMASettings`."""

    SETTINGS = EWMASettings
ezmsg/sigproc/diff.py ADDED
@@ -0,0 +1,82 @@
1
+ """
2
+ Compute differences along an axis.
3
+
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
8
+
9
+ import ezmsg.core as ez
10
+ import numpy as np
11
+ import numpy.typing as npt
12
+ from array_api_compat import get_namespace
13
+ from ezmsg.baseproc import (
14
+ BaseStatefulTransformer,
15
+ BaseTransformerUnit,
16
+ processor_state,
17
+ )
18
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
19
+ from ezmsg.util.messages.util import replace
20
+
21
+
22
class DiffSettings(ez.Settings):
    """Settings for :obj:`DiffTransformer`."""

    # Name of the axis to difference along; None is resolved at processing time.
    axis: str | None = None
    # When True, differences are divided by the sample interval (yielding a derivative).
    scale_by_fs: bool = False
25
+
26
+
27
@processor_state
class DiffState:
    """Values carried from one message to the next by :obj:`DiffTransformer`."""

    # Single-sample slice of data that the next message is differenced against.
    last_dat: npt.NDArray | None = None
    # Timestamp paired with last_dat; only used when the axis carries explicit coordinates.
    last_time: float | None = None
31
+
32
+
33
class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
    """
    Compute sample-to-sample differences along an axis, carrying the last sample
    across messages so the output has the same length as the input.

    With ``scale_by_fs`` the differences are divided by the sample interval,
    turning the diff into a derivative.
    """

    def _resolve_axis(self, message: AxisArray) -> str:
        """Return the configured axis name, defaulting to the message's first dim."""
        return self.settings.axis or message.dims[0]

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the per-sample shape and message key; a change triggers a state reset."""
        # Resolve the axis the same way _process does so axis=None works here too.
        ax_idx = message.get_axis_idx(self._resolve_axis(message))
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _reset_state(self, message) -> None:
        """Seed the state from the first sample of ``message``.

        ``last_dat`` is set to the first sample, so the first output difference
        is zero. For axes with explicit coordinates, ``last_time`` is
        extrapolated one step before the first timestamp.
        """
        axis = self._resolve_axis(message)
        ax_idx = message.get_axis_idx(axis)
        self.state.last_dat = slice_along_axis(message.data, slice(0, 1), axis=ax_idx)
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                if len(ax_info.data) > 1:
                    # Extrapolate backwards by one observed step.
                    self.state.last_time = 2 * ax_info.data[0] - ax_info.data[1]
                else:
                    # Only one timestamp available: assume a 1 ms step.
                    self.state.last_time = ax_info.data[0] - 0.001

    def _process(self, message: AxisArray) -> AxisArray:
        """Return ``message`` with its data replaced by differences along the axis."""
        xp = get_namespace(message.data)
        axis = self._resolve_axis(message)
        ax_idx = message.get_axis_idx(axis)

        diffs = xp.diff(
            xp.concat((self.state.last_dat, message.data), axis=ax_idx),
            axis=ax_idx,
        )
        # Prepare last_dat for next iteration
        self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)
        # Scale by fs if requested. This converts the diff to a derivative. e.g., diff of position becomes velocity.
        if self.settings.scale_by_fs:
            ax_info = message.get_axis(axis)
            if hasattr(ax_info, "data"):
                # ax_info.data is typically numpy for metadata, so use np.diff here
                dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
                # Expand dt dims to match diffs
                exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
                # Out-of-place division: in-place `/=` fails on integer-typed data.
                diffs = diffs / xp.asarray(dt[exp_sl])
                self.state.last_time = ax_info.data[-1]  # For next iteration
            else:
                diffs = diffs / ax_info.gain

        return replace(message, data=diffs)
75
+
76
+
77
class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
    """ezmsg unit wrapping :obj:`DiffTransformer`."""

    SETTINGS = DiffSettings
79
+
80
+
81
def diff(axis: str = "time", scale_by_fs: bool = False) -> DiffTransformer:
    """Build a :obj:`DiffTransformer` from keyword arguments instead of a settings object."""
    settings = DiffSettings(axis=axis, scale_by_fs=scale_by_fs)
    return DiffTransformer(settings)
@@ -1,82 +1,81 @@
1
- import typing
2
-
1
+ import ezmsg.core as ez
3
2
  import numpy as np
3
+ from ezmsg.baseproc import (
4
+ BaseStatefulTransformer,
5
+ BaseTransformerUnit,
6
+ processor_state,
7
+ )
4
8
  from ezmsg.util.messages.axisarray import (
5
9
  AxisArray,
6
- slice_along_axis,
7
10
  replace,
11
+ slice_along_axis,
8
12
  )
9
- from ezmsg.util.generator import consumer
10
- import ezmsg.core as ez
11
13
 
12
- from .base import GenAxisArray
13
14
 
15
class DownsampleSettings(ez.Settings):
    """
    Configuration for the :obj:`Downsample` node.
    """

    axis: str = "time"
    """Name of the axis to downsample along."""

    target_rate: float | None = None
    """Requested rate after downsampling. The achieved rate is the nearest integer division of the
    input rate that is equal to or higher than this value."""

    factor: int | None = None
    """Explicit downsample factor. When set, target_rate is ignored."""
+
30
+
31
@processor_state
class DownsampleState:
    """Counters carried between messages while downsampling."""

    q: int = 0
    """Integer downsampling factor, determined from the target rate."""

    s_idx: int = 0
    """Position of the next message's first sample within the rotating ds_factor counter."""
+
39
+
40
+ class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
19
41
  """
20
- Construct a generator that yields a downsampled version of the data .send() to it.
21
42
  Downsampled data simply comprise every `factor`th sample.
22
43
  This should only be used following appropriate lowpass filtering.
23
44
  If your pipeline does not already have lowpass filtering then consider
24
45
  using the :obj:`Decimate` collection instead.
25
-
26
- Args:
27
- axis: The name of the axis along which to downsample.
28
- Note: The axis must exist in the message .axes and be of type AxisArray.LinearAxis.
29
- target_rate: Desired rate after downsampling. The actual rate will be the nearest integer factor of the
30
- input rate that is the same or higher than the target rate.
31
-
32
- Returns:
33
- A primed generator object ready to receive an :obj:`AxisArray` via `.send(axis_array)`
34
- and yields an :obj:`AxisArray` with its data downsampled.
35
- Note that if a send chunk does not have sufficient samples to reach the
36
- next downsample interval then an :obj:`AxisArray` with size-zero data is yielded.
37
-
38
46
  """
39
- msg_out = AxisArray(np.array([]), dims=[""])
40
-
41
- # state variables
42
- factor: int = 0 # The integer downsampling factor. It will be determined based on the target rate.
43
- s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
44
-
45
- check_input = {"gain": None, "key": None}
46
47
 
47
- while True:
48
- msg_in: AxisArray = yield msg_out
48
+ def _hash_message(self, message: AxisArray) -> int:
49
+ return hash((message.axes[self.settings.axis].gain, message.key))
49
50
 
50
- if axis is None:
51
- axis = msg_in.dims[0]
52
- axis_info = msg_in.get_axis(axis)
53
- axis_idx = msg_in.get_axis_idx(axis)
51
+ def _reset_state(self, message: AxisArray) -> None:
52
+ axis_info = message.get_axis(self.settings.axis)
54
53
 
55
- b_reset = (
56
- msg_in.axes[axis].gain != check_input["gain"]
57
- or msg_in.key != check_input["key"]
58
- )
59
- if b_reset:
60
- check_input["gain"] = axis_info.gain
61
- check_input["key"] = msg_in.key
62
- # Reset state variables
63
- s_idx = 0
64
- if target_rate is None:
65
- factor = 1
66
- else:
67
- factor = int(1 / (axis_info.gain * target_rate))
68
- if factor < 1:
69
- ez.logger.warning(
70
- f"Target rate {target_rate} cannot be achieved with input rate of {1/axis_info.gain}."
71
- "Setting factor to 1."
72
- )
73
- factor = 1
74
-
75
- n_samples = msg_in.data.shape[axis_idx]
76
- samples = np.arange(s_idx, s_idx + n_samples) % factor
54
+ if self.settings.factor is not None:
55
+ q = self.settings.factor
56
+ elif self.settings.target_rate is None:
57
+ q = 1
58
+ else:
59
+ q = int(1 / (axis_info.gain * self.settings.target_rate))
60
+ if q < 1:
61
+ ez.logger.warning(
62
+ f"Target rate {self.settings.target_rate} cannot be achieved with input rate of {1 / axis_info.gain}."
63
+ "Setting factor to 1."
64
+ )
65
+ q = 1
66
+ self._state.q = q
67
+ self._state.s_idx = 0
68
+
69
+ def _process(self, message: AxisArray) -> AxisArray:
70
+ axis = self.settings.axis
71
+ axis_info = message.get_axis(axis)
72
+ axis_idx = message.get_axis_idx(axis)
73
+
74
+ n_samples = message.data.shape[axis_idx]
75
+ samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
77
76
  if n_samples > 0:
78
77
  # Update state for next iteration.
79
- s_idx = samples[-1] + 1
78
+ self._state.s_idx = samples[-1] + 1
80
79
 
81
80
  pub_samples = np.where(samples == 0)[0]
82
81
  if len(pub_samples) > 0:
@@ -86,35 +85,27 @@ def downsample(
86
85
  n_step = 0
87
86
  data_slice = slice(None, 0, None)
88
87
  msg_out = replace(
89
- msg_in,
90
- data=slice_along_axis(msg_in.data, data_slice, axis=axis_idx),
88
+ message,
89
+ data=slice_along_axis(message.data, data_slice, axis=axis_idx),
91
90
  axes={
92
- **msg_in.axes,
91
+ **message.axes,
93
92
  axis: replace(
94
93
  axis_info,
95
- gain=axis_info.gain * factor,
94
+ gain=axis_info.gain * self._state.q,
96
95
  offset=axis_info.offset + axis_info.gain * n_step,
97
96
  ),
98
97
  },
99
98
  )
99
+ return msg_out
100
100
 
101
101
 
102
- class DownsampleSettings(ez.Settings):
103
- """
104
- Settings for :obj:`Downsample` node.
105
- See :obj:`downsample` documentation for a description of the parameters.
106
- """
107
-
108
- axis: str | None = None
109
- target_rate: float | None = None
110
-
111
-
112
- class Downsample(GenAxisArray):
113
- """:obj:`Unit` for :obj:`bandpower`."""
114
-
102
+ class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
115
103
  SETTINGS = DownsampleSettings
116
104
 
117
- def construct_generator(self):
118
- self.STATE.gen = downsample(
119
- axis=self.SETTINGS.axis, target_rate=self.SETTINGS.target_rate
120
- )
105
+
106
def downsample(
    axis: str = "time",
    target_rate: float | None = None,
    factor: int | None = None,
) -> DownsampleTransformer:
    """Build a :obj:`DownsampleTransformer` from keyword arguments instead of a settings object."""
    settings = DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
    return DownsampleTransformer(settings)