ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66):
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/gaussiansmoothing.py ADDED
@@ -0,0 +1,89 @@
1
+ import warnings
2
+ from typing import Callable
3
+
4
+ import numpy as np
5
+
6
+ from .filter import (
7
+ BACoeffs,
8
+ BaseFilterByDesignTransformerUnit,
9
+ FilterBaseSettings,
10
+ FilterByDesignTransformer,
11
+ )
12
+
13
+
14
class GaussianSmoothingSettings(FilterBaseSettings):
    """Settings for :obj:`GaussianSmoothingFilter`."""

    sigma: float | None = 1.0
    """Standard deviation of the Gaussian kernel, in samples."""

    width: int | None = 4
    """
    How many standard deviations the kernel spans. Only used to size the kernel
    (``int(2 * width * sigma + 1)`` samples) when ``kernel_size`` is None.
    """

    kernel_size: int | None = None
    """
    Explicit kernel length in samples. When set, it takes precedence over the
    length derived from ``sigma`` and ``width``.
    """
32
+
33
+
34
def gaussian_smoothing_filter_design(
    sigma: float = 1.0,
    width: int = 4,
    kernel_size: int | None = None,
) -> BACoeffs | None:
    """
    Design a normalized Gaussian FIR smoothing kernel.

    Args:
        sigma: Standard deviation of the Gaussian, in samples. Must be positive.
            If None, no filter is designed and None is returned.
        width: Number of standard deviations used to derive the default kernel
            length ``int(2 * width * sigma + 1)``. Must be positive. If None,
            no filter is designed and None is returned.
        kernel_size: Explicit kernel length in samples (>= 1). Overrides the
            derived length; a warning is issued if it is shorter than that length.

    Returns:
        ``(b, a)`` where ``b`` is a Gaussian window normalized to sum to 1
        (unity DC gain) and ``a`` is ``[1.0]`` (FIR). Returns None if ``sigma``
        or ``width`` is None.

    Raises:
        ValueError: If ``sigma <= 0``, ``width <= 0`` or ``kernel_size < 1``.
    """
    # GaussianSmoothingSettings types sigma/width as Optional. Treat None as
    # "no filter configured" (the return type already allows None) instead of
    # failing with an opaque TypeError on the comparisons below.
    if sigma is None or width is None:
        return None

    if sigma <= 0:
        raise ValueError(f"sigma must be positive. Received: {sigma}")

    if width <= 0:
        raise ValueError(f"width must be positive. Received: {width}")

    expected_kernel_size = int(2 * width * sigma + 1)
    if kernel_size is None:
        kernel_size = expected_kernel_size
    elif kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1. Received: {kernel_size}")
    elif kernel_size < expected_kernel_size:
        # A short kernel still smooths, it is just truncated: warn, don't fail.
        warnings.warn(
            f"Provided kernel_size {kernel_size} is smaller than recommended "
            f"size {expected_kernel_size} for sigma={sigma} and width={width}. "
            "The kernel may be truncated.",
            stacklevel=2,
        )

    from scipy.signal.windows import gaussian

    b = gaussian(kernel_size, std=sigma)
    b /= np.sum(b)  # Unit sum -> unity gain at DC.
    a = np.array([1.0])

    return b, a
69
+
70
+
71
class GaussianSmoothingFilterTransformer(FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]):
    """Filter-by-design transformer whose coefficients are a Gaussian smoothing kernel."""

    def get_design_function(
        self,
    ) -> Callable[[float], BACoeffs]:
        """
        Return a ``design(fs)`` callable built from the current settings.

        ``fs`` is accepted to satisfy the design-function interface but is not
        used: the Gaussian kernel is specified directly in samples.
        """

        def _design(fs: float) -> BACoeffs:
            # Settings are looked up when the design runs, not when the closure is built.
            opts = self.settings
            return gaussian_smoothing_filter_design(
                sigma=opts.sigma,
                width=opts.width,
                kernel_size=opts.kernel_size,
            )

        return _design
84
+
85
+
86
class GaussianSmoothingFilter(
    BaseFilterByDesignTransformerUnit[GaussianSmoothingSettings, GaussianSmoothingFilterTransformer]
):
    """ezmsg Unit wrapper for :obj:`GaussianSmoothingFilterTransformer`."""

    SETTINGS = GaussianSmoothingSettings
ezmsg/sigproc/kaiser.py ADDED
@@ -0,0 +1,106 @@
1
+ import functools
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import scipy.signal
7
+
8
+ from .filter import (
9
+ BACoeffs,
10
+ BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
+ )
14
+
15
+
16
class KaiserFilterSettings(FilterBaseSettings):
    """Settings for :obj:`KaiserFilter`"""

    # `axis` and `coef_type` come from FilterBaseSettings.

    cutoff: float | npt.ArrayLike | None = None
    """
    Filter cutoff, in the same units as fs. A single float is the half-amplitude
    (-6 dB) point. An array gives band edges, which must be positive, strictly
    increasing, and lie strictly between 0 and fs/2 (0 and fs/2 themselves are
    not allowed).
    """

    ripple: float | None = None
    """
    Upper bound, in dB, on how far the magnitude response may deviate from the
    ideal response outside the transition bands. Passed to scipy.signal.kaiserord.
    """

    width: float | None = None
    """
    Approximate transition-band width, in the same units as fs, used for the
    Kaiser FIR design. Passed to scipy.signal.kaiserord.
    """

    pass_zero: bool | str = True
    """
    True for unit gain at DC, False for zero gain at DC. Alternatively, a filter
    type string ('lowpass', 'highpass', 'bandpass', 'bandstop'), analogous to
    btype in the IIR design functions.
    """

    wn_hz: bool = True
    """
    False when cutoff and width are already normalized to [0, 1], with 1 being
    the Nyquist frequency.
    """
55
+
56
+
57
def kaiser_design_fun(
    fs: float,
    cutoff: float | npt.ArrayLike | None = None,
    ripple: float | None = None,
    width: float | None = None,
    pass_zero: bool | str = True,
    wn_hz: bool = True,
) -> BACoeffs | None:
    """
    Design a linear-phase FIR filter with the Kaiser window method.

    The tap count and Kaiser ``beta`` are derived from ``ripple`` and ``width``
    with :func:`scipy.signal.kaiserord`. The taps are then computed with
    :func:`scipy.signal.firwin`. See :obj:`KaiserFilterSettings` for a full
    description of each argument.

    Args:
        fs: Sampling rate. Used to normalize ``width`` and passed to ``firwin``
            when ``wn_hz`` is True.
        cutoff: Cutoff frequency or band edges (Hz if ``wn_hz`` else normalized).
        ripple: Maximum deviation in dB from the ideal response.
        width: Transition-band width (Hz if ``wn_hz`` else normalized).
        pass_zero: DC-gain flag or filter-type string forwarded to ``firwin``.
        wn_hz: Whether ``cutoff`` and ``width`` are in the units of ``fs``.

    Returns:
        ``(b, a)`` where ``b`` holds the FIR taps from ``firwin`` (always an odd
        number of taps) and ``a`` is ``[1.0]``. Returns None if any of
        ``cutoff``, ``ripple`` or ``width`` is None.
    """
    if ripple is None or width is None or cutoff is None:
        return None

    # kaiserord expects the transition width normalized so that 1.0 == Nyquist.
    width = width / (0.5 * fs) if wn_hz else width
    n_taps, beta = scipy.signal.kaiserord(ripple, width)
    # firwin rejects an even tap count for responses with nonzero gain at Nyquist
    # (highpass/bandstop); forcing an odd count keeps every pass_zero mode valid.
    if n_taps % 2 == 0:
        n_taps += 1
    taps = scipy.signal.firwin(
        numtaps=n_taps,
        cutoff=cutoff,
        window=("kaiser", beta),  # type: ignore
        pass_zero=pass_zero,  # type: ignore
        scale=False,
        fs=fs if wn_hz else None,
    )

    return (taps, np.array([1.0]))
89
+
90
+
91
class KaiserFilterTransformer(FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]):
    """Filter-by-design transformer whose coefficients come from :func:`kaiser_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | None]:
        """Return ``design(fs)`` with every other design argument bound from the settings."""
        cfg = self.settings
        bound_kwargs = {
            "cutoff": cfg.cutoff,
            "ripple": cfg.ripple,
            "width": cfg.width,
            "pass_zero": cfg.pass_zero,
            "wn_hz": cfg.wn_hz,
        }
        return functools.partial(kaiser_design_fun, **bound_kwargs)
103
+
104
+
105
class KaiserFilter(BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]):
    """ezmsg Unit wrapper for :obj:`KaiserFilterTransformer`."""

    SETTINGS = KaiserFilterSettings
ezmsg/sigproc/linear.py ADDED
@@ -0,0 +1,120 @@
1
+ """
2
+ Apply a linear transformation: output = scale * input + offset.
3
+
4
+ Supports per-element scale and offset along a specified axis.
5
+ For full matrix transformations, use :obj:`AffineTransformTransformer` instead.
6
+
7
+ .. note::
8
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
9
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
10
+ """
11
+
12
+ import ezmsg.core as ez
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ from array_api_compat import get_namespace
16
+ from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
17
+ from ezmsg.util.messages.axisarray import AxisArray
18
+ from ezmsg.util.messages.util import replace
19
+
20
+
21
class LinearTransformSettings(ez.Settings):
    """Settings for :obj:`LinearTransformTransformer` (output = scale * input + offset)."""

    scale: float | list[float] | npt.ArrayLike = 1.0
    """Multiplier(s). A scalar applies to every element; an array whose length
    matches the ``axis`` dimension gives one multiplier per element of that axis."""

    offset: float | list[float] | npt.ArrayLike = 0.0
    """Additive term(s). A scalar applies to every element; an array whose length
    matches the ``axis`` dimension gives one offset per element of that axis."""

    axis: str | None = None
    """Name of the axis that per-element scale/offset are aligned with. When None,
    scale/offset are simply broadcast against the data."""
33
+
34
+
35
@processor_state
class LinearTransformState:
    # Annotated as Optional: both fields stay None until the transformer's
    # _reset_state prepares them for the first message.
    scale: npt.NDArray | None = None
    """Prepared scale array for broadcasting."""

    offset: npt.NDArray | None = None
    """Prepared offset array for broadcasting."""
42
+
43
+
44
class LinearTransformTransformer(
    BaseStatefulTransformer[LinearTransformSettings, AxisArray, AxisArray, LinearTransformState]
):
    """Apply linear transformation: output = scale * input + offset.

    This transformer is optimized for element-wise linear operations with
    optional per-channel (or per-axis) coefficients. For full matrix
    transformations, use :obj:`AffineTransformTransformer` instead.

    ``scale`` and ``offset`` may each be a scalar or any array-like (list,
    tuple, ndarray, ...). When ``settings.axis`` is set, 1-D coefficients are
    aligned with that axis and must have length 1 or the length of that axis.

    Examples:
        # Uniform scaling and offset
        >>> transformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0))

        # Per-channel scaling (e.g., for 3-channel data along "ch" axis)
        >>> transformer = LinearTransformTransformer(LinearTransformSettings(
        ...     scale=[0.5, 1.0, 2.0],
        ...     offset=[0.0, 0.1, 0.2],
        ...     axis="ch"
        ... ))
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash based on shape and axis to detect when broadcast shapes need recalculation."""
        axis = self.settings.axis
        if axis is not None:
            axis_idx = message.get_axis_idx(axis)
            return hash((message.data.ndim, axis_idx, message.data.shape[axis_idx]))
        return hash(message.data.ndim)

    @staticmethod
    def _to_float_array(value, xp):
        """Convert a scalar or array-like coefficient to a float64 array in namespace ``xp``."""
        # asarray turns Python/NumPy scalars into 0-d arrays and accepts lists,
        # tuples and arrays alike, so no type-specific branching is needed.
        return xp.asarray(value, dtype=xp.float64)

    @staticmethod
    def _align_to_axis(coef, name: str, axis: str, ndim: int, axis_idx: int, axis_len: int, xp):
        """Reshape a 1-D coefficient to (1, ..., n, ..., 1) so it broadcasts along ``axis_idx``."""
        if coef.ndim != 1:
            return coef
        n = coef.shape[0]
        # Fail here with a clear message instead of an opaque broadcast error in _process.
        if n not in (1, axis_len):
            raise ValueError(
                f"{name} has {n} elements but axis '{axis}' has length {axis_len}."
            )
        broadcast_shape = [1] * ndim
        broadcast_shape[axis_idx] = n
        return xp.reshape(coef, tuple(broadcast_shape))

    def _reset_state(self, message: AxisArray) -> None:
        """Prepare scale/offset arrays with proper broadcast shapes.

        Raises:
            ValueError: If ``settings.axis`` is set and a 1-D scale/offset has a
                length other than 1 or the length of that axis.
        """
        xp = get_namespace(message.data)
        ndim = message.data.ndim

        scale = self._to_float_array(self.settings.scale, xp)
        offset = self._to_float_array(self.settings.offset, xp)

        # With an axis given, align 1-D coefficients with it rather than relying on
        # default broadcasting (which would pair them with the trailing axis).
        if self.settings.axis is not None and ndim > 0:
            axis = self.settings.axis
            axis_idx = message.get_axis_idx(axis)
            axis_len = message.data.shape[axis_idx]
            scale = self._align_to_axis(scale, "scale", axis, ndim, axis_idx, axis_len, xp)
            offset = self._align_to_axis(offset, "offset", axis, ndim, axis_idx, axis_len, xp)

        self._state.scale = scale
        self._state.offset = offset

    def _process(self, message: AxisArray) -> AxisArray:
        result = message.data * self._state.scale + self._state.offset
        return replace(message, data=result)
115
+
116
+
117
class LinearTransform(BaseTransformerUnit[LinearTransformSettings, AxisArray, AxisArray, LinearTransformTransformer]):
    """ezmsg Unit that delegates to :obj:`LinearTransformTransformer`."""

    SETTINGS = LinearTransformSettings
ezmsg/sigproc/math/abs.py CHANGED
@@ -1,34 +1,35 @@
1
- import typing
1
+ """
2
+ Take the absolute value of the data.
2
3
 
3
- import numpy as np
4
- import ezmsg.core as ez
5
- from ezmsg.util.generator import consumer
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
8
+
9
+ from array_api_compat import get_namespace
10
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
6
11
  from ezmsg.util.messages.axisarray import AxisArray
7
12
  from ezmsg.util.messages.util import replace
8
13
 
9
- from ..base import GenAxisArray
10
14
 
15
class AbsSettings:
    """Empty settings class; :obj:`AbsTransformer` is parameterized with ``None`` settings instead."""

    pass


class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
    """Replace every element of the message data with its absolute value."""

    def _process(self, message: AxisArray) -> AxisArray:
        # Resolve the array library (NumPy, CuPy, PyTorch, ...) from the data itself.
        namespace = get_namespace(message.data)
        magnitude = namespace.abs(message.data)
        return replace(message, data=magnitude)


class Abs(BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]):
    """ezmsg Unit wrapper for :obj:`AbsTransformer`; it takes no settings (SETTINGS = None)."""


# NOTE: this public factory shadows the builtin `abs` within this module.
def abs() -> AbsTransformer:
    """
    Create a transformer that takes the absolute value of the data.
    See :obj:`np.abs` for more details.

    Returns: :obj:`AbsTransformer`.
    """
    transformer = AbsTransformer()
    return transformer
ezmsg/sigproc/math/add.py ADDED
@@ -0,0 +1,120 @@
1
+ """Add 2 signals or add a constant to a signal."""
2
+
3
+ import asyncio
4
+ import typing
5
+ from dataclasses import dataclass, field
6
+
7
+ import ezmsg.core as ez
8
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
9
+ from ezmsg.baseproc.util.asio import run_coroutine_sync
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
+ from ezmsg.util.messages.util import replace
12
+
13
+ # --- Constant Addition (single input) ---
14
+
15
+
16
class ConstAddSettings(ez.Settings):
    """Settings for :obj:`ConstAddTransformer`."""

    value: float = 0.0
    """Constant added to every element of the input data."""
19
+
20
+
21
class ConstAddTransformer(BaseTransformer[ConstAddSettings, AxisArray, AxisArray]):
    """Shift input data by the constant ``settings.value``."""

    def _process(self, message: AxisArray) -> AxisArray:
        shifted = message.data + self.settings.value
        return replace(message, data=shifted)
26
+
27
+
28
class ConstAdd(BaseTransformerUnit[ConstAddSettings, AxisArray, AxisArray, ConstAddTransformer]):
    """ezmsg Unit that delegates to :obj:`ConstAddTransformer`."""

    SETTINGS = ConstAddSettings
32
+
33
+
34
+ # --- Two-input Addition ---
35
+
36
+
37
@dataclass
class AddState:
    """Holds the two FIFO queues of pending messages used by :obj:`AddProcessor`."""

    # Messages from input A that are waiting to be paired with one from input B.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    # Messages from input B that are waiting to be paired with one from input A.
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
43
+
44
+
45
class AddProcessor:
    """Processor that adds two AxisArray signals together.

    This processor maintains separate queues for two input streams and
    adds corresponding messages element-wise. It assumes both inputs
    have compatible shapes and aligned time spans. The output reuses the
    metadata (axes, key, ...) of the message from queue A.
    """

    def __init__(self):
        self._state = AddState()

    @property
    def state(self) -> AddState:
        """Current state holding both input queues."""
        return self._state

    @state.setter
    def state(self, state: AddState | bytes | None) -> None:
        """Replace the processor state.

        ``None`` leaves the current state untouched. Serialized (bytes) state is
        not supported yet; storing the raw bytes would leave the processor
        unusable (every push/call would fail), so it is rejected immediately.

        Raises:
            NotImplementedError: If ``state`` is ``bytes``/``bytearray``.
        """
        if state is None:
            return
        if isinstance(state, (bytes, bytearray)):
            # TODO: Support hydrating state from bytes (e.g. pickle.loads) once the
            #  queued messages can be (de)serialized.
            raise NotImplementedError("Restoring AddProcessor state from bytes is not supported.")
        self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Push a message to queue A."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Push a message to queue B."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Await the oldest message from each queue and return their element-wise sum.

        Waits on queue A first, then queue B, so it does not return until both
        have a message available.
        """
        a = await self._state.queue_a.get()
        b = await self._state.queue_b.get()
        return replace(a, data=a.data + b.data)

    def __call__(self) -> AxisArray:
        """Synchronously run :meth:`__acall__` via ``run_coroutine_sync``.

        NOTE(review): this waits until both queues have a message; calling it
        with an empty queue presumably blocks indefinitely.
        """
        return run_coroutine_sync(self.__acall__())

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        return await self.__acall__()

    def __next__(self) -> AxisArray:
        return self.__call__()
93
+
94
+
95
class Add(ez.Unit):
    """Element-wise sum of two AxisArray streams.

    Inputs are assumed to have compatible axes/dimensions and aligned time
    spans. Messages are paired in arrival order: the oldest unconsumed message
    from each input is combined, and the output carries the metadata of the
    INPUT_SIGNAL_A message.
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        # A single processor owns both input queues for this unit instance.
        self.processor = AddProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        # Publish one sum each time both queues can supply a message.
        while True:
            summed = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, summed
ezmsg/sigproc/math/clip.py CHANGED
@@ -1,40 +1,48 @@
1
- import typing
1
+ """
2
+ Clips the data to be within the specified range.
3
+
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
2
8
 
3
- import numpy as np
4
9
  import ezmsg.core as ez
5
- from ezmsg.util.generator import consumer
10
+ from array_api_compat import get_namespace
11
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
6
12
  from ezmsg.util.messages.axisarray import AxisArray
7
13
  from ezmsg.util.messages.util import replace
8
14
 
9
- from ..base import GenAxisArray
10
15
 
16
class ClipSettings(ez.Settings):
    """Settings for :obj:`ClipTransformer`."""

    min: float | None = None
    """Lower bound; values below it are raised to it. None leaves the low side unbounded."""

    max: float | None = None
    """Upper bound; values above it are lowered to it. None leaves the high side unbounded."""


class ClipTransformer(BaseTransformer[ClipSettings, AxisArray, AxisArray]):
    """Clamp every element of the message data into ``[settings.min, settings.max]``."""

    def _process(self, message: AxisArray) -> AxisArray:
        # Resolve the array library (NumPy, CuPy, PyTorch, ...) from the data itself.
        namespace = get_namespace(message.data)
        lower, upper = self.settings.min, self.settings.max
        clipped = namespace.clip(message.data, lower, upper)
        return replace(message, data=clipped)


class Clip(BaseTransformerUnit[ClipSettings, AxisArray, AxisArray, ClipTransformer]):
    """ezmsg Unit wrapper for :obj:`ClipTransformer`."""

    SETTINGS = ClipSettings


def clip(min: float | None = None, max: float | None = None) -> ClipTransformer:
    """
    Create a transformer that clips the data into the specified range.

    Args:
        min: Lower clip bound. If None, no lower clipping is applied.
        max: Upper clip bound. If None, no upper clipping is applied.

    Returns:
        :obj:`ClipTransformer`.
    """
    settings = ClipSettings(min=min, max=max)
    return ClipTransformer(settings)