ezmsg-sigproc 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__init__.py +1 -1
- ezmsg/sigproc/__version__.py +1 -1
- ezmsg/sigproc/butterworthfilter.py +17 -27
- ezmsg/sigproc/decimate.py +7 -10
- ezmsg/sigproc/downsample.py +27 -33
- ezmsg/sigproc/ewmfilter.py +60 -54
- ezmsg/sigproc/filter.py +40 -24
- ezmsg/sigproc/messages.py +24 -44
- ezmsg/sigproc/sampler.py +173 -137
- ezmsg/sigproc/spectral.py +132 -0
- ezmsg/sigproc/synth.py +239 -64
- ezmsg/sigproc/window.py +92 -60
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/METADATA +2 -2
- ezmsg_sigproc-1.2.0.dist-info/RECORD +17 -0
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/WHEEL +1 -1
- ezmsg/sigproc/timeseriesmessage.py +0 -1
- ezmsg_sigproc-1.1.1.dist-info/RECORD +0 -17
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/LICENSE.txt +0 -0
- {ezmsg_sigproc-1.1.1.dist-info → ezmsg_sigproc-1.2.0.dist-info}/top_level.txt +0 -0
ezmsg/sigproc/sampler.py
CHANGED
|
@@ -1,147 +1,204 @@
|
|
|
1
1
|
from dataclasses import dataclass, replace, field
|
|
2
|
-
import logging
|
|
3
2
|
import time
|
|
4
3
|
|
|
5
4
|
import ezmsg.core as ez
|
|
6
5
|
import numpy as np
|
|
7
6
|
|
|
8
|
-
from .messages import
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
8
|
|
|
10
|
-
from typing import Optional, Any, Tuple, List, Dict
|
|
9
|
+
from typing import Optional, Any, Tuple, List, Dict, AsyncGenerator
|
|
11
10
|
|
|
12
|
-
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
|
|
13
13
|
class SampleTriggerMessage:
|
|
14
|
-
timestamp: float = field(
|
|
15
|
-
period: Optional[
|
|
14
|
+
timestamp: float = field(default_factory=time.time)
|
|
15
|
+
period: Optional[Tuple[float, float]] = None
|
|
16
16
|
value: Any = None
|
|
17
17
|
|
|
18
|
+
|
|
18
19
|
@dataclass
|
|
19
20
|
class SampleMessage:
|
|
20
21
|
trigger: SampleTriggerMessage
|
|
21
|
-
sample:
|
|
22
|
+
sample: AxisArray
|
|
22
23
|
|
|
23
|
-
|
|
24
|
+
|
|
25
|
+
class SamplerSettings(ez.Settings):
|
|
24
26
|
buffer_dur: float
|
|
25
|
-
|
|
26
|
-
|
|
27
|
+
axis: Optional[str] = None
|
|
28
|
+
period: Optional[
|
|
29
|
+
Tuple[float, float]
|
|
30
|
+
] = None # Optional default period if unspecified in SampleTriggerMessage
|
|
31
|
+
value: Any = None # Optional default value if unspecified in SampleTriggerMessage
|
|
27
32
|
|
|
28
33
|
estimate_alignment: bool = True
|
|
29
|
-
# If true, use message timestamp fields and reported sampling rate to estimate
|
|
34
|
+
# If true, use message timestamp fields and reported sampling rate to estimate
|
|
30
35
|
# sample-accurate alignment for samples.
|
|
31
36
|
# If false, sampling will be limited to incoming message rate -- "Block timing"
|
|
32
37
|
# NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
|
|
33
38
|
# "realtime" operation for estimate_alignment to operate correctly.
|
|
34
39
|
|
|
35
|
-
class SamplerState( ez.State ):
|
|
36
|
-
triggers: Dict[ SampleTriggerMessage, int ] = field( default_factory = dict )
|
|
37
|
-
last_msg: Optional[ TSMessage ] = None
|
|
38
|
-
buffer: Optional[ np.ndarray ] = None
|
|
39
40
|
|
|
40
|
-
class
|
|
41
|
+
class SamplerState(ez.State):
|
|
42
|
+
cur_settings: SamplerSettings
|
|
43
|
+
triggers: Dict[SampleTriggerMessage, int] = field(default_factory=dict)
|
|
44
|
+
last_msg: Optional[AxisArray] = None
|
|
45
|
+
buffer: Optional[np.ndarray] = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Sampler(ez.Unit):
|
|
41
49
|
SETTINGS: SamplerSettings
|
|
42
50
|
STATE: SamplerState
|
|
43
51
|
|
|
44
|
-
INPUT_TRIGGER = ez.InputStream(
|
|
45
|
-
|
|
46
|
-
|
|
52
|
+
INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
|
|
53
|
+
INPUT_SETTINGS = ez.InputStream(SamplerSettings)
|
|
54
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
55
|
+
OUTPUT_SAMPLE = ez.OutputStream(SampleMessage)
|
|
56
|
+
|
|
57
|
+
def initialize(self) -> None:
|
|
58
|
+
self.STATE.cur_settings = self.SETTINGS
|
|
59
|
+
|
|
60
|
+
@ez.subscriber(INPUT_SETTINGS)
|
|
61
|
+
async def on_settings(self, msg: SamplerSettings) -> None:
|
|
62
|
+
self.STATE.cur_settings = msg
|
|
47
63
|
|
|
48
|
-
@ez.subscriber(
|
|
49
|
-
async def on_trigger(
|
|
64
|
+
@ez.subscriber(INPUT_TRIGGER)
|
|
65
|
+
async def on_trigger(self, msg: SampleTriggerMessage) -> None:
|
|
50
66
|
if self.STATE.last_msg is not None:
|
|
51
|
-
|
|
67
|
+
axis_name = self.STATE.cur_settings.axis
|
|
68
|
+
if axis_name is None:
|
|
69
|
+
axis_name = self.STATE.last_msg.dims[0]
|
|
70
|
+
axis = self.STATE.last_msg.get_axis(axis_name)
|
|
71
|
+
axis_idx = self.STATE.last_msg.get_axis_idx(axis_name)
|
|
72
|
+
|
|
73
|
+
fs = 1.0 / axis.gain
|
|
74
|
+
last_msg_timestamp = axis.offset + (
|
|
75
|
+
self.STATE.last_msg.shape[axis_idx] / fs
|
|
76
|
+
)
|
|
52
77
|
|
|
53
|
-
period =
|
|
54
|
-
|
|
78
|
+
period = (
|
|
79
|
+
msg.period if msg.period is not None else self.STATE.cur_settings.period
|
|
80
|
+
)
|
|
81
|
+
value = (
|
|
82
|
+
msg.value if msg.value is not None else self.STATE.cur_settings.value
|
|
83
|
+
)
|
|
55
84
|
|
|
56
85
|
if period is None:
|
|
57
|
-
ez.logger.warning(
|
|
86
|
+
ez.logger.warning(f"Sampling failed: period not specified")
|
|
58
87
|
return
|
|
59
88
|
|
|
60
89
|
# Check that period is valid
|
|
61
|
-
start_offset = int(
|
|
62
|
-
stop_offset = int(
|
|
63
|
-
if (
|
|
64
|
-
ez.logger.warning(
|
|
90
|
+
start_offset = int(period[0] * fs)
|
|
91
|
+
stop_offset = int(period[1] * fs)
|
|
92
|
+
if (stop_offset - start_offset) <= 0:
|
|
93
|
+
ez.logger.warning(f"Sampling failed: invalid period requested")
|
|
65
94
|
return
|
|
66
95
|
|
|
67
96
|
# Check that period is compatible with buffer duration
|
|
68
|
-
max_buf_len = int(
|
|
69
|
-
req_buf_len = int(
|
|
97
|
+
max_buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
|
|
98
|
+
req_buf_len = int((period[1] - period[0]) * fs)
|
|
70
99
|
if req_buf_len >= max_buf_len:
|
|
71
|
-
ez.logger.warning(
|
|
100
|
+
ez.logger.warning(
|
|
101
|
+
f"Sampling failed: {period=} >= {self.STATE.cur_settings.buffer_dur=}"
|
|
102
|
+
)
|
|
72
103
|
return
|
|
73
104
|
|
|
74
105
|
offset: int = 0
|
|
75
|
-
if self.
|
|
106
|
+
if self.STATE.cur_settings.estimate_alignment:
|
|
76
107
|
# Do what we can with the wall clock to determine sample alignment
|
|
77
|
-
wall_delta = msg.timestamp -
|
|
78
|
-
offset = int(
|
|
108
|
+
wall_delta = msg.timestamp - last_msg_timestamp
|
|
109
|
+
offset = int(wall_delta * fs)
|
|
79
110
|
|
|
80
111
|
# Check that current buffer accumulation allows for offset - period start
|
|
81
|
-
if
|
|
82
|
-
|
|
112
|
+
if (
|
|
113
|
+
self.STATE.buffer is None
|
|
114
|
+
or -min(offset + start_offset, 0) >= self.STATE.buffer.shape[0]
|
|
115
|
+
):
|
|
116
|
+
ez.logger.warning(
|
|
117
|
+
"Sampling failed: insufficient buffer accumulation for requested sample period"
|
|
118
|
+
)
|
|
83
119
|
return
|
|
84
120
|
|
|
85
|
-
|
|
86
|
-
self.STATE.triggers[ trigger ] = offset
|
|
121
|
+
self.STATE.triggers[replace(msg, period=period, value=value)] = offset
|
|
87
122
|
|
|
88
|
-
else:
|
|
123
|
+
else:
|
|
124
|
+
ez.logger.warning("Sampling failed: no signal to sample yet")
|
|
89
125
|
|
|
90
|
-
@ez.subscriber(
|
|
91
|
-
@ez.publisher(
|
|
92
|
-
async def on_signal(
|
|
126
|
+
@ez.subscriber(INPUT_SIGNAL)
|
|
127
|
+
@ez.publisher(OUTPUT_SAMPLE)
|
|
128
|
+
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
|
|
129
|
+
axis_name = self.STATE.cur_settings.axis
|
|
130
|
+
if axis_name is None:
|
|
131
|
+
axis_name = msg.dims[0]
|
|
132
|
+
axis = msg.get_axis(axis_name)
|
|
133
|
+
|
|
134
|
+
fs = 1.0 / axis.gain
|
|
93
135
|
|
|
94
136
|
if self.STATE.last_msg is None:
|
|
95
137
|
self.STATE.last_msg = msg
|
|
96
138
|
|
|
97
139
|
# Easier to deal with timeseries on axis 0
|
|
98
140
|
last_msg = self.STATE.last_msg
|
|
99
|
-
msg_data = np.
|
|
100
|
-
last_msg_data = np.
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
):
|
|
141
|
+
msg_data = np.moveaxis(msg.data, msg.get_axis_idx(axis_name), 0)
|
|
142
|
+
last_msg_data = np.moveaxis(last_msg.data, last_msg.get_axis_idx(axis_name), 0)
|
|
143
|
+
last_msg_axis = last_msg.get_axis(axis_name)
|
|
144
|
+
last_msg_fs = 1.0 / last_msg_axis.gain
|
|
145
|
+
|
|
146
|
+
# Check if signal properties have changed in a breaking way
|
|
147
|
+
if fs != last_msg_fs or msg_data.shape[1:] != last_msg_data.shape[1:]:
|
|
107
148
|
# Data stream changed meaningfully -- flush buffer, stop sampling
|
|
108
|
-
if len(
|
|
109
|
-
ez.logger.warning(
|
|
110
|
-
ez.logger.warning(
|
|
149
|
+
if len(self.STATE.triggers) > 0:
|
|
150
|
+
ez.logger.warning("Sampling failed: Discarding all triggers")
|
|
151
|
+
ez.logger.warning("Flushing buffer: signal properties changed")
|
|
111
152
|
self.STATE.buffer = None
|
|
112
153
|
self.STATE.triggers = dict()
|
|
113
154
|
|
|
114
155
|
# Accumulate buffer ( time dim => dim 0 )
|
|
115
|
-
self.STATE.buffer =
|
|
116
|
-
|
|
156
|
+
self.STATE.buffer = (
|
|
157
|
+
msg_data
|
|
158
|
+
if self.STATE.buffer is None
|
|
159
|
+
else np.concatenate((self.STATE.buffer, msg_data), axis=0)
|
|
160
|
+
)
|
|
117
161
|
|
|
118
|
-
|
|
119
|
-
|
|
162
|
+
buffer_offset = np.arange(self.STATE.buffer.shape[0] + msg_data.shape[0])
|
|
163
|
+
buffer_offset -= self.STATE.buffer.shape[0] + 1
|
|
164
|
+
buffer_offset = (buffer_offset * axis.gain) + axis.offset
|
|
165
|
+
|
|
166
|
+
pub_samples: List[SampleMessage] = []
|
|
167
|
+
remaining_triggers: Dict[SampleTriggerMessage, int] = dict()
|
|
120
168
|
for trigger, offset in self.STATE.triggers.items():
|
|
169
|
+
if trigger.period is None:
|
|
170
|
+
continue
|
|
121
171
|
|
|
122
172
|
# trigger_offset points to t = 0 within buffer
|
|
123
|
-
offset
|
|
124
|
-
start = offset + int(
|
|
125
|
-
stop = offset + int(
|
|
173
|
+
offset -= msg_data.shape[0]
|
|
174
|
+
start = offset + int(trigger.period[0] * fs)
|
|
175
|
+
stop = offset + int(trigger.period[1] * fs)
|
|
176
|
+
|
|
177
|
+
if stop < 0: # We should be able to dispatch a sample
|
|
178
|
+
sample_data = self.STATE.buffer[start:stop, ...]
|
|
179
|
+
sample_data = np.moveaxis(sample_data, msg.get_axis_idx(axis_name), 0)
|
|
180
|
+
|
|
181
|
+
sample_offset = buffer_offset[start]
|
|
182
|
+
sample_axis = replace(axis, offset=sample_offset)
|
|
183
|
+
sample_axes = {**msg.axes, **{axis_name: sample_axis}}
|
|
126
184
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
sample = replace( msg, data = sample_data )
|
|
134
|
-
) )
|
|
185
|
+
pub_samples.append(
|
|
186
|
+
SampleMessage(
|
|
187
|
+
trigger=trigger,
|
|
188
|
+
sample=replace(msg, data=sample_data, axes=sample_axes),
|
|
189
|
+
)
|
|
190
|
+
)
|
|
135
191
|
|
|
136
|
-
else:
|
|
192
|
+
else:
|
|
193
|
+
remaining_triggers[trigger] = offset
|
|
137
194
|
|
|
138
195
|
for sample in pub_samples:
|
|
139
|
-
yield self.OUTPUT_SAMPLE, sample
|
|
140
|
-
|
|
196
|
+
yield self.OUTPUT_SAMPLE, sample
|
|
197
|
+
|
|
141
198
|
self.STATE.triggers = remaining_triggers
|
|
142
199
|
|
|
143
|
-
buf_len = int(
|
|
144
|
-
self.STATE.buffer = self.STATE.buffer[
|
|
200
|
+
buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
|
|
201
|
+
self.STATE.buffer = self.STATE.buffer[-buf_len:, ...]
|
|
145
202
|
self.STATE.last_msg = msg
|
|
146
203
|
|
|
147
204
|
|
|
@@ -153,99 +210,78 @@ from ezmsg.sigproc.synth import Oscillator, OscillatorSettings
|
|
|
153
210
|
|
|
154
211
|
from typing import AsyncGenerator
|
|
155
212
|
|
|
156
|
-
class SampleFormatter( ez.Unit ):
|
|
157
|
-
|
|
158
|
-
INPUT = ez.InputStream( SampleMessage )
|
|
159
|
-
OUTPUT = ez.OutputStream( str )
|
|
160
|
-
|
|
161
|
-
@ez.subscriber( INPUT )
|
|
162
|
-
@ez.publisher( OUTPUT )
|
|
163
|
-
async def format( self, msg: SampleMessage ) -> AsyncGenerator:
|
|
164
|
-
str_msg = f'Trigger: {msg.trigger.value}, '
|
|
165
|
-
str_msg += f'{msg.sample.n_time} samples @ {msg.sample.fs} Hz, '
|
|
166
|
-
|
|
167
|
-
time_axis = np.arange(msg.sample.n_time) / msg.sample.fs
|
|
168
|
-
time_axis = time_axis + msg.trigger.period[0]
|
|
169
|
-
str_msg += f'[{time_axis[0]},{time_axis[-1]})'
|
|
170
|
-
|
|
171
|
-
yield self.OUTPUT, str_msg
|
|
172
213
|
|
|
173
|
-
class TriggerGeneratorSettings(
|
|
174
|
-
period: Tuple[
|
|
175
|
-
prewait: float = 0.5
|
|
176
|
-
publish_period: float = 5.0
|
|
214
|
+
class TriggerGeneratorSettings(ez.Settings):
|
|
215
|
+
period: Tuple[float, float] # sec
|
|
216
|
+
prewait: float = 0.5 # sec
|
|
217
|
+
publish_period: float = 5.0 # sec
|
|
177
218
|
|
|
178
|
-
class TriggerGenerator( ez.Unit ):
|
|
179
219
|
|
|
220
|
+
class TriggerGenerator(ez.Unit):
|
|
180
221
|
SETTINGS: TriggerGeneratorSettings
|
|
181
222
|
|
|
182
|
-
OUTPUT_TRIGGER = ez.OutputStream(
|
|
223
|
+
OUTPUT_TRIGGER = ez.OutputStream(SampleTriggerMessage)
|
|
183
224
|
|
|
184
|
-
@ez.publisher(
|
|
185
|
-
async def generate(
|
|
186
|
-
await asyncio.sleep(
|
|
225
|
+
@ez.publisher(OUTPUT_TRIGGER)
|
|
226
|
+
async def generate(self) -> AsyncGenerator:
|
|
227
|
+
await asyncio.sleep(self.SETTINGS.prewait)
|
|
187
228
|
|
|
188
229
|
output = 0
|
|
189
230
|
while True:
|
|
190
231
|
yield self.OUTPUT_TRIGGER, SampleTriggerMessage(
|
|
191
|
-
period
|
|
192
|
-
value = output
|
|
232
|
+
period=self.SETTINGS.period, value=output
|
|
193
233
|
)
|
|
194
234
|
|
|
195
|
-
await asyncio.sleep(
|
|
235
|
+
await asyncio.sleep(self.SETTINGS.publish_period)
|
|
196
236
|
output += 1
|
|
197
237
|
|
|
198
|
-
|
|
238
|
+
|
|
239
|
+
class SamplerTestSystemSettings(ez.Settings):
|
|
199
240
|
sampler_settings: SamplerSettings
|
|
200
241
|
trigger_settings: TriggerGeneratorSettings
|
|
201
242
|
|
|
202
|
-
class SamplerTestSystem( ez.System ):
|
|
203
243
|
|
|
244
|
+
class SamplerTestSystem(ez.System):
|
|
204
245
|
SETTINGS: SamplerTestSystemSettings
|
|
205
246
|
|
|
206
247
|
OSC = Oscillator()
|
|
207
248
|
SAMPLER = Sampler()
|
|
208
249
|
TRIGGER = TriggerGenerator()
|
|
209
|
-
FORMATTER = SampleFormatter()
|
|
210
250
|
DEBUG = DebugLog()
|
|
211
251
|
|
|
212
|
-
def configure(
|
|
213
|
-
self.SAMPLER.apply_settings(
|
|
214
|
-
self.TRIGGER.apply_settings(
|
|
215
|
-
|
|
216
|
-
self.OSC.apply_settings(
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
return (
|
|
228
|
-
( self.OSC.OUTPUT_SIGNAL, self.SAMPLER.INPUT_SIGNAL ),
|
|
229
|
-
( self.TRIGGER.OUTPUT_TRIGGER, self.SAMPLER.INPUT_TRIGGER ),
|
|
252
|
+
def configure(self) -> None:
|
|
253
|
+
self.SAMPLER.apply_settings(self.SETTINGS.sampler_settings)
|
|
254
|
+
self.TRIGGER.apply_settings(self.SETTINGS.trigger_settings)
|
|
255
|
+
|
|
256
|
+
self.OSC.apply_settings(
|
|
257
|
+
OscillatorSettings(
|
|
258
|
+
n_time=2, # Number of samples to output per block
|
|
259
|
+
fs=10, # Sampling rate of signal output in Hz
|
|
260
|
+
dispatch_rate="realtime",
|
|
261
|
+
freq=2.0, # Oscillation frequency in Hz
|
|
262
|
+
amp=1.0, # Amplitude
|
|
263
|
+
phase=0.0, # Phase offset (in radians)
|
|
264
|
+
sync=True, # Adjust `freq` to sync with sampling rate
|
|
265
|
+
)
|
|
266
|
+
)
|
|
230
267
|
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
(
|
|
268
|
+
def network(self) -> ez.NetworkDefinition:
|
|
269
|
+
return (
|
|
270
|
+
(self.OSC.OUTPUT_SIGNAL, self.SAMPLER.INPUT_SIGNAL),
|
|
271
|
+
(self.TRIGGER.OUTPUT_TRIGGER, self.SAMPLER.INPUT_TRIGGER),
|
|
272
|
+
(self.TRIGGER.OUTPUT_TRIGGER, self.DEBUG.INPUT),
|
|
273
|
+
(self.SAMPLER.OUTPUT_SAMPLE, self.DEBUG.INPUT),
|
|
234
274
|
)
|
|
235
275
|
|
|
236
|
-
if __name__ == '__main__':
|
|
237
276
|
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
277
|
+
if __name__ == "__main__":
|
|
278
|
+
settings = SamplerTestSystemSettings(
|
|
279
|
+
sampler_settings=SamplerSettings(buffer_dur=5.0),
|
|
280
|
+
trigger_settings=TriggerGeneratorSettings(
|
|
281
|
+
period=(1.0, 2.0), prewait=0.5, publish_period=5.0
|
|
241
282
|
),
|
|
242
|
-
trigger_settings = TriggerGeneratorSettings(
|
|
243
|
-
period = ( 1.0, 2.0 ),
|
|
244
|
-
prewait = 0.5,
|
|
245
|
-
publish_period = 5.0
|
|
246
|
-
)
|
|
247
283
|
)
|
|
248
284
|
|
|
249
|
-
system = SamplerTestSystem(
|
|
285
|
+
system = SamplerTestSystem(settings)
|
|
250
286
|
|
|
251
|
-
ez.run_system(
|
|
287
|
+
ez.run_system(system)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
|
|
3
|
+
from dataclasses import replace
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
|
|
10
|
+
from typing import Optional, AsyncGenerator
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OptionsEnum(enum.Enum):
|
|
14
|
+
@classmethod
|
|
15
|
+
def options(cls):
|
|
16
|
+
return list(map(lambda c: c.value, cls))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class WindowFunction(OptionsEnum):
|
|
20
|
+
NONE = "None (Rectangular)"
|
|
21
|
+
HAMMING = "Hamming"
|
|
22
|
+
HANNING = "Hanning"
|
|
23
|
+
BARTLETT = "Bartlett"
|
|
24
|
+
BLACKMAN = "Blackman"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
WINDOWS = {
|
|
28
|
+
WindowFunction.NONE: np.ones,
|
|
29
|
+
WindowFunction.HAMMING: np.hamming,
|
|
30
|
+
WindowFunction.HANNING: np.hanning,
|
|
31
|
+
WindowFunction.BARTLETT: np.bartlett,
|
|
32
|
+
WindowFunction.BLACKMAN: np.blackman,
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SpectralTransform(OptionsEnum):
|
|
37
|
+
RAW_COMPLEX = "Complex FFT Output"
|
|
38
|
+
REAL = "Real Component of FFT"
|
|
39
|
+
IMAG = "Imaginary Component of FFT"
|
|
40
|
+
REL_POWER = "Relative Power"
|
|
41
|
+
REL_DB = "Log Power (Relative dB)"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class SpectralOutput(OptionsEnum):
|
|
45
|
+
FULL = "Full Spectrum"
|
|
46
|
+
POSITIVE = "Positive Frequencies"
|
|
47
|
+
NEGATIVE = "Negative Frequencies"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class SpectrumSettings(ez.Settings):
|
|
51
|
+
axis: Optional[str] = None
|
|
52
|
+
# n: Optional[int] = None # n parameter for fft
|
|
53
|
+
out_axis: Optional[str] = "freq" # If none; don't change dim name
|
|
54
|
+
window: WindowFunction = WindowFunction.HAMMING
|
|
55
|
+
transform: SpectralTransform = SpectralTransform.REL_DB
|
|
56
|
+
output: SpectralOutput = SpectralOutput.POSITIVE
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SpectrumState(ez.State):
|
|
60
|
+
cur_settings: SpectrumSettings
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Spectrum(ez.Unit):
|
|
64
|
+
SETTINGS: SpectrumSettings
|
|
65
|
+
STATE: SpectrumState
|
|
66
|
+
|
|
67
|
+
INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
|
|
68
|
+
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
69
|
+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
70
|
+
|
|
71
|
+
def initialize(self) -> None:
|
|
72
|
+
self.STATE.cur_settings = self.SETTINGS
|
|
73
|
+
|
|
74
|
+
@ez.subscriber(INPUT_SETTINGS)
|
|
75
|
+
async def on_settings(self, msg: SpectrumSettings):
|
|
76
|
+
self.STATE.cur_settings = msg
|
|
77
|
+
|
|
78
|
+
@ez.subscriber(INPUT_SIGNAL)
|
|
79
|
+
@ez.publisher(OUTPUT_SIGNAL)
|
|
80
|
+
async def on_data(self, message: AxisArray) -> AsyncGenerator:
|
|
81
|
+
axis_name = self.STATE.cur_settings.axis
|
|
82
|
+
if axis_name is None:
|
|
83
|
+
axis_name = message.dims[0]
|
|
84
|
+
axis_idx = message.get_axis_idx(axis_name)
|
|
85
|
+
axis = message.get_axis(axis_name)
|
|
86
|
+
|
|
87
|
+
spectrum = np.moveaxis(message.data, axis_idx, -1)
|
|
88
|
+
|
|
89
|
+
n_time = message.data.shape[axis_idx]
|
|
90
|
+
window = WINDOWS[self.STATE.cur_settings.window](n_time)
|
|
91
|
+
|
|
92
|
+
spectrum = np.fft.fft(spectrum * window) / n_time
|
|
93
|
+
spectrum = np.fft.fftshift(spectrum, axes=-1)
|
|
94
|
+
freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)
|
|
95
|
+
|
|
96
|
+
if self.STATE.cur_settings.transform != SpectralTransform.RAW_COMPLEX:
|
|
97
|
+
if self.STATE.cur_settings.transform == SpectralTransform.REAL:
|
|
98
|
+
spectrum = spectrum.real
|
|
99
|
+
elif self.STATE.cur_settings.transform == SpectralTransform.IMAG:
|
|
100
|
+
spectrum = spectrum.imag
|
|
101
|
+
else:
|
|
102
|
+
scale = np.sum(window**2.0) * axis.gain
|
|
103
|
+
spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale
|
|
104
|
+
|
|
105
|
+
if self.STATE.cur_settings.transform == SpectralTransform.REL_DB:
|
|
106
|
+
spectrum = 10 * np.log10(spectrum)
|
|
107
|
+
|
|
108
|
+
axis_offset = freqs[0]
|
|
109
|
+
if self.STATE.cur_settings.output == SpectralOutput.POSITIVE:
|
|
110
|
+
axis_offset = freqs[n_time // 2]
|
|
111
|
+
spectrum = spectrum[..., n_time // 2 :]
|
|
112
|
+
elif self.STATE.cur_settings.output == SpectralOutput.NEGATIVE:
|
|
113
|
+
spectrum = spectrum[..., : n_time // 2]
|
|
114
|
+
|
|
115
|
+
spectrum = np.moveaxis(spectrum, axis_idx, -1)
|
|
116
|
+
|
|
117
|
+
out_axis = self.SETTINGS.out_axis
|
|
118
|
+
if out_axis is None:
|
|
119
|
+
out_axis = axis_name
|
|
120
|
+
|
|
121
|
+
freq_axis = AxisArray.Axis(
|
|
122
|
+
unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset
|
|
123
|
+
)
|
|
124
|
+
new_axes = {**message.axes, **{out_axis: freq_axis}}
|
|
125
|
+
|
|
126
|
+
new_dims = [d for d in message.dims]
|
|
127
|
+
if self.SETTINGS.out_axis is not None:
|
|
128
|
+
new_dims[axis_idx] = self.SETTINGS.out_axis
|
|
129
|
+
|
|
130
|
+
out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)
|
|
131
|
+
|
|
132
|
+
yield self.OUTPUT_SIGNAL, out_msg
|