ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,89 @@
1
+ import warnings
2
+ from typing import Callable
3
+
4
+ import numpy as np
5
+
6
+ from .filter import (
7
+ BACoeffs,
8
+ BaseFilterByDesignTransformerUnit,
9
+ FilterBaseSettings,
10
+ FilterByDesignTransformer,
11
+ )
12
+
13
+
14
class GaussianSmoothingSettings(FilterBaseSettings):
    """Settings for :obj:`GaussianSmoothingFilter` (axis and coef_type come from :obj:`FilterBaseSettings`)."""

    sigma: float | None = 1.0
    """
    sigma : float | None
        Standard deviation of the Gaussian kernel, in samples.
    """

    width: int | None = 4
    """
    width : int | None
        Number of standard deviations spanned on each side of the kernel centre when
        kernel_size is not given; the derived length is int(2 * width * sigma + 1).
    """

    kernel_size: int | None = None
    """
    kernel_size : int | None
        Explicit kernel length in samples. When set, it takes precedence over the
        length derived from sigma and width.
    """
32
+
33
+
34
+ def gaussian_smoothing_filter_design(
35
+ sigma: float = 1.0,
36
+ width: int = 4,
37
+ kernel_size: int | None = None,
38
+ ) -> BACoeffs | None:
39
+ # Parameter checks
40
+ if sigma <= 0:
41
+ raise ValueError(f"sigma must be positive. Received: {sigma}")
42
+
43
+ if width <= 0:
44
+ raise ValueError(f"width must be positive. Received: {width}")
45
+
46
+ if kernel_size is not None:
47
+ if kernel_size < 1:
48
+ raise ValueError(f"kernel_size must be >= 1. Received: {kernel_size}")
49
+ else:
50
+ kernel_size = int(2 * width * sigma + 1)
51
+
52
+ # Warn if kernel_size is smaller than recommended but don't fail
53
+ expected_kernel_size = int(2 * width * sigma + 1)
54
+ if kernel_size < expected_kernel_size:
55
+ ## TODO: Either add a warning or determine appropriate kernel size and raise an error
56
+ warnings.warn(
57
+ f"Provided kernel_size {kernel_size} is smaller than recommended "
58
+ f"size {expected_kernel_size} for sigma={sigma} and width={width}. "
59
+ "The kernel may be truncated."
60
+ )
61
+
62
+ from scipy.signal.windows import gaussian
63
+
64
+ b = gaussian(kernel_size, std=sigma)
65
+ b /= np.sum(b) # Ensure normalization
66
+ a = np.array([1.0])
67
+
68
+ return b, a
69
+
70
+
71
class GaussianSmoothingFilterTransformer(FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]):
    """Filter-by-design transformer whose coefficients are a Gaussian smoothing kernel."""

    def get_design_function(
        self,
    ) -> Callable[[float], BACoeffs]:
        def _design(fs: float) -> BACoeffs:
            # The kernel is defined in samples, so the sampling rate is accepted only to
            # satisfy the design-function signature and is otherwise unused.
            current = self.settings
            return gaussian_smoothing_filter_design(
                sigma=current.sigma,
                width=current.width,
                kernel_size=current.kernel_size,
            )

        return _design
84
+
85
+
86
class GaussianSmoothingFilter(
    BaseFilterByDesignTransformerUnit[GaussianSmoothingSettings, GaussianSmoothingFilterTransformer]
):
    """Unit wrapper for :obj:`GaussianSmoothingFilterTransformer`."""

    SETTINGS = GaussianSmoothingSettings
@@ -0,0 +1,106 @@
1
+ import functools
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import scipy.signal
7
+
8
+ from .filter import (
9
+ BACoeffs,
10
+ BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
+ )
14
+
15
+
16
class KaiserFilterSettings(FilterBaseSettings):
    """Settings for :obj:`KaiserFilter`.

    ``axis`` and ``coef_type`` are inherited from :obj:`FilterBaseSettings`. If any of
    ``cutoff``, ``ripple`` or ``width`` is left as None, no filter is designed.
    """

    cutoff: float | npt.ArrayLike | None = None
    """
    Cutoff frequency (same units as fs), or an array of band edges. A scalar cutoff marks
    the half-amplitude (-6 dB) point. Band edges must be positive, monotonically
    increasing, and lie strictly between 0 and fs/2 (neither 0 nor fs/2 may appear).
    """

    ripple: float | None = None
    """
    Upper bound, in dB, on how far the magnitude response may deviate from the ideal
    response outside the transition bands. Passed to scipy.signal.kaiserord.
    """

    width: float | None = None
    """
    Approximate width of the transition region (same units as fs). Passed to
    scipy.signal.kaiserord for Kaiser FIR design.
    """

    pass_zero: bool | str = True
    """
    True for unit gain at DC ("DC gain" of 1), False for zero DC gain. May instead be
    one of 'lowpass', 'highpass', 'bandpass', 'bandstop' (like btype in IIR design).
    """

    wn_hz: bool = True
    """
    Set to False when cutoff and width are already normalized to [0, 1], where 1 is the
    Nyquist frequency.
    """
55
+
56
+
57
+ def kaiser_design_fun(
58
+ fs: float,
59
+ cutoff: float | npt.ArrayLike | None = None,
60
+ ripple: float | None = None,
61
+ width: float | None = None,
62
+ pass_zero: bool | str = True,
63
+ wn_hz: bool = True,
64
+ ) -> BACoeffs | None:
65
+ """
66
+ Design an `order`th-order FIR Kaiser filter and return the filter coefficients.
67
+ See :obj:`FIRFilterSettings` for argument description.
68
+
69
+ Returns:
70
+ The filter taps as designed by firwin
71
+ """
72
+ if ripple is None or width is None or cutoff is None:
73
+ return None
74
+
75
+ width = width / (0.5 * fs) if wn_hz else width
76
+ n_taps, beta = scipy.signal.kaiserord(ripple, width)
77
+ if n_taps % 2 == 0:
78
+ n_taps += 1
79
+ taps = scipy.signal.firwin(
80
+ numtaps=n_taps,
81
+ cutoff=cutoff,
82
+ window=("kaiser", beta), # type: ignore
83
+ pass_zero=pass_zero, # type: ignore
84
+ scale=False,
85
+ fs=fs if wn_hz else None,
86
+ )
87
+
88
+ return (taps, np.array([1.0]))
89
+
90
+
91
class KaiserFilterTransformer(FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]):
    """Filter-by-design transformer whose coefficients come from :func:`kaiser_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | None]:
        # Snapshot the design parameters now; the returned callable only needs fs.
        cfg = self.settings
        design_kwargs = {
            "cutoff": cfg.cutoff,
            "ripple": cfg.ripple,
            "width": cfg.width,
            "pass_zero": cfg.pass_zero,
            "wn_hz": cfg.wn_hz,
        }
        return functools.partial(kaiser_design_fun, **design_kwargs)
103
+
104
+
105
class KaiserFilter(BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]):
    """Unit wrapper for :obj:`KaiserFilterTransformer`."""

    SETTINGS = KaiserFilterSettings
@@ -0,0 +1,120 @@
1
+ """
2
+ Apply a linear transformation: output = scale * input + offset.
3
+
4
+ Supports per-element scale and offset along a specified axis.
5
+ For full matrix transformations, use :obj:`AffineTransformTransformer` instead.
6
+
7
+ .. note::
8
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
9
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
10
+ """
11
+
12
+ import ezmsg.core as ez
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ from array_api_compat import get_namespace
16
+ from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
17
+ from ezmsg.util.messages.axisarray import AxisArray
18
+ from ezmsg.util.messages.util import replace
19
+
20
+
21
class LinearTransformSettings(ez.Settings):
    """Settings for :obj:`LinearTransformTransformer` (output = scale * input + offset)."""

    scale: float | list[float] | npt.ArrayLike = 1.0
    """Multiplier(s). A scalar applies to every element; an array whose length matches
    the size of ``axis`` gives one multiplier per element along that axis."""

    offset: float | list[float] | npt.ArrayLike = 0.0
    """Additive term(s). A scalar applies to every element; an array whose length matches
    the size of ``axis`` gives one offset per element along that axis."""

    axis: str | None = None
    """Name of the axis that per-element scale/offset values run along. When None,
    scalar scale/offset values are broadcast to every element."""
33
+
34
+
35
@processor_state
class LinearTransformState:
    """Coefficients prepared by :obj:`LinearTransformTransformer` in ``_reset_state``."""

    scale: npt.NDArray = None
    """Scale coefficients, shaped so they broadcast against the message data."""

    offset: npt.NDArray = None
    """Offset coefficients, shaped so they broadcast against the message data."""
42
+
43
+
44
class LinearTransformTransformer(
    BaseStatefulTransformer[LinearTransformSettings, AxisArray, AxisArray, LinearTransformState]
):
    """Apply linear transformation: output = scale * input + offset.

    This transformer is optimized for element-wise linear operations with
    optional per-channel (or per-axis) coefficients. For full matrix
    transformations, use :obj:`AffineTransformTransformer` instead.

    Examples:
        # Uniform scaling and offset
        >>> transformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0))

        # Per-channel scaling (e.g., for 3-channel data along "ch" axis)
        >>> transformer = LinearTransformTransformer(LinearTransformSettings(
        ...     scale=[0.5, 1.0, 2.0],
        ...     offset=[0.0, 0.1, 0.2],
        ...     axis="ch"
        ... ))
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash based on shape and axis to detect when broadcast shapes need recalculation."""
        axis = self.settings.axis
        if axis is not None:
            axis_idx = message.get_axis_idx(axis)
            return hash((message.data.ndim, axis_idx, message.data.shape[axis_idx]))
        return hash(message.data.ndim)

    def _reset_state(self, message: AxisArray) -> None:
        """Prepare scale/offset arrays with proper broadcast shapes.

        Any array-like setting (list, tuple, ndarray, tensor, ...) is converted to a
        float64 array in the message's array namespace; scalars become 0-d arrays.
        When ``axis`` is set, 1-d coefficients are reshaped so that they broadcast
        along that axis of the message data.
        """
        xp = get_namespace(message.data)
        ndim = message.data.ndim
        axis_idx = None
        if self.settings.axis is not None and ndim > 0:
            axis_idx = message.get_axis_idx(self.settings.axis)

        def _prepare(value):
            # Dispatch on dimensionality rather than isinstance(list/ndarray): tuples and
            # other array-likes used to fall through to float() and raise TypeError.
            if np.ndim(value) > 0:
                arr = xp.asarray(value, dtype=xp.float64)
            else:
                arr = xp.asarray(float(value), dtype=xp.float64)
            if axis_idx is not None and arr.ndim == 1:
                # All-ones shape except along the target axis.
                broadcast_shape = [1] * ndim
                broadcast_shape[axis_idx] = arr.shape[0]
                arr = xp.reshape(arr, broadcast_shape)
            return arr

        self._state.scale = _prepare(self.settings.scale)
        self._state.offset = _prepare(self.settings.offset)

    def _process(self, message: AxisArray) -> AxisArray:
        result = message.data * self._state.scale + self._state.offset
        return replace(message, data=result)
115
+
116
+
117
class LinearTransform(BaseTransformerUnit[LinearTransformSettings, AxisArray, AxisArray, LinearTransformTransformer]):
    """ezmsg Unit wrapping :obj:`LinearTransformTransformer`."""

    SETTINGS = LinearTransformSettings
File without changes
@@ -0,0 +1,35 @@
1
+ """
2
+ Take the absolute value of the data.
3
+
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
8
+
9
+ from array_api_compat import get_namespace
10
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
11
+ from ezmsg.util.messages.axisarray import AxisArray
12
+ from ezmsg.util.messages.util import replace
13
+
14
+
15
class AbsSettings:
    """Empty placeholder; :obj:`AbsTransformer` and :obj:`Abs` are parameterized with ``None`` settings."""
17
+
18
+
19
class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
    """Replace the message data with its element-wise absolute value."""

    def _process(self, message: AxisArray) -> AxisArray:
        data = message.data
        magnitude = get_namespace(data).abs(data)
        return replace(message, data=magnitude)
23
+
24
+
25
class Abs(BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]):
    """Unit wrapper for :obj:`AbsTransformer`; it takes no settings (SETTINGS is left unset)."""
26
+
27
+
28
def abs() -> AbsTransformer:
    """
    Create a transformer that takes the element-wise absolute value of the data.

    See :obj:`np.abs` for the semantics.

    Returns: :obj:`AbsTransformer`.
    """
    transformer = AbsTransformer()
    return transformer
@@ -0,0 +1,120 @@
1
+ """Add 2 signals or add a constant to a signal."""
2
+
3
+ import asyncio
4
+ import typing
5
+ from dataclasses import dataclass, field
6
+
7
+ import ezmsg.core as ez
8
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
9
+ from ezmsg.baseproc.util.asio import run_coroutine_sync
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
+ from ezmsg.util.messages.util import replace
12
+
13
+ # --- Constant Addition (single input) ---
14
+
15
+
16
class ConstAddSettings(ez.Settings):
    """Settings for :obj:`ConstAddTransformer`."""

    value: float = 0.0
    """Constant added to every element of the input data."""
19
+
20
+
21
class ConstAddTransformer(BaseTransformer[ConstAddSettings, AxisArray, AxisArray]):
    """Return a copy of each message with ``settings.value`` added to its data."""

    def _process(self, message: AxisArray) -> AxisArray:
        shifted = message.data + self.settings.value
        return replace(message, data=shifted)
26
+
27
+
28
class ConstAdd(BaseTransformerUnit[ConstAddSettings, AxisArray, AxisArray, ConstAddTransformer]):
    """ezmsg Unit wrapping :obj:`ConstAddTransformer`."""

    SETTINGS = ConstAddSettings
32
+
33
+
34
+ # --- Two-input Addition ---
35
+
36
+
37
@dataclass
class AddState:
    """Mutable state of :obj:`AddProcessor`: one FIFO queue per input stream."""

    # Pending messages from input A, consumed oldest-first.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    # Pending messages from input B, consumed oldest-first.
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
43
+
44
+
45
class AddProcessor:
    """Processor that adds two AxisArray signals together.

    Messages pushed via :meth:`push_a` and :meth:`push_b` are held in two separate
    queues and consumed pairwise in arrival order. The output's data is ``a.data + b.data``;
    every other field is copied from the A message. Both inputs are assumed to have
    compatible shapes and aligned time spans.
    """

    def __init__(self):
        self._state = AddState()

    @property
    def state(self) -> AddState:
        return self._state

    @state.setter
    def state(self, state: AddState | bytes | None) -> None:
        if state is None:
            return
        if isinstance(state, bytes):
            # TODO: Support hydrating state from bytes. Storing the raw bytes (as was done
            # before) corrupted the processor: the next push_a/push_b failed with an
            # AttributeError far from the cause. Fail fast instead.
            raise NotImplementedError("Hydrating AddState from bytes is not supported.")
        self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Push a message to queue A."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Push a message to queue B."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Await the next message from each queue (A first, then B) and return their sum."""
        a = await self._state.queue_a.get()
        b = await self._state.queue_b.get()
        return replace(a, data=a.data + b.data)

    def __call__(self) -> AxisArray:
        """Run :meth:`__acall__` synchronously via ``run_coroutine_sync``."""
        return run_coroutine_sync(self.__acall__())

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        return await self.__acall__()

    def __next__(self) -> AxisArray:
        return self.__call__()
93
+
94
+
95
class Add(ez.Unit):
    """Add two signals together (OUTPUT = INPUT_SIGNAL_A + INPUT_SIGNAL_B).

    Inputs are assumed to have compatible axes/dimensions and aligned time spans.
    Messages are paired by arrival order: the oldest pending A with the oldest pending B.
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        # A single processor owns both input queues for this unit instance.
        self.processor = AddProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        while True:
            summed = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, summed
@@ -0,0 +1,48 @@
1
+ """
2
+ Clips the data to be within the specified range.
3
+
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
8
+
9
+ import ezmsg.core as ez
10
+ from array_api_compat import get_namespace
11
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
12
+ from ezmsg.util.messages.axisarray import AxisArray
13
+ from ezmsg.util.messages.util import replace
14
+
15
+
16
class ClipSettings(ez.Settings):
    """Settings for :obj:`ClipTransformer`."""

    min: float | None = None
    """Smallest allowed value; None disables clipping from below."""

    max: float | None = None
    """Largest allowed value; None disables clipping from above."""
22
+
23
+
24
class ClipTransformer(BaseTransformer[ClipSettings, AxisArray, AxisArray]):
    """Limit message data to ``[settings.min, settings.max]`` using the array namespace's clip."""

    def _process(self, message: AxisArray) -> AxisArray:
        lower, upper = self.settings.min, self.settings.max
        namespace = get_namespace(message.data)
        clipped = namespace.clip(message.data, lower, upper)
        return replace(message, data=clipped)
31
+
32
+
33
class Clip(BaseTransformerUnit[ClipSettings, AxisArray, AxisArray, ClipTransformer]):
    """ezmsg Unit wrapping :obj:`ClipTransformer`."""

    SETTINGS = ClipSettings
35
+
36
+
37
def clip(min: float | None = None, max: float | None = None) -> ClipTransformer:
    """
    Create a transformer that clips data into the range ``[min, max]``.

    Args:
        min: Smallest allowed value; None disables clipping from below.
        max: Largest allowed value; None disables clipping from above.

    Returns:
        :obj:`ClipTransformer`.
    """
    settings = ClipSettings(min=min, max=max)
    return ClipTransformer(settings)
@@ -0,0 +1,143 @@
1
+ """
2
+ Take the difference between 2 signals or between a signal and a constant value.
3
+
4
+ .. note::
5
+ :obj:`ConstDifferenceTransformer` supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ :obj:`DifferenceProcessor` (two-input difference) currently requires NumPy arrays.
8
+ """
9
+
10
+ import asyncio
11
+ import typing
12
+ from dataclasses import dataclass, field
13
+
14
+ import ezmsg.core as ez
15
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
16
+ from ezmsg.baseproc.util.asio import run_coroutine_sync
17
+ from ezmsg.util.messages.axisarray import AxisArray
18
+ from ezmsg.util.messages.util import replace
19
+
20
+
21
class ConstDifferenceSettings(ez.Settings):
    """Settings for :obj:`ConstDifferenceTransformer`."""

    value: float = 0.0
    """The constant that is subtracted from the data, or that the data is subtracted from."""

    subtrahend: bool = True
    """True (default): output = data - value. False: output = value - data."""
28
+
29
+
30
class ConstDifferenceTransformer(BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]):
    """Subtract a constant from the data, or the data from a constant, per ``settings.subtrahend``."""

    def _process(self, message: AxisArray) -> AxisArray:
        constant = self.settings.value
        if self.settings.subtrahend:
            difference = message.data - constant
        else:
            difference = constant - message.data
        return replace(message, data=difference)
38
+
39
+
40
class ConstDifference(BaseTransformerUnit[ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer]):
    """ezmsg Unit wrapping :obj:`ConstDifferenceTransformer`."""

    SETTINGS = ConstDifferenceSettings
42
+
43
+
44
def const_difference(value: float = 0.0, subtrahend: bool = True) -> ConstDifferenceTransformer:
    """
    Create a transformer computing ``in_data - value`` (subtrahend=True) or ``value - in_data``.

    https://en.wikipedia.org/wiki/Template:Arithmetic_operations

    Args:
        value: The constant that is subtracted from the data, or that the data is
            subtracted from.
        subtrahend: If True (default), ``value`` is subtracted from the input data.
            If False, the input data is subtracted from ``value``.

    Returns: :obj:`ConstDifferenceTransformer`.
    """
    settings = ConstDifferenceSettings(value=value, subtrahend=subtrahend)
    return ConstDifferenceTransformer(settings)
57
+
58
+
59
+ # --- Two-input Difference ---
60
+
61
+
62
@dataclass
class DifferenceState:
    """Mutable state of :obj:`DifferenceProcessor`: one FIFO queue per input stream."""

    # Pending minuend messages (input A), consumed oldest-first.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    # Pending subtrahend messages (input B), consumed oldest-first.
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
68
+
69
+
70
class DifferenceProcessor:
    """Processor that subtracts two AxisArray signals (A - B).

    Messages pushed via :meth:`push_a` (minuend) and :meth:`push_b` (subtrahend) are
    held in two separate queues and consumed pairwise in arrival order. The output's
    data is ``a.data - b.data``; every other field is copied from the A message. Both
    inputs are assumed to have compatible shapes and aligned time spans.
    """

    def __init__(self):
        self._state = DifferenceState()

    @property
    def state(self) -> DifferenceState:
        return self._state

    @state.setter
    def state(self, state: DifferenceState | bytes | None) -> None:
        if state is None:
            return
        if isinstance(state, bytes):
            # TODO: Support hydrating state from bytes. Storing the raw bytes (as was done
            # before) corrupted the processor: the next push_a/push_b failed with an
            # AttributeError far from the cause. Fail fast instead.
            raise NotImplementedError("Hydrating DifferenceState from bytes is not supported.")
        self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Push a message to queue A (minuend)."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Push a message to queue B (subtrahend)."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Await the next message from each queue (A first, then B) and return A - B."""
        a = await self._state.queue_a.get()
        b = await self._state.queue_b.get()
        return replace(a, data=a.data - b.data)

    def __call__(self) -> AxisArray:
        """Run :meth:`__acall__` synchronously via ``run_coroutine_sync``."""
        return run_coroutine_sync(self.__acall__())

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        return await self.__acall__()

    def __next__(self) -> AxisArray:
        return self.__call__()
114
+
115
+
116
class Difference(ez.Unit):
    """Subtract two signals: OUTPUT = INPUT_SIGNAL_A - INPUT_SIGNAL_B.

    Inputs are assumed to have compatible axes/dimensions and aligned time spans.
    Messages are paired by arrival order: the oldest pending A with the oldest pending B.
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        # A single processor owns both input queues for this unit instance.
        self.processor = DifferenceProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        while True:
            difference = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, difference
@@ -0,0 +1,28 @@
1
+ """
2
+ Compute the multiplicative inverse (1/x) of the data.
3
+
4
+ .. note::
5
+ This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
6
+ enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
7
+ """
8
+
9
+ from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
+ from ezmsg.util.messages.util import replace
12
+
13
+
14
class InvertTransformer(BaseTransformer[None, AxisArray, AxisArray]):
    """Replace the message data with its element-wise reciprocal (1 / x).

    Division by zero follows the array library's own semantics.
    """

    def _process(self, message: AxisArray) -> AxisArray:
        reciprocal = 1 / message.data
        return replace(message, data=reciprocal)
17
+
18
+
19
class Invert(BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]):
    """Unit wrapper for :obj:`InvertTransformer`; it takes no settings (SETTINGS is left unset)."""
20
+
21
+
22
def invert() -> InvertTransformer:
    """
    Create a transformer that takes the element-wise reciprocal (1 / x) of the data.

    Returns: :obj:`InvertTransformer`.
    """
    transformer = InvertTransformer()
    return transformer