ezmsg-sigproc 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/sampler.py CHANGED
@@ -1,15 +1,19 @@
1
+ from collections import deque
1
2
  from dataclasses import dataclass, replace, field
2
3
  import time
4
+ from typing import Optional, Any, Tuple, List, Union, AsyncGenerator, Generator
3
5
 
4
6
  import ezmsg.core as ez
5
7
  import numpy as np
6
8
 
7
- from ezmsg.util.messages.axisarray import AxisArray
9
+ from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
10
+ from ezmsg.util.generator import consumer
8
11
 
9
- from typing import Optional, Any, Tuple, List, Dict, AsyncGenerator
12
+ # Dev/test apparatus
13
+ import asyncio
10
14
 
11
15
 
12
- @dataclass(frozen=True)
16
+ @dataclass(unsafe_hash=True)
13
17
  class SampleTriggerMessage:
14
18
  timestamp: float = field(default_factory=time.time)
15
19
  period: Optional[Tuple[float, float]] = None
@@ -22,6 +26,156 @@ class SampleMessage:
22
26
  sample: AxisArray
23
27
 
24
28
 
29
@consumer
def sampler(
    buffer_dur: float,
    axis: Optional[str] = None,
    period: Optional[Tuple[float, float]] = None,
    value: Any = None,
    estimate_alignment: bool = True
) -> Generator[Union[AxisArray, SampleTriggerMessage], List[SampleMessage], None]:
    """
    A generator function that samples data into a buffer, accepts triggers, and returns slices of sampled
    data around the trigger time.

    Parameters:
    - buffer_dur (float): The duration of the buffer in seconds. The buffer must be long enough to store the oldest
        sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
        need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
    - axis (Optional[str]): The axis along which to sample the data.
        None (default) will choose the first axis in the first input.
    - period (Optional[Tuple[float, float]]): The period in seconds, relative to the trigger time, during which
        to sample the data. Defaults to None. Only used if the trigger message does not define its own period.
    - value (Any): The value attached to each accepted trigger. Defaults to None.
        Only used if the trigger message does not define its own value.
    - estimate_alignment (bool): Whether to estimate the sample alignment. Defaults to True.
        If True, the trigger timestamp field is used to slice the buffer.
        If False, the trigger timestamp is ignored and replaced by the timestamp of the sample expected to
        arrive next (the last message's .offset + its number of samples / fs).
        NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
        estimate_alignment to operate correctly.

    Sends:
    - AxisArray containing streaming data messages
    - SampleTriggerMessage containing a trigger
    Yields:
    - list[SampleMessage]: The sample messages completed by the most recent input. Triggers are only queued,
        so sending a trigger always yields an empty list; a window is emitted by a later AxisArray once the
        buffer covers the whole trigger period. Triggers received before any data are dropped, and all pending
        triggers are discarded when the sampling rate, sample shape, or axis position of the stream changes.
    """
    msg_out: Optional[List[SampleMessage]] = None

    # State variables (mostly shared between trigger- and data-processing).
    triggers: deque[SampleTriggerMessage] = deque()
    last_msg_stats: Optional[dict] = None
    buffer = None

    while True:
        msg_in = yield msg_out
        msg_out = []

        if isinstance(msg_in, SampleTriggerMessage):
            if last_msg_stats is None or buffer is None:
                # We've yet to see any data; drop the trigger.
                continue
            fs = last_msg_stats["fs"]

            _period = msg_in.period if msg_in.period is not None else period
            _value = msg_in.value if msg_in.value is not None else value

            if _period is None:
                ez.logger.warning("Sampling failed: period not specified")
                continue

            # Check that period is valid
            if _period[0] >= _period[1]:
                ez.logger.warning(f"Sampling failed: invalid period requested ({_period})")
                continue

            # Check that period is compatible with buffer duration.
            max_buf_len = int(buffer_dur * fs)
            req_buf_len = int(round((_period[1] - _period[0]) * fs))
            if req_buf_len >= max_buf_len:
                # Report the effective period (it may come from the trigger rather than the default).
                ez.logger.warning(
                    f"Sampling failed: period={_period} >= {buffer_dur=}"
                )
                continue

            trigger_ts: float = msg_in.timestamp
            if not estimate_alignment:
                # Override the trigger timestamp with the next sample's likely timestamp: the last message's
                # samples sit at offset + i / fs for i in [0, n_samples), so the next one is i == n_samples.
                trigger_ts = last_msg_stats["offset"] + last_msg_stats["n_samples"] / fs

            triggers.append(replace(msg_in, timestamp=trigger_ts, period=_period, value=_value))

        elif isinstance(msg_in, AxisArray):
            if axis is None:
                axis = msg_in.dims[0]
            axis_idx = msg_in.get_axis_idx(axis)
            axis_info = msg_in.get_axis(axis)
            fs = 1.0 / axis_info.gain
            sample_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1:]
            n_new = msg_in.data.shape[axis_idx]

            # If the signal properties have changed in a breaking way then reset buffer and triggers.
            if (
                last_msg_stats is None
                or fs != last_msg_stats["fs"]
                or sample_shape != last_msg_stats["sample_shape"]
                or axis_idx != last_msg_stats["axis_idx"]
            ):
                last_msg_stats = {
                    "fs": fs,
                    "sample_shape": sample_shape,
                    "axis_idx": axis_idx,
                }
                buffer = None
                if len(triggers) > 0:
                    ez.logger.warning("Data stream changed: Discarding all triggers")
                triggers.clear()
            # Per-message stats: refreshed on every message because they are used to predict the
            # next sample's timestamp when estimate_alignment is False.
            last_msg_stats["offset"] = axis_info.offset
            last_msg_stats["n_samples"] = n_new

            # Update buffer
            buffer = msg_in.data if buffer is None else np.concatenate((buffer, msg_in.data), axis=axis_idx)
            n_buf = buffer.shape[axis_idx]

            # Timestamps of every buffered sample: the first sample of msg_in is at axis_info.offset and older
            # samples are extrapolated backwards one sampling period at a time. Derived from sample counts
            # (not by indexing) so that an empty message does not shift or break the timeline.
            buffer_offset = np.arange(n_buf, dtype=float)
            buffer_offset -= n_buf - n_new
            buffer_offset *= axis_info.gain
            buffer_offset += axis_info.offset

            # ... for each trigger, collect the message (if possible) and append to msg_out
            if n_buf > 0:
                for trig in list(triggers):
                    if trig.period is None:
                        # This trigger was malformed; drop it.
                        triggers.remove(trig)
                        continue

                    t_start = trig.timestamp + trig.period[0]

                    # If the previous iteration had insufficient data for the trigger timestamp + period,
                    # and buffer-management removed data required for the trigger, then we will never be able
                    # to accommodate this trigger. Discard it. An increase in buffer_dur is recommended.
                    if t_start < buffer_offset[0]:
                        ez.logger.warning(
                            f"Sampling failed: Buffer span {buffer_offset[0]} is beyond the "
                            f"requested sample period start: {t_start}"
                        )
                        triggers.remove(trig)
                        continue

                    start = int(np.searchsorted(buffer_offset, t_start))
                    # Round (not truncate) so float error in the period difference cannot drop a sample.
                    stop = start + int(round(fs * (trig.period[1] - trig.period[0])))
                    if start < n_buf and stop <= n_buf:
                        # Trigger period fully enclosed in buffer.
                        msg_out.append(
                            SampleMessage(
                                trigger=trig,
                                sample=replace(
                                    msg_in,
                                    data=slice_along_axis(buffer, slice(start, stop), axis_idx),
                                    axes={**msg_in.axes, axis: replace(axis_info, offset=buffer_offset[start])},
                                ),
                            )
                        )
                        triggers.remove(trig)

            # Keep only the most recent buffer_dur seconds. Slicing from an explicit start index (instead of
            # [-buf_len:]) keeps the buffer bounded even when buf_len rounds down to 0.
            buf_len = int(buffer_dur * fs)
            buffer = slice_along_axis(buffer, slice(max(n_buf - buf_len, 0), None), axis_idx)
177
+
178
+
25
179
  class SamplerSettings(ez.Settings):
26
180
  buffer_dur: float
27
181
  axis: Optional[str] = None
@@ -40,9 +194,7 @@ class SamplerSettings(ez.Settings):
40
194
 
41
195
class SamplerState(ez.State):
    """Mutable state of the Sampler unit."""

    # Settings currently in effect: the unit's SETTINGS, or the latest message on INPUT_SETTINGS.
    cur_settings: SamplerSettings
    # The `sampler` generator built from cur_settings; it owns the data buffer and pending triggers.
    gen: Generator[Union[AxisArray, SampleTriggerMessage], List[SampleMessage], None]
46
198
 
47
199
 
48
200
  class Sampler(ez.Unit):
@@ -54,162 +206,35 @@ class Sampler(ez.Unit):
54
206
  INPUT_SIGNAL = ez.InputStream(AxisArray)
55
207
  OUTPUT_SAMPLE = ez.OutputStream(SampleMessage)
56
208
 
209
+ def construct_generator(self):
210
+ self.STATE.gen = sampler(
211
+ buffer_dur=self.STATE.cur_settings.buffer_dur,
212
+ axis=self.STATE.cur_settings.axis,
213
+ period=self.STATE.cur_settings.period,
214
+ value=self.STATE.cur_settings.value,
215
+ estimate_alignment=self.STATE.cur_settings.estimate_alignment
216
+ )
217
+
57
218
  def initialize(self) -> None:
58
219
  self.STATE.cur_settings = self.SETTINGS
220
+ self.construct_generator()
59
221
 
60
222
  @ez.subscriber(INPUT_SETTINGS)
61
223
  async def on_settings(self, msg: SamplerSettings) -> None:
62
224
  self.STATE.cur_settings = msg
225
+ self.construct_generator()
63
226
 
64
227
  @ez.subscriber(INPUT_TRIGGER)
65
228
  async def on_trigger(self, msg: SampleTriggerMessage) -> None:
66
- if self.STATE.last_msg is not None:
67
- axis_name = self.STATE.cur_settings.axis
68
- if axis_name is None:
69
- axis_name = self.STATE.last_msg.dims[0]
70
- axis = self.STATE.last_msg.get_axis(axis_name)
71
- axis_idx = self.STATE.last_msg.get_axis_idx(axis_name)
72
-
73
- fs = 1.0 / axis.gain
74
- last_msg_timestamp = axis.offset + (
75
- self.STATE.last_msg.shape[axis_idx] / fs
76
- )
77
-
78
- period = (
79
- msg.period if msg.period is not None else self.STATE.cur_settings.period
80
- )
81
- value = (
82
- msg.value if msg.value is not None else self.STATE.cur_settings.value
83
- )
84
-
85
- if period is None:
86
- ez.logger.warning(f"Sampling failed: period not specified")
87
- return
88
-
89
- # Check that period is valid
90
- start_offset = int(period[0] * fs)
91
- stop_offset = int(period[1] * fs)
92
- if (stop_offset - start_offset) <= 0:
93
- ez.logger.warning(f"Sampling failed: invalid period requested")
94
- return
95
-
96
- # Check that period is compatible with buffer duration
97
- max_buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
98
- req_buf_len = int((period[1] - period[0]) * fs)
99
- if req_buf_len >= max_buf_len:
100
- ez.logger.warning(
101
- f"Sampling failed: {period=} >= {self.STATE.cur_settings.buffer_dur=}"
102
- )
103
- return
104
-
105
- offset: int = 0
106
- if self.STATE.cur_settings.estimate_alignment:
107
- # Do what we can with the wall clock to determine sample alignment
108
- wall_delta = msg.timestamp - last_msg_timestamp
109
- offset = int(wall_delta * fs)
110
-
111
- # Check that current buffer accumulation allows for offset - period start
112
- if (
113
- self.STATE.buffer is None
114
- or -min(offset + start_offset, 0) >= self.STATE.buffer.shape[0]
115
- ):
116
- ez.logger.warning(
117
- "Sampling failed: insufficient buffer accumulation for requested sample period"
118
- )
119
- return
120
-
121
- self.STATE.triggers[replace(msg, period=period, value=value)] = offset
122
-
123
- else:
124
- ez.logger.warning("Sampling failed: no signal to sample yet")
229
+ _ = self.STATE.gen.send(msg)
125
230
 
126
231
  @ez.subscriber(INPUT_SIGNAL)
127
232
  @ez.publisher(OUTPUT_SAMPLE)
128
233
  async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
129
- axis_name = self.STATE.cur_settings.axis
130
- if axis_name is None:
131
- axis_name = msg.dims[0]
132
- axis = msg.get_axis(axis_name)
133
-
134
- fs = 1.0 / axis.gain
135
-
136
- if self.STATE.last_msg is None:
137
- self.STATE.last_msg = msg
138
-
139
- # Easier to deal with timeseries on axis 0
140
- last_msg = self.STATE.last_msg
141
- msg_data = np.moveaxis(msg.data, msg.get_axis_idx(axis_name), 0)
142
- last_msg_data = np.moveaxis(last_msg.data, last_msg.get_axis_idx(axis_name), 0)
143
- last_msg_axis = last_msg.get_axis(axis_name)
144
- last_msg_fs = 1.0 / last_msg_axis.gain
145
-
146
- # Check if signal properties have changed in a breaking way
147
- if fs != last_msg_fs or msg_data.shape[1:] != last_msg_data.shape[1:]:
148
- # Data stream changed meaningfully -- flush buffer, stop sampling
149
- if len(self.STATE.triggers) > 0:
150
- ez.logger.warning("Sampling failed: Discarding all triggers")
151
- ez.logger.warning("Flushing buffer: signal properties changed")
152
- self.STATE.buffer = None
153
- self.STATE.triggers = dict()
154
-
155
- # Accumulate buffer ( time dim => dim 0 )
156
- self.STATE.buffer = (
157
- msg_data
158
- if self.STATE.buffer is None
159
- else np.concatenate((self.STATE.buffer, msg_data), axis=0)
160
- )
161
-
162
- buffer_offset = np.arange(self.STATE.buffer.shape[0] + msg_data.shape[0])
163
- buffer_offset -= self.STATE.buffer.shape[0] + 1
164
- buffer_offset = (buffer_offset * axis.gain) + axis.offset
165
-
166
- pub_samples: List[SampleMessage] = []
167
- remaining_triggers: Dict[SampleTriggerMessage, int] = dict()
168
- for trigger, offset in self.STATE.triggers.items():
169
- if trigger.period is None:
170
- continue
171
-
172
- # trigger_offset points to t = 0 within buffer
173
- offset -= msg_data.shape[0]
174
- start = offset + int(trigger.period[0] * fs)
175
- stop = offset + int(trigger.period[1] * fs)
176
-
177
- if stop < 0: # We should be able to dispatch a sample
178
- sample_data = self.STATE.buffer[start:stop, ...]
179
- sample_data = np.moveaxis(sample_data, msg.get_axis_idx(axis_name), 0)
180
-
181
- sample_offset = buffer_offset[start]
182
- sample_axis = replace(axis, offset=sample_offset)
183
- sample_axes = {**msg.axes, **{axis_name: sample_axis}}
184
-
185
- pub_samples.append(
186
- SampleMessage(
187
- trigger=trigger,
188
- sample=replace(msg, data=sample_data, axes=sample_axes),
189
- )
190
- )
191
-
192
- else:
193
- remaining_triggers[trigger] = offset
194
-
234
+ pub_samples = self.STATE.gen.send(msg)
195
235
  for sample in pub_samples:
196
236
  yield self.OUTPUT_SAMPLE, sample
197
237
 
198
- self.STATE.triggers = remaining_triggers
199
-
200
- buf_len = int(self.STATE.cur_settings.buffer_dur * fs)
201
- self.STATE.buffer = self.STATE.buffer[-buf_len:, ...]
202
- self.STATE.last_msg = msg
203
-
204
-
205
- ## Dev/test apparatus
206
- import asyncio
207
-
208
- from ezmsg.testing.debuglog import DebugLog
209
- from ezmsg.sigproc.synth import Oscillator, OscillatorSettings
210
-
211
- from typing import AsyncGenerator
212
-
213
238
 
214
239
  class TriggerGeneratorSettings(ez.Settings):
215
240
  period: Tuple[float, float] # sec
@@ -228,60 +253,8 @@ class TriggerGenerator(ez.Unit):
228
253
 
229
254
  output = 0
230
255
  while True:
231
- yield self.OUTPUT_TRIGGER, SampleTriggerMessage(
232
- period=self.SETTINGS.period, value=output
233
- )
256
+ out_msg = SampleTriggerMessage(period=self.SETTINGS.period, value=output)
257
+ yield self.OUTPUT_TRIGGER, out_msg
234
258
 
235
259
  await asyncio.sleep(self.SETTINGS.publish_period)
236
260
  output += 1
237
-
238
-
239
- class SamplerTestSystemSettings(ez.Settings):
240
- sampler_settings: SamplerSettings
241
- trigger_settings: TriggerGeneratorSettings
242
-
243
-
244
- class SamplerTestSystem(ez.Collection):
245
- SETTINGS: SamplerTestSystemSettings
246
-
247
- OSC = Oscillator()
248
- SAMPLER = Sampler()
249
- TRIGGER = TriggerGenerator()
250
- DEBUG = DebugLog()
251
-
252
- def configure(self) -> None:
253
- self.SAMPLER.apply_settings(self.SETTINGS.sampler_settings)
254
- self.TRIGGER.apply_settings(self.SETTINGS.trigger_settings)
255
-
256
- self.OSC.apply_settings(
257
- OscillatorSettings(
258
- n_time=2, # Number of samples to output per block
259
- fs=10, # Sampling rate of signal output in Hz
260
- dispatch_rate="realtime",
261
- freq=2.0, # Oscillation frequency in Hz
262
- amp=1.0, # Amplitude
263
- phase=0.0, # Phase offset (in radians)
264
- sync=True, # Adjust `freq` to sync with sampling rate
265
- )
266
- )
267
-
268
- def network(self) -> ez.NetworkDefinition:
269
- return (
270
- (self.OSC.OUTPUT_SIGNAL, self.SAMPLER.INPUT_SIGNAL),
271
- (self.TRIGGER.OUTPUT_TRIGGER, self.SAMPLER.INPUT_TRIGGER),
272
- (self.TRIGGER.OUTPUT_TRIGGER, self.DEBUG.INPUT),
273
- (self.SAMPLER.OUTPUT_SAMPLE, self.DEBUG.INPUT),
274
- )
275
-
276
-
277
- if __name__ == "__main__":
278
- settings = SamplerTestSystemSettings(
279
- sampler_settings=SamplerSettings(buffer_dur=5.0),
280
- trigger_settings=TriggerGeneratorSettings(
281
- period=(1.0, 2.0), prewait=0.5, publish_period=5.0
282
- ),
283
- )
284
-
285
- system = SamplerTestSystem(settings)
286
-
287
- ez.run(SYSTEM = system)
@@ -0,0 +1,127 @@
1
+ from dataclasses import replace
2
+ from typing import Generator, Optional
3
+
4
+ import numpy as np
5
+
6
+ import ezmsg.core as ez
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.generator import consumer, GenAxisArray
9
+
10
+
11
+ def _tau_from_alpha(alpha: float, dt: float) -> float:
12
+ """
13
+ Inverse of _alpha_from_tau. See that function for explanation.
14
+ """
15
+ return -dt / np.log(1 - alpha)
16
+
17
+
18
+ def _alpha_from_tau(tau: float, dt: float) -> float:
19
+ """
20
+ # https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
21
+ :param tau: The amount of time for the smoothed response of a unit step function to reach
22
+ 1 - 1/e approx-eq 63.2%.
23
+ :param dt: sampling period, or 1 / sampling_rate.
24
+ :return: alpha, the "fading factor" in exponential smoothing.
25
+ """
26
+ return 1 - np.exp(-dt / tau)
27
+
28
+
29
@consumer
def scaler(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[AxisArray, AxisArray, None]:
    """
    Adaptive standard scaler backed by river's ``preprocessing.AdaptiveStandardScaler``.

    Send AxisArray messages; each send yields an AxisArray of the same shape in which every sample
    (each slice along ``axis``) was first learned by the scaler and then transformed by it.
    Messages with no samples along ``axis`` are passed through unchanged.

    :param time_constant: time constant of the exponential weighting, in the units of the axis gain;
        converted to river's ``fading_factor`` using the gain of the first non-empty message.
    :param axis: name of the axis to iterate over; None uses the first dim of the first message.
    """
    # Imported lazily so this module can be imported without river installed.
    from river import preprocessing

    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])
    _scaler = None
    while True:
        axis_arr_in = yield axis_arr_out
        data = axis_arr_in.data
        if axis is None:
            axis = axis_arr_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = axis_arr_in.get_axis_idx(axis)
            if axis_idx != 0:
                data = np.moveaxis(data, axis_idx, 0)

        if data.shape[0] == 0:
            # Nothing to learn from or transform, and np.stack would fail on an empty list.
            axis_arr_out = axis_arr_in
            continue

        if _scaler is None:
            alpha = _alpha_from_tau(time_constant, axis_arr_in.axes[axis].gain)
            _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)

        result = []
        for sample in data:
            # river works on dicts: key each flattened element by its position.
            x = dict(enumerate(sample.flatten().tolist()))
            _scaler.learn_one(x)
            y = _scaler.transform_one(x)
            result.append(np.array([y[key] for key in sorted(y)]).reshape(sample.shape))

        result = np.stack(result)
        result = np.moveaxis(result, 0, axis_idx)
        axis_arr_out = replace(axis_arr_in, data=result)
61
+
62
+
63
@consumer
def scaler_np(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[AxisArray, AxisArray, None]:
    """
    Numpy-only adaptive standard scaler.

    Send AxisArray messages; each send yields an AxisArray of the same shape where every sample
    (each slice along ``axis``) is standardized as ``(x - mean) / sqrt(mean(x**2) - mean**2)``.
    The means are exponentially-weighted running estimates that are updated with the sample before it
    is scaled. Entries whose result is NaN (e.g. 0/0 on the first sample, or a slightly negative
    variance from floating-point round-off) are reported as 0. Messages with no samples along ``axis``
    are passed through unchanged.

    The only dependency is numpy. This is faster than the river-based scaler for multi-channel data
    but slower for single-channel data.

    :param time_constant: time constant of the exponential weighting, in the units of the axis gain;
        converted to a fading factor using the gain of the first message.
    :param axis: name of the axis to iterate over; None uses the first dim of the first message.
    """
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])
    means = sq_means = None
    alpha = None

    def _ew_update(arr, prev, _alpha):
        # An all-zero state is treated as "uninitialized" and seeded with the current observation.
        if np.all(prev == 0):
            return arr
        # Equivalent to _alpha * arr + (1 - _alpha) * prev, with one fewer multiplication.
        return prev + _alpha * (arr - prev)

    while True:
        axis_arr_in = yield axis_arr_out

        data = axis_arr_in.data
        if axis is None:
            axis = axis_arr_in.dims[0]
            axis_idx = 0
        else:
            axis_idx = axis_arr_in.get_axis_idx(axis)
            data = np.moveaxis(data, axis_idx, 0)

        if alpha is None:
            alpha = _alpha_from_tau(time_constant, axis_arr_in.axes[axis].gain)

        if data.shape[0] == 0:
            # Nothing to scale; data[0] and np.stack would both fail on an empty message.
            axis_arr_out = axis_arr_in
            continue

        if not np.issubdtype(data.dtype, np.inexact):
            # Promote integer/bool input so that squaring below cannot overflow.
            data = data.astype(float)

        if means is None or means.shape != data[0].shape:
            sq_means = np.zeros_like(data[0], dtype=float)
            means = np.zeros_like(data[0], dtype=float)

        result = []
        for sample in data:
            # Update step
            means = _ew_update(sample, means, alpha)
            sq_means = _ew_update(sample ** 2, sq_means, alpha)
            # Get step
            varis = sq_means - means ** 2
            with np.errstate(divide="ignore", invalid="ignore"):
                y = (sample - means) / np.sqrt(varis)
            # np.where (not in-place masking) also works when y is a 0-d scalar (single-channel 1-D data).
            result.append(np.where(np.isnan(y), 0.0, y))

        result = np.stack(result, axis=0)
        result = np.moveaxis(result, 0, axis_idx)
        axis_arr_out = replace(axis_arr_in, data=result)
113
+
114
+
115
class AdaptiveStandardScalerSettings(ez.Settings):
    """Settings for AdaptiveStandardScaler; both fields are forwarded to ``scaler_np``."""

    # Exponential-weighting time constant, in the units of the scaled axis' gain.
    time_constant: float = 1.0
    # Axis to iterate over; None selects the first dimension of the first message.
    axis: Optional[str] = None
118
+
119
+
120
class AdaptiveStandardScaler(GenAxisArray):
    """
    ezmsg Unit that standardizes AxisArray messages with ``scaler_np``.

    The generator is built from the unit's SETTINGS in ``construct_generator`` (presumably invoked by
    the GenAxisArray base class, which is not visible here).
    """

    SETTINGS: AdaptiveStandardScalerSettings

    def construct_generator(self):
        cfg = self.SETTINGS
        self.STATE.gen = scaler_np(time_constant=cfg.time_constant, axis=cfg.axis)
@@ -0,0 +1,67 @@
1
+ import typing
2
+
3
+ import ezmsg.core as ez
4
+
5
+ from dataclasses import replace
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+
11
+
12
class SignalInjectorSettings(ez.Settings):
    """Initial configuration for SignalInjector."""

    # Input signal needs a time dimension with units in sec.
    time_dim: str = 'time'
    # Injected sinusoid frequency in Hz; None disables injection (messages pass through).
    frequency: typing.Optional[float] = None
    # Scale applied to the injected sinusoid.
    amplitude: float = 1.0
    # Seed for numpy.random.default_rng when drawing per-channel mixing weights; None is unseeded.
    mixing_seed: typing.Optional[int] = None
17
+
18
+
19
class SignalInjectorState(ez.State):
    """Runtime state of SignalInjector."""

    # Shape used to decide when the mixing weights must be regenerated.
    cur_shape: typing.Optional[typing.Tuple[int, ...]] = None
    # Current injection frequency in Hz; None means injection is disabled.
    cur_frequency: typing.Optional[float] = None
    # Current injection amplitude.
    cur_amplitude: float
    # Row vector of per-channel mixing weights in [-1, 1); empty until the first signal arrives.
    mixing: npt.NDArray
24
+
25
+
26
class SignalInjector(ez.Unit):
    """
    Adds a sinusoid to every AxisArray received on INPUT_SIGNAL.

    Each non-time element (channel) receives ``sin(2*pi*f*t) * weight * amplitude``, where ``t`` is the
    time axis' values and ``weight`` is drawn uniformly from [-1, 1) per channel. While the current
    frequency is None, messages are forwarded untouched. Frequency and amplitude can be changed at
    runtime via INPUT_FREQUENCY and INPUT_AMPLITUDE.
    """

    SETTINGS: SignalInjectorSettings
    STATE: SignalInjectorState

    INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
    INPUT_AMPLITUDE = ez.InputStream(float)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        self.STATE.cur_frequency = self.SETTINGS.frequency
        self.STATE.cur_amplitude = self.SETTINGS.amplitude
        self.STATE.mixing = np.array([])

    @ez.subscriber(INPUT_FREQUENCY)
    async def on_frequency(self, msg: typing.Optional[float]) -> None:
        self.STATE.cur_frequency = msg

    @ez.subscriber(INPUT_AMPLITUDE)
    async def on_amplitude(self, msg: float) -> None:
        self.STATE.cur_amplitude = msg

    @ez.subscriber(INPUT_SIGNAL)
    @ez.publisher(OUTPUT_SIGNAL)
    async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
        # Key the mixing weights on the shape *excluding* the time dimension. Chunk lengths commonly vary
        # between messages; regenerating on every length change would (with mixing_seed=None) draw new
        # random weights and make the injected signal jump between chunks.
        time_idx = msg.get_axis_idx(self.SETTINGS.time_dim)
        data_shape = msg.data.shape
        non_time_shape = data_shape[:time_idx] + data_shape[time_idx + 1:]
        if self.STATE.cur_shape != non_time_shape:
            self.STATE.cur_shape = non_time_shape
            rng = np.random.default_rng(self.SETTINGS.mixing_seed)
            self.STATE.mixing = rng.random((1, msg.shape2d(self.SETTINGS.time_dim)[1]))
            # Map uniform [0, 1) weights onto [-1, 1).
            self.STATE.mixing = (self.STATE.mixing * 2.0) - 1.0

        if self.STATE.cur_frequency is None:
            yield self.OUTPUT_SIGNAL, msg
        else:
            # Copy so the caller's message data is not modified in place.
            out_msg = replace(msg, data=msg.data.copy())
            t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
            signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
            # (n_time, 1) * (1, n_channels) broadcasts to one scaled sinusoid per channel.
            mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
            with out_msg.view2d(self.SETTINGS.time_dim) as view:
                view[...] = view + mixed_signal.astype(view.dtype)
            yield self.OUTPUT_SIGNAL, out_msg