ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66):
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/spectral.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from .spectrum import OptionsEnum as OptionsEnum
2
- from .spectrum import WindowFunction as WindowFunction
3
- from .spectrum import SpectralTransform as SpectralTransform
4
2
  from .spectrum import SpectralOutput as SpectralOutput
5
- from .spectrum import SpectrumSettings as SpectrumSettings
3
+ from .spectrum import SpectralTransform as SpectralTransform
6
4
  from .spectrum import Spectrum as Spectrum
5
+ from .spectrum import SpectrumSettings as SpectrumSettings
6
+ from .spectrum import WindowFunction as WindowFunction
@@ -1,90 +1,90 @@
1
- import typing
1
+ from typing import Generator
2
2
 
3
3
  import ezmsg.core as ez
4
+ from ezmsg.baseproc import (
5
+ BaseStatefulProcessor,
6
+ BaseTransformerUnit,
7
+ CompositeProcessor,
8
+ )
4
9
  from ezmsg.util.messages.axisarray import AxisArray
5
- from ezmsg.util.generator import consumer, compose
6
10
  from ezmsg.util.messages.modify import modify_axis
7
11
 
8
- from .window import windowing
9
- from .spectrum import spectrum, WindowFunction, SpectralTransform, SpectralOutput
10
- from .base import GenAxisArray
11
-
12
-
13
- @consumer
14
- def spectrogram(
15
- window_dur: float | None = None,
16
- window_shift: float | None = None,
17
- window: WindowFunction = WindowFunction.HANNING,
18
- transform: SpectralTransform = SpectralTransform.REL_DB,
19
- output: SpectralOutput = SpectralOutput.POSITIVE,
20
- ) -> typing.Generator[AxisArray | None, AxisArray, None]:
21
- """
22
- Calculate a spectrogram on streaming data.
23
-
24
- Chains :obj:`ezmsg.sigproc.window.windowing` to apply a moving window on the data,
25
- :obj:`ezmsg.sigproc.spectrum.spectrum` to calculate spectra for each window,
26
- and finally :obj:`ezmsg.util.messages.modify.modify_axis` to convert the win axis back to time axis.
27
-
28
- Args:
29
- window_dur: See :obj:`ezmsg.sigproc.window.windowing`
30
- window_shift: See :obj:`ezmsg.sigproc.window.windowing`
31
- window: See :obj:`ezmsg.sigproc.spectrum.spectrum`
32
- transform: See :obj:`ezmsg.sigproc.spectrum.spectrum`
33
- output: See :obj:`ezmsg.sigproc.spectrum.spectrum`
34
-
35
- Returns:
36
- A primed generator object that expects an :obj:`AxisArray` via `.send(axis_array)`
37
- with continuous data in its .data payload, and yields an :obj:`AxisArray` of time-frequency power values.
38
- """
39
-
40
- pipeline = compose(
41
- windowing(
42
- axis="time",
43
- newaxis="win",
44
- window_dur=window_dur,
45
- window_shift=window_shift,
46
- zero_pad_until="shift" if window_shift is not None else "input",
47
- ),
48
- spectrum(axis="time", window=window, transform=transform, output=output),
49
- modify_axis(name_map={"win": "time"}),
50
- )
51
-
52
- # State variables
53
- msg_out: AxisArray | None = None
54
-
55
- while True:
56
- msg_in: AxisArray = yield msg_out
57
- msg_out = pipeline(msg_in)
12
+ from .spectrum import (
13
+ SpectralOutput,
14
+ SpectralTransform,
15
+ SpectrumTransformer,
16
+ WindowFunction,
17
+ )
18
+ from .window import Anchor, WindowTransformer
58
19
 
59
20
 
60
21
  class SpectrogramSettings(ez.Settings):
61
22
  """
62
- Settings for :obj:`Spectrogram`.
63
- See :obj:`spectrogram` for a description of the parameters.
23
+ Settings for :obj:`SpectrogramTransformer`.
64
24
  """
65
25
 
66
- window_dur: float | None = None # window duration in seconds
26
+ window_dur: float | None = None
27
+ """window duration in seconds."""
28
+
67
29
  window_shift: float | None = None
68
30
  """"window step in seconds. If None, window_shift == window_dur"""
69
31
 
70
- # See SpectrumSettings for details of following settings:
71
- window: WindowFunction = WindowFunction.HAMMING
72
- transform: SpectralTransform = SpectralTransform.REL_DB
73
- output: SpectralOutput = SpectralOutput.POSITIVE
32
+ window_anchor: str | Anchor = Anchor.BEGINNING
33
+ """See :obj"`WindowTransformer`"""
74
34
 
35
+ window: WindowFunction = WindowFunction.HAMMING
36
+ """The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum."""
75
37
 
76
- class Spectrogram(GenAxisArray):
77
- """
78
- Unit for :obj:`spectrogram`.
79
- """
38
+ transform: SpectralTransform = SpectralTransform.REL_DB
39
+ """The :obj:`SpectralTransform` to apply to the spectral magnitude."""
80
40
 
41
+ output: SpectralOutput = SpectralOutput.POSITIVE
42
+ """The :obj:`SpectralOutput` format."""
43
+
44
+
45
+ class SpectrogramTransformer(CompositeProcessor[SpectrogramSettings, AxisArray, AxisArray]):
46
+ @staticmethod
47
+ def _initialize_processors(
48
+ settings: SpectrogramSettings,
49
+ ) -> dict[str, BaseStatefulProcessor | Generator[AxisArray, AxisArray, None]]:
50
+ return {
51
+ "windowing": WindowTransformer(
52
+ axis="time",
53
+ newaxis="win",
54
+ window_dur=settings.window_dur,
55
+ window_shift=settings.window_shift,
56
+ zero_pad_until="shift" if settings.window_shift is not None else "input",
57
+ anchor=settings.window_anchor,
58
+ ),
59
+ "spectrum": SpectrumTransformer(
60
+ axis="time",
61
+ window=settings.window,
62
+ transform=settings.transform,
63
+ output=settings.output,
64
+ ),
65
+ "modify_axis": modify_axis(name_map={"win": "time"}),
66
+ }
67
+
68
+
69
+ class Spectrogram(BaseTransformerUnit[SpectrogramSettings, AxisArray, AxisArray, SpectrogramTransformer]):
81
70
  SETTINGS = SpectrogramSettings
82
71
 
83
- def construct_generator(self):
84
- self.STATE.gen = spectrogram(
85
- window_dur=self.SETTINGS.window_dur,
86
- window_shift=self.SETTINGS.window_shift,
87
- window=self.SETTINGS.window,
88
- transform=self.SETTINGS.transform,
89
- output=self.SETTINGS.output,
72
+
73
+ def spectrogram(
74
+ window_dur: float | None = None,
75
+ window_shift: float | None = None,
76
+ window_anchor: str | Anchor = Anchor.BEGINNING,
77
+ window: WindowFunction = WindowFunction.HAMMING,
78
+ transform: SpectralTransform = SpectralTransform.REL_DB,
79
+ output: SpectralOutput = SpectralOutput.POSITIVE,
80
+ ) -> SpectrogramTransformer:
81
+ return SpectrogramTransformer(
82
+ SpectrogramSettings(
83
+ window_dur=window_dur,
84
+ window_shift=window_shift,
85
+ window_anchor=window_anchor,
86
+ window=window,
87
+ transform=transform,
88
+ output=output,
90
89
  )
90
+ )
ezmsg/sigproc/spectrum.py CHANGED
@@ -1,17 +1,20 @@
1
1
  import enum
2
- from functools import partial
3
2
  import typing
3
+ from functools import partial
4
4
 
5
- import numpy as np
6
5
  import ezmsg.core as ez
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ from ezmsg.baseproc import (
9
+ BaseStatefulTransformer,
10
+ BaseTransformerUnit,
11
+ processor_state,
12
+ )
7
13
  from ezmsg.util.messages.axisarray import (
8
14
  AxisArray,
9
- slice_along_axis,
10
15
  replace,
16
+ slice_along_axis,
11
17
  )
12
- from ezmsg.util.generator import consumer
13
-
14
- from .base import GenAxisArray
15
18
 
16
19
 
17
20
  class OptionsEnum(enum.Enum):
@@ -66,198 +69,209 @@ class SpectralOutput(OptionsEnum):
66
69
  NEGATIVE = "Negative Frequencies"
67
70
 
68
71
 
69
- @consumer
70
- def spectrum(
71
- axis: str | None = None,
72
- out_axis: str | None = "freq",
73
- window: WindowFunction = WindowFunction.HANNING,
74
- transform: SpectralTransform = SpectralTransform.REL_DB,
75
- output: SpectralOutput = SpectralOutput.POSITIVE,
76
- norm: str | None = "forward",
77
- do_fftshift: bool = True,
78
- nfft: int | None = None,
79
- ) -> typing.Generator[AxisArray, AxisArray, None]:
72
+ class SpectrumSettings(ez.Settings):
73
+ """
74
+ Settings for :obj:`Spectrum.
75
+ See :obj:`spectrum` for a description of the parameters.
80
76
  """
81
- Calculate a spectrum on a data slice.
82
77
 
83
- Args:
84
- axis: The name of the axis on which to calculate the spectrum.
85
- Note: The axis must have an .axes entry of type LinearAxis, not CoordinateAxis.
86
- out_axis: The name of the new axis. Defaults to "freq".
87
- window: The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum.
88
- transform: The :obj:`SpectralTransform` to apply to the spectral magnitude.
89
- output: The :obj:`SpectralOutput` format.
90
- norm: Normalization mode. Default "forward" is best used when the inverse transform is not needed,
91
- for example when the goal is to get spectral power. Use "backward" (equivalent to None) to not
92
- scale the spectrum which is useful when the spectra will be manipulated and possibly inverse-transformed.
93
- See numpy.fft.fft for details.
94
- do_fftshift: Whether to apply fftshift to the output. Default is True. This value is ignored unless
95
- output is SpectralOutput.FULL.
96
- nfft: The number of points to use for the FFT. If None, the length of the input data is used.
78
+ axis: str | None = None
79
+ """
80
+ The name of the axis on which to calculate the spectrum.
81
+ Note: The axis must have an .axes entry of type LinearAxis, not CoordinateAxis.
82
+ """
97
83
 
98
- Returns:
99
- A primed generator object that expects an :obj:`AxisArray` via `.send(axis_array)` containing continuous data
100
- and yields an :obj:`AxisArray` with data of spectral magnitudes or powers.
84
+ # n: int | None = None # n parameter for fft
85
+
86
+ out_axis: str | None = "freq"
87
+ """The name of the new axis. Defaults to "freq". If none; don't change dim name"""
88
+
89
+ window: WindowFunction = WindowFunction.HAMMING
90
+ """The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum."""
91
+
92
+ transform: SpectralTransform = SpectralTransform.REL_DB
93
+ """The :obj:`SpectralTransform` to apply to the spectral magnitude."""
94
+
95
+ output: SpectralOutput = SpectralOutput.POSITIVE
96
+ """The :obj:`SpectralOutput` format."""
97
+
98
+ norm: str | None = "forward"
99
+ """
100
+ Normalization mode. Default "forward" is best used when the inverse transform is not needed,
101
+ for example when the goal is to get spectral power. Use "backward" (equivalent to None) to not
102
+ scale the spectrum which is useful when the spectra will be manipulated and possibly inverse-transformed.
103
+ See numpy.fft.fft for details.
104
+ """
105
+
106
+ do_fftshift: bool = True
107
+ """
108
+ Whether to apply fftshift to the output. Default is True.
109
+ This value is ignored unless output is SpectralOutput.FULL.
110
+ """
111
+
112
+ nfft: int | None = None
101
113
  """
102
- msg_out = AxisArray(np.array([]), dims=[""])
114
+ The number of points to use for the FFT. If None, the length of the input data is used.
115
+ """
116
+
103
117
 
104
- # State variables
105
- apply_window = window != WindowFunction.NONE
106
- do_fftshift &= output == SpectralOutput.FULL
107
- f_sl = slice(None)
118
+ @processor_state
119
+ class SpectrumState:
120
+ f_sl: slice | None = None
121
+ # I would prefer `slice(None)` as f_sl default but this fails because it is mutable.
108
122
  freq_axis: AxisArray.LinearAxis | None = None
109
123
  fftfun: typing.Callable | None = None
110
124
  f_transform: typing.Callable | None = None
111
125
  new_dims: list[str] | None = None
126
+ window: npt.NDArray | None = None
127
+
128
+
129
+ class SpectrumTransformer(BaseStatefulTransformer[SpectrumSettings, AxisArray, AxisArray, SpectrumState]):
130
+ def _hash_message(self, message: AxisArray) -> int:
131
+ axis = self.settings.axis or message.dims[0]
132
+ ax_idx = message.get_axis_idx(axis)
133
+ ax_info = message.axes[axis]
134
+ targ_len = message.data.shape[ax_idx]
135
+ return hash((targ_len, message.data.ndim, message.data.dtype.kind, ax_idx, ax_info.gain))
136
+
137
+ def _reset_state(self, message: AxisArray) -> None:
138
+ axis = self.settings.axis or message.dims[0]
139
+ ax_idx = message.get_axis_idx(axis)
140
+ ax_info = message.axes[axis]
141
+ targ_len = message.data.shape[ax_idx]
142
+ nfft = self.settings.nfft or targ_len
143
+
144
+ # Pre-calculate windowing
145
+ window = WINDOWS[self.settings.window](targ_len)
146
+ window = window.reshape(
147
+ [1] * ax_idx
148
+ + [
149
+ len(window),
150
+ ]
151
+ + [1] * (message.data.ndim - 1 - ax_idx)
152
+ )
153
+ if self.settings.transform != SpectralTransform.RAW_COMPLEX and not (
154
+ self.settings.transform == SpectralTransform.REAL or self.settings.transform == SpectralTransform.IMAG
155
+ ):
156
+ scale = np.sum(window**2.0) * ax_info.gain
157
+
158
+ if self.settings.window != WindowFunction.NONE:
159
+ self.state.window = window
160
+
161
+ # Pre-calculate frequencies and select our fft function.
162
+ b_complex = message.data.dtype.kind == "c"
163
+ self.state.f_sl = slice(None)
164
+ if (not b_complex) and self.settings.output == SpectralOutput.POSITIVE:
165
+ # If input is not complex and desired output is SpectralOutput.POSITIVE, we can save some computation
166
+ # by using rfft and rfftfreq.
167
+ self.state.fftfun = partial(np.fft.rfft, n=nfft, axis=ax_idx, norm=self.settings.norm)
168
+ freqs = np.fft.rfftfreq(nfft, d=ax_info.gain * targ_len / nfft)
169
+ else:
170
+ self.state.fftfun = partial(np.fft.fft, n=nfft, axis=ax_idx, norm=self.settings.norm)
171
+ freqs = np.fft.fftfreq(nfft, d=ax_info.gain * targ_len / nfft)
172
+ if self.settings.output == SpectralOutput.POSITIVE:
173
+ self.state.f_sl = slice(None, nfft // 2 + 1 - (nfft % 2))
174
+ elif self.settings.output == SpectralOutput.NEGATIVE:
175
+ freqs = np.fft.fftshift(freqs, axes=-1)
176
+ self.state.f_sl = slice(None, nfft // 2 + 1)
177
+ elif self.settings.do_fftshift and self.settings.output == SpectralOutput.FULL:
178
+ freqs = np.fft.fftshift(freqs, axes=-1)
179
+ freqs = freqs[self.state.f_sl]
180
+ freqs = freqs.tolist() # To please type checking
181
+ self.state.freq_axis = AxisArray.LinearAxis(unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0])
182
+ self.state.new_dims = (
183
+ message.dims[:ax_idx]
184
+ + [
185
+ self.settings.out_axis or axis,
186
+ ]
187
+ + message.dims[ax_idx + 1 :]
188
+ )
189
+
190
+ def f_transform(x):
191
+ return x
112
192
 
113
- # Reset if input changes substantially
114
- check_input = {
115
- "n_time": None, # Need to recalc windows
116
- "ndim": None, # Input ndim changed: Need to recalc windows
117
- "kind": None, # Input dtype changed: Need to re-init fft funcs
118
- "ax_idx": None, # Axis index changed: Need to re-init fft funcs
119
- "gain": None, # Gain changed: Need to re-calc freqs
120
- # "key": None # There's no temporal continuity; we can ignore key changes
121
- }
122
-
123
- while True:
124
- msg_in: AxisArray = yield msg_out
125
-
126
- # Get signal properties
127
- axis = axis or msg_in.dims[0]
128
- ax_idx = msg_in.get_axis_idx(axis)
129
- ax_info = msg_in.axes[axis]
130
- targ_len = msg_in.data.shape[ax_idx]
131
-
132
- # Check signal properties for change
133
- b_reset = targ_len != check_input["n_time"]
134
- b_reset = b_reset or msg_in.data.ndim != check_input["ndim"]
135
- b_reset = b_reset or msg_in.data.dtype.kind != check_input["kind"]
136
- b_reset = b_reset or ax_idx != check_input["ax_idx"]
137
- b_reset = b_reset or ax_info.gain != check_input["gain"]
138
- if b_reset:
139
- check_input["n_time"] = targ_len
140
- check_input["ndim"] = msg_in.data.ndim
141
- check_input["kind"] = msg_in.data.dtype.kind
142
- check_input["ax_idx"] = ax_idx
143
- check_input["gain"] = ax_info.gain
144
-
145
- nfft = nfft or targ_len
146
-
147
- # Pre-calculate windowing
148
- window = WINDOWS[window](targ_len)
149
- window = window.reshape(
150
- [1] * ax_idx
151
- + [
152
- len(window),
153
- ]
154
- + [1] * (msg_in.data.ndim - 1 - ax_idx)
155
- )
156
- if transform != SpectralTransform.RAW_COMPLEX and not (
157
- transform == SpectralTransform.REAL
158
- or transform == SpectralTransform.IMAG
159
- ):
160
- scale = np.sum(window**2.0) * ax_info.gain
161
-
162
- # Pre-calculate frequencies and select our fft function.
163
- b_complex = msg_in.data.dtype.kind == "c"
164
- if (not b_complex) and output == SpectralOutput.POSITIVE:
165
- # If input is not complex and desired output is SpectralOutput.POSITIVE, we can save some computation
166
- # by using rfft and rfftfreq.
167
- fftfun = partial(np.fft.rfft, n=nfft, axis=ax_idx, norm=norm)
168
- freqs = np.fft.rfftfreq(nfft, d=ax_info.gain * targ_len / nfft)
193
+ if self.settings.transform != SpectralTransform.RAW_COMPLEX:
194
+ if self.settings.transform == SpectralTransform.REAL:
195
+
196
+ def f_transform(x):
197
+ return x.real
198
+ elif self.settings.transform == SpectralTransform.IMAG:
199
+
200
+ def f_transform(x):
201
+ return x.imag
169
202
  else:
170
- fftfun = partial(np.fft.fft, n=nfft, axis=ax_idx, norm=norm)
171
- freqs = np.fft.fftfreq(nfft, d=ax_info.gain * targ_len / nfft)
172
- if output == SpectralOutput.POSITIVE:
173
- f_sl = slice(None, nfft // 2 + 1 - (nfft % 2))
174
- elif output == SpectralOutput.NEGATIVE:
175
- freqs = np.fft.fftshift(freqs, axes=-1)
176
- f_sl = slice(None, nfft // 2 + 1)
177
- elif do_fftshift: # and FULL
178
- freqs = np.fft.fftshift(freqs, axes=-1)
179
- freqs = freqs[f_sl]
180
- freqs = freqs.tolist() # To please type checking
181
- freq_axis = AxisArray.LinearAxis(
182
- unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0]
183
- )
184
- if out_axis is None:
185
- out_axis = axis
186
- new_dims = (
187
- msg_in.dims[:ax_idx]
188
- + [
189
- out_axis,
190
- ]
191
- + msg_in.dims[ax_idx + 1 :]
192
- )
193
-
194
- def f_transform(x):
195
- return x
196
-
197
- if transform != SpectralTransform.RAW_COMPLEX:
198
- if transform == SpectralTransform.REAL:
199
203
 
200
- def f_transform(x):
201
- return x.real
202
- elif transform == SpectralTransform.IMAG:
204
+ def f1(x):
205
+ return (np.abs(x) ** 2.0) / scale
206
+
207
+ if self.settings.transform == SpectralTransform.REL_DB:
203
208
 
204
209
  def f_transform(x):
205
- return x.imag
210
+ return 10 * np.log10(f1(x))
206
211
  else:
212
+ f_transform = f1
213
+ self.state.f_transform = f_transform
207
214
 
208
- def f1(x):
209
- return (np.abs(x) ** 2.0) / scale
210
-
211
- if transform == SpectralTransform.REL_DB:
215
+ def _process(self, message: AxisArray) -> AxisArray:
216
+ axis = self.settings.axis or message.dims[0]
217
+ ax_idx = message.get_axis_idx(axis)
218
+ targ_len = message.data.shape[ax_idx]
212
219
 
213
- def f_transform(x):
214
- return 10 * np.log10(f1(x))
215
- else:
216
- f_transform = f1
220
+ new_axes = {k: v for k, v in message.axes.items() if k not in [self.settings.out_axis, axis]}
221
+ new_axes[self.settings.out_axis or axis] = self.state.freq_axis
217
222
 
218
- new_axes = {k: v for k, v in msg_in.axes.items() if k not in [out_axis, axis]}
219
- new_axes[out_axis] = freq_axis
220
-
221
- if apply_window:
222
- win_dat = msg_in.data * window
223
+ if self.state.window is not None:
224
+ win_dat = message.data * self.state.window
223
225
  else:
224
- win_dat = msg_in.data
225
- spec = fftfun(win_dat, n=nfft, axis=ax_idx, norm=norm)
226
+ win_dat = message.data
227
+ spec = self.state.fftfun(
228
+ win_dat,
229
+ n=self.settings.nfft or targ_len,
230
+ axis=ax_idx,
231
+ norm=self.settings.norm,
232
+ )
226
233
  # Note: norm="forward" equivalent to `/ nfft`
227
- if do_fftshift or output == SpectralOutput.NEGATIVE:
234
+ if (
235
+ self.settings.do_fftshift and self.settings.output == SpectralOutput.FULL
236
+ ) or self.settings.output == SpectralOutput.NEGATIVE:
228
237
  spec = np.fft.fftshift(spec, axes=ax_idx)
229
- spec = f_transform(spec)
230
- spec = slice_along_axis(spec, f_sl, ax_idx)
231
-
232
- msg_out = replace(msg_in, data=spec, dims=new_dims, axes=new_axes)
238
+ spec = self.state.f_transform(spec)
239
+ spec = slice_along_axis(spec, self.state.f_sl, ax_idx)
233
240
 
241
+ msg_out = replace(message, data=spec, dims=self.state.new_dims, axes=new_axes)
242
+ return msg_out
234
243
 
235
- class SpectrumSettings(ez.Settings):
236
- """
237
- Settings for :obj:`Spectrum.
238
- See :obj:`spectrum` for a description of the parameters.
239
- """
240
-
241
- axis: str | None = None
242
- # n: int | None = None # n parameter for fft
243
- out_axis: str | None = "freq" # If none; don't change dim name
244
- window: WindowFunction = WindowFunction.HAMMING
245
- transform: SpectralTransform = SpectralTransform.REL_DB
246
- output: SpectralOutput = SpectralOutput.POSITIVE
247
-
248
-
249
- class Spectrum(GenAxisArray):
250
- """Unit for :obj:`spectrum`"""
251
244
 
245
+ class Spectrum(BaseTransformerUnit[SpectrumSettings, AxisArray, AxisArray, SpectrumTransformer]):
252
246
  SETTINGS = SpectrumSettings
253
247
 
254
- INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
255
248
 
256
- def construct_generator(self):
257
- self.STATE.gen = spectrum(
258
- axis=self.SETTINGS.axis,
259
- out_axis=self.SETTINGS.out_axis,
260
- window=self.SETTINGS.window,
261
- transform=self.SETTINGS.transform,
262
- output=self.SETTINGS.output,
249
+ def spectrum(
250
+ axis: str | None = None,
251
+ out_axis: str | None = "freq",
252
+ window: WindowFunction = WindowFunction.HANNING,
253
+ transform: SpectralTransform = SpectralTransform.REL_DB,
254
+ output: SpectralOutput = SpectralOutput.POSITIVE,
255
+ norm: str | None = "forward",
256
+ do_fftshift: bool = True,
257
+ nfft: int | None = None,
258
+ ) -> SpectrumTransformer:
259
+ """
260
+ Calculate a spectrum on a data slice.
261
+
262
+ Returns:
263
+ A :obj:`SpectrumTransformer` object that expects an :obj:`AxisArray` via `.(axis_array)` (__call__)
264
+ containing continuous data and returns an :obj:`AxisArray` with data of spectral magnitudes or powers.
265
+ """
266
+ return SpectrumTransformer(
267
+ SpectrumSettings(
268
+ axis=axis,
269
+ out_axis=out_axis,
270
+ window=window,
271
+ transform=transform,
272
+ output=output,
273
+ norm=norm,
274
+ do_fftshift=do_fftshift,
275
+ nfft=nfft,
263
276
  )
277
+ )