ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
  4. ezmsg/sigproc/affinetransform.py +16 -42
  5. ezmsg/sigproc/aggregate.py +17 -34
  6. ezmsg/sigproc/bandpower.py +12 -20
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/coordinatespaces.py +142 -0
  13. ezmsg/sigproc/decimate.py +3 -7
  14. ezmsg/sigproc/denormalize.py +6 -11
  15. ezmsg/sigproc/detrend.py +3 -4
  16. ezmsg/sigproc/diff.py +8 -17
  17. ezmsg/sigproc/downsample.py +11 -20
  18. ezmsg/sigproc/ewma.py +11 -28
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +3 -4
  21. ezmsg/sigproc/fbcca.py +34 -59
  22. ezmsg/sigproc/filter.py +19 -45
  23. ezmsg/sigproc/filterbank.py +37 -74
  24. ezmsg/sigproc/filterbankdesign.py +7 -14
  25. ezmsg/sigproc/fir_hilbert.py +13 -30
  26. ezmsg/sigproc/fir_pmc.py +5 -10
  27. ezmsg/sigproc/firfilter.py +12 -14
  28. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  29. ezmsg/sigproc/kaiser.py +11 -15
  30. ezmsg/sigproc/math/abs.py +4 -3
  31. ezmsg/sigproc/math/add.py +121 -0
  32. ezmsg/sigproc/math/clip.py +4 -1
  33. ezmsg/sigproc/math/difference.py +100 -36
  34. ezmsg/sigproc/math/invert.py +3 -3
  35. ezmsg/sigproc/math/log.py +5 -6
  36. ezmsg/sigproc/math/scale.py +2 -0
  37. ezmsg/sigproc/messages.py +1 -2
  38. ezmsg/sigproc/quantize.py +3 -6
  39. ezmsg/sigproc/resample.py +17 -38
  40. ezmsg/sigproc/rollingscaler.py +12 -37
  41. ezmsg/sigproc/sampler.py +19 -37
  42. ezmsg/sigproc/scaler.py +11 -22
  43. ezmsg/sigproc/signalinjector.py +7 -18
  44. ezmsg/sigproc/slicer.py +14 -34
  45. ezmsg/sigproc/spectral.py +3 -3
  46. ezmsg/sigproc/spectrogram.py +12 -19
  47. ezmsg/sigproc/spectrum.py +17 -38
  48. ezmsg/sigproc/transpose.py +12 -24
  49. ezmsg/sigproc/util/asio.py +25 -156
  50. ezmsg/sigproc/util/axisarray_buffer.py +12 -26
  51. ezmsg/sigproc/util/buffer.py +22 -43
  52. ezmsg/sigproc/util/message.py +17 -31
  53. ezmsg/sigproc/util/profile.py +23 -174
  54. ezmsg/sigproc/util/sparse.py +7 -15
  55. ezmsg/sigproc/util/typeresolution.py +17 -83
  56. ezmsg/sigproc/wavelets.py +10 -19
  57. ezmsg/sigproc/window.py +29 -83
  58. ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
  59. ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
  60. ezmsg/sigproc/synth.py +0 -774
  61. ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
  62. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  63. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
  64. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/butterworthfilter.py CHANGED
@@ -5,11 +5,11 @@ import scipy.signal
5
5
  from scipy.signal import normalize
6
6
 
7
7
  from .filter import (
8
- FilterBaseSettings,
9
8
  BACoeffs,
10
- SOSCoeffs,
11
- FilterByDesignTransformer,
12
9
  BaseFilterByDesignTransformerUnit,
10
+ FilterBaseSettings,
11
+ FilterByDesignTransformer,
12
+ SOSCoeffs,
13
13
  )
14
14
 
15
15
 
@@ -27,14 +27,14 @@ class ButterworthFilterSettings(FilterBaseSettings):
27
27
  """
28
28
  Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
29
29
  if this is lower than `cutoff` then this is the beginning of the bandpass
30
- or if this is greater than `cutoff` then this is the end of the bandstop.
30
+ or if this is greater than `cutoff` then this is the end of the bandstop.
31
31
  """
32
32
 
33
33
  cutoff: float | None = None
34
34
  """
35
35
  Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
36
36
  if this is greater than `cuton` then this is the end of the bandpass,
37
- or if this is less than `cuton` then this is the beginning of the bandstop.
37
+ or if this is less than `cuton` then this is the beginning of the bandstop.
38
38
  """
39
39
 
40
40
  wn_hz: bool = True
@@ -96,9 +96,7 @@ def butter_design_fun(
96
96
  """
97
97
  coefs = None
98
98
  if order > 0:
99
- btype, cutoffs = ButterworthFilterSettings(
100
- order=order, cuton=cuton, cutoff=cutoff
101
- ).filter_specs()
99
+ btype, cutoffs = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs()
102
100
  coefs = scipy.signal.butter(
103
101
  order,
104
102
  Wn=cutoffs,
@@ -111,9 +109,7 @@ def butter_design_fun(
111
109
  return coefs
112
110
 
113
111
 
114
- class ButterworthFilterTransformer(
115
- FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]
116
- ):
112
+ class ButterworthFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
117
113
  def get_design_function(
118
114
  self,
119
115
  ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
@@ -127,11 +123,7 @@ class ButterworthFilterTransformer(
127
123
  )
128
124
 
129
125
 
130
- class ButterworthFilter(
131
- BaseFilterByDesignTransformerUnit[
132
- ButterworthFilterSettings, ButterworthFilterTransformer
133
- ]
134
- ):
126
+ class ButterworthFilter(BaseFilterByDesignTransformerUnit[ButterworthFilterSettings, ButterworthFilterTransformer]):
135
127
  SETTINGS = ButterworthFilterSettings
136
128
 
137
129
 
ezmsg/sigproc/butterworthzerophase.py CHANGED
@@ -4,6 +4,9 @@ import typing
4
4
  import ezmsg.core as ez
5
5
  import numpy as np
6
6
  import scipy.signal
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
9
+
7
10
  from ezmsg.sigproc.base import SettingsType
8
11
  from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
9
12
  from ezmsg.sigproc.filter import (
@@ -12,8 +15,6 @@ from ezmsg.sigproc.filter import (
12
15
  FilterByDesignTransformer,
13
16
  SOSCoeffs,
14
17
  )
15
- from ezmsg.util.messages.axisarray import AxisArray
16
- from ezmsg.util.messages.util import replace
17
18
 
18
19
 
19
20
  class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
@@ -34,9 +35,7 @@ class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
34
35
  """
35
36
 
36
37
 
37
- class ButterworthZeroPhaseTransformer(
38
- FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]
39
- ):
38
+ class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]):
40
39
  """Zero-phase (filtfilt) Butterworth using your design function."""
41
40
 
42
41
  def get_design_function(
@@ -51,9 +50,7 @@ class ButterworthZeroPhaseTransformer(
51
50
  wn_hz=self.settings.wn_hz,
52
51
  )
53
52
 
54
- def update_settings(
55
- self, new_settings: typing.Optional[SettingsType] = None, **kwargs
56
- ) -> None:
53
+ def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
57
54
  """
58
55
  Update settings and mark that filter coefficients need to be recalculated.
59
56
 
@@ -91,11 +88,7 @@ class ButterworthZeroPhaseTransformer(
91
88
  self._fs_cache = fs
92
89
  self.state.needs_redesign = False
93
90
 
94
- if (
95
- self._coefs_cache is None
96
- or self.settings.order <= 0
97
- or message.data.size <= 0
98
- ):
91
+ if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
99
92
  return message
100
93
 
101
94
  x = message.data
@@ -125,8 +118,6 @@ class ButterworthZeroPhaseTransformer(
125
118
 
126
119
 
127
120
  class ButterworthZeroPhase(
128
- BaseFilterByDesignTransformerUnit[
129
- ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer
130
- ]
121
+ BaseFilterByDesignTransformerUnit[ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer]
131
122
  ):
132
123
  SETTINGS = ButterworthZeroPhaseSettings
ezmsg/sigproc/cheby.py CHANGED
@@ -5,11 +5,11 @@ import scipy.signal
5
5
  from scipy.signal import normalize
6
6
 
7
7
  from .filter import (
8
+ BACoeffs,
9
+ BaseFilterByDesignTransformerUnit,
8
10
  FilterBaseSettings,
9
11
  FilterByDesignTransformer,
10
- BACoeffs,
11
12
  SOSCoeffs,
12
- BaseFilterByDesignTransformerUnit,
13
13
  )
14
14
 
15
15
 
@@ -104,9 +104,7 @@ def cheby_design_fun(
104
104
  return coefs
105
105
 
106
106
 
107
- class ChebyshevFilterTransformer(
108
- FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]
109
- ):
107
+ class ChebyshevFilterTransformer(FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]):
110
108
  def get_design_function(
111
109
  self,
112
110
  ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
@@ -123,9 +121,5 @@ class ChebyshevFilterTransformer(
123
121
  )
124
122
 
125
123
 
126
- class ChebyshevFilter(
127
- BaseFilterByDesignTransformerUnit[
128
- ChebyshevFilterSettings, ChebyshevFilterTransformer
129
- ]
130
- ):
124
+ class ChebyshevFilter(BaseFilterByDesignTransformerUnit[ChebyshevFilterSettings, ChebyshevFilterTransformer]):
131
125
  SETTINGS = ChebyshevFilterSettings
ezmsg/sigproc/combfilter.py CHANGED
@@ -1,15 +1,16 @@
1
1
  import functools
2
2
  import typing
3
+
3
4
  import numpy as np
4
5
  import scipy.signal
5
6
  from scipy.signal import normalize
6
7
 
7
8
  from .filter import (
9
+ BACoeffs,
10
+ BaseFilterByDesignTransformerUnit,
8
11
  FilterBaseSettings,
9
12
  FilterByDesignTransformer,
10
- BACoeffs,
11
13
  SOSCoeffs,
12
- BaseFilterByDesignTransformerUnit,
13
14
  )
14
15
 
15
16
 
@@ -103,9 +104,7 @@ def comb_design_fun(
103
104
  return combined_sos
104
105
 
105
106
 
106
- class CombFilterTransformer(
107
- FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]
108
- ):
107
+ class CombFilterTransformer(FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]):
109
108
  def get_design_function(
110
109
  self,
111
110
  ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
@@ -120,9 +119,7 @@ class CombFilterTransformer(
120
119
  )
121
120
 
122
121
 
123
- class CombFilterUnit(
124
- BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]
125
- ):
122
+ class CombFilterUnit(BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]):
126
123
  SETTINGS = CombFilterSettings
127
124
 
128
125
 
ezmsg/sigproc/coordinatespaces.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ Coordinate space transformations for streaming data.
3
+
4
+ This module provides utilities and ezmsg nodes for transforming between
5
+ Cartesian (x, y) and polar (r, theta) coordinate systems.
6
+ """
7
+
8
+ from enum import Enum
9
+ from typing import Tuple
10
+
11
+ import ezmsg.core as ez
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+ from ezmsg.baseproc import (
15
+ BaseTransformer,
16
+ BaseTransformerUnit,
17
+ )
18
+ from ezmsg.util.messages.axisarray import AxisArray, replace
19
+
20
+ # -- Utility functions for coordinate transformations --
21
+
22
+
23
+ def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
24
+ """Convert polar coordinates to complex number representation."""
25
+ return r * np.exp(1j * theta)
26
+
27
+
28
+ def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
29
+ """Convert complex number to polar coordinates (r, theta)."""
30
+ return np.abs(z), np.angle(z)
31
+
32
+
33
+ def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
34
+ """Convert Cartesian coordinates to complex number representation."""
35
+ return x + 1j * y
36
+
37
+
38
+ def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
39
+ """Convert complex number to Cartesian coordinates (x, y)."""
40
+ return np.real(z), np.imag(z)
41
+
42
+
43
+ def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
44
+ """Convert Cartesian coordinates (x, y) to polar coordinates (r, theta)."""
45
+ return z2polar(cart2z(x, y))
46
+
47
+
48
+ def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
49
+ """Convert polar coordinates (r, theta) to Cartesian coordinates (x, y)."""
50
+ return z2cart(polar2z(r, theta))
51
+
52
+
53
+ # -- ezmsg transformer classes --
54
+
55
+
56
+ class CoordinateMode(str, Enum):
57
+ """Transformation mode for coordinate conversion."""
58
+
59
+ CART2POL = "cart2pol"
60
+ """Convert Cartesian (x, y) to polar (r, theta)."""
61
+
62
+ POL2CART = "pol2cart"
63
+ """Convert polar (r, theta) to Cartesian (x, y)."""
64
+
65
+
66
+ class CoordinateSpacesSettings(ez.Settings):
67
+ """
68
+ Settings for :obj:`CoordinateSpaces`.
69
+
70
+ See :obj:`coordinate_spaces` for argument details.
71
+ """
72
+
73
+ mode: CoordinateMode = CoordinateMode.CART2POL
74
+ """The transformation mode: 'cart2pol' or 'pol2cart'."""
75
+
76
+ axis: str | None = None
77
+ """
78
+ The name of the axis containing the coordinate components.
79
+ Defaults to the last axis. Must have exactly 2 elements (x,y or r,theta).
80
+ """
81
+
82
+
83
+ class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, AxisArray, AxisArray]):
84
+ """
85
+ Transform between Cartesian and polar coordinate systems.
86
+
87
+ The input must have exactly 2 elements along the specified axis:
88
+ - For cart2pol: expects (x, y), outputs (r, theta)
89
+ - For pol2cart: expects (r, theta), outputs (x, y)
90
+ """
91
+
92
+ def _process(self, message: AxisArray) -> AxisArray:
93
+ axis = self.settings.axis or message.dims[-1]
94
+ axis_idx = message.get_axis_idx(axis)
95
+
96
+ if message.data.shape[axis_idx] != 2:
97
+ raise ValueError(
98
+ f"Coordinate transformation requires exactly 2 elements along axis '{axis}', "
99
+ f"got {message.data.shape[axis_idx]}."
100
+ )
101
+
102
+ # Extract components along the specified axis
103
+ slices_a = [slice(None)] * message.data.ndim
104
+ slices_b = [slice(None)] * message.data.ndim
105
+ slices_a[axis_idx] = 0
106
+ slices_b[axis_idx] = 1
107
+
108
+ component_a = message.data[tuple(slices_a)]
109
+ component_b = message.data[tuple(slices_b)]
110
+
111
+ if self.settings.mode == CoordinateMode.CART2POL:
112
+ # Input: x, y -> Output: r, theta
113
+ out_a, out_b = cart2pol(component_a, component_b)
114
+ else:
115
+ # Input: r, theta -> Output: x, y
116
+ out_a, out_b = pol2cart(component_a, component_b)
117
+
118
+ # Stack results back along the same axis
119
+ result = np.stack([out_a, out_b], axis=axis_idx)
120
+
121
+ # Update axis labels if present
122
+ axes = message.axes
123
+ if axis in axes and hasattr(axes[axis], "data"):
124
+ if self.settings.mode == CoordinateMode.CART2POL:
125
+ new_labels = np.array(["r", "theta"])
126
+ else:
127
+ new_labels = np.array(["x", "y"])
128
+ axes = {**axes, axis: replace(axes[axis], data=new_labels)}
129
+
130
+ return replace(message, data=result, axes=axes)
131
+
132
+
133
+ class CoordinateSpaces(
134
+ BaseTransformerUnit[CoordinateSpacesSettings, AxisArray, AxisArray, CoordinateSpacesTransformer]
135
+ ):
136
+ """
137
+ Unit for transforming between Cartesian and polar coordinate systems.
138
+
139
+ See :obj:`CoordinateSpacesSettings` for configuration options.
140
+ """
141
+
142
+ SETTINGS = CoordinateSpacesSettings
ezmsg/sigproc/decimate.py CHANGED
@@ -1,10 +1,10 @@
1
1
  import typing
2
2
 
3
3
  import ezmsg.core as ez
4
+ from ezmsg.baseproc import BaseTransformerUnit
4
5
  from ezmsg.util.messages.axisarray import AxisArray
5
6
 
6
- from .base import BaseTransformerUnit
7
- from .cheby import ChebyshevFilterTransformer, ChebyshevFilterSettings
7
+ from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
8
8
  from .downsample import Downsample, DownsampleSettings
9
9
  from .filter import BACoeffs, SOSCoeffs
10
10
 
@@ -30,11 +30,7 @@ class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeff
30
30
  return cheby_opt_design_fun
31
31
 
32
32
 
33
- class ChebyForDecimate(
34
- BaseTransformerUnit[
35
- ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer
36
- ]
37
- ):
33
+ class ChebyForDecimate(BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]):
38
34
  SETTINGS = ChebyshevFilterSettings
39
35
 
40
36
 
ezmsg/sigproc/denormalize.py CHANGED
@@ -1,13 +1,14 @@
1
1
  import ezmsg.core as ez
2
2
  import numpy as np
3
3
  import numpy.typing as npt
4
+ from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.util import replace
6
+
4
7
  from ezmsg.sigproc.base import (
5
- BaseTransformerUnit,
6
8
  BaseStatefulTransformer,
9
+ BaseTransformerUnit,
7
10
  processor_state,
8
11
  )
9
- from ezmsg.util.messages.axisarray import AxisArray
10
- from ezmsg.util.messages.util import replace
11
12
 
12
13
 
13
14
  class DenormalizeSettings(ez.Settings):
@@ -27,9 +28,7 @@ class DenormalizeState:
27
28
  offsets: npt.NDArray | None = None
28
29
 
29
30
 
30
- class DenormalizeTransformer(
31
- BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]
32
- ):
31
+ class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
33
32
  """
34
33
  Scales data from a normalized distribution (mean=0, std=1) to a denormalized
35
34
  distribution using random per-channel offsets and gains designed to keep the
@@ -76,9 +75,5 @@ class DenormalizeTransformer(
76
75
  )
77
76
 
78
77
 
79
- class DenormalizeUnit(
80
- BaseTransformerUnit[
81
- DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer
82
- ]
83
- ):
78
+ class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
84
79
  SETTINGS = DenormalizeSettings
ezmsg/sigproc/detrend.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import scipy.signal as sps
2
2
  from ezmsg.util.messages.axisarray import AxisArray, replace
3
- from ezmsg.sigproc.ewma import EWMATransformer, EWMASettings
3
+
4
4
  from ezmsg.sigproc.base import BaseTransformerUnit
5
+ from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
5
6
 
6
7
 
7
8
  class DetrendTransformer(EWMATransformer):
@@ -23,7 +24,5 @@ class DetrendTransformer(EWMATransformer):
23
24
  return replace(message, data=message.data - means)
24
25
 
25
26
 
26
- class DetrendUnit(
27
- BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]
28
- ):
27
+ class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
29
28
  SETTINGS = EWMASettings
ezmsg/sigproc/diff.py CHANGED
@@ -1,13 +1,14 @@
1
1
  import ezmsg.core as ez
2
2
  import numpy as np
3
3
  import numpy.typing as npt
4
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
5
+ from ezmsg.util.messages.util import replace
6
+
4
7
  from ezmsg.sigproc.base import (
8
+ BaseStatefulTransformer,
5
9
  BaseTransformerUnit,
6
10
  processor_state,
7
- BaseStatefulTransformer,
8
11
  )
9
- from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
10
- from ezmsg.util.messages.util import replace
11
12
 
12
13
 
13
14
  class DiffSettings(ez.Settings):
@@ -21,9 +22,7 @@ class DiffState:
21
22
  last_time: float | None = None
22
23
 
23
24
 
24
- class DiffTransformer(
25
- BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]
26
- ):
25
+ class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
27
26
  def _hash_message(self, message: AxisArray) -> int:
28
27
  ax_idx = message.get_axis_idx(self.settings.axis)
29
28
  sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
@@ -49,20 +48,14 @@ class DiffTransformer(
49
48
  axis=ax_idx,
50
49
  )
51
50
  # Prepare last_dat for next iteration
52
- self.state.last_dat = slice_along_axis(
53
- message.data, slice(-1, None), axis=ax_idx
54
- )
51
+ self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)
55
52
  # Scale by fs if requested. This converts the diff to a derivative. e.g., diff of position becomes velocity.
56
53
  if self.settings.scale_by_fs:
57
54
  ax_info = message.get_axis(axis)
58
55
  if hasattr(ax_info, "data"):
59
56
  dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
60
57
  # Expand dt dims to match diffs
61
- exp_sl = (
62
- (None,) * ax_idx
63
- + (Ellipsis,)
64
- + (None,) * (message.data.ndim - ax_idx - 1)
65
- )
58
+ exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
66
59
  diffs /= dt[exp_sl]
67
60
  self.state.last_time = ax_info.data[-1] # For next iteration
68
61
  else:
@@ -71,9 +64,7 @@ class DiffTransformer(
71
64
  return replace(message, data=diffs)
72
65
 
73
66
 
74
- class DiffUnit(
75
- BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]
76
- ):
67
+ class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
77
68
  SETTINGS = DiffSettings
78
69
 
79
70
 
ezmsg/sigproc/downsample.py CHANGED
@@ -1,16 +1,15 @@
1
- import numpy as np
2
- from ezmsg.util.messages.axisarray import (
3
- AxisArray,
4
- slice_along_axis,
5
- replace,
6
- )
7
1
  import ezmsg.core as ez
8
-
9
- from .base import (
2
+ import numpy as np
3
+ from ezmsg.baseproc import (
10
4
  BaseStatefulTransformer,
11
5
  BaseTransformerUnit,
12
6
  processor_state,
13
7
  )
8
+ from ezmsg.util.messages.axisarray import (
9
+ AxisArray,
10
+ replace,
11
+ slice_along_axis,
12
+ )
14
13
 
15
14
 
16
15
  class DownsampleSettings(ez.Settings):
@@ -38,9 +37,7 @@ class DownsampleState:
38
37
  """Index of the next msg's first sample into the virtual rotating ds_factor counter."""
39
38
 
40
39
 
41
- class DownsampleTransformer(
42
- BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
43
- ):
40
+ class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
44
41
  """
45
42
  Downsampled data simply comprise every `factor`th sample.
46
43
  This should only be used following appropriate lowpass filtering.
@@ -75,9 +72,7 @@ class DownsampleTransformer(
75
72
  axis_idx = message.get_axis_idx(axis)
76
73
 
77
74
  n_samples = message.data.shape[axis_idx]
78
- samples = (
79
- np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
80
- )
75
+ samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
81
76
  if n_samples > 0:
82
77
  # Update state for next iteration.
83
78
  self._state.s_idx = samples[-1] + 1
@@ -104,9 +99,7 @@ class DownsampleTransformer(
104
99
  return msg_out
105
100
 
106
101
 
107
- class Downsample(
108
- BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
109
- ):
102
+ class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
110
103
  SETTINGS = DownsampleSettings
111
104
 
112
105
 
@@ -115,6 +108,4 @@ def downsample(
115
108
  target_rate: float | None = None,
116
109
  factor: int | None = None,
117
110
  ) -> DownsampleTransformer:
118
- return DownsampleTransformer(
119
- DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
120
- )
111
+ return DownsampleTransformer(DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor))
ezmsg/sigproc/ewma.py CHANGED
@@ -1,15 +1,14 @@
1
- from dataclasses import field
2
1
  import functools
2
+ from dataclasses import field
3
3
 
4
+ import ezmsg.core as ez
4
5
  import numpy as np
5
6
  import numpy.typing as npt
6
7
  import scipy.signal as sps
7
- import ezmsg.core as ez
8
+ from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
8
9
  from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
9
10
  from ezmsg.util.messages.util import replace
10
11
 
11
- from .base import BaseStatefulTransformer, processor_state, BaseTransformerUnit
12
-
13
12
 
14
13
  def _tau_from_alpha(alpha: float, dt: float) -> float:
15
14
  """
@@ -29,9 +28,7 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
29
28
  return 1 - np.exp(-dt / tau)
30
29
 
31
30
 
32
- def ewma_step(
33
- sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
34
- ):
31
+ def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
35
32
  """
36
33
  Do an exponentially weighted moving average step.
37
34
 
@@ -97,9 +94,7 @@ class EWMA_Deprecated:
97
94
  if self.prev is None:
98
95
  self.prev = arr[:1]
99
96
 
100
- out += self.prev * np.expand_dims(
101
- self.weights[1 : n + 1], list(range(1, arr.ndim))
102
- )
97
+ out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
103
98
 
104
99
  self.prev = out[-1:]
105
100
 
@@ -128,9 +123,7 @@ class EWMA_Deprecated:
128
123
  if self.prev is None:
129
124
  self.prev = arr[:1]
130
125
 
131
- result += self.prev * np.expand_dims(
132
- self.weights[1 : n + 1], list(range(1, arr.ndim))
133
- )
126
+ result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
134
127
 
135
128
  # Store the result back into prev
136
129
  self.prev = result[-1]
@@ -155,25 +148,17 @@ class EWMAState:
155
148
  zi: npt.NDArray | None = None
156
149
 
157
150
 
158
- class EWMATransformer(
159
- BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]
160
- ):
151
+ class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
161
152
  def _hash_message(self, message: AxisArray) -> int:
162
153
  axis = self.settings.axis or message.dims[0]
163
154
  axis_idx = message.get_axis_idx(axis)
164
- sample_shape = (
165
- message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
166
- )
155
+ sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
167
156
  return hash((sample_shape, message.axes[axis].gain, message.key))
168
157
 
169
158
  def _reset_state(self, message: AxisArray) -> None:
170
159
  axis = self.settings.axis or message.dims[0]
171
- self._state.alpha = _alpha_from_tau(
172
- self.settings.time_constant, message.axes[axis].gain
173
- )
174
- sub_dat = slice_along_axis(
175
- message.data, slice(None, 1, None), axis=message.get_axis_idx(axis)
176
- )
160
+ self._state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
161
+ sub_dat = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis))
177
162
  self._state.zi = (1 - self._state.alpha) * sub_dat
178
163
 
179
164
  def _process(self, message: AxisArray) -> AxisArray:
@@ -191,7 +176,5 @@ class EWMATransformer(
191
176
  return replace(message, data=expected)
192
177
 
193
178
 
194
- class EWMAUnit(
195
- BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]
196
- ):
179
+ class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
197
180
  SETTINGS = EWMASettings
ezmsg/sigproc/ewmfilter.py CHANGED
@@ -2,9 +2,9 @@ import asyncio
2
2
  import typing
3
3
 
4
4
  import ezmsg.core as ez
5
+ import numpy as np
5
6
  from ezmsg.util.messages.axisarray import AxisArray
6
7
  from ezmsg.util.messages.util import replace
7
- import numpy as np
8
8
 
9
9
  from .window import Window, WindowSettings
10
10
 
ezmsg/sigproc/extract_axis.py CHANGED
@@ -1,6 +1,7 @@
1
- import numpy as np
2
1
  import ezmsg.core as ez
2
+ import numpy as np
3
3
  from ezmsg.util.messages.axisarray import AxisArray, replace
4
+
4
5
  from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
5
6
 
6
7
 
@@ -35,7 +36,5 @@ class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]
35
36
  )
36
37
 
37
38
 
38
- class ExtractAxisDataUnit(
39
- BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]
40
- ):
39
+ class ExtractAxisDataUnit(BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]):
41
40
  SETTINGS = ExtractAxisSettings