ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +16 -1
- ezmsg/sigproc/activation.py +75 -0
- ezmsg/sigproc/affinetransform.py +234 -0
- ezmsg/sigproc/aggregate.py +158 -0
- ezmsg/sigproc/bandpower.py +74 -0
- ezmsg/sigproc/base.py +38 -0
- ezmsg/sigproc/butterworthfilter.py +102 -11
- ezmsg/sigproc/decimate.py +7 -4
- ezmsg/sigproc/downsample.py +95 -51
- ezmsg/sigproc/ewmfilter.py +38 -16
- ezmsg/sigproc/filter.py +108 -20
- ezmsg/sigproc/filterbank.py +278 -0
- ezmsg/sigproc/math/__init__.py +0 -0
- ezmsg/sigproc/math/abs.py +28 -0
- ezmsg/sigproc/math/clip.py +30 -0
- ezmsg/sigproc/math/difference.py +60 -0
- ezmsg/sigproc/math/invert.py +29 -0
- ezmsg/sigproc/math/log.py +32 -0
- ezmsg/sigproc/math/scale.py +31 -0
- ezmsg/sigproc/messages.py +2 -3
- ezmsg/sigproc/sampler.py +259 -224
- ezmsg/sigproc/scaler.py +173 -0
- ezmsg/sigproc/signalinjector.py +64 -0
- ezmsg/sigproc/slicer.py +133 -0
- ezmsg/sigproc/spectral.py +6 -132
- ezmsg/sigproc/spectrogram.py +86 -0
- ezmsg/sigproc/spectrum.py +259 -0
- ezmsg/sigproc/synth.py +299 -105
- ezmsg/sigproc/wavelets.py +167 -0
- ezmsg/sigproc/window.py +254 -116
- ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
- ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
- ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
ezmsg/sigproc/scaler.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.generator import consumer
|
|
9
|
+
|
|
10
|
+
from .base import GenAxisArray
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _tau_from_alpha(alpha: float, dt: float) -> float:
    """
    Convert an exponential-smoothing fading factor back into a time constant.

    This is the inverse of :func:`_alpha_from_tau`, i.e. ``tau = -dt / ln(1 - alpha)``.

    :param alpha: fading factor in exponential smoothing.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: tau, the time constant in the same units as ``dt``.
    """
    decay = 1 - alpha
    return -dt / np.log(decay)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _alpha_from_tau(tau: float, dt: float) -> float:
    """
    Convert a time constant into the fading factor used by exponential smoothing.

    See https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant

    :param tau: The amount of time for the smoothed response of a unit step function to reach
        1 - 1/e approx-eq 63.2%.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: alpha, the "fading factor" in exponential smoothing.
    """
    exponent = -dt / tau
    return 1 - np.exp(exponent)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@consumer
def scaler(
    time_constant: float = 1.0, axis: typing.Optional[str] = None
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Create a generator function that applies the
    adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
    This is faster than :obj:`scaler_np` for single-channel data.

    Requires the optional ``river`` package, which is imported when the
    generator is primed.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over.
            If None, the first dimension of the first message is used.

    Returns:
        A primed generator object that expects `.send(axis_array)` and yields a
        standardized, or "Z-scored" version of the input as float64. A message with
        zero samples along ``axis`` yields an empty float64 array of the same shape.
    """
    from river import preprocessing

    msg_out = AxisArray(np.array([]), dims=[""])
    _scaler = None
    while True:
        msg_in: AxisArray = yield msg_out
        data = msg_in.data
        if axis is None:
            axis = msg_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = msg_in.get_axis_idx(axis)
            if axis_idx != 0:
                data = np.moveaxis(data, axis_idx, 0)

        if _scaler is None:
            # The fading factor is derived from the first message's sampling period only.
            alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
            _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)

        # Preallocate instead of stacking a list so that an empty message
        # (0 samples) does not hit np.stack([]), which raises ValueError.
        result = np.zeros(data.shape, dtype=float)
        for sample_ix, sample in enumerate(data):
            # river operates on dicts of features; key each element by its flat index.
            x = dict(enumerate(sample.flatten().tolist()))
            _scaler.learn_one(x)
            y = _scaler.transform_one(x)
            result[sample_ix] = np.array([y[k] for k in sorted(y)]).reshape(sample.shape)

        result = np.moveaxis(result, 0, axis_idx)
        msg_out = replace(msg_in, data=result)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@consumer
def scaler_np(
    time_constant: float = 1.0, axis: typing.Optional[str] = None
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Create a generator function that applies an adaptive standard scaler.
    This is faster than :obj:`scaler` for multichannel data.

    For each sample along ``axis``, exponentially-weighted running estimates of
    the mean ``m`` and the mean of squares ``s`` are updated. The fading factor is
    ``_alpha_from_tau(time_constant, gain)``. The output is
    ``(x - m) / sqrt(s - m**2)``. NaN outputs, such as 0/0 on the very first
    sample, are replaced with 0.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over.
            If None, the first dimension of the first message is used.

    Returns:
        A primed generator object that expects `.send(axis_array)` and yields a
        standardized, or "Z-scored" version of the input. Integer/bool input is
        promoted to float64; floating and complex input keep their dtype.
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # State variables
    alpha: float = 0.0
    means: typing.Optional[npt.NDArray] = None
    sq_means: typing.Optional[npt.NDArray] = None

    # Reset if input changes
    check_input = {
        "gain": None,  # Resets alpha
        "shape": None,
        "key": None,  # Key change implies buffered means/vars are invalid.
    }

    def _ew_update(arr, prev, _alpha):
        # Warm start: while the running statistic is still all-zero (fresh state),
        # adopt the new value outright rather than blending it with zeros.
        if np.all(prev == 0):
            return arr
        # Same as _alpha * arr + (1 - _alpha) * prev, with one fewer multiply.
        return prev + _alpha * (arr - prev)

    while True:
        msg_in: AxisArray = yield msg_out

        axis = axis or msg_in.dims[0]
        axis_idx = msg_in.get_axis_idx(axis)

        gain = msg_in.axes[axis].gain
        if gain != check_input["gain"]:
            alpha = _alpha_from_tau(time_constant, gain)
            check_input["gain"] = gain

        data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
        b_reset = data.shape[1:] != check_input["shape"]
        b_reset |= msg_in.key != check_input["key"]
        if b_reset:
            check_input["shape"] = data.shape[1:]
            check_input["key"] = msg_in.key
            # Build from the shape (not data[0]) so an empty first message does not
            # raise IndexError. The original tracked an identical `vars_means`
            # alongside `means`; one running mean suffices.
            sq_means = np.zeros(data.shape[1:], dtype=float)
            means = np.zeros(data.shape[1:], dtype=float)

        # Integer input would silently truncate the z-scores; promote it to float.
        out_dtype = (
            data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
        )
        result = np.zeros(data.shape, dtype=out_dtype)
        # Zero variance (e.g. the first sample) yields 0/0; it is cleaned up below.
        with np.errstate(divide="ignore", invalid="ignore"):
            for sample_ix, sample in enumerate(data):
                # Update step
                means = _ew_update(sample, means, alpha)
                sq_means = _ew_update(sample**2, sq_means, alpha)
                # Get step
                varis = sq_means - means**2
                result[sample_ix] = (sample - means) / (varis**0.5)

        result[np.isnan(result)] = 0.0
        result = np.moveaxis(result, 0, axis_idx)
        msg_out = replace(msg_in, data=result)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class AdaptiveStandardScalerSettings(ez.Settings):
    """
    Settings for :obj:`AdaptiveStandardScaler`.
    See :obj:`scaler_np` for a description of the parameters.
    """

    # Decay constant `tau` in seconds, forwarded to scaler_np.
    time_constant: float = 1.0
    # Axis to accumulate statistics over; None means the first dimension of the first message.
    axis: typing.Optional[str] = None
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class AdaptiveStandardScaler(GenAxisArray):
    """Unit for :obj:`scaler_np`"""

    SETTINGS = AdaptiveStandardScalerSettings

    def construct_generator(self):
        # Build the wrapped generator from the unit's configured settings.
        cfg = self.SETTINGS
        self.STATE.gen = scaler_np(time_constant=cfg.time_constant, axis=cfg.axis)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import ezmsg.core as ez
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
import numpy as np
|
|
7
|
+
import numpy.typing as npt
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SignalInjectorSettings(ez.Settings):
    """Settings for :obj:`SignalInjector`."""

    time_dim: str = "time"  # Input signal needs a time dimension with units in sec.
    # Initial frequency of the injected sinusoid; None disables injection (pass-through).
    frequency: typing.Optional[float] = None  # Hz
    # Initial scale factor applied to the injected sinusoid.
    amplitude: float = 1.0
    # Seed for the random per-channel mixing weights (np.random.default_rng); None is unseeded.
    mixing_seed: typing.Optional[int] = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SignalInjectorState(ez.State):
    """Runtime state for :obj:`SignalInjector`."""

    # Input shape the current mixing vector was generated for; a change regenerates it.
    cur_shape: typing.Optional[typing.Tuple[int, ...]] = None
    # Active injection frequency (Hz); None passes messages through unchanged.
    cur_frequency: typing.Optional[float] = None
    # Active scale factor for the injected sinusoid.
    cur_amplitude: float
    # Mixing weights in [-1, 1), shape (1, msg.shape2d(time_dim)[1]).
    mixing: npt.NDArray
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SignalInjector(ez.Unit):
    """
    Add a sinusoid, weighted by random per-channel mixing weights, to a signal.

    The frequency and amplitude can be changed at runtime via ``INPUT_FREQUENCY``
    and ``INPUT_AMPLITUDE``. While the frequency is ``None``, messages pass
    through unchanged.
    """

    SETTINGS = SignalInjectorSettings
    STATE = SignalInjectorState

    INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
    INPUT_AMPLITUDE = ez.InputStream(float)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        self.STATE.cur_frequency = self.SETTINGS.frequency
        self.STATE.cur_amplitude = self.SETTINGS.amplitude
        self.STATE.mixing = np.array([])

    @ez.subscriber(INPUT_FREQUENCY)
    async def on_frequency(self, msg: typing.Optional[float]) -> None:
        self.STATE.cur_frequency = msg

    @ez.subscriber(INPUT_AMPLITUDE)
    async def on_amplitude(self, msg: float) -> None:
        self.STATE.cur_amplitude = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
        time_dim = self.SETTINGS.time_dim
        # Only the non-time dimensions determine the mixing vector. Comparing the
        # full shape would redraw the (possibly unseeded) weights whenever the
        # chunk length changes, making the injected signal jump between chunks.
        shape = tuple(msg.shape)
        time_idx = msg.get_axis_idx(time_dim)
        chan_shape = shape[:time_idx] + shape[time_idx + 1 :]
        if self.STATE.cur_shape != chan_shape:
            self.STATE.cur_shape = chan_shape
            rng = np.random.default_rng(self.SETTINGS.mixing_seed)
            # Uniform weights in [-1, 1), one per flattened non-time element.
            self.STATE.mixing = rng.random((1, msg.shape2d(time_dim)[1]))
            self.STATE.mixing = (self.STATE.mixing * 2.0) - 1.0

        if self.STATE.cur_frequency is None:
            yield self.OUTPUT_SIGNAL, msg
        else:
            # Copy so the incoming message's buffer is not modified in place.
            out_msg = replace(msg, data=msg.data.copy())
            t = out_msg.ax(time_dim).values[..., np.newaxis]
            signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
            mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
            with out_msg.view2d(time_dim) as view:
                view[...] = view + mixed_signal.astype(view.dtype)
            yield self.OUTPUT_SIGNAL, out_msg
|
ezmsg/sigproc/slicer.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpy.typing as npt
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
8
|
+
from ezmsg.util.generator import consumer
|
|
9
|
+
|
|
10
|
+
from .base import GenAxisArray
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
Slicer:Select a subset of data along a particular axis.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
    """
    Parses a string representation of a slice and returns a tuple of slice objects.

    - ``""`` -> slice(None, None, None) (take all)
    - ``":"`` -> slice(None, None, None)
    - ``"none"`` (case-insensitive) -> slice(None, None, None)
    - ``"{start}:{stop}"`` or ``"{start}:{stop}:{step}"`` -> slice(start, stop, step)
    - ``"5"`` (or any integer) -> (5,). Take only that item;
      applying this to a ndarray or AxisArray will drop the dimension.
    - A comma-separated list of the above -> a tuple of slices | ints

    Surrounding whitespace is ignored everywhere, e.g. ``" none "`` or ``"1, 3:5"``.

    Args:
        s: The string representation of the slice.

    Returns:
        A tuple of slice objects and/or ints.

    Raises:
        ValueError: if a component is not an integer (or empty within a slice).
    """
    # Strip first so that e.g. " none " or " : " hit the take-all special case
    # instead of failing in int().
    s = s.strip()
    if s.lower() in ["", ":", "none"]:
        return (slice(None),)
    if "," in s:
        # Parse each comma-separated component independently and flatten the results.
        return tuple(item for part in s.split(",") for item in parse_slice(part))
    parts = [part.strip() for part in s.split(":")]
    if len(parts) == 1:
        return (int(parts[0]),)
    return (slice(*(int(part) if part else None for part in parts)),)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@consumer
def slicer(
    selection: str = "", axis: typing.Optional[str] = None
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Create a primed generator that selects a subset of each message along one axis.

    Args:
        selection: Slice specification understood by :obj:`parse_slice`
            (e.g. ``"2:8"``, ``"5"``, ``"0, 3:5"``). A single bare integer
            removes the sliced dimension from the output.
        axis: Name of the axis to slice. If None, the last dimension of the
            first message is used for the lifetime of the generator.

    Returns:
        A generator that expects `.send(axis_array)` and yields the sliced AxisArray.
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # Cached slicing state, rebuilt only when the input changes.
    sel: typing.Optional[typing.Union[slice, npt.NDArray]] = None
    sliced_axis: typing.Optional[AxisArray.Axis] = None
    drop_dim: bool = False  # True when slicing removes the target dimension

    # Properties of the input that last triggered a rebuild.
    last_seen = {
        "key": None,  # stands in for a label change, which is not checked directly
        "len": None,
    }

    while True:
        msg_in: AxisArray = yield msg_out

        axis = axis or msg_in.dims[-1]
        ax_idx = msg_in.get_axis_idx(axis)
        ax_len = msg_in.data.shape[ax_idx]

        # A plain slice adapts to any axis length, but a precomputed index
        # array is only valid for the length it was built from.
        rebuild = (
            sel is None
            or msg_in.key != last_seen["key"]
            or (ax_len != last_seen["len"] and type(sel) is np.ndarray)
        )
        if rebuild:
            last_seen["key"] = msg_in.key
            last_seen["len"] = ax_len
            sliced_axis = None
            drop_dim = False

            parsed = parse_slice(selection)
            if len(parsed) != 1:
                # Several slices cannot be applied in a single indexing step, so
                # flatten them into one (possibly discontinuous) integer index array.
                positions = np.arange(ax_len)
                sel = np.hstack([positions[part] for part in parsed])
            else:
                sel = parsed[0]
                drop_dim = isinstance(sel, int)

            # Carry the sliced labels (if any) over to the output axis metadata.
            if (
                axis in msg_in.axes
                and hasattr(msg_in.axes[axis], "labels")
                and len(msg_in.axes[axis].labels) > 0
            ):
                source_axis = msg_in.axes[axis]
                sliced_axis = replace(source_axis, labels=source_axis.labels[sel])

        overrides = {}
        if drop_dim:
            overrides["dims"] = [d for i, d in enumerate(msg_in.dims) if i != ax_idx]
            overrides["axes"] = {k: v for k, v in msg_in.axes.items() if k != axis}
        elif sliced_axis is not None:
            overrides["axes"] = {
                k: (sliced_axis if k == axis else v) for k, v in msg_in.axes.items()
            }
        msg_out = replace(
            msg_in,
            data=slice_along_axis(msg_in.data, sel, ax_idx),
            **overrides,
        )
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class SlicerSettings(ez.Settings):
    """Settings for :obj:`Slicer`; see :obj:`slicer` for parameter details."""

    # Slice specification parsed by parse_slice, e.g. "2:8", "5", "0, 3:5"; "" selects everything.
    selection: str = ""
    # Axis to slice; None means the last dimension of the first message.
    axis: typing.Optional[str] = None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class Slicer(GenAxisArray):
    """Unit wrapper around :obj:`slicer`, configured by :obj:`SlicerSettings`."""

    SETTINGS = SlicerSettings

    def construct_generator(self):
        # Build the wrapped generator from the unit's configured settings.
        cfg = self.SETTINGS
        self.STATE.gen = slicer(selection=cfg.selection, axis=cfg.axis)
|
ezmsg/sigproc/spectral.py
CHANGED
|
@@ -1,132 +1,6 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
-
|
|
10
|
-
from typing import Optional, AsyncGenerator
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class OptionsEnum(enum.Enum):
    """Enum base class that can list the values of all of its members."""

    @classmethod
    def options(cls):
        """Return the ``value`` of every member, in definition order."""
        return [member.value for member in cls]
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class WindowFunction(OptionsEnum):
    """Window applied before the FFT; member values are human-readable labels."""

    NONE = "None (Rectangular)"
    HAMMING = "Hamming"
    HANNING = "Hanning"
    BARTLETT = "Bartlett"
    BLACKMAN = "Blackman"
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
# Maps each WindowFunction to the NumPy callable that builds a window of length n.
WINDOWS = {
    WindowFunction.NONE: np.ones,
    WindowFunction.HAMMING: np.hamming,
    WindowFunction.HANNING: np.hanning,
    WindowFunction.BARTLETT: np.bartlett,
    WindowFunction.BLACKMAN: np.blackman,
}
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class SpectralTransform(OptionsEnum):
    """Post-processing Spectrum applies to the normalized, fftshift-ed complex FFT."""

    RAW_COMPLEX = "Complex FFT Output"
    REAL = "Real Component of FFT"
    IMAG = "Imaginary Component of FFT"
    REL_POWER = "Relative Power"
    REL_DB = "Log Power (Relative dB)"
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
class SpectralOutput(OptionsEnum):
    """Which part of the fftshift-ed spectrum Spectrum keeps."""

    FULL = "Full Spectrum"
    POSITIVE = "Positive Frequencies"
    NEGATIVE = "Negative Frequencies"
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class SpectrumSettings(ez.Settings):
    """Settings for :obj:`Spectrum`."""

    # Axis to transform; None means the first dimension of each message.
    axis: Optional[str] = None
    # n: Optional[int] = None # n parameter for fft
    out_axis: Optional[str] = "freq"  # If None, the transformed dim keeps its original name.
    window: WindowFunction = WindowFunction.HAMMING
    transform: SpectralTransform = SpectralTransform.REL_DB
    output: SpectralOutput = SpectralOutput.POSITIVE
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class SpectrumState(ez.State):
    """Runtime state for :obj:`Spectrum`."""

    # Active settings: starts as SETTINGS, replaced by messages on INPUT_SETTINGS.
    cur_settings: SpectrumSettings
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
class Spectrum(ez.Unit):
    """
    Compute an FFT-based spectrum of each incoming AxisArray along one axis.

    Settings can be replaced at runtime by publishing a new
    :obj:`SpectrumSettings` to ``INPUT_SETTINGS``.
    """

    SETTINGS: SpectrumSettings
    STATE: SpectrumState

    INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SpectrumSettings):
        self.STATE.cur_settings = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_data(self, message: AxisArray) -> AsyncGenerator:
        settings = self.STATE.cur_settings
        axis_name = settings.axis
        if axis_name is None:
            axis_name = message.dims[0]
        axis_idx = message.get_axis_idx(axis_name)
        axis = message.get_axis(axis_name)

        # Put the target axis last so the window and FFT operate along it.
        spectrum = np.moveaxis(message.data, axis_idx, -1)

        n_time = message.data.shape[axis_idx]
        window = WINDOWS[settings.window](n_time)

        spectrum = np.fft.fft(spectrum * window) / n_time
        spectrum = np.fft.fftshift(spectrum, axes=-1)
        freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)

        if settings.transform != SpectralTransform.RAW_COMPLEX:
            if settings.transform == SpectralTransform.REAL:
                spectrum = spectrum.real
            elif settings.transform == SpectralTransform.IMAG:
                spectrum = spectrum.imag
            else:
                scale = np.sum(window**2.0) * axis.gain
                spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale

                if settings.transform == SpectralTransform.REL_DB:
                    spectrum = 10 * np.log10(spectrum)

        axis_offset = freqs[0]
        if settings.output == SpectralOutput.POSITIVE:
            axis_offset = freqs[n_time // 2]
            spectrum = spectrum[..., n_time // 2 :]
        elif settings.output == SpectralOutput.NEGATIVE:
            spectrum = spectrum[..., : n_time // 2]

        # Move the frequency axis (currently last) back to the source axis position.
        # The previous moveaxis(spectrum, axis_idx, -1) produced a wrong dimension
        # order for some inputs with more than two dimensions (e.g. axis_idx == 0).
        spectrum = np.moveaxis(spectrum, -1, axis_idx)

        # Read out_axis from the live settings (like every other setting) so that
        # runtime updates via INPUT_SETTINGS take effect.
        out_axis = settings.out_axis
        if out_axis is None:
            out_axis = axis_name

        freq_axis = AxisArray.Axis(
            unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset
        )
        new_axes = {**message.axes, **{out_axis: freq_axis}}

        new_dims = list(message.dims)
        if settings.out_axis is not None:
            new_dims[axis_idx] = settings.out_axis

        out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)

        yield self.OUTPUT_SIGNAL, out_msg
|
|
1
|
+
from .spectrum import OptionsEnum as OptionsEnum
|
|
2
|
+
from .spectrum import WindowFunction as WindowFunction
|
|
3
|
+
from .spectrum import SpectralTransform as SpectralTransform
|
|
4
|
+
from .spectrum import SpectralOutput as SpectralOutput
|
|
5
|
+
from .spectrum import SpectrumSettings as SpectrumSettings
|
|
6
|
+
from .spectrum import Spectrum as Spectrum
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.generator import consumer, compose
|
|
6
|
+
from ezmsg.util.messages.modify import modify_axis
|
|
7
|
+
|
|
8
|
+
from .window import windowing
|
|
9
|
+
from .spectrum import spectrum, WindowFunction, SpectralTransform, SpectralOutput
|
|
10
|
+
from .base import GenAxisArray
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@consumer
def spectrogram(
    window_dur: typing.Optional[float] = None,
    window_shift: typing.Optional[float] = None,
    window: WindowFunction = WindowFunction.HANNING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE,
) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]:
    """
    Calculate a spectrogram on streaming data.

    This is a three-stage pipeline:
    :obj:`ezmsg.sigproc.window.windowing` applies a moving window along "time",
    stacking windows along a new "win" axis.
    :obj:`ezmsg.sigproc.spectrum.spectrum` then computes a spectrum of each window
    along "time".
    Finally :obj:`ezmsg.util.messages.modify.modify_axis` renames "win" back to "time".

    NOTE(review): the default ``window`` here is HANNING, whereas
    ``SpectrogramSettings.window`` defaults to HAMMING.

    Args:
        window_dur: See :obj:`ezmsg.sigproc.window.windowing`
        window_shift: See :obj:`ezmsg.sigproc.window.windowing`
        window: See :obj:`ezmsg.sigproc.spectrum.spectrum`
        transform: See :obj:`ezmsg.sigproc.spectrum.spectrum`
        output: See :obj:`ezmsg.sigproc.spectrum.spectrum`

    Returns:
        A primed generator object that expects `.send(axis_array)` of continuous data
        and yields an AxisArray of time-frequency power values.
    """
    windower = windowing(
        axis="time", newaxis="win", window_dur=window_dur, window_shift=window_shift
    )
    spectral = spectrum(axis="time", window=window, transform=transform, output=output)
    renamer = modify_axis(name_map={"win": "time"})
    pipeline = compose(windower, spectral, renamer)

    msg_out: typing.Optional[AxisArray] = None
    while True:
        msg_out = pipeline((yield msg_out))
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class SpectrogramSettings(ez.Settings):
    """
    Settings for :obj:`Spectrogram`.
    See :obj:`spectrogram` for a description of the parameters.
    """

    window_dur: typing.Optional[float] = None  # window duration in seconds
    window_shift: typing.Optional[float] = None
    """window step in seconds. If None, window_shift == window_dur"""

    # See SpectrumSettings for details of following settings:
    # NOTE(review): this window default (HAMMING) differs from the spectrogram()
    # generator's own default (HANNING); the unit uses HAMMING unless configured.
    window: WindowFunction = WindowFunction.HAMMING
    transform: SpectralTransform = SpectralTransform.REL_DB
    output: SpectralOutput = SpectralOutput.POSITIVE
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class Spectrogram(GenAxisArray):
    """
    Unit for :obj:`spectrogram`.
    """

    SETTINGS = SpectrogramSettings

    def construct_generator(self):
        # Forward every configured setting to the generator by keyword.
        cfg = self.SETTINGS
        self.STATE.gen = spectrogram(
            window_dur=cfg.window_dur,
            window_shift=cfg.window_shift,
            window=cfg.window,
            transform=cfg.transform,
            output=cfg.output,
        )
|