ezmsg-sigproc 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +4 -1
- ezmsg/sigproc/affinetransform.py +124 -0
- ezmsg/sigproc/aggregate.py +103 -0
- ezmsg/sigproc/bandpower.py +53 -0
- ezmsg/sigproc/butterworthfilter.py +41 -6
- ezmsg/sigproc/downsample.py +52 -26
- ezmsg/sigproc/ewmfilter.py +11 -3
- ezmsg/sigproc/filter.py +82 -14
- ezmsg/sigproc/sampler.py +173 -200
- ezmsg/sigproc/scaler.py +127 -0
- ezmsg/sigproc/signalinjector.py +67 -0
- ezmsg/sigproc/slicer.py +98 -0
- ezmsg/sigproc/spectral.py +9 -132
- ezmsg/sigproc/spectrogram.py +68 -0
- ezmsg/sigproc/spectrum.py +158 -0
- ezmsg/sigproc/synth.py +179 -80
- ezmsg/sigproc/window.py +212 -110
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.2.3.dist-info}/METADATA +15 -13
- ezmsg_sigproc-1.2.3.dist-info/RECORD +23 -0
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.2.3.dist-info}/WHEEL +1 -2
- ezmsg/sigproc/__version__.py +0 -1
- ezmsg_sigproc-1.2.2.dist-info/RECORD +0 -17
- ezmsg_sigproc-1.2.2.dist-info/top_level.txt +0 -1
- {ezmsg_sigproc-1.2.2.dist-info → ezmsg_sigproc-1.2.3.dist-info}/LICENSE.txt +0 -0
ezmsg/sigproc/slicer.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
7
|
+
from ezmsg.util.generator import consumer, GenAxisArray
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
Slicer:Select a subset of data along a particular axis.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
|
|
16
|
+
"""
|
|
17
|
+
Parses a string representation of a slice and returns a tuple of slice objects.
|
|
18
|
+
* "" -> slice(None, None, None) (take all)
|
|
19
|
+
* ":" -> slice(None, None, None)
|
|
20
|
+
* '"none"` (case-insensitive) -> slice(None, None, None)
|
|
21
|
+
* "{start}:{stop}" or {start}:{stop}:{step} -> slice(start, stop, step)
|
|
22
|
+
* "5" (or any integer) -> (5,). Take only that item.
|
|
23
|
+
applying this to a ndarray or AxisArray will drop the dimension.
|
|
24
|
+
* A comma-separated list of the above -> a tuple of slices | ints
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
s (str): The string representation of the slice.
|
|
28
|
+
|
|
29
|
+
Returns:
|
|
30
|
+
tuple[slice | int, ...]: A tuple of slice objects and/or ints.
|
|
31
|
+
"""
|
|
32
|
+
if s.lower() in ["", ":", "none"]:
|
|
33
|
+
return (slice(None),)
|
|
34
|
+
if "," not in s:
|
|
35
|
+
parts = [part.strip() for part in s.split(":")]
|
|
36
|
+
if len(parts) == 1:
|
|
37
|
+
return (int(parts[0]),)
|
|
38
|
+
return (slice(*(int(part.strip()) if part else None for part in parts)),)
|
|
39
|
+
l = [parse_slice(_) for _ in s.split(",")]
|
|
40
|
+
return tuple([item for sublist in l for item in sublist])
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@consumer
def slicer(
    selection: str = "", axis: typing.Optional[str] = None
) -> typing.Generator[AxisArray, AxisArray, None]:
    """
    Generator that slices each incoming AxisArray along a single axis.

    Args:
        selection: Slice specification string; see `parse_slice` for the syntax.
        axis: Name of the axis to slice. If None, the last dimension of the first
            message received is used.

    Returns:
        A primed generator: send an AxisArray, receive the sliced AxisArray.
        A single-integer selection removes the sliced dimension (and its entry in
        `axes`); any other selection keeps `dims` and `axes` unchanged.
    """
    msg_out = AxisArray(np.array([]), dims=[""])
    selector = None  # Built from the first message, then reused for all later ones.
    drop_dim = False  # True only for a single-int selection.

    while True:
        msg_in = yield msg_out

        if axis is None:
            axis = msg_in.dims[-1]
        ax_idx = msg_in.get_axis_idx(axis)

        if selector is None:
            parsed = parse_slice(selection)
            if len(parsed) != 1:
                # numpy cannot apply several slices to one axis in one step, so expand
                # them into a single (possibly discontinuous) integer index array.
                # NOTE(review): indices are derived from the first message's axis length.
                all_idx = np.arange(msg_in.data.shape[ax_idx])
                selector = np.s_[np.hstack([all_idx[p] for p in parsed])]
            else:
                selector = parsed[0]
                drop_dim = isinstance(selector, int)

        if not drop_dim:
            dims_out = msg_in.dims
            axes_out = msg_in.axes
        else:
            dims_out = [name for i, name in enumerate(msg_in.dims) if i != ax_idx]
            axes_out = msg_in.axes.copy()
            axes_out.pop(axis, None)

        msg_out = replace(
            msg_in,
            dims=dims_out,
            axes=axes_out,
            data=slice_along_axis(msg_in.data, selector, ax_idx),
        )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class SlicerSettings(ez.Settings):
    """Settings for the Slicer unit; both fields are forwarded to `slicer()`."""

    # Slice specification parsed by `parse_slice`, e.g. "", ":", "0:10:2", "3", "0:4,8".
    selection: str = ""
    # Axis to slice; None means the last dimension of the incoming data.
    axis: typing.Optional[str] = None
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class Slicer(GenAxisArray):
    """
    Unit that wraps the `slicer` generator.

    The generator is built from SETTINGS and stored in STATE.gen; presumably
    GenAxisArray feeds each incoming AxisArray through it (confirm in ezmsg.util).
    """

    SETTINGS: SlicerSettings

    def construct_generator(self):
        cfg = self.SETTINGS
        self.STATE.gen = slicer(selection=cfg.selection, axis=cfg.axis)
|
ezmsg/sigproc/spectral.py
CHANGED
|
@@ -1,132 +1,9 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
from typing import Optional, AsyncGenerator
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class OptionsEnum(enum.Enum):
    """Enum base class that can list the values of all of its members."""

    @classmethod
    def options(cls):
        """Return the value of every member, in definition order, as a list."""
        return [member.value for member in cls]
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class WindowFunction(OptionsEnum):
    """Window (taper) applied along the transformed axis before the FFT.

    Values are human-readable names (presumably for option lists via `options()`);
    the numpy implementation for each member is looked up in WINDOWS.
    """

    NONE = "None (Rectangular)"
    HAMMING = "Hamming"
    HANNING = "Hanning"
    BARTLETT = "Bartlett"
    BLACKMAN = "Blackman"
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
# Maps each WindowFunction to the numpy function that builds a window of a given
# length; called as WINDOWS[window](n_samples). NONE uses np.ones (rectangular).
WINDOWS = {
    WindowFunction.NONE: np.ones,
    WindowFunction.HAMMING: np.hamming,
    WindowFunction.HANNING: np.hanning,
    WindowFunction.BARTLETT: np.bartlett,
    WindowFunction.BLACKMAN: np.blackman,
}
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class SpectralTransform(OptionsEnum):
    """How the complex FFT output is converted before it is published."""

    RAW_COMPLEX = "Complex FFT Output"  # complex values, unchanged
    REAL = "Real Component of FFT"  # real part only
    IMAG = "Imaginary Component of FFT"  # imaginary part only
    REL_POWER = "Relative Power"  # 2*|X|^2 / (sum(window**2) * axis gain)
    REL_DB = "Log Power (Relative dB)"  # REL_POWER expressed as 10*log10(...)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
class SpectralOutput(OptionsEnum):
    """Which part of the fft-shifted spectrum (n bins) is kept."""

    FULL = "Full Spectrum"  # all n bins, negative to positive frequency
    POSITIVE = "Positive Frequencies"  # bins n//2 onward (0 Hz and above)
    NEGATIVE = "Negative Frequencies"  # first n//2 bins (below 0 Hz)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class SpectrumSettings(ez.Settings):
    """Settings for the Spectrum unit; may be replaced at runtime via INPUT_SETTINGS."""

    # Axis to transform; None means the first dimension of the incoming data.
    axis: Optional[str] = None
    # n: Optional[int] = None # n parameter for fft
    out_axis: Optional[str] = "freq"  # Name of the frequency dim; if None, keep the input axis name
    # Window applied along `axis` before the FFT.
    window: WindowFunction = WindowFunction.HAMMING
    # Conversion applied to the complex FFT output.
    transform: SpectralTransform = SpectralTransform.REL_DB
    # Which part of the fft-shifted spectrum is kept.
    output: SpectralOutput = SpectralOutput.POSITIVE
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class SpectrumState(ez.State):
    """Mutable state of the Spectrum unit."""

    # Settings currently in effect: set from SETTINGS in initialize() and
    # replaced by every message received on INPUT_SETTINGS.
    cur_settings: SpectrumSettings
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
class Spectrum(ez.Unit):
    """
    Unit that publishes the windowed FFT spectrum of each incoming AxisArray.

    The transform is taken along the configured axis (default: the first dimension).
    Settings may be replaced at runtime by sending a SpectrumSettings message to
    INPUT_SETTINGS. Every setting, including ``out_axis``, is then read from the
    latest message.
    """

    SETTINGS: SpectrumSettings
    STATE: SpectrumState

    INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SpectrumSettings):
        self.STATE.cur_settings = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_data(self, message: AxisArray) -> AsyncGenerator:
        """Compute the spectrum of `message` and publish it on OUTPUT_SIGNAL."""
        settings = self.STATE.cur_settings

        axis_name = settings.axis
        if axis_name is None:
            axis_name = message.dims[0]
        axis_idx = message.get_axis_idx(axis_name)
        axis = message.get_axis(axis_name)

        # Move the transformed axis last so the 1-D window broadcasts and the
        # FFT / half-spectrum slicing below can operate on axis -1.
        spectrum = np.moveaxis(message.data, axis_idx, -1)

        n_time = message.data.shape[axis_idx]
        window = WINDOWS[settings.window](n_time)

        spectrum = np.fft.fft(spectrum * window) / n_time
        spectrum = np.fft.fftshift(spectrum, axes=-1)
        freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)

        if settings.transform != SpectralTransform.RAW_COMPLEX:
            if settings.transform == SpectralTransform.REAL:
                spectrum = spectrum.real
            elif settings.transform == SpectralTransform.IMAG:
                spectrum = spectrum.imag
            else:
                scale = np.sum(window**2.0) * axis.gain
                spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale

                if settings.transform == SpectralTransform.REL_DB:
                    spectrum = 10 * np.log10(spectrum)

        axis_offset = freqs[0]
        if settings.output == SpectralOutput.POSITIVE:
            axis_offset = freqs[n_time // 2]
            spectrum = spectrum[..., n_time // 2 :]
        elif settings.output == SpectralOutput.NEGATIVE:
            spectrum = spectrum[..., : n_time // 2]

        # Put the frequency axis (currently last) back where the transformed axis
        # was. Moving axis_idx to the end a second time would scramble the dims of
        # any input that is not 2-D.
        spectrum = np.moveaxis(spectrum, -1, axis_idx)

        out_axis = settings.out_axis
        if out_axis is None:
            out_axis = axis_name

        freq_axis = AxisArray.Axis(
            unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset
        )
        new_axes = {**message.axes, **{out_axis: freq_axis}}

        new_dims = [d for d in message.dims]
        if settings.out_axis is not None:
            new_dims[axis_idx] = settings.out_axis

        out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)

        yield self.OUTPUT_SIGNAL, out_msg
|
|
1
|
+
from .spectrum import (
|
|
2
|
+
OptionsEnum,
|
|
3
|
+
WindowFunction,
|
|
4
|
+
SpectralTransform,
|
|
5
|
+
SpectralOutput,
|
|
6
|
+
SpectrumSettings,
|
|
7
|
+
SpectrumState,
|
|
8
|
+
Spectrum
|
|
9
|
+
)
|
|
ezmsg/sigproc/spectrogram.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
7
|
+
from ezmsg.util.generator import consumer, GenAxisArray # , compose
|
|
8
|
+
from ezmsg.util.messages.modify import modify_axis
|
|
9
|
+
from ezmsg.sigproc.window import windowing
|
|
10
|
+
from ezmsg.sigproc.spectrum import (
|
|
11
|
+
spectrum,
|
|
12
|
+
WindowFunction, SpectralTransform, SpectralOutput
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@consumer
def spectrogram(
    window_dur: typing.Optional[float] = None,
    window_shift: typing.Optional[float] = None,
    window: WindowFunction = WindowFunction.HANNING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE
) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]:
    """
    Generator that turns AxisArray messages with a "time" axis into spectrograms.

    Each input passes through three stages:
      1. `windowing` along "time", stacking windows on a new "step" axis;
      2. `spectrum` along "time", using `window`, `transform` and `output`;
      3. `modify_axis`, renaming "step" back to "time".

    Args:
        window_dur: Window duration in seconds, forwarded to `windowing`.
        window_shift: Window step in seconds, forwarded to `windowing`.
        window: Window function applied before each FFT.
        transform: Conversion applied to the FFT output.
        output: Which part of the spectrum to keep.

    Returns:
        A primed generator. Each send yields the resulting AxisArray, or None when
        `windowing` returned no window for that input.
    """
    # `compose` cannot be used: `windowing` returns a list of AxisArrays (always of
    # length 1 with these settings), so the stages are chained by hand.
    windower = windowing(axis="time", newaxis="step", window_dur=window_dur, window_shift=window_shift)
    spectral = spectrum(axis="time", window=window, transform=transform, output=output)
    renamer = modify_axis(name_map={"step": "time"})

    result: typing.Optional[AxisArray] = None
    while True:
        msg = yield result
        result = None

        windows = windower.send(msg)
        if not len(windows):
            continue

        spec = spectral.send(windows[0])
        if spec is not None:
            result = renamer.send(spec)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SpectrogramSettings(ez.Settings):
    """Settings for the Spectrogram unit; each field is forwarded to `spectrogram()`."""

    window_dur: typing.Optional[float] = None  # window duration in seconds
    window_shift: typing.Optional[float] = None  # window step in seconds. If None, window_shift == window_dur
    # See SpectrumSettings for details of following settings:
    # NOTE(review): this default (HAMMING) differs from spectrogram()'s own default
    # (HANNING); confirm which one is intended.
    window: WindowFunction = WindowFunction.HAMMING
    transform: SpectralTransform = SpectralTransform.REL_DB
    output: SpectralOutput = SpectralOutput.POSITIVE
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Spectrogram(GenAxisArray):
    """
    Unit that wraps the `spectrogram` generator.

    The generator is built from SETTINGS and stored in STATE.gen; presumably
    GenAxisArray feeds each incoming AxisArray through it (confirm in ezmsg.util).
    """

    SETTINGS: SpectrogramSettings

    def construct_generator(self):
        cfg = self.SETTINGS
        self.STATE.gen = spectrogram(
            window_dur=cfg.window_dur,
            window_shift=cfg.window_shift,
            window=cfg.window,
            transform=cfg.transform,
            output=cfg.output,
        )
|
|
ezmsg/sigproc/spectrum.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from dataclasses import replace
|
|
2
|
+
import enum
|
|
3
|
+
from typing import Optional, Generator, AsyncGenerator
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
|
|
8
|
+
from ezmsg.util.generator import consumer, GenAxisArray
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OptionsEnum(enum.Enum):
    """Enum base class that can list the values of all of its members."""

    @classmethod
    def options(cls):
        """Return the value of every member, in definition order, as a list."""
        return [member.value for member in cls]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class WindowFunction(OptionsEnum):
    """Window (taper) applied along the transformed axis before the FFT.

    Values are human-readable names (presumably for option lists via `options()`);
    the numpy implementation for each member is looked up in WINDOWS.
    """

    NONE = "None (Rectangular)"
    HAMMING = "Hamming"
    HANNING = "Hanning"
    BARTLETT = "Bartlett"
    BLACKMAN = "Blackman"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Maps each WindowFunction to the numpy function that builds a window of a given
# length; called as WINDOWS[window](n_samples). NONE uses np.ones (rectangular).
WINDOWS = {
    WindowFunction.NONE: np.ones,
    WindowFunction.HAMMING: np.hamming,
    WindowFunction.HANNING: np.hanning,
    WindowFunction.BARTLETT: np.bartlett,
    WindowFunction.BLACKMAN: np.blackman,
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class SpectralTransform(OptionsEnum):
    """How the complex FFT output is converted by `spectrum()`."""

    RAW_COMPLEX = "Complex FFT Output"  # complex values, unchanged
    REAL = "Real Component of FFT"  # real part only
    IMAG = "Imaginary Component of FFT"  # imaginary part only
    REL_POWER = "Relative Power"  # 2*|X|^2 / (sum(window**2) * axis gain)
    REL_DB = "Log Power (Relative dB)"  # REL_POWER expressed as 10*log10(...)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SpectralOutput(OptionsEnum):
    """Which part of the fft-shifted spectrum (n bins) `spectrum()` keeps."""

    FULL = "Full Spectrum"  # all n bins, negative to positive frequency
    POSITIVE = "Positive Frequencies"  # bins n//2 onward (0 Hz and above)
    NEGATIVE = "Negative Frequencies"  # first n//2 bins (below 0 Hz)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@consumer
def spectrum(
    axis: Optional[str] = None,
    out_axis: Optional[str] = "freq",
    window: WindowFunction = WindowFunction.HANNING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE
) -> Generator[AxisArray, AxisArray, None]:
    """
    Generator that computes the FFT spectrum of each incoming AxisArray.

    Each message is processed as follows:
      1. The data are multiplied by the selected window along `axis`.
      2. A complex FFT is taken along `axis` and divided by the number of samples.
      3. The result is fft-shifted, so bins run from negative to positive frequency.
      4. The result is converted according to `transform`.
      5. The result is trimmed according to `output`.
    Everything that depends on the length of `axis` is rebuilt whenever that
    length changes.

    Args:
        axis: Axis to transform. If None, the first dimension of the first message
            is used.
        out_axis: Name of the resulting frequency dimension. If None, the input axis
            name is kept.
        window: Window function applied along `axis` before the FFT.
        transform: Conversion applied to the complex FFT output.
        output: Which part of the shifted spectrum is kept.

    Returns:
        A primed generator: send an AxisArray, receive its spectrum as an AxisArray
        whose frequency axis is described by an "Hz" AxisArray.Axis.
    """
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    axis_name = axis
    axis_idx = None
    n_time = None

    while True:
        axis_arr_in = yield axis_arr_out

        if axis_name is None:
            axis_name = axis_arr_in.dims[0]

        # (Re)build all state that depends on the length of the transformed axis.
        if n_time is None or axis_idx is None or axis_arr_in.data.shape[axis_idx] != n_time:
            axis_idx = axis_arr_in.get_axis_idx(axis_name)
            _axis = axis_arr_in.get_axis(axis_name)
            n_time = axis_arr_in.data.shape[axis_idx]
            freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=_axis.gain), axes=-1)

            # Keep the `window` parameter (a WindowFunction) intact. Rebinding it to
            # the ndarray would make this WINDOWS lookup fail on the next rebuild.
            win = WINDOWS[window](n_time)
            # Shape the taper so it broadcasts only along axis_idx, for any ndim.
            win_shape = [1] * axis_arr_in.data.ndim
            win_shape[axis_idx] = n_time
            win = win.reshape(win_shape)

            if transform == SpectralTransform.RAW_COMPLEX:
                f_transform = lambda x: x
            elif transform == SpectralTransform.REAL:
                f_transform = lambda x: x.real
            elif transform == SpectralTransform.IMAG:
                f_transform = lambda x: x.imag
            else:
                # Power scaling: normalize by window energy times the sample spacing.
                scale = np.sum(win ** 2.0) * _axis.gain
                f_power = lambda x: (2.0 * (np.abs(x) ** 2.0)) / scale
                if transform == SpectralTransform.REL_DB:
                    f_transform = lambda x: 10 * np.log10(f_power(x))
                else:
                    f_transform = f_power

            axis_offset = freqs[n_time // 2] if output == SpectralOutput.POSITIVE else freqs[0]
            freq_axis = AxisArray.Axis(
                unit="Hz", gain=1.0 / (_axis.gain * n_time), offset=axis_offset
            )
            if out_axis is None:
                out_axis = axis_name
            new_dims = (
                list(axis_arr_in.dims[:axis_idx])
                + [out_axis]
                + list(axis_arr_in.dims[axis_idx + 1:])
            )

        new_axes = {**axis_arr_in.axes, out_axis: freq_axis}
        if out_axis != axis_name:
            new_axes.pop(axis_name, None)

        spec = np.fft.fft(axis_arr_in.data * win, axis=axis_idx) / n_time
        spec = np.fft.fftshift(spec, axes=axis_idx)
        spec = f_transform(spec)

        if output == SpectralOutput.POSITIVE:
            spec = slice_along_axis(spec, slice(n_time // 2, None), axis_idx)
        elif output == SpectralOutput.NEGATIVE:
            spec = slice_along_axis(spec, slice(None, n_time // 2), axis_idx)

        axis_arr_out = replace(axis_arr_in, data=spec, dims=new_dims, axes=new_axes)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class SpectrumSettings(ez.Settings):
    """Settings for the Spectrum unit; each field is forwarded to `spectrum()`."""

    # Axis to transform; None means the first dimension of the incoming data.
    axis: Optional[str] = None
    # n: Optional[int] = None # n parameter for fft
    out_axis: Optional[str] = "freq"  # Name of the frequency dim; if None, keep the input axis name
    # NOTE(review): this default (HAMMING) differs from spectrum()'s own default
    # (HANNING); confirm which one is intended.
    window: WindowFunction = WindowFunction.HAMMING
    # Conversion applied to the complex FFT output.
    transform: SpectralTransform = SpectralTransform.REL_DB
    # Which part of the fft-shifted spectrum is kept.
    output: SpectralOutput = SpectralOutput.POSITIVE
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class SpectrumState(ez.State):
    """Mutable state of the Spectrum unit."""

    # The `spectrum` generator built by Spectrum.construct_generator().
    gen: Generator
    # Settings currently in effect: set from SETTINGS in initialize() and replaced
    # by every message received on INPUT_SETTINGS.
    cur_settings: SpectrumSettings
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Spectrum(GenAxisArray):
    """
    Unit that wraps the `spectrum` generator.

    The generator is built from STATE.cur_settings. Receiving a new
    SpectrumSettings message on INPUT_SETTINGS replaces those settings and
    rebuilds the generator.
    """

    SETTINGS: SpectrumSettings
    STATE: SpectrumState

    INPUT_SETTINGS = ez.InputStream(SpectrumSettings)

    def initialize(self) -> None:
        # cur_settings must exist before the base class initializes, since
        # construct_generator() reads it.
        self.STATE.cur_settings = self.SETTINGS
        super().initialize()

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SpectrumSettings):
        self.STATE.cur_settings = msg
        self.construct_generator()

    def construct_generator(self):
        cfg = self.STATE.cur_settings
        self.STATE.gen = spectrum(
            axis=cfg.axis,
            out_axis=cfg.out_axis,
            window=cfg.window,
            transform=cfg.transform,
            output=cfg.output,
        )
|