ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +22 -4
- ezmsg/sigproc/activation.py +31 -40
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +171 -169
- ezmsg/sigproc/aggregate.py +190 -97
- ezmsg/sigproc/bandpower.py +60 -55
- ezmsg/sigproc/base.py +143 -33
- ezmsg/sigproc/butterworthfilter.py +34 -38
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +23 -17
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +15 -10
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +72 -81
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +254 -148
- ezmsg/sigproc/filterbank.py +226 -214
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/abs.py +23 -22
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +33 -25
- ezmsg/sigproc/math/difference.py +117 -43
- ezmsg/sigproc/math/invert.py +18 -25
- ezmsg/sigproc/math/log.py +38 -33
- ezmsg/sigproc/math/scale.py +24 -25
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +209 -254
- ezmsg/sigproc/scaler.py +93 -218
- ezmsg/sigproc/signalinjector.py +44 -43
- ezmsg/sigproc/slicer.py +74 -102
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +70 -70
- ezmsg/sigproc/spectrum.py +187 -173
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +147 -154
- ezmsg/sigproc/window.py +248 -210
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/synth.py +0 -621
- ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
- ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
- /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from .filter import (
|
|
7
|
+
BACoeffs,
|
|
8
|
+
BaseFilterByDesignTransformerUnit,
|
|
9
|
+
FilterBaseSettings,
|
|
10
|
+
FilterByDesignTransformer,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GaussianSmoothingSettings(FilterBaseSettings):
    """Settings for :obj:`GaussianSmoothingFilter`."""

    sigma: float | None = 1.0
    """
    sigma : float
        Standard deviation of the Gaussian kernel, in samples. Must be positive.
    """

    width: int | None = 4
    """
    width : int
        Half-width of the automatically sized kernel, in standard deviations. When
        kernel_size is not provided, the kernel length is ``int(2 * width * sigma + 1)``.
    """

    kernel_size: int | None = None
    """
    kernel_size : int | None
        Length of the kernel in samples. If provided, overrides the automatic calculation
        (a warning is emitted if it is shorter than the automatic size).
    """
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def gaussian_smoothing_filter_design(
|
|
35
|
+
sigma: float = 1.0,
|
|
36
|
+
width: int = 4,
|
|
37
|
+
kernel_size: int | None = None,
|
|
38
|
+
) -> BACoeffs | None:
|
|
39
|
+
# Parameter checks
|
|
40
|
+
if sigma <= 0:
|
|
41
|
+
raise ValueError(f"sigma must be positive. Received: {sigma}")
|
|
42
|
+
|
|
43
|
+
if width <= 0:
|
|
44
|
+
raise ValueError(f"width must be positive. Received: {width}")
|
|
45
|
+
|
|
46
|
+
if kernel_size is not None:
|
|
47
|
+
if kernel_size < 1:
|
|
48
|
+
raise ValueError(f"kernel_size must be >= 1. Received: {kernel_size}")
|
|
49
|
+
else:
|
|
50
|
+
kernel_size = int(2 * width * sigma + 1)
|
|
51
|
+
|
|
52
|
+
# Warn if kernel_size is smaller than recommended but don't fail
|
|
53
|
+
expected_kernel_size = int(2 * width * sigma + 1)
|
|
54
|
+
if kernel_size < expected_kernel_size:
|
|
55
|
+
## TODO: Either add a warning or determine appropriate kernel size and raise an error
|
|
56
|
+
warnings.warn(
|
|
57
|
+
f"Provided kernel_size {kernel_size} is smaller than recommended "
|
|
58
|
+
f"size {expected_kernel_size} for sigma={sigma} and width={width}. "
|
|
59
|
+
"The kernel may be truncated."
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
from scipy.signal.windows import gaussian
|
|
63
|
+
|
|
64
|
+
b = gaussian(kernel_size, std=sigma)
|
|
65
|
+
b /= np.sum(b) # Ensure normalization
|
|
66
|
+
a = np.array([1.0])
|
|
67
|
+
|
|
68
|
+
return b, a
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class GaussianSmoothingFilterTransformer(FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]):
    """Filter-by-design transformer whose design is :func:`gaussian_smoothing_filter_design`."""

    def get_design_function(
        self,
    ) -> Callable[[float], BACoeffs]:
        def _design(fs: float) -> BACoeffs:
            # `fs` is part of the design-function signature but unused here: the Gaussian
            # kernel is specified in samples. Settings are read at call time, not captured.
            opts = self.settings
            return gaussian_smoothing_filter_design(
                sigma=opts.sigma,
                width=opts.width,
                kernel_size=opts.kernel_size,
            )

        return _design
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class GaussianSmoothingFilter(
    BaseFilterByDesignTransformerUnit[GaussianSmoothingSettings, GaussianSmoothingFilterTransformer]
):
    """ezmsg unit wrapping :obj:`GaussianSmoothingFilterTransformer`."""

    SETTINGS = GaussianSmoothingSettings
|
ezmsg/sigproc/kaiser.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import scipy.signal
|
|
7
|
+
|
|
8
|
+
from .filter import (
|
|
9
|
+
BACoeffs,
|
|
10
|
+
BaseFilterByDesignTransformerUnit,
|
|
11
|
+
FilterBaseSettings,
|
|
12
|
+
FilterByDesignTransformer,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class KaiserFilterSettings(FilterBaseSettings):
    """Settings for :obj:`KaiserFilter`"""

    # `axis` and `coef_type` come from FilterBaseSettings.

    cutoff: float | npt.ArrayLike | None = None
    """
    Cutoff frequency of the filter (same units as fs), or an array of cutoff frequencies
    (band edges). As a float, the cutoff corresponds to the half-amplitude (-6 dB) point.
    As an array, the values must be positive and monotonically increasing strictly between
    0 and fs/2 (0 and fs/2 themselves must not be included).
    """

    ripple: float | None = None
    """
    Upper bound (in dB) on the deviation of the filter's magnitude response from the
    desired response, excluding transition intervals.
    See scipy.signal.kaiserord for more information.
    """

    width: float | None = None
    """
    Approximate width of the transition region (same units as fs), used for Kaiser FIR
    filter design.
    See scipy.signal.kaiserord for more information.
    """

    pass_zero: bool | str = True
    """
    If True, the gain at frequency 0 (the "DC gain") is 1; if False, the DC gain is 0.
    May also be a string naming the desired filter type (equivalent to btype in IIR
    design functions): one of 'lowpass', 'highpass', 'bandpass', 'bandstop'.
    """

    wn_hz: bool = True
    """
    Set False if cutoff and width are normalized from 0 to 1, where 1 is the Nyquist frequency.
    """
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def kaiser_design_fun(
|
|
58
|
+
fs: float,
|
|
59
|
+
cutoff: float | npt.ArrayLike | None = None,
|
|
60
|
+
ripple: float | None = None,
|
|
61
|
+
width: float | None = None,
|
|
62
|
+
pass_zero: bool | str = True,
|
|
63
|
+
wn_hz: bool = True,
|
|
64
|
+
) -> BACoeffs | None:
|
|
65
|
+
"""
|
|
66
|
+
Design an `order`th-order FIR Kaiser filter and return the filter coefficients.
|
|
67
|
+
See :obj:`FIRFilterSettings` for argument description.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
The filter taps as designed by firwin
|
|
71
|
+
"""
|
|
72
|
+
if ripple is None or width is None or cutoff is None:
|
|
73
|
+
return None
|
|
74
|
+
|
|
75
|
+
width = width / (0.5 * fs) if wn_hz else width
|
|
76
|
+
n_taps, beta = scipy.signal.kaiserord(ripple, width)
|
|
77
|
+
if n_taps % 2 == 0:
|
|
78
|
+
n_taps += 1
|
|
79
|
+
taps = scipy.signal.firwin(
|
|
80
|
+
numtaps=n_taps,
|
|
81
|
+
cutoff=cutoff,
|
|
82
|
+
window=("kaiser", beta), # type: ignore
|
|
83
|
+
pass_zero=pass_zero, # type: ignore
|
|
84
|
+
scale=False,
|
|
85
|
+
fs=fs if wn_hz else None,
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
return (taps, np.array([1.0]))
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class KaiserFilterTransformer(FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]):
    """Filter-by-design transformer whose design is :func:`kaiser_design_fun`."""

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | None]:
        # Bind the current settings now; the returned callable only needs `fs`.
        opts = self.settings
        return functools.partial(
            kaiser_design_fun,
            cutoff=opts.cutoff,
            ripple=opts.ripple,
            width=opts.width,
            pass_zero=opts.pass_zero,
            wn_hz=opts.wn_hz,
        )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KaiserFilter(BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]):
    """ezmsg unit wrapping :obj:`KaiserFilterTransformer`."""

    SETTINGS = KaiserFilterSettings
|
ezmsg/sigproc/linear.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Apply a linear transformation: output = scale * input + offset.
|
|
3
|
+
|
|
4
|
+
Supports per-element scale and offset along a specified axis.
|
|
5
|
+
For full matrix transformations, use :obj:`AffineTransformTransformer` instead.
|
|
6
|
+
|
|
7
|
+
.. note::
|
|
8
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
9
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import ezmsg.core as ez
|
|
13
|
+
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
15
|
+
from array_api_compat import get_namespace
|
|
16
|
+
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
|
|
17
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
18
|
+
from ezmsg.util.messages.util import replace
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LinearTransformSettings(ez.Settings):
    """Settings for :obj:`LinearTransformTransformer` (output = scale * input + offset)."""

    scale: float | list[float] | npt.ArrayLike = 1.0
    """Scale factor(s): a scalar applied to every element, or an array whose length
    matches the specified axis for per-element scaling."""

    offset: float | list[float] | npt.ArrayLike = 0.0
    """Offset value(s): a scalar applied to every element, or an array whose length
    matches the specified axis for per-element offsets."""

    axis: str | None = None
    """Axis along which per-element scale/offset are applied. If None, scale/offset
    are broadcast against the data as-is."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@processor_state
class LinearTransformState:
    """Cached coefficients for :obj:`LinearTransformTransformer`.

    Both fields are None until the transformer's ``_reset_state`` fills them.
    """

    scale: npt.NDArray | None = None
    """Prepared scale array for broadcasting."""

    offset: npt.NDArray | None = None
    """Prepared offset array for broadcasting."""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class LinearTransformTransformer(
    BaseStatefulTransformer[LinearTransformSettings, AxisArray, AxisArray, LinearTransformState]
):
    """Apply linear transformation: output = scale * input + offset.

    This transformer is optimized for element-wise linear operations with
    optional per-channel (or per-axis) coefficients. For full matrix
    transformations, use :obj:`AffineTransformTransformer` instead.

    Examples:
        # Uniform scaling and offset
        >>> transformer = LinearTransformTransformer(LinearTransformSettings(scale=2.0, offset=1.0))

        # Per-channel scaling (e.g., for 3-channel data along "ch" axis)
        >>> transformer = LinearTransformTransformer(LinearTransformSettings(
        ...     scale=[0.5, 1.0, 2.0],
        ...     offset=[0.0, 0.1, 0.2],
        ...     axis="ch"
        ... ))
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash based on shape and axis to detect when broadcast shapes need recalculation."""
        axis = self.settings.axis
        if axis is not None:
            axis_idx = message.get_axis_idx(axis)
            return hash((message.data.ndim, axis_idx, message.data.shape[axis_idx]))
        return hash(message.data.ndim)

    def _prepare_coefficient(self, xp, value, name: str, message: AxisArray):
        """Convert one scale/offset setting into an array that broadcasts against message.data.

        Scalars become 0-d arrays. When ``settings.axis`` is set, a 1-d coefficient is
        reshaped to size 1 on every dimension except the target axis.

        Raises:
            ValueError: If a 1-d coefficient's length is neither 1 nor the axis length.
        """
        # xp.asarray accepts any array-like (list, tuple, ndarray, scalar), so tuples and
        # other ArrayLike values permitted by the settings type no longer fail in float().
        # NOTE(review): float64 coefficients may promote lower-precision input data
        # (e.g. float32) to float64; confirm this is intended.
        coef = xp.asarray(value, dtype=xp.float64)

        axis = self.settings.axis
        ndim = message.data.ndim
        if axis is None or ndim == 0 or coef.ndim != 1:
            return coef

        axis_idx = message.get_axis_idx(axis)
        n_coef = coef.shape[0]
        axis_len = message.data.shape[axis_idx]
        if n_coef not in (1, axis_len):
            raise ValueError(
                f"{name} has {n_coef} elements but axis '{axis}' has length {axis_len}."
            )

        # All 1s except at axis_idx so the coefficient lines up with the target axis.
        broadcast_shape = [1] * ndim
        broadcast_shape[axis_idx] = n_coef
        # The Array API specifies a tuple shape for reshape.
        return xp.reshape(coef, tuple(broadcast_shape))

    def _reset_state(self, message: AxisArray) -> None:
        """Prepare scale/offset arrays with proper broadcast shapes."""
        xp = get_namespace(message.data)
        self._state.scale = self._prepare_coefficient(xp, self.settings.scale, "scale", message)
        self._state.offset = self._prepare_coefficient(xp, self.settings.offset, "offset", message)

    def _process(self, message: AxisArray) -> AxisArray:
        result = message.data * self._state.scale + self._state.offset
        return replace(message, data=result)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class LinearTransform(BaseTransformerUnit[LinearTransformSettings, AxisArray, AxisArray, LinearTransformTransformer]):
    """ezmsg unit that applies :obj:`LinearTransformTransformer` to each incoming message."""

    SETTINGS = LinearTransformSettings
|
ezmsg/sigproc/math/abs.py
CHANGED
|
@@ -1,34 +1,35 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Take the absolute value of the data.
|
|
2
3
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from array_api_compat import get_namespace
|
|
10
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
6
11
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
12
|
from ezmsg.util.messages.util import replace
|
|
8
13
|
|
|
9
|
-
from ..base import GenAxisArray
|
|
10
14
|
|
|
15
|
+
class AbsSettings:
    """Parameter-free settings placeholder for the absolute-value transform.

    The absolute value has nothing to configure; :obj:`Abs` itself is declared with
    ``None`` as its settings type.
    """
|
|
11
17
|
|
|
12
|
-
@consumer
|
|
13
|
-
def abs() -> typing.Generator[AxisArray, AxisArray, None]:
|
|
14
|
-
"""
|
|
15
|
-
Take the absolute value of the data. See :obj:`np.abs` for more details.
|
|
16
18
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
while True:
|
|
22
|
-
msg_in: AxisArray = yield msg_out
|
|
23
|
-
msg_out = replace(msg_in, data=np.abs(msg_in.data))
|
|
19
|
+
class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
    """Element-wise absolute value, computed in the input array's own namespace."""

    def _process(self, message: AxisArray) -> AxisArray:
        data = message.data
        magnitude = get_namespace(data).abs(data)
        return replace(message, data=magnitude)
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
class
|
|
27
|
-
pass
|
|
25
|
+
class Abs(BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]):
    """ezmsg unit wrapping :obj:`AbsTransformer`; it takes no settings (``None`` settings type)."""
|
|
28
26
|
|
|
29
27
|
|
|
30
|
-
|
|
31
|
-
|
|
28
|
+
def abs() -> AbsTransformer:
    """
    Take the absolute value of the data. See :obj:`np.abs` for more details.

    Returns: :obj:`AbsTransformer`.

    """
    transformer = AbsTransformer()
    return transformer
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Add 2 signals or add a constant to a signal."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import typing
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
import ezmsg.core as ez
|
|
8
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
9
|
+
from ezmsg.baseproc.util.asio import run_coroutine_sync
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
12
|
+
|
|
13
|
+
# --- Constant Addition (single input) ---
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConstAddSettings(ez.Settings):
    """Settings for :obj:`ConstAddTransformer`."""

    value: float = 0.0
    """Number to add to the input data."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ConstAddTransformer(BaseTransformer[ConstAddSettings, AxisArray, AxisArray]):
    """Add a constant value to input data."""

    def _process(self, message: AxisArray) -> AxisArray:
        shifted = message.data + self.settings.value
        return replace(message, data=shifted)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ConstAdd(BaseTransformerUnit[ConstAddSettings, AxisArray, AxisArray, ConstAddTransformer]):
    """ezmsg unit that adds ``settings.value`` to every incoming message via :obj:`ConstAddTransformer`."""

    SETTINGS = ConstAddSettings
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# --- Two-input Addition ---
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class AddState:
    """Pending inputs for :obj:`AddProcessor`: one FIFO queue per input stream.

    Each instance gets its own fresh queues (via ``default_factory``).
    """

    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class AddProcessor:
    """Processor that adds two AxisArray signals together.

    This processor maintains separate queues for two input streams and
    adds corresponding messages element-wise. It assumes both inputs
    have compatible shapes and aligned time spans. The output keeps the
    metadata (axes, etc.) of the message from input A.
    """

    def __init__(self):
        self._state = AddState()

    @property
    def state(self) -> AddState:
        return self._state

    @state.setter
    def state(self, state: AddState | bytes | None) -> None:
        """Replace the processor state. ``None`` leaves the current state unchanged.

        Raises:
            NotImplementedError: If ``state`` is serialized bytes (not supported yet).
        """
        if state is None:
            return
        if isinstance(state, (bytes, bytearray)):
            # TODO: Support hydrating state from bytes. Previously the raw bytes were
            # stored as the state, so the next push/get failed with an AttributeError.
            raise NotImplementedError("Hydrating AddProcessor state from bytes is not supported yet.")
        self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Push a message to queue A."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Push a message to queue B."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Await the oldest message from each queue and return their element-wise sum."""
        a = await self._state.queue_a.get()
        b = await self._state.queue_b.get()
        return replace(a, data=a.data + b.data)

    def __call__(self) -> AxisArray:
        """Synchronously get and add the next messages from both queues."""
        return run_coroutine_sync(self.__acall__())

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        return await self.__acall__()

    def __next__(self) -> AxisArray:
        return self.__call__()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class Add(ez.Unit):
    """Add two signals together.

    Assumes compatible/similar axes/dimensions and aligned time spans.
    Messages are paired by arrival order (oldest from each queue).
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        # A single processor owns both input queues.
        self.processor = AddProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        # Publish a sum each time both queues have a message available.
        while True:
            summed = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, summed
|
ezmsg/sigproc/math/clip.py
CHANGED
|
@@ -1,40 +1,48 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Clips the data to be within the specified range.
|
|
3
|
+
|
|
4
|
+
.. note::
|
|
5
|
+
This module supports the :doc:`Array API standard </guides/explanations/array_api>`,
|
|
6
|
+
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
|
|
7
|
+
"""
|
|
2
8
|
|
|
3
|
-
import numpy as np
|
|
4
9
|
import ezmsg.core as ez
|
|
5
|
-
from
|
|
10
|
+
from array_api_compat import get_namespace
|
|
11
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
6
12
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
13
|
from ezmsg.util.messages.util import replace
|
|
8
14
|
|
|
9
|
-
from ..base import GenAxisArray
|
|
10
15
|
|
|
16
|
+
class ClipSettings(ez.Settings):
    """Settings for :obj:`ClipTransformer`; either bound may be None to leave that side open."""

    min: float | None = None
    """Lower clip bound. If None, no lower clipping is applied."""

    max: float | None = None
    """Upper clip bound. If None, no upper clipping is applied."""
|
|
16
22
|
|
|
17
|
-
Args:
|
|
18
|
-
a_min: Lower clip bound
|
|
19
|
-
a_max: Upper clip bound
|
|
20
23
|
|
|
21
|
-
|
|
22
|
-
|
|
24
|
+
class ClipTransformer(BaseTransformer[ClipSettings, AxisArray, AxisArray]):
    """Clip message data to ``[settings.min, settings.max]`` using the array's own namespace."""

    def _process(self, message: AxisArray) -> AxisArray:
        xp = get_namespace(message.data)
        lower, upper = self.settings.min, self.settings.max
        clipped = xp.clip(message.data, lower, upper)
        return replace(message, data=clipped)
|
|
23
31
|
|
|
24
|
-
"""
|
|
25
|
-
msg_out = AxisArray(np.array([]), dims=[""])
|
|
26
|
-
while True:
|
|
27
|
-
msg_in: AxisArray = yield msg_out
|
|
28
|
-
msg_out = replace(msg_in, data=np.clip(msg_in.data, a_min, a_max))
|
|
29
32
|
|
|
33
|
+
class Clip(BaseTransformerUnit[ClipSettings, AxisArray, AxisArray, ClipTransformer]):
    """ezmsg unit wrapping :obj:`ClipTransformer`."""

    SETTINGS = ClipSettings
|
|
30
35
|
|
|
31
|
-
class ClipSettings(ez.Settings):
|
|
32
|
-
a_min: float
|
|
33
|
-
a_max: float
|
|
34
36
|
|
|
37
|
+
def clip(min: float | None = None, max: float | None = None) -> ClipTransformer:
    """
    Clips the data to be within the specified range.

    Args:
        min: Lower clip bound. If None, no lower clipping is applied.
        max: Upper clip bound. If None, no upper clipping is applied.

    Returns:
        :obj:`ClipTransformer`.
    """
    # Parameter names shadow the builtins min/max; they are kept for API compatibility.
    settings = ClipSettings(min=min, max=max)
    return ClipTransformer(settings)
|