ezmsg-sigproc 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/sampler.py CHANGED
@@ -1,147 +1,204 @@
1
1
  from dataclasses import dataclass, replace, field
2
- import logging
3
2
  import time
4
3
 
5
4
  import ezmsg.core as ez
6
5
  import numpy as np
7
6
 
8
- from .messages import TSMessage
7
+ from ezmsg.util.messages.axisarray import AxisArray
9
8
 
10
- from typing import Optional, Any, Tuple, List, Dict
9
+ from typing import Optional, Any, Tuple, List, Dict, AsyncGenerator
11
10
 
12
@dataclass(frozen=True)
class SampleTriggerMessage:
    """Request to cut a window of signal around a moment in time.

    Instances are frozen, and therefore hashable whenever ``period`` and
    ``value`` are, because the sampler keys its pending-trigger dict on them.
    """

    # Wall-clock creation time, taken from ``time.time()`` by default.
    timestamp: float = field(default_factory=time.time)
    # (start, stop) window in seconds relative to the trigger; None defers to
    # the sampler's configured default period.
    period: Optional[Tuple[float, float]] = field(default=None)
    # Arbitrary payload carried through to the resulting sample; None defers to
    # the sampler's configured default value.
    value: Any = field(default=None)
17
17
 
18
+
18
19
@dataclass
class SampleMessage:
    """One window of signal cut out in response to a trigger."""

    # The trigger that produced this sample, with period/value resolved.
    trigger: SampleTriggerMessage
    # The sampled window, with the same dims as the input signal.
    sample: AxisArray
22
23
 
23
class SamplerSettings(ez.Settings):
    """Configuration for the ``Sampler`` unit."""

    # Seconds of signal retained in the rolling buffer; a requested period
    # must be strictly shorter than this.
    buffer_dur: float
    # Name of the axis to sample along; None selects each message's first dim.
    axis: Optional[str] = None
    # Optional default period if unspecified in SampleTriggerMessage
    period: Optional[Tuple[float, float]] = None
    # Optional default value if unspecified in SampleTriggerMessage
    value: Any = None

    # When True, use message timestamp fields and the reported sampling rate to
    # estimate sample-accurate alignment for samples. When False, sampling is
    # limited to the incoming message rate ("block timing").
    # NOTE: For faster-than-realtime playback, incoming timestamps must reflect
    # "realtime" operation for estimate_alignment to operate correctly.
    estimate_alignment: bool = True
34
39
 
35
class SamplerState(ez.State):
    """Mutable runtime state of the ``Sampler`` unit."""

    # Settings currently in effect (initially SETTINGS, replaceable at runtime).
    cur_settings: SamplerSettings
    # Pending triggers -> sample offset relative to the end of the last signal message.
    triggers: Dict[SampleTriggerMessage, int] = field(default_factory=dict)
    # Most recently received signal message, or None before any arrives.
    last_msg: Optional[AxisArray] = None
    # Rolling signal buffer with the sampled axis moved to dim 0.
    buffer: Optional[np.ndarray] = None
46
+
47
+
48
class Sampler(ez.Unit):
    """Cut trigger-aligned windows ("samples") out of a streaming AxisArray.

    Signal messages on ``INPUT_SIGNAL`` are accumulated in a rolling buffer
    (``buffer_dur`` seconds long) with the sampled axis moved to dim 0.
    Triggers on ``INPUT_TRIGGER`` are validated and registered with a sample
    offset relative to the end of the most recent signal message. Once enough
    data has arrived to cover a trigger's ``period``, the window is published
    on ``OUTPUT_SAMPLE`` as a ``SampleMessage``. Settings may be replaced at
    runtime through ``INPUT_SETTINGS``.
    """

    SETTINGS: SamplerSettings
    STATE: SamplerState

    INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
    INPUT_SETTINGS = ez.InputStream(SamplerSettings)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SAMPLE = ez.OutputStream(SampleMessage)

    def initialize(self) -> None:
        """Start out with the settings the unit was constructed with."""
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SamplerSettings) -> None:
        """Replace the active settings; later messages use the new values."""
        self.STATE.cur_settings = msg

    @ez.subscriber(INPUT_TRIGGER)
    async def on_trigger(self, msg: SampleTriggerMessage) -> None:
        """Validate a trigger and, if it can be satisfied, queue it.

        A missing ``period``/``value`` on the trigger falls back to the active
        settings. Triggers that cannot be satisfied are dropped with a warning.
        """
        last_msg = self.STATE.last_msg
        if last_msg is None:
            ez.logger.warning("Sampling failed: no signal to sample yet")
            return

        settings = self.STATE.cur_settings
        axis_name = settings.axis
        if axis_name is None:
            axis_name = last_msg.dims[0]
        axis = last_msg.get_axis(axis_name)
        axis_idx = last_msg.get_axis_idx(axis_name)

        fs = 1.0 / axis.gain
        # Time just past the final sample of the most recent signal message.
        last_msg_timestamp = axis.offset + (last_msg.data.shape[axis_idx] / fs)

        period = msg.period if msg.period is not None else settings.period
        value = msg.value if msg.value is not None else settings.value

        if period is None:
            ez.logger.warning("Sampling failed: period not specified")
            return

        # Check that period is valid
        start_offset = int(period[0] * fs)
        stop_offset = int(period[1] * fs)
        if (stop_offset - start_offset) <= 0:
            ez.logger.warning("Sampling failed: invalid period requested")
            return

        # Check that period is compatible with buffer duration
        max_buf_len = int(settings.buffer_dur * fs)
        req_buf_len = int((period[1] - period[0]) * fs)
        if req_buf_len >= max_buf_len:
            ez.logger.warning(
                f"Sampling failed: {period=} >= {self.STATE.cur_settings.buffer_dur=}"
            )
            return

        offset: int = 0
        if settings.estimate_alignment:
            # Do what we can with the wall clock to determine sample alignment
            wall_delta = msg.timestamp - last_msg_timestamp
            offset = int(wall_delta * fs)

        # Check that current buffer accumulation allows for offset - period start
        if (
            self.STATE.buffer is None
            or -min(offset + start_offset, 0) >= self.STATE.buffer.shape[0]
        ):
            ez.logger.warning(
                "Sampling failed: insufficient buffer accumulation for requested sample period"
            )
            return

        self.STATE.triggers[replace(msg, period=period, value=value)] = offset

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SAMPLE)
    async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
        """Buffer incoming signal and publish every trigger window it completes."""
        axis_name = self.STATE.cur_settings.axis
        if axis_name is None:
            axis_name = msg.dims[0]
        axis = msg.get_axis(axis_name)
        axis_idx = msg.get_axis_idx(axis_name)
        fs = 1.0 / axis.gain

        if self.STATE.last_msg is None:
            self.STATE.last_msg = msg

        # Easier to deal with timeseries on axis 0
        last_msg = self.STATE.last_msg
        msg_data = np.moveaxis(msg.data, axis_idx, 0)
        last_msg_data = np.moveaxis(last_msg.data, last_msg.get_axis_idx(axis_name), 0)
        last_msg_fs = 1.0 / last_msg.get_axis(axis_name).gain

        # Check if signal properties have changed in a breaking way
        if fs != last_msg_fs or msg_data.shape[1:] != last_msg_data.shape[1:]:
            # Data stream changed meaningfully -- flush buffer, stop sampling
            if len(self.STATE.triggers) > 0:
                ez.logger.warning("Sampling failed: Discarding all triggers")
            ez.logger.warning("Flushing buffer: signal properties changed")
            self.STATE.buffer = None
            self.STATE.triggers = dict()

        # Accumulate buffer ( time dim => dim 0 )
        self.STATE.buffer = (
            msg_data
            if self.STATE.buffer is None
            else np.concatenate((self.STATE.buffer, msg_data), axis=0)
        )
        n_new = msg_data.shape[0]

        pub_samples: List[SampleMessage] = []
        remaining_triggers: Dict[SampleTriggerMessage, int] = dict()
        for trigger, offset in self.STATE.triggers.items():
            if trigger.period is None:
                continue

            # Offsets were relative to the end of the previous message; shift
            # them so they are relative to the end of the buffer (this msg).
            offset -= n_new
            start = offset + int(trigger.period[0] * fs)
            stop = offset + int(trigger.period[1] * fs)

            if stop < 0:  # We should be able to dispatch a sample
                sample_data = self.STATE.buffer[start:stop, ...]
                # Inverse of the moveaxis above: dim 0 goes back to axis_idx.
                sample_data = np.moveaxis(sample_data, 0, axis_idx)

                # ``start`` is negative (counted from the buffer end) and the
                # buffer ends with msg_data, whose first sample sits at
                # axis.offset; so buffer row ``start`` is sample
                # ``start + n_new`` relative to that first sample.
                sample_offset = axis.offset + (start + n_new) * axis.gain
                sample_axis = replace(axis, offset=sample_offset)
                sample_axes = {**msg.axes, axis_name: sample_axis}

                pub_samples.append(
                    SampleMessage(
                        trigger=trigger,
                        sample=replace(msg, data=sample_data, axes=sample_axes),
                    )
                )
            else:
                remaining_triggers[trigger] = offset

        for sample in pub_samples:
            yield self.OUTPUT_SAMPLE, sample

        self.STATE.triggers = remaining_triggers

        # Keep only the most recent buffer_dur seconds. An explicit start index
        # (instead of [-buf_len:]) makes buf_len == 0 keep nothing rather than
        # silently keeping the entire, ever-growing buffer.
        buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
        buffer = self.STATE.buffer
        self.STATE.buffer = buffer[max(buffer.shape[0] - buf_len, 0) :, ...]
        self.STATE.last_msg = msg
146
203
 
147
204
 
@@ -153,99 +210,78 @@ from ezmsg.sigproc.synth import Oscillator, OscillatorSettings
153
210
 
154
211
  from typing import AsyncGenerator
155
212
 
156
class TriggerGeneratorSettings(ez.Settings):
    """Configuration for the demo ``TriggerGenerator`` unit."""

    # (start, stop) sampling window attached to every trigger, in seconds.
    period: Tuple[float, float]
    # Delay before the first trigger is published, in seconds.
    prewait: float = 0.5
    # Interval between consecutive triggers, in seconds.
    publish_period: float = 5.0
177
218
 
178
class TriggerGenerator(ez.Unit):
    """Demo unit that periodically publishes a trigger with a counting value."""

    SETTINGS: TriggerGeneratorSettings

    OUTPUT_TRIGGER = ez.OutputStream(SampleTriggerMessage)

    @ez.publisher(OUTPUT_TRIGGER)
    async def generate(self) -> AsyncGenerator:
        """Wait ``prewait`` seconds, then emit triggers valued 0, 1, 2, ... forever."""
        await asyncio.sleep(self.SETTINGS.prewait)

        trigger_count = 0
        while True:
            next_trigger = SampleTriggerMessage(
                period=self.SETTINGS.period,
                value=trigger_count,
            )
            yield self.OUTPUT_TRIGGER, next_trigger

            await asyncio.sleep(self.SETTINGS.publish_period)
            trigger_count += 1
197
237
 
198
class SamplerTestSystemSettings(ez.Settings):
    """Settings bundle for the demo ``SamplerTestSystem``."""

    # Forwarded to the SAMPLER component.
    sampler_settings: SamplerSettings
    # Forwarded to the TRIGGER component.
    trigger_settings: TriggerGeneratorSettings
201
242
 
202
class SamplerTestSystem(ez.System):
    """Demo system: sample an oscillator on periodic triggers and log the results."""

    SETTINGS: SamplerTestSystemSettings

    OSC = Oscillator()
    SAMPLER = Sampler()
    TRIGGER = TriggerGenerator()
    DEBUG = DebugLog()

    def configure(self) -> None:
        """Push the bundled settings down to each component."""
        self.SAMPLER.apply_settings(self.SETTINGS.sampler_settings)
        self.TRIGGER.apply_settings(self.SETTINGS.trigger_settings)

        oscillator_settings = OscillatorSettings(
            n_time=2,  # Number of samples to output per block
            fs=10,  # Sampling rate of signal output in Hz
            dispatch_rate="realtime",
            freq=2.0,  # Oscillation frequency in Hz
            amp=1.0,  # Amplitude
            phase=0.0,  # Phase offset (in radians)
            sync=True,  # Adjust `freq` to sync with sampling rate
        )
        self.OSC.apply_settings(oscillator_settings)

    def network(self) -> ez.NetworkDefinition:
        """Wire signal and triggers into the sampler; log triggers and samples."""
        signal_to_sampler = (self.OSC.OUTPUT_SIGNAL, self.SAMPLER.INPUT_SIGNAL)
        trigger_to_sampler = (self.TRIGGER.OUTPUT_TRIGGER, self.SAMPLER.INPUT_TRIGGER)
        trigger_to_log = (self.TRIGGER.OUTPUT_TRIGGER, self.DEBUG.INPUT)
        sample_to_log = (self.SAMPLER.OUTPUT_SAMPLE, self.DEBUG.INPUT)
        return (signal_to_sampler, trigger_to_sampler, trigger_to_log, sample_to_log)
235
275
 
236
if __name__ == "__main__":
    # Demo: 5 s sampler buffer, triggers every 5 s requesting the 1-2 s window.
    demo_sampler_settings = SamplerSettings(buffer_dur=5.0)
    demo_trigger_settings = TriggerGeneratorSettings(
        period=(1.0, 2.0),
        prewait=0.5,
        publish_period=5.0,
    )
    settings = SamplerTestSystemSettings(
        sampler_settings=demo_sampler_settings,
        trigger_settings=demo_trigger_settings,
    )

    system = SamplerTestSystem(settings)
    ez.run_system(system)
@@ -0,0 +1,132 @@
1
+ import enum
2
+
3
+ from dataclasses import replace
4
+
5
+ import numpy as np
6
+ import ezmsg.core as ez
7
+
8
+ from ezmsg.util.messages.axisarray import AxisArray
9
+
10
+ from typing import Optional, AsyncGenerator
11
+
12
+
13
class OptionsEnum(enum.Enum):
    """Enum base class that can list the values of its members."""

    @classmethod
    def options(cls):
        """Return the values of all (non-alias) members in definition order."""
        return [member.value for member in cls]
17
+
18
+
19
class WindowFunction(OptionsEnum):
    """Taper applied along the transformed axis before the FFT.

    Each member is mapped to a NumPy window constructor in ``WINDOWS``.
    """

    NONE = "None (Rectangular)"  # all-ones window
    HAMMING = "Hamming"
    HANNING = "Hanning"
    BARTLETT = "Bartlett"
    BLACKMAN = "Blackman"
25
+
26
+
27
# Window option -> NumPy function that builds a window of the given length.
WINDOWS = dict(
    [
        (WindowFunction.NONE, np.ones),
        (WindowFunction.HAMMING, np.hamming),
        (WindowFunction.HANNING, np.hanning),
        (WindowFunction.BARTLETT, np.bartlett),
        (WindowFunction.BLACKMAN, np.blackman),
    ]
)
34
+
35
+
36
class SpectralTransform(OptionsEnum):
    """Post-processing applied to the (normalized, shifted) complex FFT output."""

    RAW_COMPLEX = "Complex FFT Output"  # returned unchanged
    REAL = "Real Component of FFT"
    IMAG = "Imaginary Component of FFT"
    # 2 * |X|**2 / (sum(window**2) * sample period)
    REL_POWER = "Relative Power"
    # 10 * log10 of REL_POWER
    REL_DB = "Log Power (Relative dB)"
44
class SpectralOutput(OptionsEnum):
    """Which part of the fftshift-ed spectrum is emitted."""

    FULL = "Full Spectrum"  # every bin, negative to positive
    POSITIVE = "Positive Frequencies"  # bins from 0 Hz upward
    NEGATIVE = "Negative Frequencies"  # bins strictly below 0 Hz
48
+
49
+
50
class SpectrumSettings(ez.Settings):
    """Configuration for the ``Spectrum`` unit.

    The FFT length always equals the input length along ``axis``; there is no
    separate zero-padding (``n``) option.
    """

    # Axis to transform; None selects each message's first dim.
    axis: Optional[str] = None
    # Name given to the output frequency dim; None keeps the input axis name.
    out_axis: Optional[str] = "freq"
    window: WindowFunction = WindowFunction.HAMMING
    transform: SpectralTransform = SpectralTransform.REL_DB
    output: SpectralOutput = SpectralOutput.POSITIVE
57
+
58
+
59
class SpectrumState(ez.State):
    """Mutable runtime state of the ``Spectrum`` unit."""

    # Settings in effect: SETTINGS at startup, replaceable via INPUT_SETTINGS.
    cur_settings: SpectrumSettings
61
+
62
+
63
class Spectrum(ez.Unit):
    """Compute a windowed FFT of each incoming AxisArray along one axis.

    The transformed axis is replaced in place (same position in ``dims``) by a
    frequency axis named ``out_axis``. Its gain is the frequency resolution
    ``1 / (gain * n)`` and its offset is the first emitted frequency. Settings
    may be replaced at runtime through ``INPUT_SETTINGS``.
    """

    SETTINGS: SpectrumSettings
    STATE: SpectrumState

    INPUT_SETTINGS = ez.InputStream(SpectrumSettings)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        """Start out with the settings the unit was constructed with."""
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SpectrumSettings):
        """Replace the active settings; later messages use the new values."""
        self.STATE.cur_settings = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_data(self, message: AxisArray) -> AsyncGenerator:
        """Transform ``message`` and publish the resulting spectrum."""
        # Snapshot so every step for this message uses the same configuration
        # (previously out_axis was read from the static SETTINGS instead).
        settings = self.STATE.cur_settings

        axis_name = settings.axis
        if axis_name is None:
            axis_name = message.dims[0]
        axis_idx = message.get_axis_idx(axis_name)
        axis = message.get_axis(axis_name)

        # Put the transformed axis last so windowing, FFT and slicing act on it.
        spectrum = np.moveaxis(message.data, axis_idx, -1)

        n_time = message.data.shape[axis_idx]
        window = WINDOWS[settings.window](n_time)

        spectrum = np.fft.fft(spectrum * window) / n_time
        spectrum = np.fft.fftshift(spectrum, axes=-1)
        freqs = np.fft.fftshift(np.fft.fftfreq(n_time, d=axis.gain), axes=-1)

        if settings.transform != SpectralTransform.RAW_COMPLEX:
            if settings.transform == SpectralTransform.REAL:
                spectrum = spectrum.real
            elif settings.transform == SpectralTransform.IMAG:
                spectrum = spectrum.imag
            else:
                scale = np.sum(window**2.0) * axis.gain
                spectrum = (2.0 * (np.abs(spectrum) ** 2.0)) / scale

                if settings.transform == SpectralTransform.REL_DB:
                    spectrum = 10 * np.log10(spectrum)

        # After fftshift, index n_time // 2 holds 0 Hz for both even and odd n.
        axis_offset = freqs[0]
        if settings.output == SpectralOutput.POSITIVE:
            axis_offset = freqs[n_time // 2]
            spectrum = spectrum[..., n_time // 2 :]
        elif settings.output == SpectralOutput.NEGATIVE:
            spectrum = spectrum[..., : n_time // 2]

        # Inverse of the moveaxis above: the frequency axis returns to axis_idx.
        # (moveaxis(spectrum, axis_idx, -1) only coincided with this in 2-D
        # with axis_idx == 0.)
        spectrum = np.moveaxis(spectrum, -1, axis_idx)

        out_axis = settings.out_axis
        if out_axis is None:
            out_axis = axis_name

        freq_axis = AxisArray.Axis(
            unit="Hz", gain=1.0 / (axis.gain * n_time), offset=axis_offset
        )
        # The input axis no longer names a dim, so drop its stale metadata.
        new_axes = {k: v for k, v in message.axes.items() if k != axis_name}
        new_axes[out_axis] = freq_axis

        new_dims = list(message.dims)
        new_dims[axis_idx] = out_axis

        out_msg = replace(message, data=spectrum, dims=new_dims, axes=new_axes)

        yield self.OUTPUT_SIGNAL, out_msg