ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
  4. ezmsg/sigproc/affinetransform.py +16 -42
  5. ezmsg/sigproc/aggregate.py +17 -34
  6. ezmsg/sigproc/bandpower.py +12 -20
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/coordinatespaces.py +142 -0
  13. ezmsg/sigproc/decimate.py +3 -7
  14. ezmsg/sigproc/denormalize.py +6 -11
  15. ezmsg/sigproc/detrend.py +3 -4
  16. ezmsg/sigproc/diff.py +8 -17
  17. ezmsg/sigproc/downsample.py +11 -20
  18. ezmsg/sigproc/ewma.py +11 -28
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +3 -4
  21. ezmsg/sigproc/fbcca.py +34 -59
  22. ezmsg/sigproc/filter.py +19 -45
  23. ezmsg/sigproc/filterbank.py +37 -74
  24. ezmsg/sigproc/filterbankdesign.py +7 -14
  25. ezmsg/sigproc/fir_hilbert.py +13 -30
  26. ezmsg/sigproc/fir_pmc.py +5 -10
  27. ezmsg/sigproc/firfilter.py +12 -14
  28. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  29. ezmsg/sigproc/kaiser.py +11 -15
  30. ezmsg/sigproc/math/abs.py +4 -3
  31. ezmsg/sigproc/math/add.py +121 -0
  32. ezmsg/sigproc/math/clip.py +4 -1
  33. ezmsg/sigproc/math/difference.py +100 -36
  34. ezmsg/sigproc/math/invert.py +3 -3
  35. ezmsg/sigproc/math/log.py +5 -6
  36. ezmsg/sigproc/math/scale.py +2 -0
  37. ezmsg/sigproc/messages.py +1 -2
  38. ezmsg/sigproc/quantize.py +3 -6
  39. ezmsg/sigproc/resample.py +17 -38
  40. ezmsg/sigproc/rollingscaler.py +12 -37
  41. ezmsg/sigproc/sampler.py +19 -37
  42. ezmsg/sigproc/scaler.py +11 -22
  43. ezmsg/sigproc/signalinjector.py +7 -18
  44. ezmsg/sigproc/slicer.py +14 -34
  45. ezmsg/sigproc/spectral.py +3 -3
  46. ezmsg/sigproc/spectrogram.py +12 -19
  47. ezmsg/sigproc/spectrum.py +17 -38
  48. ezmsg/sigproc/transpose.py +12 -24
  49. ezmsg/sigproc/util/asio.py +25 -156
  50. ezmsg/sigproc/util/axisarray_buffer.py +12 -26
  51. ezmsg/sigproc/util/buffer.py +22 -43
  52. ezmsg/sigproc/util/message.py +17 -31
  53. ezmsg/sigproc/util/profile.py +23 -174
  54. ezmsg/sigproc/util/sparse.py +7 -15
  55. ezmsg/sigproc/util/typeresolution.py +17 -83
  56. ezmsg/sigproc/wavelets.py +10 -19
  57. ezmsg/sigproc/window.py +29 -83
  58. ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
  59. ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
  60. ezmsg/sigproc/synth.py +0 -774
  61. ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
  62. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  63. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
  64. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
@@ -4,6 +4,9 @@ import typing
4
4
  import ezmsg.core as ez
5
5
  import numpy as np
6
6
  import scipy.signal as sps
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
9
+
7
10
  from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state
8
11
  from ezmsg.sigproc.filter import (
9
12
  BACoeffs,
@@ -12,8 +15,6 @@ from ezmsg.sigproc.filter import (
12
15
  FilterBaseSettings,
13
16
  FilterByDesignTransformer,
14
17
  )
15
- from ezmsg.util.messages.axisarray import AxisArray
16
- from ezmsg.util.messages.util import replace
17
18
 
18
19
 
19
20
  class FIRHilbertFilterSettings(FilterBaseSettings):
@@ -60,7 +61,7 @@ class FIRHilbertFilterSettings(FilterBaseSettings):
60
61
 
61
62
  weight_pass: float = 1.0
62
63
  """
63
- Weight for Hilbert pass region.
64
+ Weight for Hilbert pass region.
64
65
  """
65
66
 
66
67
  weight_stop_lo: float = 1.0
@@ -74,7 +75,7 @@ class FIRHilbertFilterSettings(FilterBaseSettings):
74
75
  """
75
76
 
76
77
  norm_band: tuple[float, float] | None = None
77
- """
78
+ """
78
79
  Optional normalization band (f_lo, f_hi) in Hz for gain normalization.
79
80
  If None, no normalization is applied.
80
81
  """
@@ -128,9 +129,7 @@ def fir_hilbert_design_fun(
128
129
  if bands[i] <= bands[i - 1]:
129
130
  bands[i] = np.nextafter(bands[i - 1], np.inf)
130
131
  if bands[-2] >= nyq:
131
- ez.logger.warning(
132
- "Hilbert upper stopband collapsed; using 2-band (stop/pass) design."
133
- )
132
+ ez.logger.warning("Hilbert upper stopband collapsed; using 2-band (stop/pass) design.")
134
133
  bands = bands[:-3] + [nyq]
135
134
  desired = desired[:-1]
136
135
  weight = weight[:-1]
@@ -139,9 +138,7 @@ def fir_hilbert_design_fun(
139
138
  g = None
140
139
  if norm_freq is not None:
141
140
  if norm_freq < f1 or norm_freq > f2:
142
- ez.logger.warning(
143
- "Invalid normalization frequency specifications. Skipping normalization."
144
- )
141
+ ez.logger.warning("Invalid normalization frequency specifications. Skipping normalization.")
145
142
  else:
146
143
  f0 = float(norm_freq)
147
144
  w = 2.0 * np.pi * (np.asarray([f0], dtype=np.float64) / fs)
@@ -152,13 +149,9 @@ def fir_hilbert_design_fun(
152
149
  if lo < f1 or hi > f2:
153
150
  lo = max(lo, f1)
154
151
  hi = min(hi, f2)
155
- ez.logger.warning(
156
- "Normalization band outside passband. Clipping to passband for normalization."
157
- )
152
+ ez.logger.warning("Normalization band outside passband. Clipping to passband for normalization.")
158
153
  if lo >= hi:
159
- ez.logger.warning(
160
- "Invalid normalization band specifications. Skipping normalization."
161
- )
154
+ ez.logger.warning("Invalid normalization band specifications. Skipping normalization.")
162
155
  else:
163
156
  freqs = np.linspace(lo, hi, 2048, dtype=np.float64)
164
157
  w = 2.0 * np.pi * (np.asarray(freqs, dtype=np.float64) / fs)
@@ -169,9 +162,7 @@ def fir_hilbert_design_fun(
169
162
  return (b, a)
170
163
 
171
164
 
172
- class FIRHilbertFilterTransformer(
173
- FilterByDesignTransformer[FIRHilbertFilterSettings, BACoeffs]
174
- ):
165
+ class FIRHilbertFilterTransformer(FilterByDesignTransformer[FIRHilbertFilterSettings, BACoeffs]):
175
166
  def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
176
167
  if self.settings.coef_type != "ba":
177
168
  ez.logger.error("FIRHilbert only supports coef_type='ba'.")
@@ -198,11 +189,7 @@ class FIRHilbertFilterTransformer(
198
189
  return b.size if b is not None else None
199
190
 
200
191
 
201
- class FIRHilbertFilterUnit(
202
- BaseFilterByDesignTransformerUnit[
203
- FIRHilbertFilterSettings, FIRHilbertFilterTransformer
204
- ]
205
- ):
192
+ class FIRHilbertFilterUnit(BaseFilterByDesignTransformerUnit[FIRHilbertFilterSettings, FIRHilbertFilterTransformer]):
206
193
  SETTINGS = FIRHilbertFilterSettings
207
194
 
208
195
 
@@ -214,9 +201,7 @@ class FIRHilbertEnvelopeState:
214
201
 
215
202
 
216
203
  class FIRHilbertEnvelopeTransformer(
217
- BaseStatefulTransformer[
218
- FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeState
219
- ]
204
+ BaseStatefulTransformer[FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeState]
220
205
  ):
221
206
  """
222
207
  Processor for computing the envelope of a signal using the Hilbert transform.
@@ -304,9 +289,7 @@ class FIRHilbertEnvelopeTransformer(
304
289
 
305
290
  if self._state.delay_buf is None:
306
291
  lead_shape = x.shape[:-1]
307
- self._state.delay_buf = np.zeros(
308
- lead_shape + (self._state.dly,), dtype=x.dtype
309
- )
292
+ self._state.delay_buf = np.zeros(lead_shape + (self._state.dly,), dtype=x.dtype)
310
293
 
311
294
  x_cat = np.concatenate([self._state.delay_buf, x], axis=-1)
312
295
  x_delayed_full = x_cat[..., : -self._state.dly]
ezmsg/sigproc/fir_pmc.py CHANGED
@@ -4,6 +4,7 @@ import typing
4
4
  import ezmsg.core as ez
5
5
  import numpy as np
6
6
  import scipy.signal
7
+
7
8
  from ezmsg.sigproc.filter import (
8
9
  BACoeffs,
9
10
  BaseFilterByDesignTransformerUnit,
@@ -33,14 +34,14 @@ class ParksMcClellanFIRSettings(FilterBaseSettings):
33
34
  """
34
35
  Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
35
36
  if this is lower than `cutoff` then this is the beginning of the bandpass
36
- or if this is greater than `cutoff` then this is the end of the bandstop.
37
+ or if this is greater than `cutoff` then this is the end of the bandstop.
37
38
  """
38
39
 
39
40
  cutoff: float | None = None
40
41
  """
41
42
  Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
42
43
  if this is greater than `cuton` then this is the end of the bandpass,
43
- or if this is less than `cuton` then this is the beginning of the bandstop.
44
+ or if this is less than `cuton` then this is the beginning of the bandstop.
44
45
  """
45
46
 
46
47
  transition: float = 10.0
@@ -187,9 +188,7 @@ def parks_mcclellan_design_fun(
187
188
  return (b, np.array([1.0]))
188
189
 
189
190
 
190
- class ParksMcClellanFIRTransformer(
191
- FilterByDesignTransformer[ParksMcClellanFIRSettings, BACoeffs]
192
- ):
191
+ class ParksMcClellanFIRTransformer(FilterByDesignTransformer[ParksMcClellanFIRSettings, BACoeffs]):
193
192
  def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]:
194
193
  if self.settings.coef_type != "ba":
195
194
  ez.logger.error("ParksMcClellanFIR only supports coef_type='ba'.")
@@ -206,9 +205,5 @@ class ParksMcClellanFIRTransformer(
206
205
  )
207
206
 
208
207
 
209
- class ParksMcClellanFIR(
210
- BaseFilterByDesignTransformerUnit[
211
- ParksMcClellanFIRSettings, ParksMcClellanFIRTransformer
212
- ]
213
- ):
208
+ class ParksMcClellanFIR(BaseFilterByDesignTransformerUnit[ParksMcClellanFIRSettings, ParksMcClellanFIRTransformer]):
214
209
  SETTINGS = ParksMcClellanFIRSettings
@@ -6,10 +6,10 @@ import numpy.typing as npt
6
6
  import scipy.signal
7
7
 
8
8
  from .filter import (
9
- FilterBaseSettings,
10
- FilterByDesignTransformer,
11
9
  BACoeffs,
12
10
  BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
13
  )
14
14
 
15
15
 
@@ -25,16 +25,16 @@ class FIRFilterSettings(FilterBaseSettings):
25
25
 
26
26
  cutoff: float | npt.ArrayLike | None = None
27
27
  """
28
- Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
29
- (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
30
- the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
31
- cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
28
+ Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
29
+ (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
30
+ the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
31
+ cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
32
32
  not be included in cutoff.
33
33
  """
34
34
 
35
35
  width: float | None = None
36
36
  """
37
- If width is not None, then assume it is the approximate width of the transition region (expressed in
37
+ If width is not None, then assume it is the approximate width of the transition region (expressed in
38
38
  the same units as fs) for use in Kaiser FIR filter design. In this case, the window argument is ignored.
39
39
  """
40
40
 
@@ -45,18 +45,18 @@ class FIRFilterSettings(FilterBaseSettings):
45
45
 
46
46
  pass_zero: bool | str = True
47
47
  """
48
- If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
48
+ If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
49
49
  be a string argument for the desired filter type (equivalent to btype in IIR design functions).
50
50
  {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
51
51
  """
52
52
 
53
53
  scale: bool = True
54
54
  """
55
- Set to True to scale the coefficients so that the frequency response is exactly unity at a certain
55
+ Set to True to scale the coefficients so that the frequency response is exactly unity at a certain
56
56
  frequency. That frequency is either:
57
57
  * 0 (DC) if the first passband starts at 0 (i.e. pass_zero is True)
58
- * fs/2 (the Nyquist frequency) if the first passband ends at fs/2
59
- (i.e the filter is a single band highpass filter);
58
+ * fs/2 (the Nyquist frequency) if the first passband ends at fs/2
59
+ (i.e the filter is a single band highpass filter);
60
60
  center of first passband otherwise
61
61
  """
62
62
 
@@ -113,7 +113,5 @@ class FIRFilterTransformer(FilterByDesignTransformer[FIRFilterSettings, BACoeffs
113
113
  )
114
114
 
115
115
 
116
- class FIRFilter(
117
- BaseFilterByDesignTransformerUnit[FIRFilterSettings, FIRFilterTransformer]
118
- ):
116
+ class FIRFilter(BaseFilterByDesignTransformerUnit[FIRFilterSettings, FIRFilterTransformer]):
119
117
  SETTINGS = FIRFilterSettings
@@ -1,13 +1,13 @@
1
- from typing import Callable
2
1
  import warnings
2
+ from typing import Callable
3
3
 
4
4
  import numpy as np
5
5
 
6
6
  from .filter import (
7
- FilterBaseSettings,
8
7
  BACoeffs,
9
- FilterByDesignTransformer,
10
8
  BaseFilterByDesignTransformerUnit,
9
+ FilterBaseSettings,
10
+ FilterByDesignTransformer,
11
11
  )
12
12
 
13
13
 
@@ -68,9 +68,7 @@ def gaussian_smoothing_filter_design(
68
68
  return b, a
69
69
 
70
70
 
71
- class GaussianSmoothingFilterTransformer(
72
- FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]
73
- ):
71
+ class GaussianSmoothingFilterTransformer(FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]):
74
72
  def get_design_function(
75
73
  self,
76
74
  ) -> Callable[[float], BACoeffs]:
@@ -86,8 +84,6 @@ class GaussianSmoothingFilterTransformer(
86
84
 
87
85
 
88
86
  class GaussianSmoothingFilter(
89
- BaseFilterByDesignTransformerUnit[
90
- GaussianSmoothingSettings, GaussianSmoothingFilterTransformer
91
- ]
87
+ BaseFilterByDesignTransformerUnit[GaussianSmoothingSettings, GaussianSmoothingFilterTransformer]
92
88
  ):
93
89
  SETTINGS = GaussianSmoothingSettings
ezmsg/sigproc/kaiser.py CHANGED
@@ -6,10 +6,10 @@ import numpy.typing as npt
6
6
  import scipy.signal
7
7
 
8
8
  from .filter import (
9
- FilterBaseSettings,
10
- FilterByDesignTransformer,
11
9
  BACoeffs,
12
10
  BaseFilterByDesignTransformerUnit,
11
+ FilterBaseSettings,
12
+ FilterByDesignTransformer,
13
13
  )
14
14
 
15
15
 
@@ -20,30 +20,30 @@ class KaiserFilterSettings(FilterBaseSettings):
20
20
 
21
21
  cutoff: float | npt.ArrayLike | None = None
22
22
  """
23
- Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
24
- (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
25
- the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
26
- cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
23
+ Cutoff frequency of filter (expressed in the same units as fs) OR an array of cutoff frequencies
24
+ (that is, band edges). In the former case, as a float, the cutoff frequency should correspond with
25
+ the half-amplitude point, where the attenuation will be -6dB. In the latter case, the frequencies in
26
+ cutoff should be positive and monotonically increasing between 0 and fs/2. The values 0 and fs/2 must
27
27
  not be included in cutoff.
28
28
  """
29
29
 
30
30
  ripple: float | None = None
31
31
  """
32
- Upper bound for the deviation (in dB) of the magnitude of the filter's frequency response from that of
32
+ Upper bound for the deviation (in dB) of the magnitude of the filter's frequency response from that of
33
33
  the desired filter (not including frequencies in any transition intervals).
34
34
  See scipy.signal.kaiserord for more information.
35
35
  """
36
36
 
37
37
  width: float | None = None
38
38
  """
39
- If width is not None, then assume it is the approximate width of the transition region (expressed in
39
+ If width is not None, then assume it is the approximate width of the transition region (expressed in
40
40
  the same units as fs) for use in Kaiser FIR filter design.
41
41
  See scipy.signal.kaiserord for more information.
42
42
  """
43
43
 
44
44
  pass_zero: bool | str = True
45
45
  """
46
- If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
46
+ If True, the gain at the frequency 0 (i.e., the “DC gain”) is 1. If False, the DC gain is 0. Can also
47
47
  be a string argument for the desired filter type (equivalent to btype in IIR design functions).
48
48
  {‘lowpass’, ‘highpass’, ‘bandpass’, ‘bandstop’}
49
49
  """
@@ -88,9 +88,7 @@ def kaiser_design_fun(
88
88
  return (taps, np.array([1.0]))
89
89
 
90
90
 
91
- class KaiserFilterTransformer(
92
- FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]
93
- ):
91
+ class KaiserFilterTransformer(FilterByDesignTransformer[KaiserFilterSettings, BACoeffs]):
94
92
  def get_design_function(
95
93
  self,
96
94
  ) -> typing.Callable[[float], BACoeffs | None]:
@@ -104,7 +102,5 @@ class KaiserFilterTransformer(
104
102
  )
105
103
 
106
104
 
107
- class KaiserFilter(
108
- BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]
109
- ):
105
+ class KaiserFilter(BaseFilterByDesignTransformerUnit[KaiserFilterSettings, KaiserFilterTransformer]):
110
106
  SETTINGS = KaiserFilterSettings
ezmsg/sigproc/math/abs.py CHANGED
@@ -1,3 +1,6 @@
1
+ """Take the absolute value of the data."""
2
+ # TODO: Array API
3
+
1
4
  import numpy as np
2
5
  from ezmsg.util.messages.axisarray import AxisArray
3
6
  from ezmsg.util.messages.util import replace
@@ -14,9 +17,7 @@ class AbsTransformer(BaseTransformer[None, AxisArray, AxisArray]):
14
17
  return replace(message, data=np.abs(message.data))
15
18
 
16
19
 
17
- class Abs(
18
- BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]
19
- ): ... # SETTINGS = None
20
+ class Abs(BaseTransformerUnit[None, AxisArray, AxisArray, AbsTransformer]): ... # SETTINGS = None
20
21
 
21
22
 
22
23
  def abs() -> AbsTransformer:
@@ -0,0 +1,121 @@
1
+ """Add 2 signals or add a constant to a signal."""
2
+
3
+ import asyncio
4
+ import typing
5
+ from dataclasses import dataclass, field
6
+
7
+ import ezmsg.core as ez
8
+ from ezmsg.baseproc.util.asio import run_coroutine_sync
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
11
+
12
+ from ..base import BaseTransformer, BaseTransformerUnit
13
+
14
+ # --- Constant Addition (single input) ---
15
+
16
+
17
class ConstAddSettings(ez.Settings):
    """Settings for :class:`ConstAddTransformer`."""

    value: float = 0.0
    """Constant offset added to the input data (default 0.0, i.e. a no-op)."""
20
+
21
+
22
class ConstAddTransformer(BaseTransformer[ConstAddSettings, AxisArray, AxisArray]):
    """Shift the input data by the constant ``settings.value``."""

    def _process(self, message: AxisArray) -> AxisArray:
        # Read the offset once; the message data is shifted and nothing else changes.
        offset = self.settings.value
        shifted = message.data + offset
        # Only ``data`` is swapped out; every other field of the message is reused.
        return replace(message, data=shifted)
27
+
28
+
29
class ConstAdd(BaseTransformerUnit[ConstAddSettings, AxisArray, AxisArray, ConstAddTransformer]):
    """ezmsg Unit exposing :class:`ConstAddTransformer`, configured by :class:`ConstAddSettings`."""

    SETTINGS = ConstAddSettings
33
+
34
+
35
+ # --- Two-input Addition ---
36
+
37
+
38
@dataclass
class AddState:
    """Mutable state of the two-input Add processor.

    Holds one FIFO per input stream. Every instance gets its own fresh pair of
    queues because they are built by ``default_factory``.
    """

    # Messages received on input A, consumed oldest-first.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=lambda: asyncio.Queue())
    # Messages received on input B, consumed oldest-first.
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=lambda: asyncio.Queue())
44
+
45
+
46
class AddProcessor:
    """Processor that adds two AxisArray signals together.

    This processor maintains separate queues for two input streams and
    adds corresponding messages element-wise. It assumes both inputs
    have compatible shapes and aligned time spans.

    Messages are paired strictly by arrival order: the n-th message pushed to
    queue A is added to the n-th message pushed to queue B.
    """

    def __init__(self):
        self._state = AddState()
        # A message already taken from queue A whose partner from queue B has not
        # arrived yet. It is kept here instead of being discarded, so that a
        # cancelled ``__acall__`` does not drop it and shift every later A/B pairing.
        self._pending_a: AxisArray | None = None

    @property
    def state(self) -> AddState:
        """The processor's queue state (excludes a pending A message held back by a cancelled call)."""
        return self._state

    @state.setter
    def state(self, state: AddState | bytes | None) -> None:
        """Replace the processor state.

        Args:
            state: A new :class:`AddState`. ``None`` leaves the current state untouched.

        Raises:
            NotImplementedError: If ``state`` is ``bytes``; hydrating from bytes is not supported yet.
        """
        if state is None:
            return
        if isinstance(state, bytes):
            # Storing the raw bytes would leave the processor without queues and only
            # fail later on push/pull; fail immediately with a clear message instead.
            raise NotImplementedError("Hydrating AddProcessor state from bytes is not supported.")
        self._state = state
        # An A message pulled from the old state must not be paired with the new state's B queue.
        self._pending_a = None

    def push_a(self, msg: AxisArray) -> None:
        """Push a message to queue A."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Push a message to queue B."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Await the next message from each queue and return their element-wise sum.

        The result is the A message with its ``data`` replaced by ``a.data + b.data``.
        If this coroutine is cancelled while waiting for B, the A message already
        dequeued is retained and used by the next call, so pairing stays aligned.
        """
        if self._pending_a is None:
            # asyncio.Queue.get does not remove an item if it is cancelled while waiting.
            self._pending_a = await self._state.queue_a.get()
        b = await self._state.queue_b.get()
        a, self._pending_a = self._pending_a, None
        return replace(a, data=a.data + b.data)

    def __call__(self) -> AxisArray:
        """Synchronously get and add the next messages from both queues."""
        return run_coroutine_sync(self.__acall__())

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        return await self.__acall__()

    def __next__(self) -> AxisArray:
        return self.__call__()
94
+
95
+
96
class Add(ez.Unit):
    """Element-wise sum of two AxisArray streams.

    Each output is the oldest unconsumed message from INPUT_SIGNAL_A plus the
    oldest unconsumed message from INPUT_SIGNAL_B. Because pairing is purely by
    arrival order, both inputs are assumed to have compatible axes/dimensions
    and aligned time spans.
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        # The processor owns both queues; the subscribers below only enqueue into it.
        self.processor = AddProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        """Enqueue a message arriving on input A."""
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        """Enqueue a message arriving on input B."""
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        """Publish one sum for every A/B pair, forever."""
        while True:
            summed = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, summed
@@ -1,5 +1,8 @@
1
- import numpy as np
1
+ """Clips the data to be within the specified range."""
2
+ # TODO: Array API
3
+
2
4
  import ezmsg.core as ez
5
+ import numpy as np
3
6
  from ezmsg.util.messages.axisarray import AxisArray
4
7
  from ezmsg.util.messages.util import replace
5
8
 
@@ -1,4 +1,11 @@
1
+ """Take the difference between 2 signals or between a signal and a constant value."""
2
+
3
+ import asyncio
4
+ import typing
5
+ from dataclasses import dataclass, field
6
+
1
7
  import ezmsg.core as ez
8
+ from ezmsg.baseproc.util.asio import run_coroutine_sync
2
9
  from ezmsg.util.messages.axisarray import AxisArray
3
10
  from ezmsg.util.messages.util import replace
4
11
 
@@ -10,12 +17,11 @@ class ConstDifferenceSettings(ez.Settings):
10
17
  """number to subtract or be subtracted from the input data"""
11
18
 
12
19
  subtrahend: bool = True
13
- """If True (default) then value is subtracted from the input data. If False, the input data is subtracted from value."""
20
+ """If True (default) then value is subtracted from the input data. If False, the input data
21
+ is subtracted from value."""
14
22
 
15
23
 
16
- class ConstDifferenceTransformer(
17
- BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]
18
- ):
24
+ class ConstDifferenceTransformer(BaseTransformer[ConstDifferenceSettings, AxisArray, AxisArray]):
19
25
  def _process(self, message: AxisArray) -> AxisArray:
20
26
  return replace(
21
27
  message,
@@ -25,17 +31,11 @@ class ConstDifferenceTransformer(
25
31
  )
26
32
 
27
33
 
28
- class ConstDifference(
29
- BaseTransformerUnit[
30
- ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer
31
- ]
32
- ):
34
+ class ConstDifference(BaseTransformerUnit[ConstDifferenceSettings, AxisArray, AxisArray, ConstDifferenceTransformer]):
33
35
  SETTINGS = ConstDifferenceSettings
34
36
 
35
37
 
36
- def const_difference(
37
- value: float = 0.0, subtrahend: bool = True
38
- ) -> ConstDifferenceTransformer:
38
+ def const_difference(value: float = 0.0, subtrahend: bool = True) -> ConstDifferenceTransformer:
39
39
  """
40
40
  result = (in_data - value) if subtrahend else (value - in_data)
41
41
  https://en.wikipedia.org/wiki/Template:Arithmetic_operations
@@ -47,27 +47,91 @@ def const_difference(
47
47
 
48
48
  Returns: :obj:`ConstDifferenceTransformer`.
49
49
  """
50
- return ConstDifferenceTransformer(
51
- ConstDifferenceSettings(value=value, subtrahend=subtrahend)
52
- )
53
-
54
-
55
- # class DifferenceSettings(ez.Settings):
56
- # pass
57
- #
58
- #
59
- # class Difference(ez.Unit):
60
- # SETTINGS = DifferenceSettings
61
- #
62
- # INPUT_SIGNAL_1 = ez.InputStream(AxisArray)
63
- # INPUT_SIGNAL_2 = ez.InputStream(AxisArray)
64
- # OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
65
- #
66
- # @ez.subscriber(INPUT_SIGNAL_2, zero_copy=True)
67
- # @ez.publisher(OUTPUT_SIGNAL)
68
- # async def on_input_2(self, message: AxisArray) -> typing.AsyncGenerator:
69
- # # TODO: buffer_2
70
- # # TODO: take buffer_1 - buffer_2 for ranges that align
71
- # # TODO: Drop samples from buffer_1 and buffer_2
72
- # if ret is not None:
73
- # yield self.OUTPUT_SIGNAL, ret
50
+ return ConstDifferenceTransformer(ConstDifferenceSettings(value=value, subtrahend=subtrahend))
51
+
52
+
53
+ # --- Two-input Difference ---
54
+
55
+
56
@dataclass
class DifferenceState:
    """Mutable state of the two-input Difference processor.

    Holds one FIFO for the minuend stream (A) and one for the subtrahend
    stream (B). Each instance builds its own queues via ``default_factory``.
    """

    # Minuend messages (input A), consumed oldest-first.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=lambda: asyncio.Queue())
    # Subtrahend messages (input B), consumed oldest-first.
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=lambda: asyncio.Queue())
62
+
63
+
64
class DifferenceProcessor:
    """Processor that subtracts two AxisArray signals (A - B).

    This processor maintains separate queues for two input streams and
    subtracts corresponding messages element-wise. It assumes both inputs
    have compatible shapes and aligned time spans.

    Messages are paired strictly by arrival order: the n-th message pushed to
    queue A has the n-th message pushed to queue B subtracted from it.
    """

    def __init__(self):
        self._state = DifferenceState()
        # A minuend already taken from queue A whose subtrahend has not arrived yet.
        # It is kept here so that a cancelled ``__acall__`` does not drop it and
        # shift every later A/B pairing by one.
        self._pending_a: AxisArray | None = None

    @property
    def state(self) -> DifferenceState:
        """The processor's queue state (excludes a pending A message held back by a cancelled call)."""
        return self._state

    @state.setter
    def state(self, state: DifferenceState | bytes | None) -> None:
        """Replace the processor state.

        Args:
            state: A new :class:`DifferenceState`. ``None`` leaves the current state untouched.

        Raises:
            NotImplementedError: If ``state`` is ``bytes``; hydrating from bytes is not supported yet.
        """
        if state is None:
            return
        if isinstance(state, bytes):
            # Storing the raw bytes would leave the processor without queues and only
            # fail later on push/pull; fail immediately with a clear message instead.
            raise NotImplementedError("Hydrating DifferenceProcessor state from bytes is not supported.")
        self._state = state
        # A minuend pulled from the old state must not be paired with the new state's B queue.
        self._pending_a = None

    def push_a(self, msg: AxisArray) -> None:
        """Push a message to queue A (minuend)."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Push a message to queue B (subtrahend)."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Await the next message from each queue and return A - B.

        The result is the A message with its ``data`` replaced by ``a.data - b.data``.
        If this coroutine is cancelled while waiting for B, the A message already
        dequeued is retained and used by the next call, so pairing stays aligned.
        """
        if self._pending_a is None:
            # asyncio.Queue.get does not remove an item if it is cancelled while waiting.
            self._pending_a = await self._state.queue_a.get()
        b = await self._state.queue_b.get()
        a, self._pending_a = self._pending_a, None
        return replace(a, data=a.data - b.data)

    def __call__(self) -> AxisArray:
        """Synchronously get and subtract the next messages."""
        return run_coroutine_sync(self.__acall__())

    # Aliases for legacy interface
    async def __anext__(self) -> AxisArray:
        return await self.__acall__()

    def __next__(self) -> AxisArray:
        return self.__call__()
108
+
109
+
110
class Difference(ez.Unit):
    """Element-wise difference of two AxisArray streams (A - B).

    Each output is the oldest unconsumed message from INPUT_SIGNAL_A minus the
    oldest unconsumed message from INPUT_SIGNAL_B. Because pairing is purely by
    arrival order, both inputs are assumed to have compatible axes/dimensions
    and aligned time spans.

    OUTPUT = INPUT_SIGNAL_A - INPUT_SIGNAL_B
    """

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        # The processor owns both queues; the subscribers below only enqueue into it.
        self.processor = DifferenceProcessor()

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        """Enqueue a minuend message arriving on input A."""
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        """Enqueue a subtrahend message arriving on input B."""
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        """Publish one A - B result for every A/B pair, forever."""
        while True:
            delta = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, delta
@@ -1,3 +1,5 @@
1
+ """1/data transformer."""
2
+
1
3
  from ezmsg.util.messages.axisarray import AxisArray
2
4
  from ezmsg.util.messages.util import replace
3
5
 
@@ -9,9 +11,7 @@ class InvertTransformer(BaseTransformer[None, AxisArray, AxisArray]):
9
11
  return replace(message, data=1 / message.data)
10
12
 
11
13
 
12
- class Invert(
13
- BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]
14
- ): ... # SETTINGS = None
14
+ class Invert(BaseTransformerUnit[None, AxisArray, AxisArray, InvertTransformer]): ... # SETTINGS = None
15
15
 
16
16
 
17
17
  def invert() -> InvertTransformer: