ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/slicer.py CHANGED
@@ -2,9 +2,12 @@ from dataclasses import replace
2
2
  import typing
3
3
 
4
4
  import numpy as np
5
+ import numpy.typing as npt
5
6
  import ezmsg.core as ez
6
7
  from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
7
- from ezmsg.util.generator import consumer, GenAxisArray
8
+ from ezmsg.util.generator import consumer
9
+
10
+ from .base import GenAxisArray
8
11
 
9
12
 
10
13
  """
@@ -15,19 +18,20 @@ Slicer:Select a subset of data along a particular axis.
15
18
  def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
16
19
  """
17
20
  Parses a string representation of a slice and returns a tuple of slice objects.
18
- * "" -> slice(None, None, None) (take all)
19
- * ":" -> slice(None, None, None)
20
- * '"none"` (case-insensitive) -> slice(None, None, None)
21
- * "{start}:{stop}" or {start}:{stop}:{step} -> slice(start, stop, step)
22
- * "5" (or any integer) -> (5,). Take only that item.
21
+
22
+ - "" -> slice(None, None, None) (take all)
23
+ - ":" -> slice(None, None, None)
24
+ - '"none"` (case-insensitive) -> slice(None, None, None)
25
+ - "{start}:{stop}" or {start}:{stop}:{step} -> slice(start, stop, step)
26
+ - "5" (or any integer) -> (5,). Take only that item.
23
27
  applying this to a ndarray or AxisArray will drop the dimension.
24
- * A comma-separated list of the above -> a tuple of slices | ints
28
+ - A comma-separated list of the above -> a tuple of slices | ints
25
29
 
26
30
  Args:
27
- s (str): The string representation of the slice.
31
+ s: The string representation of the slice.
28
32
 
29
33
  Returns:
30
- tuple[slice | int, ...]: A tuple of slice objects and/or ints.
34
+ A tuple of slice objects and/or ints.
31
35
  """
32
36
  if s.lower() in ["", ":", "none"]:
33
37
  return (slice(None),)
@@ -36,51 +40,82 @@ def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
36
40
  if len(parts) == 1:
37
41
  return (int(parts[0]),)
38
42
  return (slice(*(int(part.strip()) if part else None for part in parts)),)
39
- l = [parse_slice(_) for _ in s.split(",")]
40
- return tuple([item for sublist in l for item in sublist])
43
+ suplist = [parse_slice(_) for _ in s.split(",")]
44
+ return tuple([item for sublist in suplist for item in sublist])
41
45
 
42
46
 
43
47
  @consumer
44
48
  def slicer(
45
49
  selection: str = "", axis: typing.Optional[str] = None
46
50
  ) -> typing.Generator[AxisArray, AxisArray, None]:
47
- axis_arr_in = AxisArray(np.array([]), dims=[""])
48
- axis_arr_out = AxisArray(np.array([]), dims=[""])
49
- _slice = None
50
- b_change_dims = False
51
+ msg_out = AxisArray(np.array([]), dims=[""])
52
+
53
+ # State variables
54
+ _slice: typing.Optional[typing.Union[slice, npt.NDArray]] = None
55
+ new_axis: typing.Optional[AxisArray.Axis] = None
56
+ b_change_dims: bool = False # If number of dimensions changes when slicing
57
+
58
+ # Reset if input changes
59
+ check_input = {
60
+ "key": None, # key change used as proxy for label change, which we don't check explicitly
61
+ "len": None,
62
+ }
51
63
 
52
64
  while True:
53
- axis_arr_in = yield axis_arr_out
65
+ msg_in: AxisArray = yield msg_out
54
66
 
55
- if axis is None:
56
- axis = axis_arr_in.dims[-1]
57
- axis_idx = axis_arr_in.get_axis_idx(axis)
67
+ axis = axis or msg_in.dims[-1]
68
+ axis_idx = msg_in.get_axis_idx(axis)
58
69
 
59
- if _slice is None:
70
+ b_reset = _slice is None # or new_axis is None
71
+ b_reset = b_reset or msg_in.key != check_input["key"]
72
+ b_reset = b_reset or (
73
+ (msg_in.data.shape[axis_idx] != check_input["len"])
74
+ and (type(_slice) is np.ndarray)
75
+ )
76
+ if b_reset:
77
+ check_input["key"] = msg_in.key
78
+ check_input["len"] = msg_in.data.shape[axis_idx]
79
+ new_axis = None # Will hold updated metadata
80
+ b_change_dims = False
81
+
82
+ # Calculate the slice
60
83
  _slices = parse_slice(selection)
61
84
  if len(_slices) == 1:
62
85
  _slice = _slices[0]
86
+ # Do we drop the sliced dimension?
63
87
  b_change_dims = isinstance(_slice, int)
64
88
  else:
65
89
  # Multiple slices, but this cannot be done in a single step, so we convert the slices
66
90
  # to a discontinuous set of integer indexes.
67
- indices = np.arange(axis_arr_in.data.shape[axis_idx])
91
+ indices = np.arange(msg_in.data.shape[axis_idx])
68
92
  indices = np.hstack([indices[_] for _ in _slices])
69
- _slice = np.s_[indices]
70
-
93
+ _slice = np.s_[indices] # Integer scalar array
94
+
95
+ # Create the output axis.
96
+ if (
97
+ axis in msg_in.axes
98
+ and hasattr(msg_in.axes[axis], "labels")
99
+ and len(msg_in.axes[axis].labels) > 0
100
+ ):
101
+ new_labels = msg_in.axes[axis].labels[_slice]
102
+ new_axis = replace(msg_in.axes[axis], labels=new_labels)
103
+
104
+ replace_kwargs = {}
71
105
  if b_change_dims:
72
- out_dims = [_ for dim_ix, _ in enumerate(axis_arr_in.dims) if dim_ix != axis_idx]
73
- out_axes = axis_arr_in.axes.copy()
74
- out_axes.pop(axis, None)
75
- else:
76
- out_dims = axis_arr_in.dims
77
- out_axes = axis_arr_in.axes
78
-
79
- axis_arr_out = replace(
80
- axis_arr_in,
81
- dims=out_dims,
82
- axes=out_axes,
83
- data=slice_along_axis(axis_arr_in.data, _slice, axis_idx),
106
+ # Dropping the target axis
107
+ replace_kwargs["dims"] = [
108
+ _ for dim_ix, _ in enumerate(msg_in.dims) if dim_ix != axis_idx
109
+ ]
110
+ replace_kwargs["axes"] = {k: v for k, v in msg_in.axes.items() if k != axis}
111
+ elif new_axis is not None:
112
+ replace_kwargs["axes"] = {
113
+ k: (v if k != axis else new_axis) for k, v in msg_in.axes.items()
114
+ }
115
+ msg_out = replace(
116
+ msg_in,
117
+ data=slice_along_axis(msg_in.data, _slice, axis_idx),
118
+ **replace_kwargs,
84
119
  )
85
120
 
86
121
 
@@ -90,7 +125,7 @@ class SlicerSettings(ez.Settings):
90
125
 
91
126
 
92
127
  class Slicer(GenAxisArray):
93
- SETTINGS: SlicerSettings
128
+ SETTINGS = SlicerSettings
94
129
 
95
130
  def construct_generator(self):
96
131
  self.STATE.gen = slicer(
ezmsg/sigproc/spectral.py CHANGED
@@ -1,9 +1,6 @@
1
- from .spectrum import (
2
- OptionsEnum,
3
- WindowFunction,
4
- SpectralTransform,
5
- SpectralOutput,
6
- SpectrumSettings,
7
- SpectrumState,
8
- Spectrum
9
- )
1
+ from .spectrum import OptionsEnum as OptionsEnum
2
+ from .spectrum import WindowFunction as WindowFunction
3
+ from .spectrum import SpectralTransform as SpectralTransform
4
+ from .spectrum import SpectralOutput as SpectralOutput
5
+ from .spectrum import SpectrumSettings as SpectrumSettings
6
+ from .spectrum import Spectrum as Spectrum
@@ -1,16 +1,13 @@
1
1
  import typing
2
2
 
3
- import numpy as np
4
-
5
3
  import ezmsg.core as ez
6
4
  from ezmsg.util.messages.axisarray import AxisArray
7
- from ezmsg.util.generator import consumer, GenAxisArray # , compose
5
+ from ezmsg.util.generator import consumer, compose
8
6
  from ezmsg.util.messages.modify import modify_axis
9
- from ezmsg.sigproc.window import windowing
10
- from ezmsg.sigproc.spectrum import (
11
- spectrum,
12
- WindowFunction, SpectralTransform, SpectralOutput
13
- )
7
+
8
+ from .window import windowing
9
+ from .spectrum import spectrum, WindowFunction, SpectralTransform, SpectralOutput
10
+ from .base import GenAxisArray
14
11
 
15
12
 
16
13
  @consumer
@@ -19,36 +16,53 @@ def spectrogram(
19
16
  window_shift: typing.Optional[float] = None,
20
17
  window: WindowFunction = WindowFunction.HANNING,
21
18
  transform: SpectralTransform = SpectralTransform.REL_DB,
22
- output: SpectralOutput = SpectralOutput.POSITIVE
19
+ output: SpectralOutput = SpectralOutput.POSITIVE,
23
20
  ) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]:
21
+ """
22
+ Calculate a spectrogram on streaming data.
23
+
24
+ Chains :obj:`ezmsg.sigproc.window.windowing` to apply a moving window on the data,
25
+ :obj:`ezmsg.sigproc.spectrum.spectrum` to calculate spectra for each window,
26
+ and finally :obj:`ezmsg.util.messages.modify.modify_axis` to convert the win axis back to time axis.
27
+
28
+ Args:
29
+ window_dur: See :obj:`ezmsg.sigproc.window.windowing`
30
+ window_shift: See :obj:`ezmsg.sigproc.window.windowing`
31
+ window: See :obj:`ezmsg.sigproc.spectrum.spectrum`
32
+ transform: See :obj:`ezmsg.sigproc.spectrum.spectrum`
33
+ output: See :obj:`ezmsg.sigproc.spectrum.spectrum`
24
34
 
25
- # We cannot use `compose` because `windowing` returns a list of axisarray objects,
26
- # even though the length is always exactly 1 for the settings used here.
27
- # pipeline = compose(
28
- f_win = windowing(axis="time", newaxis="step", window_dur=window_dur, window_shift=window_shift)
29
- f_spec = spectrum(axis="time", window=window, transform=transform, output=output)
30
- f_modify = modify_axis(name_map={"step": "time"})
31
- # )
35
+ Returns:
36
+ A primed generator object that expects `.send(axis_array)` of continuous data
37
+ and yields an AxisArray of time-frequency power values.
38
+ """
39
+
40
+ pipeline = compose(
41
+ windowing(
42
+ axis="time", newaxis="win", window_dur=window_dur, window_shift=window_shift
43
+ ),
44
+ spectrum(axis="time", window=window, transform=transform, output=output),
45
+ modify_axis(name_map={"win": "time"}),
46
+ )
32
47
 
33
48
  # State variables
34
- axis_arr_in = AxisArray(np.array([]), dims=[""])
35
- axis_arr_out: typing.Optional[AxisArray] = None
49
+ msg_out: typing.Optional[AxisArray] = None
36
50
 
37
51
  while True:
38
- axis_arr_in = yield axis_arr_out
39
-
40
- # axis_arr_out = pipeline(axis_arr_in)
41
- axis_arr_out = None
42
- wins = f_win.send(axis_arr_in)
43
- if len(wins):
44
- specs = f_spec.send(wins[0])
45
- if specs is not None:
46
- axis_arr_out = f_modify.send(specs)
52
+ msg_in: AxisArray = yield msg_out
53
+ msg_out = pipeline(msg_in)
47
54
 
48
55
 
49
56
  class SpectrogramSettings(ez.Settings):
57
+ """
58
+ Settings for :obj:`Spectrogram`.
59
+ See :obj:`spectrogram` for a description of the parameters.
60
+ """
61
+
50
62
  window_dur: typing.Optional[float] = None # window duration in seconds
51
- window_shift: typing.Optional[float] = None # window step in seconds. If None, window_shift == window_dur
63
+ window_shift: typing.Optional[float] = None
64
+ """"window step in seconds. If None, window_shift == window_dur"""
65
+
52
66
  # See SpectrumSettings for details of following settings:
53
67
  window: WindowFunction = WindowFunction.HAMMING
54
68
  transform: SpectralTransform = SpectralTransform.REL_DB
@@ -56,7 +70,11 @@ class SpectrogramSettings(ez.Settings):
56
70
 
57
71
 
58
72
  class Spectrogram(GenAxisArray):
59
- SETTINGS: SpectrogramSettings
73
+ """
74
+ Unit for :obj:`spectrogram`.
75
+ """
76
+
77
+ SETTINGS = SpectrogramSettings
60
78
 
61
79
  def construct_generator(self):
62
80
  self.STATE.gen = spectrogram(
@@ -64,5 +82,5 @@ class Spectrogram(GenAxisArray):
64
82
  window_shift=self.SETTINGS.window_shift,
65
83
  window=self.SETTINGS.window,
66
84
  transform=self.SETTINGS.transform,
67
- output=self.SETTINGS.output
85
+ output=self.SETTINGS.output,
68
86
  )
ezmsg/sigproc/spectrum.py CHANGED
@@ -1,11 +1,14 @@
1
1
  from dataclasses import replace
2
2
  import enum
3
- from typing import Optional, Generator, AsyncGenerator
3
+ from functools import partial
4
+ import typing
4
5
 
5
6
  import numpy as np
6
7
  import ezmsg.core as ez
7
8
  from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
8
- from ezmsg.util.generator import consumer, GenAxisArray
9
+ from ezmsg.util.generator import consumer
10
+
11
+ from .base import GenAxisArray
9
12
 
10
13
 
11
14
  class OptionsEnum(enum.Enum):
@@ -15,11 +18,22 @@ class OptionsEnum(enum.Enum):
15
18
 
16
19
 
17
20
  class WindowFunction(OptionsEnum):
21
+ """Windowing function prior to calculating spectrum."""
22
+
18
23
  NONE = "None (Rectangular)"
24
+ """None."""
25
+
19
26
  HAMMING = "Hamming"
27
+ """:obj:`numpy.hamming`"""
28
+
20
29
  HANNING = "Hanning"
30
+ """:obj:`numpy.hanning`"""
31
+
21
32
  BARTLETT = "Bartlett"
33
+ """:obj:`numpy.bartlett`"""
34
+
22
35
  BLACKMAN = "Blackman"
36
+ """:obj:`numpy.blackman`"""
23
37
 
24
38
 
25
39
  WINDOWS = {
@@ -32,6 +46,8 @@ WINDOWS = {
32
46
 
33
47
 
34
48
  class SpectralTransform(OptionsEnum):
49
+ """Additional transformation functions to apply to the spectral result."""
50
+
35
51
  RAW_COMPLEX = "Complex FFT Output"
36
52
  REAL = "Real Component of FFT"
37
53
  IMAG = "Imaginary Component of FFT"
@@ -40,6 +56,8 @@ class SpectralTransform(OptionsEnum):
40
56
 
41
57
 
42
58
  class SpectralOutput(OptionsEnum):
59
+ """The expected spectral contents."""
60
+
43
61
  FULL = "Full Spectrum"
44
62
  POSITIVE = "Positive Frequencies"
45
63
  NEGATIVE = "Negative Frequencies"
@@ -47,112 +65,195 @@ class SpectralOutput(OptionsEnum):
47
65
 
48
66
  @consumer
49
67
  def spectrum(
50
- axis: Optional[str] = None,
51
- out_axis: Optional[str] = "freq",
68
+ axis: typing.Optional[str] = None,
69
+ out_axis: typing.Optional[str] = "freq",
52
70
  window: WindowFunction = WindowFunction.HANNING,
53
71
  transform: SpectralTransform = SpectralTransform.REL_DB,
54
- output: SpectralOutput = SpectralOutput.POSITIVE
55
- ) -> Generator[AxisArray, AxisArray, None]:
72
+ output: SpectralOutput = SpectralOutput.POSITIVE,
73
+ norm: typing.Optional[str] = "forward",
74
+ do_fftshift: bool = True,
75
+ nfft: typing.Optional[int] = None,
76
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
77
+ """
78
+ Calculate a spectrum on a data slice.
79
+
80
+ Args:
81
+ axis: The name of the axis on which to calculate the spectrum.
82
+ out_axis: The name of the new axis. Defaults to "freq".
83
+ window: The :obj:`WindowFunction` to apply to the data slice prior to calculating the spectrum.
84
+ transform: The :obj:`SpectralTransform` to apply to the spectral magnitude.
85
+ output: The :obj:`SpectralOutput` format.
86
+ norm: Normalization mode. Default "forward" is best used when the inverse transform is not needed,
87
+ for example when the goal is to get spectral power. Use "backward" (equivalent to None) to not
88
+ scale the spectrum which is useful when the spectra will be manipulated and possibly inverse-transformed.
89
+ See numpy.fft.fft for details.
90
+ do_fftshift: Whether to apply fftshift to the output. Default is True. This value is ignored unless
91
+ output is SpectralOutput.FULL.
92
+ nfft: The number of points to use for the FFT. If None, the length of the input data is used.
93
+
94
+ Returns:
95
+ A primed generator object that expects `.send(axis_array)` of continuous data
96
+ and yields an AxisArray of spectral magnitudes or powers.
97
+ """
98
+ msg_out = AxisArray(np.array([]), dims=[""])
56
99
 
57
100
  # State variables
58
- axis_arr_in = AxisArray(np.array([]), dims=[""])
59
- axis_arr_out = AxisArray(np.array([]), dims=[""])
60
-
61
- axis_name = axis
62
- axis_idx = None
63
- n_time = None
101
+ apply_window = window != WindowFunction.NONE
102
+ do_fftshift &= output == SpectralOutput.FULL
103
+ f_sl = slice(None)
104
+ freq_axis: typing.Optional[AxisArray.Axis] = None
105
+ fftfun: typing.Optional[typing.Callable] = None
106
+ f_transform: typing.Optional[typing.Callable] = None
107
+ new_dims: typing.Optional[typing.List[str]] = None
108
+
109
+ # Reset if input changes substantially
110
+ check_input = {
111
+ "n_time": None, # Need to recalc windows
112
+ "ndim": None, # Input ndim changed: Need to recalc windows
113
+ "kind": None, # Input dtype changed: Need to re-init fft funcs
114
+ "ax_idx": None, # Axis index changed: Need to re-init fft funcs
115
+ "gain": None, # Gain changed: Need to re-calc freqs
116
+ # "key": None # There's no temporal continuity; we can ignore key changes
117
+ }
64
118
 
65
119
  while True:
66
- axis_arr_in = yield axis_arr_out
67
-
68
- if axis_name is None:
69
- axis_name = axis_arr_in.dims[0]
70
-
71
- # Initial setup
72
- if n_time is None or axis_idx is None or axis_arr_in.data.shape[axis_idx] != n_time:
73
- axis_idx = axis_arr_in.get_axis_idx(axis_name)
74
- _axis = axis_arr_in.get_axis(axis_name)
75
- n_time = axis_arr_in.data.shape[axis_idx]
76
- freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=_axis.gain), axes=-1)
77
- window = WINDOWS[window](n_time)
78
- window = window.reshape([1] * axis_idx + [len(window),] + [1] * (axis_arr_in.data.ndim-2))
79
- if (transform != SpectralTransform.RAW_COMPLEX and
80
- not (transform == SpectralTransform.REAL or transform == SpectralTransform.IMAG)):
81
- scale = np.sum(window ** 2.0) * _axis.gain
82
- axis_offset = freqs[0]
83
- if output == SpectralOutput.POSITIVE:
84
- axis_offset = freqs[n_time // 2]
120
+ msg_in: AxisArray = yield msg_out
121
+
122
+ # Get signal properties
123
+ axis = axis or msg_in.dims[0]
124
+ ax_idx = msg_in.get_axis_idx(axis)
125
+ ax_info = msg_in.axes[axis]
126
+ targ_len = msg_in.data.shape[ax_idx]
127
+
128
+ # Check signal properties for change
129
+ b_reset = targ_len != check_input["n_time"]
130
+ b_reset = b_reset or msg_in.data.ndim != check_input["ndim"]
131
+ b_reset = b_reset or msg_in.data.dtype.kind != check_input["kind"]
132
+ b_reset = b_reset or ax_idx != check_input["ax_idx"]
133
+ b_reset = b_reset or ax_info.gain != check_input["gain"]
134
+ if b_reset:
135
+ check_input["n_time"] = targ_len
136
+ check_input["ndim"] = msg_in.data.ndim
137
+ check_input["kind"] = msg_in.data.dtype.kind
138
+ check_input["ax_idx"] = ax_idx
139
+ check_input["gain"] = ax_info.gain
140
+
141
+ nfft = nfft or targ_len
142
+
143
+ # Pre-calculate windowing
144
+ window = WINDOWS[window](targ_len)
145
+ window = window.reshape(
146
+ [1] * ax_idx
147
+ + [
148
+ len(window),
149
+ ]
150
+ + [1] * (msg_in.data.ndim - 1 - ax_idx)
151
+ )
152
+ if transform != SpectralTransform.RAW_COMPLEX and not (
153
+ transform == SpectralTransform.REAL
154
+ or transform == SpectralTransform.IMAG
155
+ ):
156
+ scale = np.sum(window**2.0) * ax_info.gain
157
+
158
+ # Pre-calculate frequencies and select our fft function.
159
+ b_complex = msg_in.data.dtype.kind == "c"
160
+ if (not b_complex) and output == SpectralOutput.POSITIVE:
161
+ # If input is not complex and desired output is SpectralOutput.POSITIVE, we can save some computation
162
+ # by using rfft and rfftfreq.
163
+ fftfun = partial(np.fft.rfft, n=nfft, axis=ax_idx, norm=norm)
164
+ freqs = np.fft.rfftfreq(nfft, d=ax_info.gain * targ_len / nfft)
165
+ else:
166
+ fftfun = partial(np.fft.fft, n=nfft, axis=ax_idx, norm=norm)
167
+ freqs = np.fft.fftfreq(nfft, d=ax_info.gain * targ_len / nfft)
168
+ if output == SpectralOutput.POSITIVE:
169
+ f_sl = slice(None, nfft // 2 + 1 - (nfft % 2))
170
+ elif output == SpectralOutput.NEGATIVE:
171
+ freqs = np.fft.fftshift(freqs, axes=-1)
172
+ f_sl = slice(None, nfft // 2 + 1)
173
+ elif do_fftshift: # and FULL
174
+ freqs = np.fft.fftshift(freqs, axes=-1)
175
+ freqs = freqs[f_sl]
176
+ freqs = freqs.tolist() # To please type checking
85
177
  freq_axis = AxisArray.Axis(
86
- unit="Hz", gain=1.0 / (_axis.gain * n_time), offset=axis_offset
178
+ unit="Hz", gain=freqs[1] - freqs[0], offset=freqs[0]
87
179
  )
88
180
  if out_axis is None:
89
- out_axis = axis_name
90
- new_dims = axis_arr_in.dims[:axis_idx] + [out_axis, ] + axis_arr_in.dims[axis_idx + 1:]
181
+ out_axis = axis
182
+ new_dims = (
183
+ msg_in.dims[:ax_idx]
184
+ + [
185
+ out_axis,
186
+ ]
187
+ + msg_in.dims[ax_idx + 1 :]
188
+ )
189
+
190
+ def f_transform(x):
191
+ return x
91
192
 
92
- f_transform = lambda x: x
93
193
  if transform != SpectralTransform.RAW_COMPLEX:
94
194
  if transform == SpectralTransform.REAL:
95
- f_transform = lambda x: x.real
195
+
196
+ def f_transform(x):
197
+ return x.real
96
198
  elif transform == SpectralTransform.IMAG:
97
- f_transform = lambda x: x.imag
199
+
200
+ def f_transform(x):
201
+ return x.imag
98
202
  else:
99
- f1 = lambda x: (2.0 * (np.abs(x) ** 2.0)) / scale
203
+
204
+ def f1(x):
205
+ return (np.abs(x) ** 2.0) / scale
206
+
100
207
  if transform == SpectralTransform.REL_DB:
101
- f_transform = lambda x: 10 * np.log10(f1(x))
208
+
209
+ def f_transform(x):
210
+ return 10 * np.log10(f1(x))
102
211
  else:
103
212
  f_transform = f1
104
213
 
105
- new_axes = {**axis_arr_in.axes, **{out_axis: freq_axis}}
106
- if out_axis != axis_name:
107
- new_axes.pop(axis_name, None)
108
-
109
- spec = np.fft.fft(axis_arr_in.data * window, axis=axis_idx) / n_time
110
- spec = np.fft.fftshift(spec, axes=axis_idx)
214
+ new_axes = {k: v for k, v in msg_in.axes.items() if k not in [out_axis, axis]}
215
+ new_axes[out_axis] = freq_axis
216
+
217
+ if apply_window:
218
+ win_dat = msg_in.data * window
219
+ else:
220
+ win_dat = msg_in.data
221
+ spec = fftfun(win_dat, n=nfft, axis=ax_idx, norm=norm)
222
+ # Note: norm="forward" equivalent to `/ nfft`
223
+ if do_fftshift or output == SpectralOutput.NEGATIVE:
224
+ spec = np.fft.fftshift(spec, axes=ax_idx)
111
225
  spec = f_transform(spec)
226
+ spec = slice_along_axis(spec, f_sl, ax_idx)
112
227
 
113
- if output == SpectralOutput.POSITIVE:
114
- spec = slice_along_axis(spec, slice(n_time // 2, None), axis_idx)
115
-
116
- elif output == SpectralOutput.NEGATIVE:
117
- spec = slice_along_axis(spec, slice(None, n_time // 2), axis_idx)
118
-
119
- axis_arr_out = replace(axis_arr_in, data=spec, dims=new_dims, axes=new_axes)
228
+ msg_out = replace(msg_in, data=spec, dims=new_dims, axes=new_axes)
120
229
 
121
230
 
122
231
  class SpectrumSettings(ez.Settings):
123
- axis: Optional[str] = None
124
- # n: Optional[int] = None # n parameter for fft
125
- out_axis: Optional[str] = "freq" # If none; don't change dim name
232
+ """
233
+ Settings for :obj:`Spectrum.
234
+ See :obj:`spectrum` for a description of the parameters.
235
+ """
236
+
237
+ axis: typing.Optional[str] = None
238
+ # n: typing.Optional[int] = None # n parameter for fft
239
+ out_axis: typing.Optional[str] = "freq" # If none; don't change dim name
126
240
  window: WindowFunction = WindowFunction.HAMMING
127
241
  transform: SpectralTransform = SpectralTransform.REL_DB
128
242
  output: SpectralOutput = SpectralOutput.POSITIVE
129
243
 
130
244
 
131
- class SpectrumState(ez.State):
132
- gen: Generator
133
- cur_settings: SpectrumSettings
134
-
135
-
136
245
  class Spectrum(GenAxisArray):
137
- SETTINGS: SpectrumSettings
138
- STATE: SpectrumState
246
+ """Unit for :obj:`spectrum`"""
139
247
 
140
- INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
141
-
142
- def initialize(self) -> None:
143
- self.STATE.cur_settings = self.SETTINGS
144
- super().initialize()
248
+ SETTINGS = SpectrumSettings
145
249
 
146
- @ez.subscriber(INPUT_SETTINGS)
147
- async def on_settings(self, msg: SpectrumSettings):
148
- self.STATE.cur_settings = msg
149
- self.construct_generator()
250
+ INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
150
251
 
151
252
  def construct_generator(self):
152
253
  self.STATE.gen = spectrum(
153
- axis=self.STATE.cur_settings.axis,
154
- out_axis=self.STATE.cur_settings.out_axis,
155
- window=self.STATE.cur_settings.window,
156
- transform=self.STATE.cur_settings.transform,
157
- output=self.STATE.cur_settings.output
254
+ axis=self.SETTINGS.axis,
255
+ out_axis=self.SETTINGS.out_axis,
256
+ window=self.SETTINGS.window,
257
+ transform=self.SETTINGS.transform,
258
+ output=self.SETTINGS.output,
158
259
  )