ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +13 -30
  25. ezmsg/sigproc/fir_pmc.py +5 -10
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +12 -37
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +3 -2
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  60. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
@@ -5,11 +5,11 @@ import scipy.signal
5
5
  from scipy.signal import normalize
6
6
 
7
7
  from .filter import (
8
- FilterBaseSettings,
9
8
  BACoeffs,
10
- SOSCoeffs,
11
- FilterByDesignTransformer,
12
9
  BaseFilterByDesignTransformerUnit,
10
+ FilterBaseSettings,
11
+ FilterByDesignTransformer,
12
+ SOSCoeffs,
13
13
  )
14
14
 
15
15
 
@@ -27,14 +27,14 @@ class ButterworthFilterSettings(FilterBaseSettings):
27
27
  """
28
28
  Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
29
29
  if this is lower than `cutoff` then this is the beginning of the bandpass
30
- or if this is greater than `cutoff` then this is the end of the bandstop.
30
+ or if this is greater than `cutoff` then this is the end of the bandstop.
31
31
  """
32
32
 
33
33
  cutoff: float | None = None
34
34
  """
35
35
  Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
36
36
  if this is greater than `cuton` then this is the end of the bandpass,
37
- or if this is less than `cuton` then this is the beginning of the bandstop.
37
+ or if this is less than `cuton` then this is the beginning of the bandstop.
38
38
  """
39
39
 
40
40
  wn_hz: bool = True
@@ -96,9 +96,7 @@ def butter_design_fun(
96
96
  """
97
97
  coefs = None
98
98
  if order > 0:
99
- btype, cutoffs = ButterworthFilterSettings(
100
- order=order, cuton=cuton, cutoff=cutoff
101
- ).filter_specs()
99
+ btype, cutoffs = ButterworthFilterSettings(order=order, cuton=cuton, cutoff=cutoff).filter_specs()
102
100
  coefs = scipy.signal.butter(
103
101
  order,
104
102
  Wn=cutoffs,
@@ -111,9 +109,7 @@ def butter_design_fun(
111
109
  return coefs
112
110
 
113
111
 
114
- class ButterworthFilterTransformer(
115
- FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]
116
- ):
112
+ class ButterworthFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
117
113
  def get_design_function(
118
114
  self,
119
115
  ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
@@ -127,11 +123,7 @@ class ButterworthFilterTransformer(
127
123
  )
128
124
 
129
125
 
130
- class ButterworthFilter(
131
- BaseFilterByDesignTransformerUnit[
132
- ButterworthFilterSettings, ButterworthFilterTransformer
133
- ]
134
- ):
126
+ class ButterworthFilter(BaseFilterByDesignTransformerUnit[ButterworthFilterSettings, ButterworthFilterTransformer]):
135
127
  SETTINGS = ButterworthFilterSettings
136
128
 
137
129
 
@@ -4,6 +4,9 @@ import typing
4
4
  import ezmsg.core as ez
5
5
  import numpy as np
6
6
  import scipy.signal
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
9
+
7
10
  from ezmsg.sigproc.base import SettingsType
8
11
  from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
9
12
  from ezmsg.sigproc.filter import (
@@ -12,8 +15,6 @@ from ezmsg.sigproc.filter import (
12
15
  FilterByDesignTransformer,
13
16
  SOSCoeffs,
14
17
  )
15
- from ezmsg.util.messages.axisarray import AxisArray
16
- from ezmsg.util.messages.util import replace
17
18
 
18
19
 
19
20
  class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
@@ -34,9 +35,7 @@ class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
34
35
  """
35
36
 
36
37
 
37
- class ButterworthZeroPhaseTransformer(
38
- FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]
39
- ):
38
+ class ButterworthZeroPhaseTransformer(FilterByDesignTransformer[ButterworthZeroPhaseSettings, BACoeffs | SOSCoeffs]):
40
39
  """Zero-phase (filtfilt) Butterworth using your design function."""
41
40
 
42
41
  def get_design_function(
@@ -51,9 +50,7 @@ class ButterworthZeroPhaseTransformer(
51
50
  wn_hz=self.settings.wn_hz,
52
51
  )
53
52
 
54
- def update_settings(
55
- self, new_settings: typing.Optional[SettingsType] = None, **kwargs
56
- ) -> None:
53
+ def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
57
54
  """
58
55
  Update settings and mark that filter coefficients need to be recalculated.
59
56
 
@@ -91,11 +88,7 @@ class ButterworthZeroPhaseTransformer(
91
88
  self._fs_cache = fs
92
89
  self.state.needs_redesign = False
93
90
 
94
- if (
95
- self._coefs_cache is None
96
- or self.settings.order <= 0
97
- or message.data.size <= 0
98
- ):
91
+ if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
99
92
  return message
100
93
 
101
94
  x = message.data
@@ -125,8 +118,6 @@ class ButterworthZeroPhaseTransformer(
125
118
 
126
119
 
127
120
  class ButterworthZeroPhase(
128
- BaseFilterByDesignTransformerUnit[
129
- ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer
130
- ]
121
+ BaseFilterByDesignTransformerUnit[ButterworthZeroPhaseSettings, ButterworthZeroPhaseTransformer]
131
122
  ):
132
123
  SETTINGS = ButterworthZeroPhaseSettings
ezmsg/sigproc/cheby.py CHANGED
@@ -5,11 +5,11 @@ import scipy.signal
5
5
  from scipy.signal import normalize
6
6
 
7
7
  from .filter import (
8
+ BACoeffs,
9
+ BaseFilterByDesignTransformerUnit,
8
10
  FilterBaseSettings,
9
11
  FilterByDesignTransformer,
10
- BACoeffs,
11
12
  SOSCoeffs,
12
- BaseFilterByDesignTransformerUnit,
13
13
  )
14
14
 
15
15
 
@@ -104,9 +104,7 @@ def cheby_design_fun(
104
104
  return coefs
105
105
 
106
106
 
107
- class ChebyshevFilterTransformer(
108
- FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]
109
- ):
107
+ class ChebyshevFilterTransformer(FilterByDesignTransformer[ChebyshevFilterSettings, BACoeffs | SOSCoeffs]):
110
108
  def get_design_function(
111
109
  self,
112
110
  ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
@@ -123,9 +121,5 @@ class ChebyshevFilterTransformer(
123
121
  )
124
122
 
125
123
 
126
- class ChebyshevFilter(
127
- BaseFilterByDesignTransformerUnit[
128
- ChebyshevFilterSettings, ChebyshevFilterTransformer
129
- ]
130
- ):
124
+ class ChebyshevFilter(BaseFilterByDesignTransformerUnit[ChebyshevFilterSettings, ChebyshevFilterTransformer]):
131
125
  SETTINGS = ChebyshevFilterSettings
@@ -1,15 +1,16 @@
1
1
  import functools
2
2
  import typing
3
+
3
4
  import numpy as np
4
5
  import scipy.signal
5
6
  from scipy.signal import normalize
6
7
 
7
8
  from .filter import (
9
+ BACoeffs,
10
+ BaseFilterByDesignTransformerUnit,
8
11
  FilterBaseSettings,
9
12
  FilterByDesignTransformer,
10
- BACoeffs,
11
13
  SOSCoeffs,
12
- BaseFilterByDesignTransformerUnit,
13
14
  )
14
15
 
15
16
 
@@ -103,9 +104,7 @@ def comb_design_fun(
103
104
  return combined_sos
104
105
 
105
106
 
106
- class CombFilterTransformer(
107
- FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]
108
- ):
107
+ class CombFilterTransformer(FilterByDesignTransformer[CombFilterSettings, BACoeffs | SOSCoeffs]):
109
108
  def get_design_function(
110
109
  self,
111
110
  ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
@@ -120,9 +119,7 @@ class CombFilterTransformer(
120
119
  )
121
120
 
122
121
 
123
- class CombFilterUnit(
124
- BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]
125
- ):
122
+ class CombFilterUnit(BaseFilterByDesignTransformerUnit[CombFilterSettings, CombFilterTransformer]):
126
123
  SETTINGS = CombFilterSettings
127
124
 
128
125
 
ezmsg/sigproc/decimate.py CHANGED
@@ -4,7 +4,7 @@ import ezmsg.core as ez
4
4
  from ezmsg.util.messages.axisarray import AxisArray
5
5
 
6
6
  from .base import BaseTransformerUnit
7
- from .cheby import ChebyshevFilterTransformer, ChebyshevFilterSettings
7
+ from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
8
8
  from .downsample import Downsample, DownsampleSettings
9
9
  from .filter import BACoeffs, SOSCoeffs
10
10
 
@@ -30,11 +30,7 @@ class ChebyForDecimateTransformer(ChebyshevFilterTransformer[BACoeffs | SOSCoeff
30
30
  return cheby_opt_design_fun
31
31
 
32
32
 
33
- class ChebyForDecimate(
34
- BaseTransformerUnit[
35
- ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer
36
- ]
37
- ):
33
+ class ChebyForDecimate(BaseTransformerUnit[ChebyshevFilterSettings, AxisArray, AxisArray, ChebyForDecimateTransformer]):
38
34
  SETTINGS = ChebyshevFilterSettings
39
35
 
40
36
 
@@ -1,13 +1,14 @@
1
1
  import ezmsg.core as ez
2
2
  import numpy as np
3
3
  import numpy.typing as npt
4
+ from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.messages.util import replace
6
+
4
7
  from ezmsg.sigproc.base import (
5
- BaseTransformerUnit,
6
8
  BaseStatefulTransformer,
9
+ BaseTransformerUnit,
7
10
  processor_state,
8
11
  )
9
- from ezmsg.util.messages.axisarray import AxisArray
10
- from ezmsg.util.messages.util import replace
11
12
 
12
13
 
13
14
  class DenormalizeSettings(ez.Settings):
@@ -27,9 +28,7 @@ class DenormalizeState:
27
28
  offsets: npt.NDArray | None = None
28
29
 
29
30
 
30
- class DenormalizeTransformer(
31
- BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]
32
- ):
31
+ class DenormalizeTransformer(BaseStatefulTransformer[DenormalizeSettings, AxisArray, AxisArray, DenormalizeState]):
33
32
  """
34
33
  Scales data from a normalized distribution (mean=0, std=1) to a denormalized
35
34
  distribution using random per-channel offsets and gains designed to keep the
@@ -76,9 +75,5 @@ class DenormalizeTransformer(
76
75
  )
77
76
 
78
77
 
79
- class DenormalizeUnit(
80
- BaseTransformerUnit[
81
- DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer
82
- ]
83
- ):
78
+ class DenormalizeUnit(BaseTransformerUnit[DenormalizeSettings, AxisArray, AxisArray, DenormalizeTransformer]):
84
79
  SETTINGS = DenormalizeSettings
ezmsg/sigproc/detrend.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import scipy.signal as sps
2
2
  from ezmsg.util.messages.axisarray import AxisArray, replace
3
- from ezmsg.sigproc.ewma import EWMATransformer, EWMASettings
3
+
4
4
  from ezmsg.sigproc.base import BaseTransformerUnit
5
+ from ezmsg.sigproc.ewma import EWMASettings, EWMATransformer
5
6
 
6
7
 
7
8
  class DetrendTransformer(EWMATransformer):
@@ -23,7 +24,5 @@ class DetrendTransformer(EWMATransformer):
23
24
  return replace(message, data=message.data - means)
24
25
 
25
26
 
26
- class DetrendUnit(
27
- BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]
28
- ):
27
+ class DetrendUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, DetrendTransformer]):
29
28
  SETTINGS = EWMASettings
ezmsg/sigproc/diff.py CHANGED
@@ -1,13 +1,14 @@
1
1
  import ezmsg.core as ez
2
2
  import numpy as np
3
3
  import numpy.typing as npt
4
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
5
+ from ezmsg.util.messages.util import replace
6
+
4
7
  from ezmsg.sigproc.base import (
8
+ BaseStatefulTransformer,
5
9
  BaseTransformerUnit,
6
10
  processor_state,
7
- BaseStatefulTransformer,
8
11
  )
9
- from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
10
- from ezmsg.util.messages.util import replace
11
12
 
12
13
 
13
14
  class DiffSettings(ez.Settings):
@@ -21,9 +22,7 @@ class DiffState:
21
22
  last_time: float | None = None
22
23
 
23
24
 
24
- class DiffTransformer(
25
- BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]
26
- ):
25
+ class DiffTransformer(BaseStatefulTransformer[DiffSettings, AxisArray, AxisArray, DiffState]):
27
26
  def _hash_message(self, message: AxisArray) -> int:
28
27
  ax_idx = message.get_axis_idx(self.settings.axis)
29
28
  sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
@@ -49,20 +48,14 @@ class DiffTransformer(
49
48
  axis=ax_idx,
50
49
  )
51
50
  # Prepare last_dat for next iteration
52
- self.state.last_dat = slice_along_axis(
53
- message.data, slice(-1, None), axis=ax_idx
54
- )
51
+ self.state.last_dat = slice_along_axis(message.data, slice(-1, None), axis=ax_idx)
55
52
  # Scale by fs if requested. This converts the diff to a derivative, e.g., diff of position becomes velocity.
56
53
  if self.settings.scale_by_fs:
57
54
  ax_info = message.get_axis(axis)
58
55
  if hasattr(ax_info, "data"):
59
56
  dt = np.diff(np.concatenate(([self.state.last_time], ax_info.data)))
60
57
  # Expand dt dims to match diffs
61
- exp_sl = (
62
- (None,) * ax_idx
63
- + (Ellipsis,)
64
- + (None,) * (message.data.ndim - ax_idx - 1)
65
- )
58
+ exp_sl = (None,) * ax_idx + (Ellipsis,) + (None,) * (message.data.ndim - ax_idx - 1)
66
59
  diffs /= dt[exp_sl]
67
60
  self.state.last_time = ax_info.data[-1] # For next iteration
68
61
  else:
@@ -71,9 +64,7 @@ class DiffTransformer(
71
64
  return replace(message, data=diffs)
72
65
 
73
66
 
74
- class DiffUnit(
75
- BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]
76
- ):
67
+ class DiffUnit(BaseTransformerUnit[DiffSettings, AxisArray, AxisArray, DiffTransformer]):
77
68
  SETTINGS = DiffSettings
78
69
 
79
70
 
@@ -1,10 +1,10 @@
1
+ import ezmsg.core as ez
1
2
  import numpy as np
2
3
  from ezmsg.util.messages.axisarray import (
3
4
  AxisArray,
4
- slice_along_axis,
5
5
  replace,
6
+ slice_along_axis,
6
7
  )
7
- import ezmsg.core as ez
8
8
 
9
9
  from .base import (
10
10
  BaseStatefulTransformer,
@@ -38,9 +38,7 @@ class DownsampleState:
38
38
  """Index of the next msg's first sample into the virtual rotating ds_factor counter."""
39
39
 
40
40
 
41
- class DownsampleTransformer(
42
- BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]
43
- ):
41
+ class DownsampleTransformer(BaseStatefulTransformer[DownsampleSettings, AxisArray, AxisArray, DownsampleState]):
44
42
  """
45
43
  Downsampled data simply comprise every `factor`th sample.
46
44
  This should only be used following appropriate lowpass filtering.
@@ -75,9 +73,7 @@ class DownsampleTransformer(
75
73
  axis_idx = message.get_axis_idx(axis)
76
74
 
77
75
  n_samples = message.data.shape[axis_idx]
78
- samples = (
79
- np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
80
- )
76
+ samples = np.arange(self.state.s_idx, self.state.s_idx + n_samples) % self._state.q
81
77
  if n_samples > 0:
82
78
  # Update state for next iteration.
83
79
  self._state.s_idx = samples[-1] + 1
@@ -104,9 +100,7 @@ class DownsampleTransformer(
104
100
  return msg_out
105
101
 
106
102
 
107
- class Downsample(
108
- BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]
109
- ):
103
+ class Downsample(BaseTransformerUnit[DownsampleSettings, AxisArray, AxisArray, DownsampleTransformer]):
110
104
  SETTINGS = DownsampleSettings
111
105
 
112
106
 
@@ -115,6 +109,4 @@ def downsample(
115
109
  target_rate: float | None = None,
116
110
  factor: int | None = None,
117
111
  ) -> DownsampleTransformer:
118
- return DownsampleTransformer(
119
- DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor)
120
- )
112
+ return DownsampleTransformer(DownsampleSettings(axis=axis, target_rate=target_rate, factor=factor))
ezmsg/sigproc/ewma.py CHANGED
@@ -1,14 +1,14 @@
1
- from dataclasses import field
2
1
  import functools
2
+ from dataclasses import field
3
3
 
4
+ import ezmsg.core as ez
4
5
  import numpy as np
5
6
  import numpy.typing as npt
6
7
  import scipy.signal as sps
7
- import ezmsg.core as ez
8
8
  from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
9
9
  from ezmsg.util.messages.util import replace
10
10
 
11
- from .base import BaseStatefulTransformer, processor_state, BaseTransformerUnit
11
+ from .base import BaseStatefulTransformer, BaseTransformerUnit, processor_state
12
12
 
13
13
 
14
14
  def _tau_from_alpha(alpha: float, dt: float) -> float:
@@ -29,9 +29,7 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
29
29
  return 1 - np.exp(-dt / tau)
30
30
 
31
31
 
32
- def ewma_step(
33
- sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None
34
- ):
32
+ def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
35
33
  """
36
34
  Do an exponentially weighted moving average step.
37
35
 
@@ -97,9 +95,7 @@ class EWMA_Deprecated:
97
95
  if self.prev is None:
98
96
  self.prev = arr[:1]
99
97
 
100
- out += self.prev * np.expand_dims(
101
- self.weights[1 : n + 1], list(range(1, arr.ndim))
102
- )
98
+ out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
103
99
 
104
100
  self.prev = out[-1:]
105
101
 
@@ -128,9 +124,7 @@ class EWMA_Deprecated:
128
124
  if self.prev is None:
129
125
  self.prev = arr[:1]
130
126
 
131
- result += self.prev * np.expand_dims(
132
- self.weights[1 : n + 1], list(range(1, arr.ndim))
133
- )
127
+ result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
134
128
 
135
129
  # Store the result back into prev
136
130
  self.prev = result[-1]
@@ -155,25 +149,17 @@ class EWMAState:
155
149
  zi: npt.NDArray | None = None
156
150
 
157
151
 
158
- class EWMATransformer(
159
- BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]
160
- ):
152
+ class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
161
153
  def _hash_message(self, message: AxisArray) -> int:
162
154
  axis = self.settings.axis or message.dims[0]
163
155
  axis_idx = message.get_axis_idx(axis)
164
- sample_shape = (
165
- message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
166
- )
156
+ sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
167
157
  return hash((sample_shape, message.axes[axis].gain, message.key))
168
158
 
169
159
  def _reset_state(self, message: AxisArray) -> None:
170
160
  axis = self.settings.axis or message.dims[0]
171
- self._state.alpha = _alpha_from_tau(
172
- self.settings.time_constant, message.axes[axis].gain
173
- )
174
- sub_dat = slice_along_axis(
175
- message.data, slice(None, 1, None), axis=message.get_axis_idx(axis)
176
- )
161
+ self._state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
162
+ sub_dat = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis))
177
163
  self._state.zi = (1 - self._state.alpha) * sub_dat
178
164
 
179
165
  def _process(self, message: AxisArray) -> AxisArray:
@@ -191,7 +177,5 @@ class EWMATransformer(
191
177
  return replace(message, data=expected)
192
178
 
193
179
 
194
- class EWMAUnit(
195
- BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]
196
- ):
180
+ class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
197
181
  SETTINGS = EWMASettings
@@ -2,9 +2,9 @@ import asyncio
2
2
  import typing
3
3
 
4
4
  import ezmsg.core as ez
5
+ import numpy as np
5
6
  from ezmsg.util.messages.axisarray import AxisArray
6
7
  from ezmsg.util.messages.util import replace
7
- import numpy as np
8
8
 
9
9
  from .window import Window, WindowSettings
10
10
 
@@ -1,6 +1,7 @@
1
- import numpy as np
2
1
  import ezmsg.core as ez
2
+ import numpy as np
3
3
  from ezmsg.util.messages.axisarray import AxisArray, replace
4
+
4
5
  from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
5
6
 
6
7
 
@@ -35,7 +36,5 @@ class ExtractAxisData(BaseTransformer[ExtractAxisSettings, AxisArray, AxisArray]
35
36
  )
36
37
 
37
38
 
38
- class ExtractAxisDataUnit(
39
- BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]
40
- ):
39
+ class ExtractAxisDataUnit(BaseTransformerUnit[ExtractAxisSettings, AxisArray, AxisArray, ExtractAxisData]):
41
40
  SETTINGS = ExtractAxisSettings