ezmsg-sigproc 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,98 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
7
+ from ezmsg.util.generator import consumer, GenAxisArray
8
+
9
+
"""
Slicer: Select a subset of data along a particular axis.
"""


def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
    """
    Parse a string representation of a slice into a tuple of slices and/or ints.

    * "" -> slice(None, None, None) (take all)
    * ":" -> slice(None, None, None)
    * "none" (case-insensitive) -> slice(None, None, None)
    * "{start}:{stop}" or "{start}:{stop}:{step}" -> slice(start, stop, step);
      an empty field becomes None (e.g. "::2" -> slice(None, None, 2)).
    * "5" (or any integer) -> (5,). Take only that item.
      Applying this to an ndarray or AxisArray will drop the dimension.
    * A comma-separated list of the above -> a tuple of slices | ints.

    Whitespace around the whole string and around each comma-separated
    element is ignored, so "0:2, none" is accepted.

    Args:
        s: The string representation of the slice.

    Returns:
        A tuple of slice objects and/or ints.

    Raises:
        ValueError: If a non-empty field is not an integer (e.g. "a:b").
    """
    # Strip first so keywords like " none" (common after splitting on ",")
    # are recognized instead of being fed to int().
    s = s.strip()
    if s.lower() in ("", ":", "none"):
        return (slice(None),)
    if "," in s:
        # Parse each element independently and flatten the per-element tuples.
        return tuple(item for part in s.split(",") for item in parse_slice(part))
    parts = [part.strip() for part in s.split(":")]
    if len(parts) == 1:
        return (int(parts[0]),)
    return (slice(*(int(part) if part else None for part in parts)),)
41
+
42
+
43
+ @consumer
44
+ def slicer(
45
+ selection: str = "", axis: typing.Optional[str] = None
46
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
47
+ axis_arr_in = AxisArray(np.array([]), dims=[""])
48
+ axis_arr_out = AxisArray(np.array([]), dims=[""])
49
+ _slice = None
50
+ b_change_dims = False
51
+
52
+ while True:
53
+ axis_arr_in = yield axis_arr_out
54
+
55
+ if axis is None:
56
+ axis = axis_arr_in.dims[-1]
57
+ axis_idx = axis_arr_in.get_axis_idx(axis)
58
+
59
+ if _slice is None:
60
+ _slices = parse_slice(selection)
61
+ if len(_slices) == 1:
62
+ _slice = _slices[0]
63
+ b_change_dims = isinstance(_slice, int)
64
+ else:
65
+ # Multiple slices, but this cannot be done in a single step, so we convert the slices
66
+ # to a discontinuous set of integer indexes.
67
+ indices = np.arange(axis_arr_in.data.shape[axis_idx])
68
+ indices = np.hstack([indices[_] for _ in _slices])
69
+ _slice = np.s_[indices]
70
+
71
+ if b_change_dims:
72
+ out_dims = [_ for dim_ix, _ in enumerate(axis_arr_in.dims) if dim_ix != axis_idx]
73
+ out_axes = axis_arr_in.axes.copy()
74
+ out_axes.pop(axis, None)
75
+ else:
76
+ out_dims = axis_arr_in.dims
77
+ out_axes = axis_arr_in.axes
78
+
79
+ axis_arr_out = replace(
80
+ axis_arr_in,
81
+ dims=out_dims,
82
+ axes=out_axes,
83
+ data=slice_along_axis(axis_arr_in.data, _slice, axis_idx),
84
+ )
85
+
86
+
87
+ class SlicerSettings(ez.Settings):
88
+ selection: str = ""
89
+ axis: typing.Optional[str] = None
90
+
91
+
92
+ class Slicer(GenAxisArray):
93
+ SETTINGS: SlicerSettings
94
+
95
+ def construct_generator(self):
96
+ self.STATE.gen = slicer(
97
+ selection=self.SETTINGS.selection, axis=self.SETTINGS.axis
98
+ )
ezmsg/sigproc/spectral.py CHANGED
@@ -1,132 +1,9 @@
1
- import enum
2
-
3
- from dataclasses import replace
4
-
5
- import numpy as np
6
- import ezmsg.core as ez
7
-
8
- from ezmsg.util.messages.axisarray import AxisArray
9
-
10
- from typing import Optional, AsyncGenerator
11
-
12
-
13
- class OptionsEnum(enum.Enum):
14
- @classmethod
15
- def options(cls):
16
- return list(map(lambda c: c.value, cls))
17
-
18
-
19
- class WindowFunction(OptionsEnum):
20
- NONE = "None (Rectangular)"
21
- HAMMING = "Hamming"
22
- HANNING = "Hanning"
23
- BARTLETT = "Bartlett"
24
- BLACKMAN = "Blackman"
25
-
26
-
27
- WINDOWS = {
28
- WindowFunction.NONE: np.ones,
29
- WindowFunction.HAMMING: np.hamming,
30
- WindowFunction.HANNING: np.hanning,
31
- WindowFunction.BARTLETT: np.bartlett,
32
- WindowFunction.BLACKMAN: np.blackman,
33
- }
34
-
35
-
36
- class SpectralTransform(OptionsEnum):
37
- RAW_COMPLEX = "Complex FFT Output"
38
- REAL = "Real Component of FFT"
39
- IMAG = "Imaginary Component of FFT"
40
- REL_POWER = "Relative Power"
41
- REL_DB = "Log Power (Relative dB)"
42
-
43
-
44
- class SpectralOutput(OptionsEnum):
45
- FULL = "Full Spectrum"
46
- POSITIVE = "Positive Frequencies"
47
- NEGATIVE = "Negative Frequencies"
48
-
49
-
50
- class SpectrumSettings(ez.Settings):
51
- axis: Optional[str] = None
52
- # n: Optional[int] = None # n parameter for fft
53
- out_axis: Optional[str] = "freq" # If none; don't change dim name
54
- window: WindowFunction = WindowFunction.HAMMING
55
- transform: SpectralTransform = SpectralTransform.REL_DB
56
- output: SpectralOutput = SpectralOutput.POSITIVE
57
-
58
-
59
- class SpectrumState(ez.State):
60
- cur_settings: SpectrumSettings
61
-
62
-
63
- class Spectrum(ez.Unit):
64
- SETTINGS: SpectrumSettings
65
- STATE: SpectrumState
66
-
67
- INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
68
- INPUT_SIGNAL = ez.InputStream(AxisArray)
69
- OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
70
-
71
- def initialize(self) -> None:
72
- self.STATE.cur_settings = self.SETTINGS
73
-
74
- @ez.subscriber(INPUT_SETTINGS)
75
- async def on_settings(self, msg: SpectrumSettings):
76
- self.STATE.cur_settings = msg
77
-
78
- @ez.subscriber(INPUT_SIGNAL)
79
- @ez.publisher(OUTPUT_SIGNAL)
80
- async def on_data(self, message: AxisArray) -> AsyncGenerator:
81
- axis_name = self.STATE.cur_settings.axis
82
- if axis_name is None:
83
- axis_name = message.dims[0]
84
- axis_idx = message.get_axis_idx(axis_name)
85
- axis = message.get_axis(axis_name)
86
-
87
- spectrum = np.moveaxis(message.data, axis_idx, -1)
88
-
89
- n_time = message.data.shape[axis_idx]
90
- window = WINDOWS[self.STATE.cur_settings.window](n_time)
91
-
92
- spectrum = np.fft.fft(spectrum * window) / n_time
93
- spectrum = np.fft.fftshift(spectrum, axes=-1)
94
- freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)
95
-
96
- if self.STATE.cur_settings.transform != SpectralTransform.RAW_COMPLEX:
97
- if self.STATE.cur_settings.transform == SpectralTransform.REAL:
98
- spectrum = spectrum.real
99
- elif self.STATE.cur_settings.transform == SpectralTransform.IMAG:
100
- spectrum = spectrum.imag
101
- else:
102
- scale = np.sum(window**2.0) * axis.gain
103
- spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale
104
-
105
- if self.STATE.cur_settings.transform == SpectralTransform.REL_DB:
106
- spectrum = 10 * np.log10(spectrum)
107
-
108
- axis_offset = freqs[0]
109
- if self.STATE.cur_settings.output == SpectralOutput.POSITIVE:
110
- axis_offset = freqs[n_time // 2]
111
- spectrum = spectrum[..., n_time // 2 :]
112
- elif self.STATE.cur_settings.output == SpectralOutput.NEGATIVE:
113
- spectrum = spectrum[..., : n_time // 2]
114
-
115
- spectrum = np.moveaxis(spectrum, axis_idx, -1)
116
-
117
- out_axis = self.SETTINGS.out_axis
118
- if out_axis is None:
119
- out_axis = axis_name
120
-
121
- freq_axis = AxisArray.Axis(
122
- unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset
123
- )
124
- new_axes = {**message.axes, **{out_axis: freq_axis}}
125
-
126
- new_dims = [d for d in message.dims]
127
- if self.SETTINGS.out_axis is not None:
128
- new_dims[axis_idx] = self.SETTINGS.out_axis
129
-
130
- out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)
131
-
132
- yield self.OUTPUT_SIGNAL, out_msg
1
+ from .spectrum import (
2
+ OptionsEnum,
3
+ WindowFunction,
4
+ SpectralTransform,
5
+ SpectralOutput,
6
+ SpectrumSettings,
7
+ SpectrumState,
8
+ Spectrum
9
+ )
@@ -0,0 +1,68 @@
1
+ import typing
2
+
3
+ import numpy as np
4
+
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+ from ezmsg.util.generator import consumer, GenAxisArray # , compose
8
+ from ezmsg.util.messages.modify import modify_axis
9
+ from ezmsg.sigproc.window import windowing
10
+ from ezmsg.sigproc.spectrum import (
11
+ spectrum,
12
+ WindowFunction, SpectralTransform, SpectralOutput
13
+ )
14
+
15
+
16
@consumer
def spectrogram(
    window_dur: typing.Optional[float] = None,
    window_shift: typing.Optional[float] = None,
    window: WindowFunction = WindowFunction.HANNING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE
) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]:
    """
    Chain windowing -> spectrum -> axis renaming into a spectrogram generator.

    Each received AxisArray goes through three stages:
    `windowing(axis="time", newaxis="step", ...)`, then
    `spectrum(axis="time", ...)`, then `modify_axis` renaming "step" to "time".

    Args:
        window_dur: Window duration in seconds, forwarded to `windowing`.
        window_shift: Window step in seconds, forwarded to `windowing`.
        window: Window function, forwarded to `spectrum`.
        transform: Spectral transform, forwarded to `spectrum`.
        output: Which part of the spectrum to keep, forwarded to `spectrum`.

    Returns:
        A generator wrapped by `consumer`. Each send yields the resulting
        AxisArray, or None when the windowing stage returned nothing.
    """
    # `compose` cannot be used because `windowing` returns a list of AxisArray
    # objects, even though the length is always exactly 1 for these settings.
    windower = windowing(axis="time", newaxis="step", window_dur=window_dur, window_shift=window_shift)
    spectral = spectrum(axis="time", window=window, transform=transform, output=output)
    renamer = modify_axis(name_map={"step": "time"})

    result: typing.Optional[AxisArray] = None
    while True:
        msg_in = yield result
        result = None

        windowed = windower.send(msg_in)
        if len(windowed) == 0:
            continue

        spec_msg = spectral.send(windowed[0])
        if spec_msg is not None:
            result = renamer.send(spec_msg)
47
+
48
+
49
class SpectrogramSettings(ez.Settings):
    """Settings for the Spectrogram unit; each field is passed to `spectrogram`."""

    # Window duration in seconds.
    window_dur: typing.Optional[float] = None
    # Window step in seconds. If None, window_shift == window_dur.
    window_shift: typing.Optional[float] = None

    # See SpectrumSettings for details of the following settings.
    # NOTE(review): this default (HAMMING) differs from the `spectrogram`
    # generator's default (HANNING); confirm which one is intended.
    window: WindowFunction = WindowFunction.HAMMING
    transform: SpectralTransform = SpectralTransform.REL_DB
    output: SpectralOutput = SpectralOutput.POSITIVE
56
+
57
+
58
class Spectrogram(GenAxisArray):
    """ezmsg unit that applies the `spectrogram` generator to each incoming message."""

    SETTINGS: SpectrogramSettings

    def construct_generator(self):
        """Create the `spectrogram` generator from this unit's settings."""
        cfg = self.SETTINGS
        self.STATE.gen = spectrogram(
            window_dur=cfg.window_dur,
            window_shift=cfg.window_shift,
            window=cfg.window,
            transform=cfg.transform,
            output=cfg.output,
        )
@@ -0,0 +1,158 @@
1
+ from dataclasses import replace
2
+ import enum
3
+ from typing import Optional, Generator, AsyncGenerator
4
+
5
+ import numpy as np
6
+ import ezmsg.core as ez
7
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
8
+ from ezmsg.util.generator import consumer, GenAxisArray
9
+
10
+
11
class OptionsEnum(enum.Enum):
    """Enum base class that can list the values of its members."""

    @classmethod
    def options(cls):
        """Return a list of every member's value, in definition order."""
        return [member.value for member in cls]
15
+
16
+
17
class WindowFunction(OptionsEnum):
    """Window applied before the FFT; values are display labels."""

    NONE = "None (Rectangular)"
    HAMMING = "Hamming"
    HANNING = "Hanning"
    BARTLETT = "Bartlett"
    BLACKMAN = "Blackman"
23
+
24
+
25
# Lookup from WindowFunction to the NumPy callable that builds a length-n window.
WINDOWS = {
    WindowFunction.NONE: np.ones,  # all-ones, i.e. rectangular
    WindowFunction.HAMMING: np.hamming,
    WindowFunction.HANNING: np.hanning,
    WindowFunction.BARTLETT: np.bartlett,
    WindowFunction.BLACKMAN: np.blackman,
}
32
+
33
+
34
class SpectralTransform(OptionsEnum):
    """Post-processing applied to the complex FFT output; values are display labels."""

    RAW_COMPLEX = "Complex FFT Output"
    REAL = "Real Component of FFT"
    IMAG = "Imaginary Component of FFT"
    REL_POWER = "Relative Power"
    REL_DB = "Log Power (Relative dB)"
40
+
41
+
42
class SpectralOutput(OptionsEnum):
    """Portion of the fftshift-ed spectrum to emit; values are display labels."""

    FULL = "Full Spectrum"
    POSITIVE = "Positive Frequencies"
    NEGATIVE = "Negative Frequencies"
46
+
47
+
48
@consumer
def spectrum(
    axis: Optional[str] = None,
    out_axis: Optional[str] = "freq",
    window: WindowFunction = WindowFunction.HANNING,
    transform: SpectralTransform = SpectralTransform.REL_DB,
    output: SpectralOutput = SpectralOutput.POSITIVE
) -> Generator[AxisArray, AxisArray, None]:
    """
    Compute the windowed, fftshift-ed FFT of each incoming AxisArray along one axis.

    The FFT is normalized by the number of samples along the axis. Everything that
    depends on that length (window, frequency axis, scaling, output dims, transform)
    is computed on the first message. It is recomputed only when the length changes.

    Args:
        axis: Name of the axis to transform. Defaults to the first dim of the
            first message received.
        out_axis: Name of the transformed axis in the output. If None, the
            input axis name is kept.
        window: Window function applied along `axis` before the FFT.
        transform: Post-processing of the complex spectrum: raw, real part,
            imaginary part, relative power, or relative power in dB.
        output: Keep the full spectrum, the positive half, or the negative half.

    Returns:
        A generator wrapped by `consumer`: send an AxisArray, receive the
        transformed AxisArray.
    """
    # State variables
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    axis_name = axis
    axis_idx = None
    n_time = None

    while True:
        axis_arr_in = yield axis_arr_out

        if axis_name is None:
            axis_name = axis_arr_in.dims[0]

        # (Re)initialize whenever the length of the transformed axis changes.
        if n_time is None or axis_idx is None or axis_arr_in.data.shape[axis_idx] != n_time:
            axis_idx = axis_arr_in.get_axis_idx(axis_name)
            _axis = axis_arr_in.get_axis(axis_name)
            n_time = axis_arr_in.data.shape[axis_idx]
            freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=_axis.gain), axes=-1)

            # Use a separate name so the `window` argument (a WindowFunction) is not
            # overwritten by the ndarray; otherwise a later re-initialization would
            # evaluate WINDOWS[<ndarray>] and fail.
            win = WINDOWS[window](n_time)
            # Put singleton dims both before AND after axis_idx so the window
            # broadcasts only along the transformed axis, whatever its position.
            n_trailing = axis_arr_in.data.ndim - axis_idx - 1
            win = win.reshape([1] * axis_idx + [n_time] + [1] * n_trailing)

            # Describe the output frequency axis. After fftshift, index n_time // 2
            # is the start of the positive half.
            axis_offset = freqs[0]
            if output == SpectralOutput.POSITIVE:
                axis_offset = freqs[n_time // 2]
            freq_axis = AxisArray.Axis(
                unit="Hz", gain=1.0 / (_axis.gain * n_time), offset=axis_offset
            )
            if out_axis is None:
                out_axis = axis_name
            new_dims = axis_arr_in.dims[:axis_idx] + [out_axis, ] + axis_arr_in.dims[axis_idx + 1:]

            # The post-FFT transform depends only on the settings and the window,
            # so build it here rather than on every message.
            if transform == SpectralTransform.RAW_COMPLEX:
                f_transform = lambda x: x
            elif transform == SpectralTransform.REAL:
                f_transform = lambda x: x.real
            elif transform == SpectralTransform.IMAG:
                f_transform = lambda x: x.imag
            else:
                scale = np.sum(win ** 2.0) * _axis.gain
                f1 = lambda x: (2.0 * (np.abs(x) ** 2.0)) / scale
                if transform == SpectralTransform.REL_DB:
                    f_transform = lambda x: 10 * np.log10(f1(x))
                else:
                    f_transform = f1

        # Add the frequency axis; drop the input axis entry if it was renamed.
        new_axes = {**axis_arr_in.axes, **{out_axis: freq_axis}}
        if out_axis != axis_name:
            new_axes.pop(axis_name, None)

        spec = np.fft.fft(axis_arr_in.data * win, axis=axis_idx) / n_time
        spec = np.fft.fftshift(spec, axes=axis_idx)
        spec = f_transform(spec)

        # Keep only the requested half of the shifted spectrum.
        if output == SpectralOutput.POSITIVE:
            spec = slice_along_axis(spec, slice(n_time // 2, None), axis_idx)
        elif output == SpectralOutput.NEGATIVE:
            spec = slice_along_axis(spec, slice(None, n_time // 2), axis_idx)

        axis_arr_out = replace(axis_arr_in, data=spec, dims=new_dims, axes=new_axes)
120
+
121
+
122
class SpectrumSettings(ez.Settings):
    """Settings for the Spectrum unit; each field is passed to `spectrum`."""

    # Axis to transform; None lets `spectrum` use the first dim of the data.
    axis: Optional[str] = None
    # Name of the transformed output axis. If None, the dim name is not changed.
    out_axis: Optional[str] = "freq"
    # NOTE(review): this default (HAMMING) differs from the `spectrum`
    # generator's default (HANNING); confirm which one is intended.
    window: WindowFunction = WindowFunction.HAMMING
    transform: SpectralTransform = SpectralTransform.REL_DB
    output: SpectralOutput = SpectralOutput.POSITIVE
129
+
130
+
131
class SpectrumState(ez.State):
    """Mutable state held by the Spectrum unit."""

    # The active `spectrum` generator.
    gen: Generator
    # Settings in effect: the initial SETTINGS or the latest INPUT_SETTINGS message.
    cur_settings: SpectrumSettings
134
+
135
+
136
class Spectrum(GenAxisArray):
    """
    ezmsg unit that applies the `spectrum` generator to each incoming message.

    New settings received on INPUT_SETTINGS replace the current ones and
    rebuild the generator.
    """

    SETTINGS: SpectrumSettings
    STATE: SpectrumState

    INPUT_SETTINGS = ez.InputStream(SpectrumSettings)

    def initialize(self) -> None:
        # cur_settings is set before super().initialize(); presumably the base
        # class builds the generator there via construct_generator — TODO confirm.
        self.STATE.cur_settings = self.SETTINGS
        super().initialize()

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SpectrumSettings):
        """Adopt `msg` as the current settings and build a fresh generator."""
        self.STATE.cur_settings = msg
        self.construct_generator()

    def construct_generator(self):
        """Create the `spectrum` generator from the current settings."""
        cfg = self.STATE.cur_settings
        self.STATE.gen = spectrum(
            axis=cfg.axis,
            out_axis=cfg.out_axis,
            window=cfg.window,
            transform=cfg.transform,
            output=cfg.output,
        )