ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +16 -1
  3. ezmsg/sigproc/activation.py +75 -0
  4. ezmsg/sigproc/affinetransform.py +234 -0
  5. ezmsg/sigproc/aggregate.py +158 -0
  6. ezmsg/sigproc/bandpower.py +74 -0
  7. ezmsg/sigproc/base.py +38 -0
  8. ezmsg/sigproc/butterworthfilter.py +102 -11
  9. ezmsg/sigproc/decimate.py +7 -4
  10. ezmsg/sigproc/downsample.py +95 -51
  11. ezmsg/sigproc/ewmfilter.py +38 -16
  12. ezmsg/sigproc/filter.py +108 -20
  13. ezmsg/sigproc/filterbank.py +278 -0
  14. ezmsg/sigproc/math/__init__.py +0 -0
  15. ezmsg/sigproc/math/abs.py +28 -0
  16. ezmsg/sigproc/math/clip.py +30 -0
  17. ezmsg/sigproc/math/difference.py +60 -0
  18. ezmsg/sigproc/math/invert.py +29 -0
  19. ezmsg/sigproc/math/log.py +32 -0
  20. ezmsg/sigproc/math/scale.py +31 -0
  21. ezmsg/sigproc/messages.py +2 -3
  22. ezmsg/sigproc/sampler.py +259 -224
  23. ezmsg/sigproc/scaler.py +173 -0
  24. ezmsg/sigproc/signalinjector.py +64 -0
  25. ezmsg/sigproc/slicer.py +133 -0
  26. ezmsg/sigproc/spectral.py +6 -132
  27. ezmsg/sigproc/spectrogram.py +86 -0
  28. ezmsg/sigproc/spectrum.py +259 -0
  29. ezmsg/sigproc/synth.py +299 -105
  30. ezmsg/sigproc/wavelets.py +167 -0
  31. ezmsg/sigproc/window.py +254 -116
  32. ezmsg_sigproc-1.3.1.dist-info/METADATA +59 -0
  33. ezmsg_sigproc-1.3.1.dist-info/RECORD +35 -0
  34. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info}/WHEEL +1 -2
  35. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  36. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  37. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  38. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.3.1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -0,0 +1,173 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import ezmsg.core as ez
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.generator import consumer
9
+
10
+ from .base import GenAxisArray
11
+
12
+
13
def _tau_from_alpha(alpha: float, dt: float) -> float:
    """
    Convert an exponential-smoothing fading factor back into a time constant.

    This is the inverse of :obj:`_alpha_from_tau`, i.e. it solves
    ``alpha = 1 - exp(-dt / tau)`` for ``tau``.

    :param alpha: the "fading factor" used in exponential smoothing.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: tau, expressed in the same units as ``dt``.
    """
    retained_fraction = 1 - alpha
    return -dt / np.log(retained_fraction)
18
+
19
+
20
def _alpha_from_tau(tau: float, dt: float) -> float:
    """
    Convert a time constant into an exponential-smoothing fading factor.

    See https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant

    :param tau: The amount of time for the smoothed response of a unit step function to reach
        1 - 1/e approx-eq 63.2%.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: alpha, the "fading factor" in exponential smoothing.
    """
    decay_per_sample = np.exp(-dt / tau)
    return 1 - decay_per_sample
29
+
30
+
31
@consumer
def scaler(
    time_constant: float = 1.0, axis: typing.Optional[str] = None
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Create a generator function that applies the
    adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
    This is faster than :obj:`scaler_np` for single-channel data.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over.
            If None, the first dimension of the first message is used.

    Returns:
        A primed generator object that expects `.send(axis_array)` and yields a
        standardized, or "Z-scored" version of the input.
    """
    # Imported lazily so that `river` is only required when this generator is used.
    from river import preprocessing

    msg_out = AxisArray(np.array([]), dims=[""])
    _scaler = None
    while True:
        msg_in: AxisArray = yield msg_out
        data = msg_in.data
        if axis is None:
            axis = msg_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = msg_in.get_axis_idx(axis)
            if axis_idx != 0:
                # Iterate sample-by-sample along the leading dimension.
                data = np.moveaxis(data, axis_idx, 0)

        if _scaler is None:
            # The fading factor is derived once, from the first message's sampling period.
            alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
            _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)

        # Preallocate instead of np.stack(list): np.stack raises ValueError on an
        # empty list, which made a zero-sample chunk crash the generator.
        result = np.empty(data.shape, dtype=float)
        for sample_ix, sample in enumerate(data):
            # river works on dicts of features; key each flattened element by its position.
            x = dict(enumerate(sample.flatten().tolist()))
            _scaler.learn_one(x)
            y = _scaler.transform_one(x)
            # Restore the original element order before reshaping back.
            result[sample_ix] = np.array([y[k] for k in sorted(y)]).reshape(
                sample.shape
            )

        result = np.moveaxis(result, 0, axis_idx)
        msg_out = replace(msg_in, data=result)
78
+
79
+
80
@consumer
def scaler_np(
    time_constant: float = 1.0, axis: typing.Optional[str] = None
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Create a generator function that applies an adaptive standard scaler.
    This is faster than :obj:`scaler` for multichannel data.

    Running statistics are exponentially weighted and are reset whenever the
    per-sample shape or the message key changes. NaN outputs (e.g. 0/0 on the
    first sample, when the variance estimate is still zero) are replaced by 0.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over.
            If None, the first dimension of the first message is used.

    Returns:
        A primed generator object that expects `.send(axis_array)` and yields a
        standardized, or "Z-scored" version of the input.
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # State variables
    alpha: float = 0.0
    means: typing.Optional[npt.NDArray] = None
    vars_means: typing.Optional[npt.NDArray] = None
    vars_sq_means: typing.Optional[npt.NDArray] = None

    # Reset if input changes
    check_input = {
        "gain": None,  # Resets alpha
        "shape": None,
        "key": None,  # Key change implies buffered means/vars are invalid.
    }

    def _ew_update(arr, prev, _alpha):
        # An all-zero state is treated as "not yet initialised": seed it with the sample.
        if np.all(prev == 0):
            return arr
        # Equivalent to _alpha * arr + (1 - _alpha) * prev, with one fewer multiply.
        return prev + _alpha * (arr - prev)

    while True:
        msg_in: AxisArray = yield msg_out

        axis = axis or msg_in.dims[0]
        axis_idx = msg_in.get_axis_idx(axis)

        if msg_in.axes[axis].gain != check_input["gain"]:
            alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
            check_input["gain"] = msg_in.axes[axis].gain

        data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
        if not np.issubdtype(data.dtype, np.inexact):
            # Integer/bool input: compute in float. Otherwise the output buffer
            # below inherits the integer dtype and truncates the z-scores, and
            # sample**2 can overflow narrow integer types.
            data = data.astype(float)

        b_reset = data.shape[1:] != check_input["shape"]
        b_reset |= msg_in.key != check_input["key"]
        if b_reset:
            check_input["shape"] = data.shape[1:]
            check_input["key"] = msg_in.key
            vars_sq_means = np.zeros_like(data[0], dtype=float)
            vars_means = np.zeros_like(data[0], dtype=float)
            means = np.zeros_like(data[0], dtype=float)

        result = np.zeros_like(data)
        # 0/0 and sqrt(<0) produce NaN, which is deliberately zeroed below;
        # silence only that expected "invalid value" warning.
        with np.errstate(invalid="ignore"):
            for sample_ix in range(data.shape[0]):
                sample = data[sample_ix]
                # Update step
                vars_means = _ew_update(sample, vars_means, alpha)
                vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha)
                means = _ew_update(sample, means, alpha)
                # Get step: var = E[x^2] - E[x]^2
                varis = vars_sq_means - vars_means**2
                y = (sample - means) / (varis**0.5)
                result[sample_ix] = y

        result[np.isnan(result)] = 0.0
        result = np.moveaxis(result, 0, axis_idx)
        msg_out = replace(msg_in, data=result)
153
+
154
+
155
class AdaptiveStandardScalerSettings(ez.Settings):
    """
    Settings for :obj:`AdaptiveStandardScaler`.
    See :obj:`scaler_np` for a description of the parameters.
    """

    # Decay constant `tau` in seconds; passed to scaler_np as `time_constant`.
    time_constant: float = 1.0
    # Name of the axis to accumulate statistics over; passed to scaler_np as `axis`.
    axis: typing.Optional[str] = None
163
+
164
+
165
class AdaptiveStandardScaler(GenAxisArray):
    """Unit for :obj:`scaler_np`"""

    SETTINGS = AdaptiveStandardScalerSettings

    def construct_generator(self) -> None:
        """Build the :obj:`scaler_np` generator from the current settings and store it in ``STATE.gen``."""
        self.STATE.gen = scaler_np(
            time_constant=self.SETTINGS.time_constant, axis=self.SETTINGS.axis
        )
@@ -0,0 +1,64 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+
9
+
10
class SignalInjectorSettings(ez.Settings):
    """Settings for :obj:`SignalInjector`."""

    time_dim: str = "time"  # Input signal needs a time dimension with units in sec.
    # Initial frequency of the injected sinusoid; while the current frequency is None
    # the input is passed through unchanged.
    frequency: typing.Optional[float] = None  # Hz
    # Initial multiplier applied to the injected sinusoid.
    amplitude: float = 1.0
    # Seed given to np.random.default_rng when drawing the mixing weights.
    mixing_seed: typing.Optional[int] = None
15
+
16
+
17
class SignalInjectorState(ez.State):
    """Mutable runtime state for :obj:`SignalInjector`."""

    # Shape last used to build `mixing`; `inject` rebuilds the weights when it changes.
    cur_shape: typing.Optional[typing.Tuple[int, ...]] = None
    # Current sinusoid frequency in Hz; None disables injection.
    cur_frequency: typing.Optional[float] = None
    # Current multiplier applied to the injected sinusoid.
    cur_amplitude: float
    # Mixing weights in [-1, 1), one per non-time element (set in `initialize`/`inject`).
    mixing: npt.NDArray
22
+
23
+
24
class SignalInjector(ez.Unit):
    """
    Add a sinusoid to every message received on ``INPUT_SIGNAL``.

    The sinusoid is evaluated on the message's time axis, scaled by the current
    amplitude and by a random per-feature mixing weight in [-1, 1), then added
    to the data. Frequency and amplitude can be changed at runtime through
    ``INPUT_FREQUENCY`` and ``INPUT_AMPLITUDE``. While the current frequency is
    None, messages are forwarded unchanged.
    """

    SETTINGS = SignalInjectorSettings
    STATE = SignalInjectorState

    INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
    INPUT_AMPLITUDE = ez.InputStream(float)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Seed the runtime state from the settings."""
        self.STATE.cur_frequency = self.SETTINGS.frequency
        self.STATE.cur_amplitude = self.SETTINGS.amplitude
        self.STATE.mixing = np.array([])

    @ez.subscriber(INPUT_FREQUENCY)
    async def on_frequency(self, msg: typing.Optional[float]) -> None:
        """Replace the current injection frequency (None disables injection)."""
        self.STATE.cur_frequency = msg

    @ez.subscriber(INPUT_AMPLITUDE)
    async def on_amplitude(self, msg: float) -> None:
        """Replace the current injection amplitude."""
        self.STATE.cur_amplitude = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
        """Yield ``msg`` with the sinusoid mixed in, or unchanged if no frequency is set."""
        # Key the mixing weights on the per-sample shape only. Comparing the full
        # shape (time length included) re-drew the weights whenever the chunk
        # length changed, which alters the mixing mid-stream when mixing_seed is None.
        time_idx = msg.get_axis_idx(self.SETTINGS.time_dim)
        full_shape = tuple(msg.shape)
        sample_shape = full_shape[:time_idx] + full_shape[time_idx + 1 :]
        if self.STATE.cur_shape != sample_shape:
            self.STATE.cur_shape = sample_shape
            rng = np.random.default_rng(self.SETTINGS.mixing_seed)
            n_features = msg.shape2d(self.SETTINGS.time_dim)[1]
            # Uniform in [0, 1) rescaled to [-1, 1).
            self.STATE.mixing = (rng.random((1, n_features)) * 2.0) - 1.0

        if self.STATE.cur_frequency is None:
            yield self.OUTPUT_SIGNAL, msg
        else:
            # Copy the data so the incoming message is not modified in place.
            out_msg = replace(msg, data=msg.data.copy())
            # Column vector of timestamps so it broadcasts against the (1, n_features) weights.
            t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
            signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
            mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
            with out_msg.view2d(self.SETTINGS.time_dim) as view:
                view[...] = view + mixed_signal.astype(view.dtype)
            yield self.OUTPUT_SIGNAL, out_msg
@@ -0,0 +1,133 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import ezmsg.core as ez
7
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
8
+ from ezmsg.util.generator import consumer
9
+
10
+ from .base import GenAxisArray
11
+
12
+
13
+ """
14
+ Slicer:Select a subset of data along a particular axis.
15
+ """
16
+
17
+
18
def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
    """
    Parses a string representation of a slice and returns a tuple of slice objects.

    Leading and trailing whitespace around the string, and around each
    comma-separated item, is ignored.

    - "" -> slice(None, None, None) (take all)
    - ":" -> slice(None, None, None)
    - "none" (case-insensitive) -> slice(None, None, None)
    - "{start}:{stop}" or "{start}:{stop}:{step}" -> slice(start, stop, step)
    - "5" (or any integer) -> (5,). Take only that item.
        applying this to a ndarray or AxisArray will drop the dimension.
    - A comma-separated list of the above -> a tuple of slices | ints

    Args:
        s: The string representation of the slice.

    Returns:
        A tuple of slice objects and/or ints.

    Raises:
        ValueError: if a start/stop/step or index field is not an integer.
    """
    # Strip first so that items such as " none" or " " coming out of a
    # comma-separated list are still recognised as "take all".
    s = s.strip()
    if s.lower() in ("", ":", "none"):
        return (slice(None),)
    if "," in s:
        # Parse each item on its own and flatten the per-item tuples.
        return tuple(item for token in s.split(",") for item in parse_slice(token))
    parts = [part.strip() for part in s.split(":")]
    if len(parts) == 1:
        return (int(parts[0]),)
    # Empty fields (e.g. "::2") map to None, as in Python slice syntax.
    return (slice(*(int(part) if part else None for part in parts)),)
45
+
46
+
47
@consumer
def slicer(
    selection: str = "", axis: typing.Optional[str] = None
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Create a generator that selects a subset of each message along one axis.

    Args:
        selection: Slice description understood by :obj:`parse_slice`.
        axis: Name of the axis to slice. If None, the last dimension of the
            first message is used.

    Returns:
        A primed generator object that expects `.send(axis_array)` and yields
        the sliced AxisArray. A single-integer selection drops the axis.
    """
    msg_out = AxisArray(np.array([]), dims=[""])

    # Cached slicing state, recomputed whenever the input changes.
    _slice: typing.Optional[typing.Union[slice, npt.NDArray]] = None
    new_axis: typing.Optional[AxisArray.Axis] = None
    b_change_dims: bool = False  # True when slicing removes the target dimension

    # Input properties seen at the last recomputation.
    # key change is used as a proxy for label change, which is not checked explicitly.
    last_key = None
    last_len = None

    while True:
        msg_in: AxisArray = yield msg_out

        axis = axis or msg_in.dims[-1]
        axis_idx = msg_in.get_axis_idx(axis)
        n_along_axis = msg_in.data.shape[axis_idx]

        # A length change only matters for index arrays, which were built from the old length.
        needs_reset = (
            _slice is None
            or msg_in.key != last_key
            or (n_along_axis != last_len and type(_slice) is np.ndarray)
        )
        if needs_reset:
            last_key = msg_in.key
            last_len = n_along_axis
            new_axis = None  # Will hold updated metadata
            b_change_dims = False

            # Work out what to index with.
            parsed = parse_slice(selection)
            if len(parsed) == 1:
                _slice = parsed[0]
                # An int selection drops the sliced dimension.
                b_change_dims = isinstance(_slice, int)
            else:
                # Several slices cannot be applied in one step, so expand them
                # into a (possibly discontinuous) array of integer indexes.
                all_indices = np.arange(n_along_axis)
                picked = np.hstack([all_indices[sl] for sl in parsed])
                _slice = np.s_[picked]  # Integer scalar array

            # Build the output axis when the input axis carries labels.
            if axis in msg_in.axes:
                src_axis = msg_in.axes[axis]
                if hasattr(src_axis, "labels") and len(src_axis.labels) > 0:
                    new_axis = replace(src_axis, labels=src_axis.labels[_slice])

        if b_change_dims:
            # Dropping the target axis
            replace_kwargs = {
                "dims": [
                    dim for dim_ix, dim in enumerate(msg_in.dims) if dim_ix != axis_idx
                ],
                "axes": {k: v for k, v in msg_in.axes.items() if k != axis},
            }
        elif new_axis is not None:
            replace_kwargs = {
                "axes": {
                    k: (new_axis if k == axis else v) for k, v in msg_in.axes.items()
                }
            }
        else:
            replace_kwargs = {}

        sliced = slice_along_axis(msg_in.data, _slice, axis_idx)
        msg_out = replace(msg_in, data=sliced, **replace_kwargs)
120
+
121
+
122
class SlicerSettings(ez.Settings):
    """Settings for :obj:`Slicer`; both fields are forwarded to :obj:`slicer`."""

    # Slice description in the string format accepted by parse_slice; "" takes everything.
    selection: str = ""
    # Name of the axis to slice; None lets slicer use the last dimension of the input.
    axis: typing.Optional[str] = None
125
+
126
+
127
class Slicer(GenAxisArray):
    """Unit for :obj:`slicer`."""

    SETTINGS = SlicerSettings

    def construct_generator(self) -> None:
        """Build the :obj:`slicer` generator from the current settings and store it in ``STATE.gen``."""
        self.STATE.gen = slicer(
            selection=self.SETTINGS.selection, axis=self.SETTINGS.axis
        )
ezmsg/sigproc/spectral.py CHANGED
@@ -1,132 +1,6 @@
1
- import enum
2
-
3
- from dataclasses import replace
4
-
5
- import numpy as np
6
- import ezmsg.core as ez
7
-
8
- from ezmsg.util.messages.axisarray import AxisArray
9
-
10
- from typing import Optional, AsyncGenerator
11
-
12
-
13
- class OptionsEnum(enum.Enum):
14
- @classmethod
15
- def options(cls):
16
- return list(map(lambda c: c.value, cls))
17
-
18
-
19
- class WindowFunction(OptionsEnum):
20
- NONE = "None (Rectangular)"
21
- HAMMING = "Hamming"
22
- HANNING = "Hanning"
23
- BARTLETT = "Bartlett"
24
- BLACKMAN = "Blackman"
25
-
26
-
27
- WINDOWS = {
28
- WindowFunction.NONE: np.ones,
29
- WindowFunction.HAMMING: np.hamming,
30
- WindowFunction.HANNING: np.hanning,
31
- WindowFunction.BARTLETT: np.bartlett,
32
- WindowFunction.BLACKMAN: np.blackman,
33
- }
34
-
35
-
36
- class SpectralTransform(OptionsEnum):
37
- RAW_COMPLEX = "Complex FFT Output"
38
- REAL = "Real Component of FFT"
39
- IMAG = "Imaginary Component of FFT"
40
- REL_POWER = "Relative Power"
41
- REL_DB = "Log Power (Relative dB)"
42
-
43
-
44
- class SpectralOutput(OptionsEnum):
45
- FULL = "Full Spectrum"
46
- POSITIVE = "Positive Frequencies"
47
- NEGATIVE = "Negative Frequencies"
48
-
49
-
50
- class SpectrumSettings(ez.Settings):
51
- axis: Optional[str] = None
52
- # n: Optional[int] = None # n parameter for fft
53
- out_axis: Optional[str] = "freq" # If none; don't change dim name
54
- window: WindowFunction = WindowFunction.HAMMING
55
- transform: SpectralTransform = SpectralTransform.REL_DB
56
- output: SpectralOutput = SpectralOutput.POSITIVE
57
-
58
-
59
- class SpectrumState(ez.State):
60
- cur_settings: SpectrumSettings
61
-
62
-
63
- class Spectrum(ez.Unit):
64
- SETTINGS: SpectrumSettings
65
- STATE: SpectrumState
66
-
67
- INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
68
- INPUT_SIGNAL = ez.InputStream(AxisArray)
69
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
70
-
71
- def initialize(self) -> None:
72
- self.STATE.cur_settings = self.SETTINGS
73
-
74
- @ez.subscriber(INPUT_SETTINGS)
75
- async def on_settings(self, msg: SpectrumSettings):
76
- self.STATE.cur_settings = msg
77
-
78
- @ez.subscriber(INPUT_SIGNAL)
79
- @ez.publisher(OUTPUT_SIGNAL)
80
- async def on_data(self, message: AxisArray) -> AsyncGenerator:
81
- axis_name = self.STATE.cur_settings.axis
82
- if axis_name is None:
83
- axis_name = message.dims[0]
84
- axis_idx = message.get_axis_idx(axis_name)
85
- axis = message.get_axis(axis_name)
86
-
87
- spectrum = np.moveaxis(message.data, axis_idx, -1)
88
-
89
- n_time = message.data.shape[axis_idx]
90
- window = WINDOWS[self.STATE.cur_settings.window](n_time)
91
-
92
- spectrum = np.fft.fft(spectrum * window) / n_time
93
- spectrum = np.fft.fftshift(spectrum, axes=-1)
94
- freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)
95
-
96
- if self.STATE.cur_settings.transform != SpectralTransform.RAW_COMPLEX:
97
- if self.STATE.cur_settings.transform == SpectralTransform.REAL:
98
- spectrum = spectrum.real
99
- elif self.STATE.cur_settings.transform == SpectralTransform.IMAG:
100
- spectrum = spectrum.imag
101
- else:
102
- scale = np.sum(window**2.0) * axis.gain
103
- spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale
104
-
105
- if self.STATE.cur_settings.transform == SpectralTransform.REL_DB:
106
- spectrum = 10 * np.log10(spectrum)
107
-
108
- axis_offset = freqs[0]
109
- if self.STATE.cur_settings.output == SpectralOutput.POSITIVE:
110
- axis_offset = freqs[n_time // 2]
111
- spectrum = spectrum[..., n_time // 2 :]
112
- elif self.STATE.cur_settings.output == SpectralOutput.NEGATIVE:
113
- spectrum = spectrum[..., : n_time // 2]
114
-
115
- spectrum = np.moveaxis(spectrum, axis_idx, -1)
116
-
117
- out_axis = self.SETTINGS.out_axis
118
- if out_axis is None:
119
- out_axis = axis_name
120
-
121
- freq_axis = AxisArray.Axis(
122
- unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset
123
- )
124
- new_axes = {**message.axes, **{out_axis: freq_axis}}
125
-
126
- new_dims = [d for d in message.dims]
127
- if self.SETTINGS.out_axis is not None:
128
- new_dims[axis_idx] = self.SETTINGS.out_axis
129
-
130
- out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)
131
-
132
- yield self.OUTPUT_SIGNAL, out_msg
1
+ from .spectrum import OptionsEnum as OptionsEnum
2
+ from .spectrum import WindowFunction as WindowFunction
3
+ from .spectrum import SpectralTransform as SpectralTransform
4
+ from .spectrum import SpectralOutput as SpectralOutput
5
+ from .spectrum import SpectrumSettings as SpectrumSettings
6
+ from .spectrum import Spectrum as Spectrum
@@ -0,0 +1,86 @@
1
+ import typing
2
+
3
+ import ezmsg.core as ez
4
+ from ezmsg.util.messages.axisarray import AxisArray
5
+ from ezmsg.util.generator import consumer, compose
6
+ from ezmsg.util.messages.modify import modify_axis
7
+
8
+ from .window import windowing
9
+ from .spectrum import spectrum, WindowFunction, SpectralTransform, SpectralOutput
10
+ from .base import GenAxisArray
11
+
12
+
13
@consumer
def spectrogram(
    window_dur: typing.Optional[float] = None,
    window_shift: typing.Optional[float] = None,
    window: WindowFunction = WindowFunction.HANNING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE,
) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]:
    """
    Calculate a spectrogram on streaming data.

    Three stages are composed into one pipeline:
    :obj:`ezmsg.sigproc.window.windowing` slides a window over the "time" axis
    and adds a "win" axis, :obj:`ezmsg.sigproc.spectrum.spectrum` computes a
    spectrum along "time" for each window, and
    :obj:`ezmsg.util.messages.modify.modify_axis` renames "win" back to "time".

    Args:
        window_dur: See :obj:`ezmsg.sigproc.window.windowing`
        window_shift: See :obj:`ezmsg.sigproc.window.windowing`
        window: See :obj:`ezmsg.sigproc.spectrum.spectrum`
        transform: See :obj:`ezmsg.sigproc.spectrum.spectrum`
        output: See :obj:`ezmsg.sigproc.spectrum.spectrum`

    Returns:
        A primed generator object that expects `.send(axis_array)` of continuous data
        and yields an AxisArray of time-frequency power values.
    """
    # Stage 1: cut the continuous "time" axis into windows indexed by "win".
    to_windows = windowing(
        axis="time", newaxis="win", window_dur=window_dur, window_shift=window_shift
    )
    # Stage 2: spectrum of each window along its within-window "time" axis.
    to_spectra = spectrum(
        axis="time", window=window, transform=transform, output=output
    )
    # Stage 3: the window index becomes the new "time" axis.
    rename_win = modify_axis(name_map={"win": "time"})
    pipeline = compose(to_windows, to_spectra, rename_win)

    # Nothing to report until the first message has been processed.
    msg_out: typing.Optional[AxisArray] = None

    while True:
        msg_in: AxisArray = yield msg_out
        msg_out = pipeline(msg_in)
54
+
55
+
56
class SpectrogramSettings(ez.Settings):
    """
    Settings for :obj:`Spectrogram`.
    See :obj:`spectrogram` for a description of the parameters.
    """

    window_dur: typing.Optional[float] = None  # window duration in seconds
    window_shift: typing.Optional[float] = None
    """window step in seconds. If None, window_shift == window_dur"""

    # See SpectrumSettings for details of following settings:
    window: WindowFunction = WindowFunction.HAMMING
    transform: SpectralTransform = SpectralTransform.REL_DB
    output: SpectralOutput = SpectralOutput.POSITIVE
70
+
71
+
72
class Spectrogram(GenAxisArray):
    """
    Unit for :obj:`spectrogram`.
    """

    SETTINGS = SpectrogramSettings

    def construct_generator(self) -> None:
        """Build the :obj:`spectrogram` generator from the current settings and store it in ``STATE.gen``."""
        self.STATE.gen = spectrogram(
            window_dur=self.SETTINGS.window_dur,
            window_shift=self.SETTINGS.window_shift,
            window=self.SETTINGS.window,
            transform=self.SETTINGS.transform,
            output=self.SETTINGS.output,
        )
86
+ )