ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +34 -1
- ezmsg/sigproc/activation.py +78 -0
- ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
- ezmsg/sigproc/affinetransform.py +235 -0
- ezmsg/sigproc/aggregate.py +276 -0
- ezmsg/sigproc/bandpower.py +80 -0
- ezmsg/sigproc/base.py +149 -0
- ezmsg/sigproc/butterworthfilter.py +129 -39
- ezmsg/sigproc/butterworthzerophase.py +305 -0
- ezmsg/sigproc/cheby.py +125 -0
- ezmsg/sigproc/combfilter.py +160 -0
- ezmsg/sigproc/coordinatespaces.py +159 -0
- ezmsg/sigproc/decimate.py +46 -18
- ezmsg/sigproc/denormalize.py +78 -0
- ezmsg/sigproc/detrend.py +28 -0
- ezmsg/sigproc/diff.py +82 -0
- ezmsg/sigproc/downsample.py +97 -49
- ezmsg/sigproc/ewma.py +217 -0
- ezmsg/sigproc/ewmfilter.py +45 -19
- ezmsg/sigproc/extract_axis.py +39 -0
- ezmsg/sigproc/fbcca.py +307 -0
- ezmsg/sigproc/filter.py +282 -117
- ezmsg/sigproc/filterbank.py +292 -0
- ezmsg/sigproc/filterbankdesign.py +129 -0
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +117 -0
- ezmsg/sigproc/gaussiansmoothing.py +89 -0
- ezmsg/sigproc/kaiser.py +106 -0
- ezmsg/sigproc/linear.py +120 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +35 -0
- ezmsg/sigproc/math/add.py +120 -0
- ezmsg/sigproc/math/clip.py +48 -0
- ezmsg/sigproc/math/difference.py +143 -0
- ezmsg/sigproc/math/invert.py +28 -0
- ezmsg/sigproc/math/log.py +57 -0
- ezmsg/sigproc/math/scale.py +39 -0
- ezmsg/sigproc/messages.py +3 -6
- ezmsg/sigproc/quantize.py +68 -0
- ezmsg/sigproc/resample.py +278 -0
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +232 -241
- ezmsg/sigproc/scaler.py +165 -0
- ezmsg/sigproc/signalinjector.py +70 -0
- ezmsg/sigproc/slicer.py +138 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +90 -0
- ezmsg/sigproc/spectrum.py +277 -0
- ezmsg/sigproc/transpose.py +134 -0
- ezmsg/sigproc/util/__init__.py +0 -0
- ezmsg/sigproc/util/asio.py +25 -0
- ezmsg/sigproc/util/axisarray_buffer.py +365 -0
- ezmsg/sigproc/util/buffer.py +449 -0
- ezmsg/sigproc/util/message.py +17 -0
- ezmsg/sigproc/util/profile.py +23 -0
- ezmsg/sigproc/util/sparse.py +115 -0
- ezmsg/sigproc/util/typeresolution.py +17 -0
- ezmsg/sigproc/wavelets.py +187 -0
- ezmsg/sigproc/window.py +301 -117
- ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/synth.py +0 -411
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/scaler.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
from ezmsg.baseproc import (
|
|
6
|
+
BaseStatefulTransformer,
|
|
7
|
+
BaseTransformerUnit,
|
|
8
|
+
processor_state,
|
|
9
|
+
)
|
|
10
|
+
from ezmsg.util.generator import consumer
|
|
11
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
12
|
+
from ezmsg.util.messages.util import replace
|
|
13
|
+
|
|
14
|
+
# Imports for backwards compatibility with previous module location
|
|
15
|
+
from .ewma import EWMA_Deprecated as EWMA_Deprecated
|
|
16
|
+
from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
|
|
17
|
+
from .ewma import _tau_from_alpha as _tau_from_alpha
|
|
18
|
+
from .ewma import ewma_step as ewma_step
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@consumer
def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
    This is faster than :obj:`scaler_np` for single-channel data.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over. If None, the first
            dimension of the first received message is used for all subsequent messages.

    Returns:
        A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
        and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
        A message with zero samples along `axis` yields an empty float array of the same shape.
    """
    from river import preprocessing

    msg_out = AxisArray(np.array([]), dims=[""])
    _scaler = None
    while True:
        msg_in: AxisArray = yield msg_out
        if axis is None:
            # Lock onto the first dimension of the first message.
            axis = msg_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = msg_in.get_axis_idx(axis)
        # Put the accumulation axis first so iterating yields one sample at a time.
        data = np.moveaxis(msg_in.data, axis_idx, 0)

        if _scaler is None:
            # The fading factor is derived from tau and the sample period of the first message.
            alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
            _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)

        # Preallocate rather than np.stack(...) so a message with zero samples does not raise.
        result = np.empty(data.shape, dtype=float)
        for ix, sample in enumerate(data):
            # river works on dicts of features; key each feature by its flat index.
            x = dict(enumerate(np.ravel(sample).tolist()))
            _scaler.learn_one(x)
            y = _scaler.transform_one(x)
            result[ix] = np.array([y[k] for k in sorted(y)]).reshape(np.shape(sample))

        msg_out = replace(msg_in, data=np.moveaxis(result, 0, axis_idx))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class AdaptiveStandardScalerSettings(EWMASettings):
    """
    Settings for :obj:`AdaptiveStandardScalerTransformer`.

    Declares no fields of its own; `time_constant`, `axis` and `accumulate` (the fields the
    transformer reads) are inherited from :obj:`EWMASettings`.
    """
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@processor_state
class AdaptiveStandardScalerState:
    """Mutable state for :obj:`AdaptiveStandardScalerTransformer`."""

    # EWMA of the raw samples (running mean); built in `_reset_state`.
    samps_ewma: EWMATransformer | None = None
    # EWMA of the squared samples (running mean of squares); built in `_reset_state`.
    vars_sq_ewma: EWMATransformer | None = None
    # NOTE(review): not read or written by AdaptiveStandardScalerTransformer in this module.
    alpha: float | None = None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class AdaptiveStandardScalerTransformer(
    BaseStatefulTransformer[
        AdaptiveStandardScalerSettings,
        AxisArray,
        AxisArray,
        AdaptiveStandardScalerState,
    ]
):
    """
    Z-score each sample using exponentially-weighted running statistics.

    Two child :obj:`EWMATransformer` s track the mean of the samples and the mean of the
    squared samples; the variance is derived as E[x^2] - E[x]^2.
    """

    def _make_ewma(self) -> EWMATransformer:
        """Build a child EWMA transformer configured from the current settings."""
        return EWMATransformer(
            time_constant=self.settings.time_constant,
            axis=self.settings.axis,
            accumulate=self.settings.accumulate,
        )

    def _reset_state(self, message: AxisArray) -> None:
        self._state.samps_ewma = self._make_ewma()
        self._state.vars_sq_ewma = self._make_ewma()

    @property
    def accumulate(self) -> bool:
        """Whether to accumulate statistics from incoming samples."""
        return self.settings.accumulate

    @accumulate.setter
    def accumulate(self, value: bool) -> None:
        """
        Set the accumulate mode and propagate to child EWMA transformers.

        The value is also stored on this transformer's own settings. The getter then reflects it,
        and child EWMAs built later (on the first message or after a reset) inherit it rather
        than the previous value.

        Args:
            value: If True, update statistics with each sample.
                   If False, only apply current statistics without updating.
        """
        self.settings = replace(self.settings, accumulate=value)
        for ewma in (self._state.samps_ewma, self._state.vars_sq_ewma):
            if ewma is not None:
                ewma.settings = replace(ewma.settings, accumulate=value)

    def _process(self, message: AxisArray) -> AxisArray:
        # Update step (respects accumulate setting via child EWMAs)
        mean_message = self._state.samps_ewma(message)
        var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))

        # Get step: Var[x] = E[x^2] - E[x]^2
        varis = var_sq_message.data - mean_message.data**2
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - mean_message.data) / (varis**0.5)
        # 0/0 and the sqrt of a round-off-negative variance produce NaN; report those as 0.
        # np.where (instead of masked assignment) also works when `result` is 0-d.
        result = np.where(np.isnan(result), 0.0, result)
        return replace(message, data=result)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class AdaptiveStandardScaler(
    BaseTransformerUnit[
        AdaptiveStandardScalerSettings,
        AxisArray,
        AxisArray,
        AdaptiveStandardScalerTransformer,
    ]
):
    """ezmsg Unit wrapping :obj:`AdaptiveStandardScalerTransformer`."""

    SETTINGS = AdaptiveStandardScalerSettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
    async def on_settings(self, msg: AdaptiveStandardScalerSettings) -> None:
        """
        Handle settings updates with smart reset behavior.

        A change of `axis` (structural) rebuilds the processor, discarding accumulated
        statistics. For any other change the processor is kept. `accumulate` is pushed down to
        the child EWMAs, and the processor's settings are replaced with `msg`.

        NOTE(review): the child EWMA transformers are not rebuilt on this path, so they keep
        the `time_constant` they were constructed with. Confirm whether EWMATransformer picks
        up a new `time_constant` without a reset.
        """
        previous_axis = self.SETTINGS.axis
        self.apply_settings(msg)

        if msg.axis == previous_axis:
            # Non-structural change: keep statistics; propagate accumulate to child EWMAs.
            self.processor.accumulate = msg.accumulate
            # Also update own settings reference
            self.processor.settings = msg
            return

        # Axis changed - need full reset
        self.create_processor()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# Backwards compatibility...
|
|
162
|
+
def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
    """
    Backwards-compatible factory for :obj:`AdaptiveStandardScalerTransformer`.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over.

    Returns:
        A new :obj:`AdaptiveStandardScalerTransformer`.
    """
    settings = AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
    return AdaptiveStandardScalerTransformer(settings=settings)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseAsyncTransformer,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
processor_state,
|
|
8
|
+
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SignalInjectorSettings(ez.Settings):
    """Settings for :obj:`SignalInjectorTransformer`."""

    # Name of the time dimension; its axis values are used as the sinusoid's time base.
    time_dim: str = "time"  # Input signal needs a time dimension with units in sec.
    # Initial frequency of the injected sinusoid; None means messages pass through unchanged.
    frequency: float | None = None  # Hz
    # Initial multiplier applied to the injected sinusoid.
    amplitude: float = 1.0
    # Seed for the random per-channel mixing weights; None lets NumPy pick fresh entropy.
    mixing_seed: int | None = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@processor_state
class SignalInjectorState:
    """Mutable state for :obj:`SignalInjectorTransformer`."""

    # Data shape with the time axis removed, recorded at the last reset.
    cur_shape: tuple[int, ...] | None = None
    # Frequency currently injected; None means pass-through.
    cur_frequency: float | None = None
    # Multiplier currently applied to the injected signal.
    cur_amplitude: float | None = None
    # Shape (1, n_channels) weights drawn uniformly from [-1, 1), one per non-time channel.
    mixing: npt.NDArray | None = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SignalInjectorTransformer(
    BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
):
    """
    Add a sinusoid, randomly weighted per channel, to the input signal.

    The sinusoid is sin(2*pi*f*t), where t comes from the values of the time axis. It is
    multiplied by per-channel mixing weights in [-1, 1) and by the current amplitude.
    """

    def _sample_shape(self, message: AxisArray) -> tuple[int, ...]:
        """Return the data shape with the time dimension removed."""
        shape = message.data.shape
        t_idx = message.get_axis_idx(self.settings.time_dim)
        return shape[:t_idx] + shape[t_idx + 1 :]

    def _hash_message(self, message: AxisArray) -> int:
        # Reset whenever the stream key or the per-sample shape changes.
        return hash((message.key, *self._sample_shape(message)))

    def _reset_state(self, message: AxisArray) -> None:
        state = self._state
        # Values pushed in at runtime (via the unit's input streams) survive a reset.
        # NOTE(review): a frequency explicitly set to None is indistinguishable from "unset"
        # here, so a reset (e.g. on shape change) restores settings.frequency. Confirm this is intended.
        if state.cur_frequency is None:
            state.cur_frequency = self.settings.frequency
        if state.cur_amplitude is None:
            state.cur_amplitude = self.settings.amplitude
        state.cur_shape = self._sample_shape(message)
        n_channels = message.shape2d(self.settings.time_dim)[1]
        rng = np.random.default_rng(self.settings.mixing_seed)
        # Map uniform [0, 1) draws onto [-1, 1).
        state.mixing = 2.0 * rng.random((1, n_channels)) - 1.0

    async def _aprocess(self, message: AxisArray) -> AxisArray:
        freq = self._state.cur_frequency
        if freq is None:
            return message
        # Copy so the caller's array is not modified in place.
        out_msg = replace(message, data=message.data.copy())
        t = out_msg.ax(self.settings.time_dim).values[..., np.newaxis]
        injected = np.sin(2 * np.pi * freq * t) * self._state.mixing * self._state.cur_amplitude
        with out_msg.view2d(self.settings.time_dim) as view:
            view[...] = view + injected.astype(view.dtype)
        return out_msg
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
    """ezmsg Unit for :obj:`SignalInjectorTransformer` with live frequency and amplitude inputs."""

    SETTINGS = SignalInjectorSettings

    # Publishing None here makes the transformer pass messages through unchanged.
    INPUT_FREQUENCY = ez.InputStream(float | None)
    INPUT_AMPLITUDE = ez.InputStream(float)

    @ez.subscriber(INPUT_FREQUENCY)
    async def on_frequency(self, msg: float | None) -> None:
        """Overwrite the injected frequency on the running processor's state."""
        self.processor.state.cur_frequency = msg

    @ez.subscriber(INPUT_AMPLITUDE)
    async def on_amplitude(self, msg: float) -> None:
        """Overwrite the injected amplitude on the running processor's state."""
        self.processor.state.cur_amplitude = msg
|
ezmsg/sigproc/slicer.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import ezmsg.core as ez
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseStatefulTransformer,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
processor_state,
|
|
8
|
+
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import (
|
|
10
|
+
AxisArray,
|
|
11
|
+
AxisBase,
|
|
12
|
+
replace,
|
|
13
|
+
slice_along_axis,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
Slicer: Select a subset of data along a particular axis.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def parse_slice(
|
|
22
|
+
s: str,
|
|
23
|
+
axinfo: AxisArray.CoordinateAxis | None = None,
|
|
24
|
+
) -> tuple[slice | int, ...]:
|
|
25
|
+
"""
|
|
26
|
+
Parses a string representation of a slice and returns a tuple of slice objects.
|
|
27
|
+
|
|
28
|
+
- "" -> slice(None, None, None) (take all)
|
|
29
|
+
- ":" -> slice(None, None, None)
|
|
30
|
+
- '"none"` (case-insensitive) -> slice(None, None, None)
|
|
31
|
+
- "{start}:{stop}" or {start}:{stop}:{step} -> slice(start, stop, step)
|
|
32
|
+
- "5" (or any integer) -> (5,). Take only that item.
|
|
33
|
+
applying this to a ndarray or AxisArray will drop the dimension.
|
|
34
|
+
- A comma-separated list of the above -> a tuple of slices | ints
|
|
35
|
+
- A comma-separated list of values and axinfo is provided and is a CoordinateAxis -> a tuple of ints
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
s: The string representation of the slice.
|
|
39
|
+
axinfo: (Optional) If provided, and of type CoordinateAxis,
|
|
40
|
+
and `s` is a comma-separated list of values, then the values
|
|
41
|
+
in s will be checked against the values in axinfo.data.
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
A tuple of slice objects and/or ints.
|
|
45
|
+
"""
|
|
46
|
+
if s.lower() in ["", ":", "none"]:
|
|
47
|
+
return (slice(None),)
|
|
48
|
+
if "," not in s:
|
|
49
|
+
parts = [part.strip() for part in s.split(":")]
|
|
50
|
+
if len(parts) == 1:
|
|
51
|
+
if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
|
|
52
|
+
return tuple(np.where(axinfo.data == parts[0])[0])
|
|
53
|
+
return (int(parts[0]),)
|
|
54
|
+
return (slice(*(int(part.strip()) if part else None for part in parts)),)
|
|
55
|
+
suplist = [parse_slice(_, axinfo=axinfo) for _ in s.split(",")]
|
|
56
|
+
return tuple([item for sublist in suplist for item in sublist])
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SlicerSettings(ez.Settings):
    """Settings for :obj:`SlicerTransformer`."""

    selection: str = ""
    """selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details."""

    axis: str | None = None
    """The name of the axis to slice along. If None, the last axis is used."""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@processor_state
class SlicerState:
    """Mutable state for :obj:`SlicerTransformer`, computed in `_reset_state`."""

    # Index applied along the sliced axis: a slice, a single int, or an array of indices.
    slice_: slice | int | npt.NDArray | None = None
    # Replacement axis for the sliced dimension when the input axis carries `.data`.
    new_axis: AxisBase | None = None
    # True when a single integer index is used, so the sliced dimension is removed.
    b_change_dims: bool = False
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
    """
    Select a subset of data along one axis, as described by :obj:`SlicerSettings`.

    A selection that resolves to a single integer index (positional, or a label resolved via
    the axis' `.data`) removes the sliced dimension from `data`, `dims` and `axes`. Other
    selections keep the dimension and, when the input axis has `.data`, slice it to match.
    """

    def _resolve_axis(self, message: AxisArray) -> tuple[str, int]:
        """Return the name and index of the axis to slice (defaults to the last dim)."""
        axis = self.settings.axis or message.dims[-1]
        return axis, message.get_axis_idx(axis)

    def _hash_message(self, message: AxisArray) -> int:
        # Recompute the slice when the stream key or the length of the sliced axis changes.
        _, axis_idx = self._resolve_axis(message)
        return hash((message.key, message.data.shape[axis_idx]))

    def _reset_state(self, message: AxisArray) -> None:
        axis, axis_idx = self._resolve_axis(message)
        self._state.new_axis = None
        self._state.b_change_dims = False

        # Calculate the slice
        in_axis = message.axes.get(axis, None)
        _slices = parse_slice(self.settings.selection, in_axis)
        if len(_slices) == 1:
            slice_ = _slices[0]
            # Label lookups may yield NumPy integers. Indexing with them drops the dimension
            # from the data, so they must also drop it from dims/axes, exactly like a plain int.
            if isinstance(slice_, (int, np.integer)):
                slice_ = int(slice_)
                self._state.b_change_dims = True
            self._state.slice_ = slice_
        else:
            indices = np.arange(message.data.shape[axis_idx])
            self._state.slice_ = np.hstack([indices[_] for _ in _slices])

        # Create the output axis
        if in_axis is not None and hasattr(in_axis, "data") and len(in_axis.data) > 0:
            in_data = np.array(in_axis.data)
            if self._state.b_change_dims:
                # List indexing keeps a length-1 result and, unlike in_data[i : i + 1],
                # is also correct for negative indices such as -1.
                out_data = in_data[[self._state.slice_]]
            else:
                out_data = in_data[self._state.slice_]
            self._state.new_axis = replace(in_axis, data=out_data)

    def _process(self, message: AxisArray) -> AxisArray:
        axis, axis_idx = self._resolve_axis(message)

        replace_kwargs = {}
        if self._state.b_change_dims:
            # The sliced dimension disappears from the output.
            replace_kwargs["dims"] = [d for dim_ix, d in enumerate(message.dims) if dim_ix != axis_idx]
            replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
        elif self._state.new_axis is not None:
            replace_kwargs["axes"] = {k: (self._state.new_axis if k == axis else v) for k, v in message.axes.items()}

        return replace(
            message,
            data=slice_along_axis(message.data, self._state.slice_, axis_idx),
            **replace_kwargs,
        )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
    """ezmsg Unit wrapping :obj:`SlicerTransformer`; configured by :obj:`SlicerSettings`."""

    SETTINGS = SlicerSettings
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
    """
    Slice along a particular axis.

    Args:
        selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
        axis: The name of the axis to slice along. If None, the last axis is used.

    Returns:
        :obj:`SlicerTransformer`
    """
    settings = SlicerSettings(selection=selection, axis=axis)
    return SlicerTransformer(settings)
|
ezmsg/sigproc/spectral.py
CHANGED
|
@@ -1,132 +1,6 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
-
|
|
10
|
-
from typing import Optional, AsyncGenerator
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class OptionsEnum(enum.Enum):
|
|
14
|
-
@classmethod
|
|
15
|
-
def options(cls):
|
|
16
|
-
return list(map(lambda c: c.value, cls))
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class WindowFunction(OptionsEnum):
|
|
20
|
-
NONE = "None (Rectangular)"
|
|
21
|
-
HAMMING = "Hamming"
|
|
22
|
-
HANNING = "Hanning"
|
|
23
|
-
BARTLETT = "Bartlett"
|
|
24
|
-
BLACKMAN = "Blackman"
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
WINDOWS = {
|
|
28
|
-
WindowFunction.NONE: np.ones,
|
|
29
|
-
WindowFunction.HAMMING: np.hamming,
|
|
30
|
-
WindowFunction.HANNING: np.hanning,
|
|
31
|
-
WindowFunction.BARTLETT: np.bartlett,
|
|
32
|
-
WindowFunction.BLACKMAN: np.blackman,
|
|
33
|
-
}
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class SpectralTransform(OptionsEnum):
|
|
37
|
-
RAW_COMPLEX = "Complex FFT Output"
|
|
38
|
-
REAL = "Real Component of FFT"
|
|
39
|
-
IMAG = "Imaginary Component of FFT"
|
|
40
|
-
REL_POWER = "Relative Power"
|
|
41
|
-
REL_DB = "Log Power (Relative dB)"
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
class SpectralOutput(OptionsEnum):
|
|
45
|
-
FULL = "Full Spectrum"
|
|
46
|
-
POSITIVE = "Positive Frequencies"
|
|
47
|
-
NEGATIVE = "Negative Frequencies"
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class SpectrumSettings(ez.Settings):
|
|
51
|
-
axis: Optional[str] = None
|
|
52
|
-
# n: Optional[int] = None # n parameter for fft
|
|
53
|
-
out_axis: Optional[str] = "freq" # If none; don't change dim name
|
|
54
|
-
window: WindowFunction = WindowFunction.HAMMING
|
|
55
|
-
transform: SpectralTransform = SpectralTransform.REL_DB
|
|
56
|
-
output: SpectralOutput = SpectralOutput.POSITIVE
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class SpectrumState(ez.State):
|
|
60
|
-
cur_settings: SpectrumSettings
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
class Spectrum(ez.Unit):
|
|
64
|
-
SETTINGS: SpectrumSettings
|
|
65
|
-
STATE: SpectrumState
|
|
66
|
-
|
|
67
|
-
INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
|
|
68
|
-
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
69
|
-
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
70
|
-
|
|
71
|
-
def initialize(self) -> None:
|
|
72
|
-
self.STATE.cur_settings = self.SETTINGS
|
|
73
|
-
|
|
74
|
-
@ez.subscriber(INPUT_SETTINGS)
|
|
75
|
-
async def on_settings(self, msg: SpectrumSettings):
|
|
76
|
-
self.STATE.cur_settings = msg
|
|
77
|
-
|
|
78
|
-
@ez.subscriber(INPUT_SIGNAL)
|
|
79
|
-
@ez.publisher(OUTPUT_SIGNAL)
|
|
80
|
-
async def on_data(self, message: AxisArray) -> AsyncGenerator:
|
|
81
|
-
axis_name = self.STATE.cur_settings.axis
|
|
82
|
-
if axis_name is None:
|
|
83
|
-
axis_name = message.dims[0]
|
|
84
|
-
axis_idx = message.get_axis_idx(axis_name)
|
|
85
|
-
axis = message.get_axis(axis_name)
|
|
86
|
-
|
|
87
|
-
spectrum = np.moveaxis(message.data, axis_idx, -1)
|
|
88
|
-
|
|
89
|
-
n_time = message.data.shape[axis_idx]
|
|
90
|
-
window = WINDOWS[self.STATE.cur_settings.window](n_time)
|
|
91
|
-
|
|
92
|
-
spectrum = np.fft.fft(spectrum * window) / n_time
|
|
93
|
-
spectrum = np.fft.fftshift(spectrum, axes=-1)
|
|
94
|
-
freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)
|
|
95
|
-
|
|
96
|
-
if self.STATE.cur_settings.transform != SpectralTransform.RAW_COMPLEX:
|
|
97
|
-
if self.STATE.cur_settings.transform == SpectralTransform.REAL:
|
|
98
|
-
spectrum = spectrum.real
|
|
99
|
-
elif self.STATE.cur_settings.transform == SpectralTransform.IMAG:
|
|
100
|
-
spectrum = spectrum.imag
|
|
101
|
-
else:
|
|
102
|
-
scale = np.sum(window**2.0) * axis.gain
|
|
103
|
-
spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale
|
|
104
|
-
|
|
105
|
-
if self.STATE.cur_settings.transform == SpectralTransform.REL_DB:
|
|
106
|
-
spectrum = 10 * np.log10(spectrum)
|
|
107
|
-
|
|
108
|
-
axis_offset = freqs[0]
|
|
109
|
-
if self.STATE.cur_settings.output == SpectralOutput.POSITIVE:
|
|
110
|
-
axis_offset = freqs[n_time // 2]
|
|
111
|
-
spectrum = spectrum[..., n_time // 2 :]
|
|
112
|
-
elif self.STATE.cur_settings.output == SpectralOutput.NEGATIVE:
|
|
113
|
-
spectrum = spectrum[..., : n_time // 2]
|
|
114
|
-
|
|
115
|
-
spectrum = np.moveaxis(spectrum, axis_idx, -1)
|
|
116
|
-
|
|
117
|
-
out_axis = self.SETTINGS.out_axis
|
|
118
|
-
if out_axis is None:
|
|
119
|
-
out_axis = axis_name
|
|
120
|
-
|
|
121
|
-
freq_axis = AxisArray.Axis(
|
|
122
|
-
unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset
|
|
123
|
-
)
|
|
124
|
-
new_axes = {**message.axes, **{out_axis: freq_axis}}
|
|
125
|
-
|
|
126
|
-
new_dims = [d for d in message.dims]
|
|
127
|
-
if self.SETTINGS.out_axis is not None:
|
|
128
|
-
new_dims[axis_idx] = self.SETTINGS.out_axis
|
|
129
|
-
|
|
130
|
-
out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)
|
|
131
|
-
|
|
132
|
-
yield self.OUTPUT_SIGNAL, out_msg
|
|
1
|
+
from .spectrum import OptionsEnum as OptionsEnum
|
|
2
|
+
from .spectrum import SpectralOutput as SpectralOutput
|
|
3
|
+
from .spectrum import SpectralTransform as SpectralTransform
|
|
4
|
+
from .spectrum import Spectrum as Spectrum
|
|
5
|
+
from .spectrum import SpectrumSettings as SpectrumSettings
|
|
6
|
+
from .spectrum import WindowFunction as WindowFunction
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from typing import Generator
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.baseproc import (
|
|
5
|
+
BaseStatefulProcessor,
|
|
6
|
+
BaseTransformerUnit,
|
|
7
|
+
CompositeProcessor,
|
|
8
|
+
)
|
|
9
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.modify import modify_axis
|
|
11
|
+
|
|
12
|
+
from .spectrum import (
|
|
13
|
+
SpectralOutput,
|
|
14
|
+
SpectralTransform,
|
|
15
|
+
SpectrumTransformer,
|
|
16
|
+
WindowFunction,
|
|
17
|
+
)
|
|
18
|
+
from .window import Anchor, WindowTransformer
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SpectrogramSettings(ez.Settings):
    """
    Settings for :obj:`SpectrogramTransformer`.
    """

    window_dur: float | None = None
    """Window duration in seconds."""

    window_shift: float | None = None
    """Window step in seconds. If None, window_shift == window_dur."""

    window_anchor: str | Anchor = Anchor.BEGINNING
    """See :obj:`WindowTransformer`."""

    window: WindowFunction = WindowFunction.HAMMING
    """The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum."""

    transform: SpectralTransform = SpectralTransform.REL_DB
    """The :obj:`SpectralTransform` to apply to the spectral magnitude."""

    output: SpectralOutput = SpectralOutput.POSITIVE
    """The :obj:`SpectralOutput` format."""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SpectrogramTransformer(CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]):
    """
    Spectrogram built as a chain of three processors.

    1. Windowing: splits the "time" axis into windows indexed by a new "win" axis.
    2. Spectrum: computed along "time" (presumably the within-window axis after windowing).
    3. Axis rename: renames "win" back to "time".
    """

    @staticmethod
    def _initialize_processors(
        settings: SpectrogramSettings,
    ) -> dict[str, BaseStatefulProcessor | Generator[AxisArray, AxisArray, None]]:
        # Zero-pad to the shift when a shift is given; otherwise pad relative to the input.
        pad_mode = "input" if settings.window_shift is None else "shift"
        windowing = WindowTransformer(
            axis="time",
            newaxis="win",
            window_dur=settings.window_dur,
            window_shift=settings.window_shift,
            zero_pad_until=pad_mode,
            anchor=settings.window_anchor,
        )
        spectrum = SpectrumTransformer(
            axis="time",
            window=settings.window,
            transform=settings.transform,
            output=settings.output,
        )
        return {
            "windowing": windowing,
            "spectrum": spectrum,
            "modify_axis": modify_axis(name_map={"win": "time"}),
        }
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Spectrogram(BaseTransformerUnit[SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer]):
    """ezmsg Unit wrapping :obj:`SpectrogramTransformer`; configured by :obj:`SpectrogramSettings`."""

    SETTINGS = SpectrogramSettings
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def spectrogram(
    window_dur: float | None = None,
    window_shift: float | None = None,
    window_anchor: str | Anchor = Anchor.BEGINNING,
    window: WindowFunction = WindowFunction.HAMMING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE,
) -> SpectrogramTransformer:
    """
    Construct a :obj:`SpectrogramTransformer`.

    Each argument maps one-to-one onto the :obj:`SpectrogramSettings` field of the same name.

    Returns:
        A new :obj:`SpectrogramTransformer`.
    """
    settings = SpectrogramSettings(
        window_dur=window_dur,
        window_shift=window_shift,
        window_anchor=window_anchor,
        window=window,
        transform=transform,
        output=output,
    )
    return SpectrogramTransformer(settings)
|