ezmsg-sigproc 1.2.2__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ezmsg/sigproc/__init__.py +1 -1
  2. ezmsg/sigproc/__version__.py +34 -1
  3. ezmsg/sigproc/activation.py +78 -0
  4. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  5. ezmsg/sigproc/affinetransform.py +235 -0
  6. ezmsg/sigproc/aggregate.py +276 -0
  7. ezmsg/sigproc/bandpower.py +80 -0
  8. ezmsg/sigproc/base.py +149 -0
  9. ezmsg/sigproc/butterworthfilter.py +129 -39
  10. ezmsg/sigproc/butterworthzerophase.py +305 -0
  11. ezmsg/sigproc/cheby.py +125 -0
  12. ezmsg/sigproc/combfilter.py +160 -0
  13. ezmsg/sigproc/coordinatespaces.py +159 -0
  14. ezmsg/sigproc/decimate.py +46 -18
  15. ezmsg/sigproc/denormalize.py +78 -0
  16. ezmsg/sigproc/detrend.py +28 -0
  17. ezmsg/sigproc/diff.py +82 -0
  18. ezmsg/sigproc/downsample.py +97 -49
  19. ezmsg/sigproc/ewma.py +217 -0
  20. ezmsg/sigproc/ewmfilter.py +45 -19
  21. ezmsg/sigproc/extract_axis.py +39 -0
  22. ezmsg/sigproc/fbcca.py +307 -0
  23. ezmsg/sigproc/filter.py +282 -117
  24. ezmsg/sigproc/filterbank.py +292 -0
  25. ezmsg/sigproc/filterbankdesign.py +129 -0
  26. ezmsg/sigproc/fir_hilbert.py +336 -0
  27. ezmsg/sigproc/fir_pmc.py +209 -0
  28. ezmsg/sigproc/firfilter.py +117 -0
  29. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  30. ezmsg/sigproc/kaiser.py +106 -0
  31. ezmsg/sigproc/linear.py +120 -0
  32. ezmsg/sigproc/math/__init__.py +0 -0
  33. ezmsg/sigproc/math/abs.py +35 -0
  34. ezmsg/sigproc/math/add.py +120 -0
  35. ezmsg/sigproc/math/clip.py +48 -0
  36. ezmsg/sigproc/math/difference.py +143 -0
  37. ezmsg/sigproc/math/invert.py +28 -0
  38. ezmsg/sigproc/math/log.py +57 -0
  39. ezmsg/sigproc/math/scale.py +39 -0
  40. ezmsg/sigproc/messages.py +3 -6
  41. ezmsg/sigproc/quantize.py +68 -0
  42. ezmsg/sigproc/resample.py +278 -0
  43. ezmsg/sigproc/rollingscaler.py +232 -0
  44. ezmsg/sigproc/sampler.py +232 -241
  45. ezmsg/sigproc/scaler.py +165 -0
  46. ezmsg/sigproc/signalinjector.py +70 -0
  47. ezmsg/sigproc/slicer.py +138 -0
  48. ezmsg/sigproc/spectral.py +6 -132
  49. ezmsg/sigproc/spectrogram.py +90 -0
  50. ezmsg/sigproc/spectrum.py +277 -0
  51. ezmsg/sigproc/transpose.py +134 -0
  52. ezmsg/sigproc/util/__init__.py +0 -0
  53. ezmsg/sigproc/util/asio.py +25 -0
  54. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  55. ezmsg/sigproc/util/buffer.py +449 -0
  56. ezmsg/sigproc/util/message.py +17 -0
  57. ezmsg/sigproc/util/profile.py +23 -0
  58. ezmsg/sigproc/util/sparse.py +115 -0
  59. ezmsg/sigproc/util/typeresolution.py +17 -0
  60. ezmsg/sigproc/wavelets.py +187 -0
  61. ezmsg/sigproc/window.py +301 -117
  62. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  63. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  64. {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -2
  65. ezmsg/sigproc/synth.py +0 -411
  66. ezmsg_sigproc-1.2.2.dist-info/METADATA +0 -36
  67. ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
  68. ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
  69. /ezmsg_sigproc-1.2.2.dist-info/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -0,0 +1,165 @@
1
+ import typing
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ from ezmsg.baseproc import (
6
+ BaseStatefulTransformer,
7
+ BaseTransformerUnit,
8
+ processor_state,
9
+ )
10
+ from ezmsg.util.generator import consumer
11
+ from ezmsg.util.messages.axisarray import AxisArray
12
+ from ezmsg.util.messages.util import replace
13
+
14
+ # Imports for backwards compatibility with previous module location
15
+ from .ewma import EWMA_Deprecated as EWMA_Deprecated
16
+ from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
17
+ from .ewma import _tau_from_alpha as _tau_from_alpha
18
+ from .ewma import ewma_step as ewma_step
19
+
20
+
21
@consumer
def scaler(time_constant: float = 1.0, axis: str | None = None) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
    This is faster than :obj:`scaler_np` for single-channel data.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over.
            If None, the first dimension of the first received message is used.

    Returns:
        A primed generator object that expects to be sent a :obj:`AxisArray` via `.send(axis_array)`
        and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
        A message with zero samples along `axis` yields an empty float array of the same shape.
    """
    # Imported lazily so `river` is only needed when this legacy generator is actually used.
    from river import preprocessing

    msg_out = AxisArray(np.array([]), dims=[""])
    _scaler = None
    while True:
        msg_in: AxisArray = yield msg_out
        data = msg_in.data
        if axis is None:
            # Bind the axis from the first message; subsequent messages reuse it.
            axis = msg_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = msg_in.get_axis_idx(axis)
            if axis_idx != 0:
                # Make `axis` the leading dimension so we can iterate over its samples.
                data = np.moveaxis(data, axis_idx, 0)

        if _scaler is None:
            alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
            _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)

        if data.shape[0] == 0:
            # np.stack([]) raises on an empty list; emit an empty result of the input's shape instead.
            msg_out = replace(msg_in, data=np.zeros(msg_in.data.shape, dtype=float))
            continue

        result = []
        for sample in data:
            # river operates on dicts of scalar features; key each flattened element by its position.
            x = dict(enumerate(sample.flatten().tolist()))
            _scaler.learn_one(x)
            y = _scaler.transform_one(x)
            result.append(np.array([y[k] for k in sorted(y)]).reshape(sample.shape))

        result = np.moveaxis(np.stack(result), 0, axis_idx)
        msg_out = replace(msg_in, data=result)
65
+
66
+
67
class AdaptiveStandardScalerSettings(EWMASettings):
    """
    Settings for :obj:`AdaptiveStandardScalerTransformer`.

    Adds no fields of its own. The transformer reads the inherited
    `time_constant`, `axis` and `accumulate` fields of :obj:`EWMASettings`.
    """
68
+
69
+
70
@processor_state
class AdaptiveStandardScalerState:
    """Mutable state of :obj:`AdaptiveStandardScalerTransformer`."""

    # EWMA tracker fed the raw samples (running mean); built in _reset_state.
    samps_ewma: EWMATransformer | None = None
    # EWMA tracker fed the squared samples (running mean of x**2); built in _reset_state.
    vars_sq_ewma: EWMATransformer | None = None
    # NOTE(review): not read or written by the transformer code in this module — confirm it is still needed.
    alpha: float | None = None
75
+
76
+
77
class AdaptiveStandardScalerTransformer(
    BaseStatefulTransformer[
        AdaptiveStandardScalerSettings,
        AxisArray,
        AxisArray,
        AdaptiveStandardScalerState,
    ]
):
    """
    Exponentially weighted ("adaptive") standard scaler.

    Keeps one EWMA of the samples (the running mean) and one EWMA of the squared
    samples. Each input is output as ``(x - mean) / sqrt(mean_sq - mean**2)``.
    Entries that evaluate to NaN are replaced with 0. This covers, e.g., 0/0 when
    the variance estimate is zero, or a slightly negative variance estimate from
    floating-point round-off.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Create fresh EWMA trackers for the mean and for the mean of squares."""
        self._state.samps_ewma = EWMATransformer(
            time_constant=self.settings.time_constant,
            axis=self.settings.axis,
            accumulate=self.settings.accumulate,
        )
        self._state.vars_sq_ewma = EWMATransformer(
            time_constant=self.settings.time_constant,
            axis=self.settings.axis,
            accumulate=self.settings.accumulate,
        )

    @property
    def accumulate(self) -> bool:
        """Whether to accumulate statistics from incoming samples."""
        return self.settings.accumulate

    @accumulate.setter
    def accumulate(self, value: bool) -> None:
        """
        Set the accumulate mode on this transformer and propagate it to the child EWMA transformers.

        Args:
            value: If True, update statistics with each sample.
                If False, only apply current statistics without updating.
        """
        # Keep our own settings in sync as well. Otherwise the getter keeps reporting the old value
        # and any later _reset_state would rebuild the child EWMAs with the stale flag.
        self.settings = replace(self.settings, accumulate=value)
        if self._state.samps_ewma is not None:
            self._state.samps_ewma.settings = replace(self._state.samps_ewma.settings, accumulate=value)
        if self._state.vars_sq_ewma is not None:
            self._state.vars_sq_ewma.settings = replace(self._state.vars_sq_ewma.settings, accumulate=value)

    def _process(self, message: AxisArray) -> AxisArray:
        # Update step (respects accumulate setting via child EWMAs)
        mean_message = self._state.samps_ewma(message)
        var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))

        # Get step: Var[x] = E[x^2] - E[x]^2
        varis = var_sq_message.data - mean_message.data**2
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - mean_message.data) / (varis**0.5)
        # 0/0 and sqrt of a (round-off) negative variance both yield NaN; report those as 0.
        result[np.isnan(result)] = 0.0
        return replace(message, data=result)
127
+
128
+
129
class AdaptiveStandardScaler(
    BaseTransformerUnit[
        AdaptiveStandardScalerSettings,
        AxisArray,
        AxisArray,
        AdaptiveStandardScalerTransformer,
    ]
):
    """ezmsg Unit wrapping :obj:`AdaptiveStandardScalerTransformer`."""

    SETTINGS = AdaptiveStandardScalerSettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
    async def on_settings(self, msg: AdaptiveStandardScalerSettings) -> None:
        """
        Handle settings updates with smart reset behavior.

        Only rebuilds the processor (discarding accumulated statistics) if `axis` changes.
        Otherwise, `accumulate` is pushed to the child EWMAs and the processor's settings
        object is replaced without resetting accumulated statistics.

        NOTE(review): only `accumulate` is explicitly propagated to the child EWMA
        transformers here. Whether a changed `time_constant` takes effect before the
        next state reset depends on EWMATransformer — confirm.
        """
        # Capture the axis before apply_settings replaces the settings.
        # (Presumably self.SETTINGS is the applied settings instance at runtime — confirm.)
        old_axis = self.SETTINGS.axis
        self.apply_settings(msg)

        if msg.axis != old_axis:
            # Axis changed - need full reset
            self.create_processor()
        else:
            # Update accumulate on processor (propagates to child EWMAs)
            self.processor.accumulate = msg.accumulate
            # Also update own settings reference
            self.processor.settings = msg
159
+
160
+
161
# Backwards compatibility...
def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
    """
    Legacy factory returning an :obj:`AdaptiveStandardScalerTransformer`.

    Args:
        time_constant: Decay constant `tau` in seconds.
        axis: The name of the axis to accumulate statistics over.

    Returns:
        A new :obj:`AdaptiveStandardScalerTransformer` configured with these settings.
    """
    settings = AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
    return AdaptiveStandardScalerTransformer(settings=settings)
@@ -0,0 +1,70 @@
1
+ import ezmsg.core as ez
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ from ezmsg.baseproc import (
5
+ BaseAsyncTransformer,
6
+ BaseTransformerUnit,
7
+ processor_state,
8
+ )
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.util import replace
11
+
12
+
13
class SignalInjectorSettings(ez.Settings):
    """Settings for :obj:`SignalInjectorTransformer`."""

    time_dim: str = "time"  # Input signal needs a time dimension with units in sec.
    # Frequency (Hz) of the injected sinusoid; None means no injection (input passes through).
    frequency: float | None = None  # Hz
    # Multiplier applied to the mixed sinusoid before it is added to the data.
    amplitude: float = 1.0
    # Seed for np.random.default_rng, used to draw per-channel mixing weights in [-1, 1).
    mixing_seed: int | None = None
18
+
19
+
20
@processor_state
class SignalInjectorState:
    """Mutable state of :obj:`SignalInjectorTransformer`."""

    # Data shape with the time dimension removed (shape of one time-sample).
    cur_shape: tuple[int, ...] | None = None
    # Live injection frequency in Hz; None disables injection.
    cur_frequency: float | None = None
    # Live injection amplitude.
    cur_amplitude: float | None = None
    # (1, n_channels) weights in [-1, 1) spreading the sinusoid across flattened channels.
    mixing: npt.NDArray | None = None
26
+
27
+
28
class SignalInjectorTransformer(
    BaseAsyncTransformer[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorState]
):
    """
    Add a sinusoid to the input signal, weighted per channel by random mixing coefficients.

    The sinusoid is evaluated at the values of the time axis. It is scaled by a fixed
    random weight per flattened channel and by the current amplitude. When the current
    frequency is None, messages pass through untouched.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State depends on the stream key and on the shape of a single time-sample.
        t_idx = message.get_axis_idx(self.settings.time_dim)
        shape = message.data.shape
        return hash((message.key, *shape[:t_idx], *shape[t_idx + 1 :]))

    def _reset_state(self, message: AxisArray) -> None:
        state = self._state
        # Seed the live frequency/amplitude from settings only when they are unset.
        if state.cur_frequency is None:
            state.cur_frequency = self.settings.frequency
        if state.cur_amplitude is None:
            state.cur_amplitude = self.settings.amplitude
        t_idx = message.get_axis_idx(self.settings.time_dim)
        shape = message.data.shape
        state.cur_shape = shape[:t_idx] + shape[t_idx + 1 :]
        rng = np.random.default_rng(self.settings.mixing_seed)
        n_channels = message.shape2d(self.settings.time_dim)[1]
        # Map uniform [0, 1) draws to [-1, 1): one weight per flattened channel.
        state.mixing = 2.0 * rng.random((1, n_channels)) - 1.0

    async def _aprocess(self, message: AxisArray) -> AxisArray:
        freq = self._state.cur_frequency
        if freq is None:
            return message
        time_dim = self.settings.time_dim
        out_msg = replace(message, data=message.data.copy())
        t = out_msg.ax(time_dim).values[..., np.newaxis]
        injected = np.sin(2 * np.pi * freq * t) * self._state.mixing * self._state.cur_amplitude
        with out_msg.view2d(time_dim) as view:
            view[...] = view + injected.astype(view.dtype)
        return out_msg
57
+
58
+
59
class SignalInjector(BaseTransformerUnit[SignalInjectorSettings, AxisArray, AxisArray, SignalInjectorTransformer]):
    """ezmsg Unit wrapping :obj:`SignalInjectorTransformer`, with live frequency and amplitude inputs."""

    SETTINGS = SignalInjectorSettings
    # Live frequency updates in Hz; None disables injection.
    INPUT_FREQUENCY = ez.InputStream(float | None)
    # Live amplitude updates.
    INPUT_AMPLITUDE = ez.InputStream(float)

    @ez.subscriber(INPUT_FREQUENCY)
    async def on_frequency(self, msg: float | None) -> None:
        """Overwrite the processor's current injection frequency (Hz); None disables injection."""
        # NOTE(review): SignalInjectorTransformer._reset_state re-seeds a None frequency from
        # settings.frequency, so a later reset may re-enable injection — confirm intended.
        self.processor.state.cur_frequency = msg

    @ez.subscriber(INPUT_AMPLITUDE)
    async def on_amplitude(self, msg: float) -> None:
        """Overwrite the processor's current injection amplitude."""
        self.processor.state.cur_amplitude = msg
@@ -0,0 +1,138 @@
1
+ import ezmsg.core as ez
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ from ezmsg.baseproc import (
5
+ BaseStatefulTransformer,
6
+ BaseTransformerUnit,
7
+ processor_state,
8
+ )
9
+ from ezmsg.util.messages.axisarray import (
10
+ AxisArray,
11
+ AxisBase,
12
+ replace,
13
+ slice_along_axis,
14
+ )
15
+
16
# Slicer: Select a subset of data along a particular axis.


def parse_slice(
    s: str,
    axinfo: AxisArray.CoordinateAxis | None = None,
) -> tuple[slice | int, ...]:
    """
    Parses a string representation of a slice and returns a tuple of slices and/or ints.

    - "" -> slice(None, None, None) (take all)
    - ":" -> slice(None, None, None)
    - "none" (case-insensitive) -> slice(None, None, None)
    - "{start}:{stop}" or "{start}:{stop}:{step}" -> slice(start, stop, step)
    - "5" (or any integer) -> (5,). Take only that item.
      Applying this to an ndarray or AxisArray will drop the dimension.
    - A comma-separated list of the above -> a tuple of slices | ints
    - A comma-separated list of values, when `axinfo` is provided and has a `data`
      attribute containing those values -> a tuple of the integer indices at which
      each value occurs in `axinfo.data`.

    Whitespace around `s` and around each comma-separated item is ignored,
    e.g. "1, none" -> (1, slice(None)).

    Args:
        s: The string representation of the slice.
        axinfo: (Optional) If provided, and of type CoordinateAxis,
            and `s` is a comma-separated list of values, then the values
            in s will be checked against the values in axinfo.data.

    Returns:
        A tuple of slice objects and/or ints.

    Raises:
        ValueError: If an item is not a slice, not an integer, and not a value found in `axinfo.data`.
    """
    s = s.strip()
    if s.lower() in ("", ":", "none"):
        return (slice(None),)
    if "," in s:
        # Parse each item on its own, then flatten the per-item tuples.
        return tuple(item for part in s.split(",") for item in parse_slice(part, axinfo=axinfo))
    parts = [part.strip() for part in s.split(":")]
    if len(parts) == 1:
        if axinfo is not None and hasattr(axinfo, "data") and parts[0] in axinfo.data:
            # Value lookup: return every index whose coordinate equals the requested value.
            return tuple(np.where(axinfo.data == parts[0])[0])
        return (int(parts[0]),)
    return (slice(*(int(part) if part else None for part in parts)),)
57
+
58
+
59
class SlicerSettings(ez.Settings):
    """Settings for :obj:`SlicerTransformer`."""

    selection: str = ""
    """Selection string; see :obj:`ezmsg.sigproc.slicer.parse_slice` for the syntax. The default "" selects everything."""

    axis: str | None = None
    """The name of the axis to slice along. If None (or empty), the last axis is used."""
65
+
66
+
67
@processor_state
class SlicerState:
    """Slicing plan cached by :obj:`SlicerTransformer` between resets."""

    # Index applied along the axis: a slice, a single int (drops the dim), or an int index array.
    slice_: slice | int | npt.NDArray | None = None
    # Replacement axis whose coordinate data is sliced to match, or None to leave axes untouched.
    new_axis: AxisBase | None = None
    # True when a single integer index is selected, so the dimension is removed from the output.
    b_change_dims: bool = False
72
+
73
+
74
class SlicerTransformer(BaseStatefulTransformer[SlicerSettings, AxisArray, AxisArray, SlicerState]):
    """
    Select a subset of an :obj:`AxisArray` along one axis.

    The selection is parsed once per (message key, axis length) and cached in state.
    A single integer index, including a single coordinate-value match, drops the
    dimension together with its axis entry. Any other selection keeps the dimension
    and, if the axis carries coordinate `data`, slices that data to match.
    """

    def _hash_message(self, message: AxisArray) -> int:
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        return hash((message.key, message.data.shape[axis_idx]))

    def _reset_state(self, message: AxisArray) -> None:
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        self._state.new_axis = None
        self._state.b_change_dims = False

        # Calculate the slice
        _slices = parse_slice(self.settings.selection, message.axes.get(axis, None))
        if len(_slices) == 1:
            slice_ = _slices[0]
            # parse_slice returns NumPy integers for coordinate-value matches, and those are
            # not `int` instances. Treat them as ints so that dims/axes are dropped together
            # with the data dimension.
            if isinstance(slice_, (int, np.integer)):
                slice_ = int(slice_)
                self._state.b_change_dims = True
            self._state.slice_ = slice_
        else:
            # Resolve every item against the axis positions and concatenate into one index array.
            indices = np.arange(message.data.shape[axis_idx])
            indices = np.hstack([indices[_] for _ in _slices])
            self._state.slice_ = np.s_[indices]

        # Create the output axis. When the dimension is dropped, the axis is removed in _process,
        # so there is nothing to build.
        if (
            not self._state.b_change_dims
            and axis in message.axes
            and hasattr(message.axes[axis], "data")
            and len(message.axes[axis].data) > 0
        ):
            in_data = np.array(message.axes[axis].data)
            self._state.new_axis = replace(message.axes[axis], data=in_data[self._state.slice_])

    def _process(self, message: AxisArray) -> AxisArray:
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)

        replace_kwargs = {}
        if self._state.b_change_dims:
            # Single index: the dimension disappears, so drop it from dims and axes.
            replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(message.dims) if dim_ix != axis_idx]
            replace_kwargs["axes"] = {k: v for k, v in message.axes.items() if k != axis}
        elif self._state.new_axis is not None:
            replace_kwargs["axes"] = {k: (v if k != axis else self._state.new_axis) for k, v in message.axes.items()}

        return replace(
            message,
            data=slice_along_axis(message.data, self._state.slice_, axis_idx),
            **replace_kwargs,
        )
121
+
122
+
123
class Slicer(BaseTransformerUnit[SlicerSettings, AxisArray, AxisArray, SlicerTransformer]):
    """ezmsg Unit wrapping :obj:`SlicerTransformer`."""

    SETTINGS = SlicerSettings
125
+
126
+
127
def slicer(selection: str = "", axis: str | None = None) -> SlicerTransformer:
    """
    Build a transformer that slices along a particular axis.

    Args:
        selection: Selection string; see :obj:`ezmsg.sigproc.slicer.parse_slice` for the syntax.
        axis: The name of the axis to slice along. If None, the last axis is used.

    Returns:
        :obj:`SlicerTransformer`
    """
    settings = SlicerSettings(selection=selection, axis=axis)
    return SlicerTransformer(settings)
ezmsg/sigproc/spectral.py CHANGED
@@ -1,132 +1,6 @@
1
- import enum
2
-
3
- from dataclasses import replace
4
-
5
- import numpy as np
6
- import ezmsg.core as ez
7
-
8
- from ezmsg.util.messages.axisarray import AxisArray
9
-
10
- from typing import Optional, AsyncGenerator
11
-
12
-
13
class OptionsEnum(enum.Enum):
    """Enum base class that can list the values of all its members."""

    @classmethod
    def options(cls):
        """Return the value of every member, in definition order."""
        return [member.value for member in cls]
17
-
18
-
19
class WindowFunction(OptionsEnum):
    """Taper applied to the data before the FFT; WINDOWS maps each member to a NumPy window function."""

    NONE = "None (Rectangular)"
    HAMMING = "Hamming"
    HANNING = "Hanning"
    BARTLETT = "Bartlett"
    BLACKMAN = "Blackman"
25
-
26
-
27
# Maps each WindowFunction to a NumPy callable that takes the window length and returns the window.
WINDOWS = {
    WindowFunction.NONE: np.ones,
    WindowFunction.HAMMING: np.hamming,
    WindowFunction.HANNING: np.hanning,
    WindowFunction.BARTLETT: np.bartlett,
    WindowFunction.BLACKMAN: np.blackman,
}
34
-
35
-
36
class SpectralTransform(OptionsEnum):
    """How the complex FFT output is converted before it is emitted."""

    RAW_COMPLEX = "Complex FFT Output"
    REAL = "Real Component of FFT"
    IMAG = "Imaginary Component of FFT"
    REL_POWER = "Relative Power"
    REL_DB = "Log Power (Relative dB)"
42
-
43
-
44
class SpectralOutput(OptionsEnum):
    """Which part of the (fftshift-ed) frequency axis to emit."""

    FULL = "Full Spectrum"
    POSITIVE = "Positive Frequencies"
    NEGATIVE = "Negative Frequencies"
48
-
49
-
50
class SpectrumSettings(ez.Settings):
    """Settings for the legacy :obj:`Spectrum` unit."""

    # Axis to transform along; None means the message's first dimension.
    axis: Optional[str] = None
    # n: Optional[int] = None # n parameter for fft
    out_axis: Optional[str] = "freq"  # If none; don't change dim name
    # Window applied along `axis` before the FFT.
    window: WindowFunction = WindowFunction.HAMMING
    # Conversion applied to the complex FFT output.
    transform: SpectralTransform = SpectralTransform.REL_DB
    # Portion of the frequency axis to emit.
    output: SpectralOutput = SpectralOutput.POSITIVE
57
-
58
-
59
class SpectrumState(ez.State):
    """State of the legacy :obj:`Spectrum` unit."""

    # Settings currently in effect; initialized from SETTINGS and replaced via INPUT_SETTINGS.
    cur_settings: SpectrumSettings
61
-
62
-
63
class Spectrum(ez.Unit):
    """
    Legacy unit computing an FFT-based spectrum of each incoming AxisArray along one axis.

    The transform axis is replaced by a frequency axis. That axis is named by
    `out_axis`, or keeps its original name when `out_axis` is None, and it stays
    at the same position in `dims`.
    """

    SETTINGS: SpectrumSettings
    STATE: SpectrumState

    INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SpectrumSettings):
        self.STATE.cur_settings = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_data(self, message: AxisArray) -> AsyncGenerator:
        # Use the live settings consistently (out_axis was previously read from the static SETTINGS).
        settings = self.STATE.cur_settings
        axis_name = settings.axis
        if axis_name is None:
            axis_name = message.dims[0]
        axis_idx = message.get_axis_idx(axis_name)
        axis = message.get_axis(axis_name)

        # Put the transform axis last so the window broadcasts and the FFT runs along -1.
        spectrum = np.moveaxis(message.data, axis_idx, -1)

        n_time = message.data.shape[axis_idx]
        window = WINDOWS[settings.window](n_time)

        spectrum = np.fft.fft(spectrum * window) / n_time
        spectrum = np.fft.fftshift(spectrum, axes=-1)
        freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)

        if settings.transform != SpectralTransform.RAW_COMPLEX:
            if settings.transform == SpectralTransform.REAL:
                spectrum = spectrum.real
            elif settings.transform == SpectralTransform.IMAG:
                spectrum = spectrum.imag
            else:
                # REL_POWER / REL_DB: power normalized by window energy and sample period.
                scale = np.sum(window**2.0) * axis.gain
                spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale

                if settings.transform == SpectralTransform.REL_DB:
                    spectrum = 10 * np.log10(spectrum)

        axis_offset = freqs[0]
        if settings.output == SpectralOutput.POSITIVE:
            axis_offset = freqs[n_time // 2]
            spectrum = spectrum[..., n_time // 2 :]
        elif settings.output == SpectralOutput.NEGATIVE:
            spectrum = spectrum[..., : n_time // 2]

        # Move the frequency axis back to where the transformed axis was.
        # (The previous moveaxis(spectrum, axis_idx, -1) scrambled dims whenever axis_idx was not last.)
        spectrum = np.moveaxis(spectrum, -1, axis_idx)

        out_axis = settings.out_axis
        if out_axis is None:
            out_axis = axis_name

        freq_axis = AxisArray.Axis(unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset)
        new_axes = {**message.axes, **{out_axis: freq_axis}}

        new_dims = list(message.dims)
        new_dims[axis_idx] = out_axis

        out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)

        yield self.OUTPUT_SIGNAL, out_msg
1
+ from .spectrum import OptionsEnum as OptionsEnum
2
+ from .spectrum import SpectralOutput as SpectralOutput
3
+ from .spectrum import SpectralTransform as SpectralTransform
4
+ from .spectrum import Spectrum as Spectrum
5
+ from .spectrum import SpectrumSettings as SpectrumSettings
6
+ from .spectrum import WindowFunction as WindowFunction
@@ -0,0 +1,90 @@
1
+ from typing import Generator
2
+
3
+ import ezmsg.core as ez
4
+ from ezmsg.baseproc import (
5
+ BaseStatefulProcessor,
6
+ BaseTransformerUnit,
7
+ CompositeProcessor,
8
+ )
9
+ from ezmsg.util.messages.axisarray import AxisArray
10
+ from ezmsg.util.messages.modify import modify_axis
11
+
12
+ from .spectrum import (
13
+ SpectralOutput,
14
+ SpectralTransform,
15
+ SpectrumTransformer,
16
+ WindowFunction,
17
+ )
18
+ from .window import Anchor, WindowTransformer
19
+
20
+
21
class SpectrogramSettings(ez.Settings):
    """
    Settings for :obj:`SpectrogramTransformer`.
    """

    window_dur: float | None = None
    """Window duration in seconds."""

    window_shift: float | None = None
    """Window step in seconds. If None, window_shift == window_dur (as handled by :obj:`WindowTransformer`)."""

    window_anchor: str | Anchor = Anchor.BEGINNING
    """See :obj:`WindowTransformer`."""

    window: WindowFunction = WindowFunction.HAMMING
    """The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum."""

    transform: SpectralTransform = SpectralTransform.REL_DB
    """The :obj:`SpectralTransform` to apply to the spectral magnitude."""

    output: SpectralOutput = SpectralOutput.POSITIVE
    """The :obj:`SpectralOutput` format."""
43
+
44
+
45
class SpectrogramTransformer(CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]):
    """
    Sliding-window spectrogram.

    Windows the "time" axis into a new "win" axis, computes the spectrum along
    "time", then renames "win" back to "time".
    """

    @staticmethod
    def _initialize_processors(
        settings: SpectrogramSettings,
    ) -> dict[str, BaseStatefulProcessor | Generator[AxisArray, AxisArray, None]]:
        # Without an explicit shift, zero-pad based on the input instead of the shift.
        pad_mode = "input" if settings.window_shift is None else "shift"
        windowing = WindowTransformer(
            axis="time",
            newaxis="win",
            window_dur=settings.window_dur,
            window_shift=settings.window_shift,
            zero_pad_until=pad_mode,
            anchor=settings.window_anchor,
        )
        spectrum = SpectrumTransformer(
            axis="time",
            window=settings.window,
            transform=settings.transform,
            output=settings.output,
        )
        rename_win = modify_axis(name_map={"win": "time"})
        # Keys are kept in the original order — presumably the chain order used by CompositeProcessor.
        return {"windowing": windowing, "spectrum": spectrum, "modify_axis": rename_win}
67
+
68
+
69
class Spectrogram(BaseTransformerUnit[SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer]):
    """ezmsg Unit wrapping :obj:`SpectrogramTransformer`."""

    SETTINGS = SpectrogramSettings
71
+
72
+
73
def spectrogram(
    window_dur: float | None = None,
    window_shift: float | None = None,
    window_anchor: str | Anchor = Anchor.BEGINNING,
    window: WindowFunction = WindowFunction.HAMMING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE,
) -> SpectrogramTransformer:
    """
    Build a :obj:`SpectrogramTransformer` from keyword arguments.

    Each argument maps one-to-one onto the same-named field of :obj:`SpectrogramSettings`.
    """
    settings = SpectrogramSettings(
        window_dur=window_dur,
        window_shift=window_shift,
        window_anchor=window_anchor,
        window=window,
        transform=transform,
        output=output,
    )
    return SpectrogramTransformer(settings)