ezmsg-sigproc 2.6.0__py3-none-any.whl → 2.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +1 -1
- ezmsg/sigproc/adaptive_lattice_notch.py +1 -2
- ezmsg/sigproc/affinetransform.py +26 -5
- ezmsg/sigproc/aggregate.py +6 -6
- ezmsg/sigproc/bandpower.py +6 -6
- ezmsg/sigproc/butterworthzerophase.py +1 -1
- ezmsg/sigproc/coordinatespaces.py +142 -0
- ezmsg/sigproc/decimate.py +1 -1
- ezmsg/sigproc/denormalize.py +3 -4
- ezmsg/sigproc/detrend.py +1 -1
- ezmsg/sigproc/diff.py +3 -4
- ezmsg/sigproc/downsample.py +5 -6
- ezmsg/sigproc/ewma.py +45 -9
- ezmsg/sigproc/extract_axis.py +1 -2
- ezmsg/sigproc/fbcca.py +4 -4
- ezmsg/sigproc/filter.py +3 -4
- ezmsg/sigproc/filterbank.py +5 -5
- ezmsg/sigproc/filterbankdesign.py +4 -4
- ezmsg/sigproc/fir_hilbert.py +1 -1
- ezmsg/sigproc/linear.py +118 -0
- ezmsg/sigproc/math/abs.py +3 -0
- ezmsg/sigproc/math/add.py +1 -1
- ezmsg/sigproc/math/clip.py +3 -0
- ezmsg/sigproc/math/difference.py +2 -0
- ezmsg/sigproc/math/invert.py +2 -0
- ezmsg/sigproc/math/log.py +3 -0
- ezmsg/sigproc/math/scale.py +2 -0
- ezmsg/sigproc/quantize.py +1 -2
- ezmsg/sigproc/resample.py +4 -4
- ezmsg/sigproc/rollingscaler.py +4 -4
- ezmsg/sigproc/sampler.py +6 -6
- ezmsg/sigproc/scaler.py +56 -8
- ezmsg/sigproc/signalinjector.py +3 -4
- ezmsg/sigproc/slicer.py +5 -6
- ezmsg/sigproc/spectrogram.py +4 -4
- ezmsg/sigproc/spectrum.py +5 -6
- ezmsg/sigproc/transpose.py +5 -6
- ezmsg/sigproc/util/axisarray_buffer.py +2 -0
- ezmsg/sigproc/util/buffer.py +4 -0
- ezmsg/sigproc/util/sparse.py +2 -0
- ezmsg/sigproc/wavelets.py +4 -4
- ezmsg/sigproc/window.py +5 -5
- ezmsg_sigproc-2.8.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.8.0.dist-info/RECORD +65 -0
- ezmsg_sigproc-2.6.0.dist-info/METADATA +0 -73
- ezmsg_sigproc-2.6.0.dist-info/RECORD +0 -63
- {ezmsg_sigproc-2.6.0.dist-info → ezmsg_sigproc-2.8.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-2.6.0.dist-info → ezmsg_sigproc-2.8.0.dist-info}/licenses/LICENSE +0 -0
ezmsg/sigproc/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.6.0'
|
|
32
|
-
__version_tuple__ = version_tuple = (2, 6, 0)
|
|
31
|
+
__version__ = version = '2.8.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 8, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
ezmsg/sigproc/activation.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import scipy.special
|
|
3
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
3
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
5
|
from ezmsg.util.messages.util import replace
|
|
5
6
|
|
|
6
|
-
from .base import BaseTransformer, BaseTransformerUnit
|
|
7
7
|
from .spectral import OptionsEnum
|
|
8
8
|
|
|
9
9
|
|
|
ezmsg/sigproc/adaptive_lattice_notch.py
CHANGED
|
@@ -2,11 +2,10 @@ import ezmsg.core as ez
|
|
|
2
2
|
import numpy as np
|
|
3
3
|
import numpy.typing as npt
|
|
4
4
|
import scipy.signal
|
|
5
|
+
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
|
|
5
6
|
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
|
|
6
7
|
from ezmsg.util.messages.util import replace
|
|
7
8
|
|
|
8
|
-
from .base import BaseStatefulTransformer, processor_state
|
|
9
|
-
|
|
10
9
|
|
|
11
10
|
class AdaptiveLatticeNotchFilterSettings(ez.Settings):
|
|
12
11
|
"""Settings for the Adaptive Lattice Notch Filter."""
|
ezmsg/sigproc/affinetransform.py
CHANGED
|
@@ -1,24 +1,32 @@
|
|
|
1
|
+
"""Affine transformations via matrix multiplication: y = Ax or y = Ax + B.
|
|
2
|
+
|
|
3
|
+
For full matrix transformations where channels are mixed (off-diagonal weights),
|
|
4
|
+
use :obj:`AffineTransformTransformer` or the `AffineTransform` unit.
|
|
5
|
+
|
|
6
|
+
For simple per-channel scaling and offset (diagonal weights only), use
|
|
7
|
+
:obj:`LinearTransformTransformer` from :mod:`ezmsg.sigproc.linear` instead,
|
|
8
|
+
which is more efficient as it avoids matrix multiplication.
|
|
9
|
+
"""
|
|
10
|
+
|
|
1
11
|
import os
|
|
2
12
|
from pathlib import Path
|
|
3
13
|
|
|
4
14
|
import ezmsg.core as ez
|
|
5
15
|
import numpy as np
|
|
6
16
|
import numpy.typing as npt
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
8
|
-
from ezmsg.util.messages.util import replace
|
|
9
|
-
|
|
10
|
-
from .base import (
|
|
17
|
+
from ezmsg.baseproc import (
|
|
11
18
|
BaseStatefulTransformer,
|
|
12
19
|
BaseTransformer,
|
|
13
20
|
BaseTransformerUnit,
|
|
14
21
|
processor_state,
|
|
15
22
|
)
|
|
23
|
+
from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
24
|
+
from ezmsg.util.messages.util import replace
|
|
16
25
|
|
|
17
26
|
|
|
18
27
|
class AffineTransformSettings(ez.Settings):
|
|
19
28
|
"""
|
|
20
29
|
Settings for :obj:`AffineTransform`.
|
|
21
|
-
See :obj:`affine_transform` for argument details.
|
|
22
30
|
"""
|
|
23
31
|
|
|
24
32
|
weights: np.ndarray | str | Path
|
|
@@ -40,6 +48,19 @@ class AffineTransformState:
|
|
|
40
48
|
class AffineTransformTransformer(
|
|
41
49
|
BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState]
|
|
42
50
|
):
|
|
51
|
+
"""Apply affine transformation via matrix multiplication: y = Ax or y = Ax + B.
|
|
52
|
+
|
|
53
|
+
Use this transformer when you need full matrix transformations that mix
|
|
54
|
+
channels (off-diagonal weights), such as spatial filters or projections.
|
|
55
|
+
|
|
56
|
+
For simple per-channel scaling and offset where each output channel depends
|
|
57
|
+
only on its corresponding input channel (diagonal weight matrix), use
|
|
58
|
+
:obj:`LinearTransformTransformer` instead, which is more efficient.
|
|
59
|
+
|
|
60
|
+
The weights matrix can include an offset row (stacked as [A|B]) where the
|
|
61
|
+
input is automatically augmented with a column of ones to compute y = Ax + B.
|
|
62
|
+
"""
|
|
63
|
+
|
|
43
64
|
def __call__(self, message: AxisArray) -> AxisArray:
|
|
44
65
|
# Override __call__ so we can shortcut if weights are None.
|
|
45
66
|
if self.settings.weights is None or (
|
ezmsg/sigproc/aggregate.py
CHANGED
|
@@ -4,6 +4,12 @@ import ezmsg.core as ez
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
6
|
from array_api_compat import get_namespace
|
|
7
|
+
from ezmsg.baseproc import (
|
|
8
|
+
BaseStatefulTransformer,
|
|
9
|
+
BaseTransformer,
|
|
10
|
+
BaseTransformerUnit,
|
|
11
|
+
processor_state,
|
|
12
|
+
)
|
|
7
13
|
from ezmsg.util.messages.axisarray import (
|
|
8
14
|
AxisArray,
|
|
9
15
|
AxisBase,
|
|
@@ -11,12 +17,6 @@ from ezmsg.util.messages.axisarray import (
|
|
|
11
17
|
slice_along_axis,
|
|
12
18
|
)
|
|
13
19
|
|
|
14
|
-
from .base import (
|
|
15
|
-
BaseStatefulTransformer,
|
|
16
|
-
BaseTransformer,
|
|
17
|
-
BaseTransformerUnit,
|
|
18
|
-
processor_state,
|
|
19
|
-
)
|
|
20
20
|
from .spectral import OptionsEnum
|
|
21
21
|
|
|
22
22
|
|
ezmsg/sigproc/bandpower.py
CHANGED
|
@@ -1,6 +1,12 @@
|
|
|
1
1
|
from dataclasses import field
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseProcessor,
|
|
6
|
+
BaseStatefulProcessor,
|
|
7
|
+
BaseTransformerUnit,
|
|
8
|
+
CompositeProcessor,
|
|
9
|
+
)
|
|
4
10
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
11
|
|
|
6
12
|
from .aggregate import (
|
|
@@ -8,12 +14,6 @@ from .aggregate import (
|
|
|
8
14
|
RangedAggregateSettings,
|
|
9
15
|
RangedAggregateTransformer,
|
|
10
16
|
)
|
|
11
|
-
from .base import (
|
|
12
|
-
BaseProcessor,
|
|
13
|
-
BaseStatefulProcessor,
|
|
14
|
-
BaseTransformerUnit,
|
|
15
|
-
CompositeProcessor,
|
|
16
|
-
)
|
|
17
17
|
from .spectrogram import SpectrogramSettings, SpectrogramTransformer
|
|
18
18
|
|
|
19
19
|
|
|
ezmsg/sigproc/butterworthzerophase.py
CHANGED
|
@@ -4,10 +4,10 @@ import typing
|
|
|
4
4
|
import ezmsg.core as ez
|
|
5
5
|
import numpy as np
|
|
6
6
|
import scipy.signal
|
|
7
|
+
from ezmsg.baseproc import SettingsType
|
|
7
8
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
9
|
from ezmsg.util.messages.util import replace
|
|
9
10
|
|
|
10
|
-
from ezmsg.sigproc.base import SettingsType
|
|
11
11
|
from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings, butter_design_fun
|
|
12
12
|
from ezmsg.sigproc.filter import (
|
|
13
13
|
BACoeffs,
|
|
ezmsg/sigproc/coordinatespaces.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Coordinate space transformations for streaming data.
|
|
3
|
+
|
|
4
|
+
This module provides utilities and ezmsg nodes for transforming between
|
|
5
|
+
Cartesian (x, y) and polar (r, theta) coordinate systems.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
import ezmsg.core as ez
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
from ezmsg.baseproc import (
|
|
15
|
+
BaseTransformer,
|
|
16
|
+
BaseTransformerUnit,
|
|
17
|
+
)
|
|
18
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
19
|
+
|
|
20
|
+
# -- Utility functions for coordinate transformations --
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def polar2z(r: npt.ArrayLike, theta: npt.ArrayLike) -> npt.ArrayLike:
|
|
24
|
+
"""Convert polar coordinates to complex number representation."""
|
|
25
|
+
return r * np.exp(1j * theta)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def z2polar(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
29
|
+
"""Convert complex number to polar coordinates (r, theta)."""
|
|
30
|
+
return np.abs(z), np.angle(z)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def cart2z(x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike:
|
|
34
|
+
"""Convert Cartesian coordinates to complex number representation."""
|
|
35
|
+
return x + 1j * y
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def z2cart(z: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
39
|
+
"""Convert complex number to Cartesian coordinates (x, y)."""
|
|
40
|
+
return np.real(z), np.imag(z)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def cart2pol(x: npt.ArrayLike, y: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
44
|
+
"""Convert Cartesian coordinates (x, y) to polar coordinates (r, theta)."""
|
|
45
|
+
return z2polar(cart2z(x, y))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def pol2cart(r: npt.ArrayLike, theta: npt.ArrayLike) -> Tuple[npt.ArrayLike, npt.ArrayLike]:
|
|
49
|
+
"""Convert polar coordinates (r, theta) to Cartesian coordinates (x, y)."""
|
|
50
|
+
return z2cart(polar2z(r, theta))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# -- ezmsg transformer classes --
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class CoordinateMode(str, Enum):
|
|
57
|
+
"""Transformation mode for coordinate conversion."""
|
|
58
|
+
|
|
59
|
+
CART2POL = "cart2pol"
|
|
60
|
+
"""Convert Cartesian (x, y) to polar (r, theta)."""
|
|
61
|
+
|
|
62
|
+
POL2CART = "pol2cart"
|
|
63
|
+
"""Convert polar (r, theta) to Cartesian (x, y)."""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class CoordinateSpacesSettings(ez.Settings):
|
|
67
|
+
"""
|
|
68
|
+
Settings for :obj:`CoordinateSpaces`.
|
|
69
|
+
|
|
70
|
+
See :obj:`coordinate_spaces` for argument details.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
mode: CoordinateMode = CoordinateMode.CART2POL
|
|
74
|
+
"""The transformation mode: 'cart2pol' or 'pol2cart'."""
|
|
75
|
+
|
|
76
|
+
axis: str | None = None
|
|
77
|
+
"""
|
|
78
|
+
The name of the axis containing the coordinate components.
|
|
79
|
+
Defaults to the last axis. Must have exactly 2 elements (x,y or r,theta).
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class CoordinateSpacesTransformer(BaseTransformer[CoordinateSpacesSettings, AxisArray, AxisArray]):
|
|
84
|
+
"""
|
|
85
|
+
Transform between Cartesian and polar coordinate systems.
|
|
86
|
+
|
|
87
|
+
The input must have exactly 2 elements along the specified axis:
|
|
88
|
+
- For cart2pol: expects (x, y), outputs (r, theta)
|
|
89
|
+
- For pol2cart: expects (r, theta), outputs (x, y)
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
def _process(self, message: AxisArray) -> AxisArray:
|
|
93
|
+
axis = self.settings.axis or message.dims[-1]
|
|
94
|
+
axis_idx = message.get_axis_idx(axis)
|
|
95
|
+
|
|
96
|
+
if message.data.shape[axis_idx] != 2:
|
|
97
|
+
raise ValueError(
|
|
98
|
+
f"Coordinate transformation requires exactly 2 elements along axis '{axis}', "
|
|
99
|
+
f"got {message.data.shape[axis_idx]}."
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
# Extract components along the specified axis
|
|
103
|
+
slices_a = [slice(None)] * message.data.ndim
|
|
104
|
+
slices_b = [slice(None)] * message.data.ndim
|
|
105
|
+
slices_a[axis_idx] = 0
|
|
106
|
+
slices_b[axis_idx] = 1
|
|
107
|
+
|
|
108
|
+
component_a = message.data[tuple(slices_a)]
|
|
109
|
+
component_b = message.data[tuple(slices_b)]
|
|
110
|
+
|
|
111
|
+
if self.settings.mode == CoordinateMode.CART2POL:
|
|
112
|
+
# Input: x, y -> Output: r, theta
|
|
113
|
+
out_a, out_b = cart2pol(component_a, component_b)
|
|
114
|
+
else:
|
|
115
|
+
# Input: r, theta -> Output: x, y
|
|
116
|
+
out_a, out_b = pol2cart(component_a, component_b)
|
|
117
|
+
|
|
118
|
+
# Stack results back along the same axis
|
|
119
|
+
result = np.stack([out_a, out_b], axis=axis_idx)
|
|
120
|
+
|
|
121
|
+
# Update axis labels if present
|
|
122
|
+
axes = message.axes
|
|
123
|
+
if axis in axes and hasattr(axes[axis], "data"):
|
|
124
|
+
if self.settings.mode == CoordinateMode.CART2POL:
|
|
125
|
+
new_labels = np.array(["r", "theta"])
|
|
126
|
+
else:
|
|
127
|
+
new_labels = np.array(["x", "y"])
|
|
128
|
+
axes = {**axes, axis: replace(axes[axis], data=new_labels)}
|
|
129
|
+
|
|
130
|
+
return replace(message, data=result, axes=axes)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class CoordinateSpaces(
|
|
134
|
+
BaseTransformerUnit[CoordinateSpacesSettings, AxisArray, AxisArray, CoordinateSpacesTransformer]
|
|
135
|
+
):
|
|
136
|
+
"""
|
|
137
|
+
Unit for transforming between Cartesian and polar coordinate systems.
|
|
138
|
+
|
|
139
|
+
See :obj:`CoordinateSpacesSettings` for configuration options.
|
|
140
|
+
"""
|
|
141
|
+
|
|
142
|
+
SETTINGS = CoordinateSpacesSettings
|
ezmsg/sigproc/decimate.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import BaseTransformerUnit
|
|
4
5
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
6
|
|
|
6
|
-
from .base import BaseTransformerUnit
|
|
7
7
|
from .cheby import ChebyshevFilterSettings, ChebyshevFilterTransformer
|
|
8
8
|
from .downsample import Downsample, DownsampleSettings
|
|
9
9
|
from .filter import BACoeffs, SOSCoeffs
|
ezmsg/sigproc/denormalize.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
-
from ezmsg.util.messages.util import replace
|
|
6
|
-
|
|
7
|
-
from ezmsg.sigproc.base import (
|
|
4
|
+
from ezmsg.baseproc import (
|
|
8
5
|
BaseStatefulTransformer,
|
|
9
6
|
BaseTransformerUnit,
|
|
10
7
|
processor_state,
|
|
11
8
|
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
class DenormalizeSettings(ez.Settings):
|
ezmsg/sigproc/detrend.py
CHANGED
ezmsg/sigproc/diff.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
-
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
5
|
-
from ezmsg.util.messages.util import replace
|
|
6
|
-
|
|
7
|
-
from ezmsg.sigproc.base import (
|
|
4
|
+
from ezmsg.baseproc import (
|
|
8
5
|
BaseStatefulTransformer,
|
|
9
6
|
BaseTransformerUnit,
|
|
10
7
|
processor_state,
|
|
11
8
|
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
class DiffSettings(ez.Settings):
|
ezmsg/sigproc/downsample.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
|
+
from ezmsg.baseproc import (
|
|
4
|
+
BaseStatefulTransformer,
|
|
5
|
+
BaseTransformerUnit,
|
|
6
|
+
processor_state,
|
|
7
|
+
)
|
|
3
8
|
from ezmsg.util.messages.axisarray import (
|
|
4
9
|
AxisArray,
|
|
5
10
|
replace,
|
|
6
11
|
slice_along_axis,
|
|
7
12
|
)
|
|
8
13
|
|
|
9
|
-
from .base import (
|
|
10
|
-
BaseStatefulTransformer,
|
|
11
|
-
BaseTransformerUnit,
|
|
12
|
-
processor_state,
|
|
13
|
-
)
|
|
14
|
-
|
|
15
14
|
|
|
16
15
|
class DownsampleSettings(ez.Settings):
|
|
17
16
|
"""
|
ezmsg/sigproc/ewma.py
CHANGED
|
@@ -5,11 +5,10 @@ import ezmsg.core as ez
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import numpy.typing as npt
|
|
7
7
|
import scipy.signal as sps
|
|
8
|
+
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
|
|
8
9
|
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
9
10
|
from ezmsg.util.messages.util import replace
|
|
10
11
|
|
|
11
|
-
from .base import BaseStatefulTransformer, BaseTransformerUnit, processor_state
|
|
12
|
-
|
|
13
12
|
|
|
14
13
|
def _tau_from_alpha(alpha: float, dt: float) -> float:
|
|
15
14
|
"""
|
|
@@ -140,8 +139,15 @@ class EWMA_Deprecated:
|
|
|
140
139
|
|
|
141
140
|
class EWMASettings(ez.Settings):
|
|
142
141
|
time_constant: float = 1.0
|
|
142
|
+
"""The amount of time for the smoothed response of a unit step function to reach 1 - 1/e approx-eq 63.2%."""
|
|
143
|
+
|
|
143
144
|
axis: str | None = None
|
|
144
145
|
|
|
146
|
+
accumulate: bool = True
|
|
147
|
+
"""If True, update the EWMA state with each sample. If False, only apply
|
|
148
|
+
the current EWMA estimate without updating state (useful for inference
|
|
149
|
+
periods where you don't want to adapt statistics)."""
|
|
150
|
+
|
|
145
151
|
|
|
146
152
|
@processor_state
|
|
147
153
|
class EWMAState:
|
|
@@ -167,15 +173,45 @@ class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray
|
|
|
167
173
|
return message
|
|
168
174
|
axis = self.settings.axis or message.dims[0]
|
|
169
175
|
axis_idx = message.get_axis_idx(axis)
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
176
|
+
if self.settings.accumulate:
|
|
177
|
+
# Normal behavior: update state with new samples
|
|
178
|
+
expected, self._state.zi = sps.lfilter(
|
|
179
|
+
[self._state.alpha],
|
|
180
|
+
[1.0, self._state.alpha - 1.0],
|
|
181
|
+
message.data,
|
|
182
|
+
axis=axis_idx,
|
|
183
|
+
zi=self._state.zi,
|
|
184
|
+
)
|
|
185
|
+
else:
|
|
186
|
+
# Process-only: compute output without updating state
|
|
187
|
+
expected, _ = sps.lfilter(
|
|
188
|
+
[self._state.alpha],
|
|
189
|
+
[1.0, self._state.alpha - 1.0],
|
|
190
|
+
message.data,
|
|
191
|
+
axis=axis_idx,
|
|
192
|
+
zi=self._state.zi,
|
|
193
|
+
)
|
|
177
194
|
return replace(message, data=expected)
|
|
178
195
|
|
|
179
196
|
|
|
180
197
|
class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
|
|
181
198
|
SETTINGS = EWMASettings
|
|
199
|
+
|
|
200
|
+
@ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
|
|
201
|
+
async def on_settings(self, msg: EWMASettings) -> None:
|
|
202
|
+
"""
|
|
203
|
+
Handle settings updates with smart reset behavior.
|
|
204
|
+
|
|
205
|
+
Only resets state if `axis` changes (structural change).
|
|
206
|
+
Changes to `time_constant` or `accumulate` are applied without
|
|
207
|
+
resetting accumulated state.
|
|
208
|
+
"""
|
|
209
|
+
old_axis = self.SETTINGS.axis
|
|
210
|
+
self.apply_settings(msg)
|
|
211
|
+
|
|
212
|
+
if msg.axis != old_axis:
|
|
213
|
+
# Axis changed - need full reset
|
|
214
|
+
self.create_processor()
|
|
215
|
+
else:
|
|
216
|
+
# Only accumulate or time_constant changed - keep state
|
|
217
|
+
self.processor.settings = msg
|
ezmsg/sigproc/extract_axis.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
2
|
import numpy as np
|
|
3
|
+
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
3
4
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
4
5
|
|
|
5
|
-
from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit
|
|
6
|
-
|
|
7
6
|
|
|
8
7
|
class ExtractAxisSettings(ez.Settings):
|
|
9
8
|
axis: str = "freq"
|
ezmsg/sigproc/fbcca.py
CHANGED
|
@@ -4,16 +4,16 @@ from dataclasses import field
|
|
|
4
4
|
|
|
5
5
|
import ezmsg.core as ez
|
|
6
6
|
import numpy as np
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
-
from ezmsg.util.messages.util import replace
|
|
9
|
-
|
|
10
|
-
from .base import (
|
|
7
|
+
from ezmsg.baseproc import (
|
|
11
8
|
BaseProcessor,
|
|
12
9
|
BaseStatefulProcessor,
|
|
13
10
|
BaseTransformer,
|
|
14
11
|
BaseTransformerUnit,
|
|
15
12
|
CompositeProcessor,
|
|
16
13
|
)
|
|
14
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
15
|
+
from ezmsg.util.messages.util import replace
|
|
16
|
+
|
|
17
17
|
from .filterbankdesign import (
|
|
18
18
|
FilterbankDesignSettings,
|
|
19
19
|
FilterbankDesignTransformer,
|
ezmsg/sigproc/filter.py
CHANGED
|
@@ -6,10 +6,7 @@ import ezmsg.core as ez
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import numpy.typing as npt
|
|
8
8
|
import scipy.signal
|
|
9
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
-
from ezmsg.util.messages.util import replace
|
|
11
|
-
|
|
12
|
-
from ezmsg.sigproc.base import (
|
|
9
|
+
from ezmsg.baseproc import (
|
|
13
10
|
BaseConsumerUnit,
|
|
14
11
|
BaseStatefulTransformer,
|
|
15
12
|
BaseTransformerUnit,
|
|
@@ -17,6 +14,8 @@ from ezmsg.sigproc.base import (
|
|
|
17
14
|
TransformerType,
|
|
18
15
|
processor_state,
|
|
19
16
|
)
|
|
17
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
18
|
+
from ezmsg.util.messages.util import replace
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
@dataclass
|
ezmsg/sigproc/filterbank.py
CHANGED
|
@@ -7,15 +7,15 @@ import numpy as np
|
|
|
7
7
|
import numpy.typing as npt
|
|
8
8
|
import scipy.fft as sp_fft
|
|
9
9
|
import scipy.signal as sps
|
|
10
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
-
from ezmsg.util.messages.util import replace
|
|
12
|
-
from scipy.special import lambertw
|
|
13
|
-
|
|
14
|
-
from .base import (
|
|
10
|
+
from ezmsg.baseproc import (
|
|
15
11
|
BaseStatefulTransformer,
|
|
16
12
|
BaseTransformerUnit,
|
|
17
13
|
processor_state,
|
|
18
14
|
)
|
|
15
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
16
|
+
from ezmsg.util.messages.util import replace
|
|
17
|
+
from scipy.special import lambertw
|
|
18
|
+
|
|
19
19
|
from .spectrum import OptionsEnum
|
|
20
20
|
from .window import WindowTransformer
|
|
21
21
|
|
|
ezmsg/sigproc/filterbankdesign.py
CHANGED
|
@@ -3,13 +3,13 @@ import typing
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
4
|
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
-
from ezmsg.util.messages.util import replace
|
|
8
|
-
|
|
9
|
-
from .base import (
|
|
6
|
+
from ezmsg.baseproc import (
|
|
10
7
|
BaseStatefulTransformer,
|
|
11
8
|
processor_state,
|
|
12
9
|
)
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
12
|
+
|
|
13
13
|
from .filterbank import (
|
|
14
14
|
FilterbankMode,
|
|
15
15
|
FilterbankSettings,
|
ezmsg/sigproc/fir_hilbert.py
CHANGED
|
@@ -4,10 +4,10 @@ import typing
|
|
|
4
4
|
import ezmsg.core as ez
|
|
5
5
|
import numpy as np
|
|
6
6
|
import scipy.signal as sps
|
|
7
|
+
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
|
|
7
8
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
9
|
from ezmsg.util.messages.util import replace
|
|
9
10
|
|
|
10
|
-
from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state
|
|
11
11
|
from ezmsg.sigproc.filter import (
|
|
12
12
|
BACoeffs,
|
|
13
13
|
BaseFilterByDesignTransformerUnit,
|