ezmsg-sigproc 1.2.3__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
ezmsg/sigproc/sampler.py CHANGED
@@ -1,81 +1,94 @@
1
+ import asyncio # Dev/test apparatus
1
2
  from collections import deque
2
3
  from dataclasses import dataclass, replace, field
3
4
  import time
4
- from typing import Optional, Any, Tuple, List, Union, AsyncGenerator, Generator
5
+ import typing
5
6
 
6
- import ezmsg.core as ez
7
7
  import numpy as np
8
-
8
+ import numpy.typing as npt
9
+ import ezmsg.core as ez
9
10
  from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
10
11
  from ezmsg.util.generator import consumer
11
12
 
12
- # Dev/test apparatus
13
- import asyncio
14
-
15
13
 
16
14
  @dataclass(unsafe_hash=True)
17
15
  class SampleTriggerMessage:
18
16
  timestamp: float = field(default_factory=time.time)
19
- period: Optional[Tuple[float, float]] = None
20
- value: Any = None
17
+ """Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
18
+
19
+ period: typing.Optional[typing.Tuple[float, float]] = None
20
+ """The period around the timestamp, in seconds"""
21
+
22
+ value: typing.Any = None
23
+ """A value or 'label' associated with the trigger."""
21
24
 
22
25
 
23
26
  @dataclass
24
27
  class SampleMessage:
25
28
  trigger: SampleTriggerMessage
29
+ """The time, window, and value (if any) associated with the trigger."""
30
+
26
31
  sample: AxisArray
32
+ """The data sampled around the trigger."""
27
33
 
28
34
 
29
35
  @consumer
30
36
  def sampler(
31
- buffer_dur: float,
32
- axis: Optional[str] = None,
33
- period: Optional[Tuple[float, float]] = None,
34
- value: Any = None,
35
- estimate_alignment: bool = True
36
- ) -> Generator[Union[AxisArray, SampleTriggerMessage], List[SampleMessage], None]:
37
+ buffer_dur: float,
38
+ axis: typing.Optional[str] = None,
39
+ period: typing.Optional[typing.Tuple[float, float]] = None,
40
+ value: typing.Any = None,
41
+ estimate_alignment: bool = True,
42
+ ) -> typing.Generator[
43
+ typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None
44
+ ]:
37
45
  """
38
46
  A generator function that samples data into a buffer, accepts triggers, and returns slices of sampled
39
47
  data around the trigger time.
40
48
 
41
- Parameters:
42
- - buffer_dur (float): The duration of the buffer in seconds. The buffer must be long enough to store the oldest
43
- sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
44
- need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
45
- - axis (Optional[str]): The axis along which to sample the data.
46
- None (default) will choose the first axis in the first input.
47
- - period (Optional[Tuple[float, float]]): The period in seconds during which to sample the data.
48
- Defaults to None. Only used if not None and the trigger message does not define its own period.
49
- - value (Any): The value to sample. Defaults to None.
50
- - estimate_alignment (bool): Whether to estimate the sample alignment. Defaults to True.
51
- If True, the trigger timestamp field is used to slice the buffer.
52
- If False, the trigger timestamp is ignored and the next signal's .offset is used.
53
- NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
54
- estimate_alignment to operate correctly.
55
-
56
- Sends:
57
- - AxisArray containing streaming data messages
58
- - SampleTriggerMessage containing a trigger
59
- Yields:
60
- - list[SampleMessage]: The list of sample messages.
49
+ Args:
50
+ buffer_dur: The duration of the buffer in seconds. The buffer must be long enough to store the oldest
51
+ sample to be included in a window. e.g., a trigger lagged by 0.5 seconds with a period of (-1.0, +1.5) will
52
+ need a buffer of 0.5 + (1.5 - -1.0) = 3.0 seconds. It is best to at least double your estimate if memory allows.
53
+ axis: The axis along which to sample the data.
54
+ None (default) will choose the first axis in the first input.
55
+ period: The period in seconds during which to sample the data.
56
+ Defaults to None. Only used if not None and the trigger message does not define its own period.
57
+ value: The value to sample. Defaults to None.
58
+ estimate_alignment: Whether to estimate the sample alignment. Defaults to True.
59
+ If True, the trigger timestamp field is used to slice the buffer.
60
+ If False, the trigger timestamp is ignored and the next signal's .offset is used.
61
+ NOTE: For faster-than-realtime playback -- Signals and triggers must share the same (fast) clock for
62
+ estimate_alignment to operate correctly.
63
+
64
+ Returns:
65
+ A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
66
+ or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
61
67
  """
62
- msg_in = None
63
- msg_out: Optional[list[SampleMessage]] = None
68
+ msg_out: list[SampleMessage] = []
64
69
 
65
70
  # State variables (most shared between trigger- and data-processing.
66
71
  triggers: deque[SampleTriggerMessage] = deque()
67
- last_msg_stats = None
68
- buffer = None
72
+ buffer: typing.Optional[npt.NDArray] = None
73
+ n_samples: int = 0
74
+ offset: float = 0.0
75
+
76
+ check_inputs = {
77
+ "fs": None, # Also a state variable
78
+ "key": None,
79
+ "shape": None,
80
+ }
69
81
 
70
82
  while True:
71
83
  msg_in = yield msg_out
72
84
  msg_out = []
85
+
73
86
  if isinstance(msg_in, SampleTriggerMessage):
74
- if last_msg_stats is None or buffer is None:
87
+ # Input is a trigger message that we will use to sample the buffer.
88
+
89
+ if buffer is None or check_inputs["fs"] is None:
75
90
  # We've yet to see any data; drop the trigger.
76
91
  continue
77
- fs = last_msg_stats["fs"]
78
- axis_idx = last_msg_stats["axis_idx"]
79
92
 
80
93
  _period = msg_in.period if msg_in.period is not None else period
81
94
  _value = msg_in.value if msg_in.value is not None else value
@@ -86,50 +99,73 @@ def sampler(
86
99
 
87
100
  # Check that period is valid
88
101
  if _period[0] >= _period[1]:
89
- ez.logger.warning(f"Sampling failed: invalid period requested ({_period})")
102
+ ez.logger.warning(
103
+ f"Sampling failed: invalid period requested ({_period})"
104
+ )
90
105
  continue
91
106
 
92
107
  # Check that period is compatible with buffer duration.
93
- max_buf_len = int(buffer_dur * fs)
94
- req_buf_len = int((_period[1] - _period[0]) * fs)
108
+ max_buf_len = int(np.round(buffer_dur * check_inputs["fs"]))
109
+ req_buf_len = int(np.round((_period[1] - _period[0]) * check_inputs["fs"]))
95
110
  if req_buf_len >= max_buf_len:
96
- ez.logger.warning(
97
- f"Sampling failed: {period=} >= {buffer_dur=}"
98
- )
111
+ ez.logger.warning(f"Sampling failed: {period=} >= {buffer_dur=}")
99
112
  continue
100
113
 
101
114
  trigger_ts: float = msg_in.timestamp
102
115
  if not estimate_alignment:
103
116
  # Override the trigger timestamp with the next sample's likely timestamp.
104
- trigger_ts = last_msg_stats["offset"] + (last_msg_stats["n_samples"] + 1) / fs
117
+ trigger_ts = offset + (n_samples + 1) / check_inputs["fs"]
105
118
 
106
- new_trig_msg = replace(msg_in, timestamp=trigger_ts, period=_period, value=_value)
119
+ new_trig_msg = replace(
120
+ msg_in, timestamp=trigger_ts, period=_period, value=_value
121
+ )
107
122
  triggers.append(new_trig_msg)
108
123
 
109
124
  elif isinstance(msg_in, AxisArray):
110
- if axis is None:
111
- axis = msg_in.dims[0]
125
+ # Get properties from message
126
+ axis = axis or msg_in.dims[0]
112
127
  axis_idx = msg_in.get_axis_idx(axis)
113
128
  axis_info = msg_in.get_axis(axis)
114
129
  fs = 1.0 / axis_info.gain
115
- sample_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1:]
116
-
117
- # If the signal properties have changed in a breaking way then reset buffer and triggers.
118
- if last_msg_stats is None or fs != last_msg_stats["fs"] or sample_shape != last_msg_stats["sample_shape"]:
119
- last_msg_stats = {
120
- "fs": fs,
121
- "sample_shape": sample_shape,
122
- "axis_idx": axis_idx,
123
- "n_samples": msg_in.data.shape[axis_idx]
124
- }
130
+ sample_shape = (
131
+ msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :]
132
+ )
133
+
134
+ # TODO: We could accommodate change in dim order.
135
+ # if axis_idx != check_inputs["axis_idx"]:
136
+ # msg_in = replace(
137
+ # msg_in,
138
+ # data=np.moveaxis(msg_in.data, axis_idx, check_inputs["axis_idx"]),
139
+ # dims=TODO...
140
+ # )
141
+ # axis_idx = check_inputs["axis_idx"]
142
+
143
+ # If the properties have changed in a breaking way then reset buffer and triggers.
144
+ b_reset = fs != check_inputs["fs"]
145
+ b_reset = b_reset or sample_shape != check_inputs["shape"]
146
+ # TODO: Skip next line if we do np.moveaxis above
147
+ b_reset = b_reset or axis_idx != check_inputs["axis_idx"]
148
+ b_reset = b_reset or msg_in.key != check_inputs["key"]
149
+ if b_reset:
150
+ check_inputs["fs"] = fs
151
+ check_inputs["shape"] = sample_shape
152
+ check_inputs["axis_idx"] = axis_idx
153
+ check_inputs["key"] = msg_in.key
154
+ n_samples = msg_in.data.shape[axis_idx]
125
155
  buffer = None
126
156
  if len(triggers) > 0:
127
157
  ez.logger.warning("Data stream changed: Discarding all triggers")
128
158
  triggers.clear()
129
- last_msg_stats["offset"] = axis_info.offset # Should be updated on every message.
159
+
160
+ # Save some info for trigger processing
161
+ offset = axis_info.offset
130
162
 
131
163
  # Update buffer
132
- buffer = msg_in.data if buffer is None else np.concatenate((buffer, msg_in.data), axis=axis_idx)
164
+ buffer = (
165
+ msg_in.data
166
+ if buffer is None
167
+ else np.concatenate((buffer, msg_in.data), axis=axis_idx)
168
+ )
133
169
 
134
170
  # Calculate timestamps associated with buffer.
135
171
  buffer_offset = np.arange(buffer.shape[axis_idx], dtype=float)
@@ -153,11 +189,10 @@ def sampler(
153
189
  )
154
190
  triggers.remove(trig)
155
191
 
156
- # TODO: Speed up with searchsorted?
157
192
  t_start = trig.timestamp + trig.period[0]
158
193
  if t_start >= buffer_offset[0]:
159
194
  start = np.searchsorted(buffer_offset, t_start)
160
- stop = start + int(fs * (trig.period[1] - trig.period[0]))
195
+ stop = start + int(np.round(fs * (trig.period[1] - trig.period[0])))
161
196
  if buffer.shape[axis_idx] > stop:
162
197
  # Trigger period fully enclosed in buffer.
163
198
  msg_out.append(
@@ -165,9 +200,16 @@ def sampler(
165
200
  trigger=trig,
166
201
  sample=replace(
167
202
  msg_in,
168
- data=slice_along_axis(buffer, slice(start, stop), axis_idx),
169
- axes={**msg_in.axes, axis: replace(axis_info, offset=buffer_offset[start])}
170
- )
203
+ data=slice_along_axis(
204
+ buffer, slice(start, stop), axis_idx
205
+ ),
206
+ axes={
207
+ **msg_in.axes,
208
+ axis: replace(
209
+ axis_info, offset=buffer_offset[start]
210
+ ),
211
+ },
212
+ ),
171
213
  )
172
214
  )
173
215
  triggers.remove(trig)
@@ -177,29 +219,40 @@ def sampler(
177
219
 
178
220
 
179
221
  class SamplerSettings(ez.Settings):
222
+ """
223
+ Settings for :obj:`Sampler`.
224
+ See :obj:`sampler` for a description of the fields.
225
+ """
226
+
180
227
  buffer_dur: float
181
- axis: Optional[str] = None
182
- period: Optional[
183
- Tuple[float, float]
184
- ] = None # Optional default period if unspecified in SampleTriggerMessage
185
- value: Any = None # Optional default value if unspecified in SampleTriggerMessage
228
+ axis: typing.Optional[str] = None
229
+ period: typing.Optional[typing.Tuple[float, float]] = None
230
+ """Optional default period if unspecified in SampleTriggerMessage"""
231
+
232
+ value: typing.Any = None
233
+ """Optional default value if unspecified in SampleTriggerMessage"""
186
234
 
187
235
  estimate_alignment: bool = True
188
- # If true, use message timestamp fields and reported sampling rate to estimate
189
- # sample-accurate alignment for samples.
190
- # If false, sampling will be limited to incoming message rate -- "Block timing"
191
- # NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
192
- # "realtime" operation for estimate_alignment to operate correctly.
236
+ """
237
+ If true, use message timestamp fields and reported sampling rate to estimate sample-accurate alignment for samples.
238
+ If false, sampling will be limited to incoming message rate -- "Block timing"
239
+ NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect
240
+ "realtime" operation for estimate_alignment to operate correctly.
241
+ """
193
242
 
194
243
 
195
244
  class SamplerState(ez.State):
196
245
  cur_settings: SamplerSettings
197
- gen: Generator[Union[AxisArray, SampleTriggerMessage], List[SampleMessage], None]
246
+ gen: typing.Generator[
247
+ typing.Union[AxisArray, SampleTriggerMessage], typing.List[SampleMessage], None
248
+ ]
198
249
 
199
250
 
200
251
  class Sampler(ez.Unit):
201
- SETTINGS: SamplerSettings
202
- STATE: SamplerState
252
+ """An :obj:`Unit` for :obj:`sampler`."""
253
+
254
+ SETTINGS = SamplerSettings
255
+ STATE = SamplerState
203
256
 
204
257
  INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
205
258
  INPUT_SETTINGS = ez.InputStream(SamplerSettings)
@@ -212,10 +265,10 @@ class Sampler(ez.Unit):
212
265
  axis=self.STATE.cur_settings.axis,
213
266
  period=self.STATE.cur_settings.period,
214
267
  value=self.STATE.cur_settings.value,
215
- estimate_alignment=self.STATE.cur_settings.estimate_alignment
268
+ estimate_alignment=self.STATE.cur_settings.estimate_alignment,
216
269
  )
217
270
 
218
- def initialize(self) -> None:
271
+ async def initialize(self) -> None:
219
272
  self.STATE.cur_settings = self.SETTINGS
220
273
  self.construct_generator()
221
274
 
@@ -228,27 +281,36 @@ class Sampler(ez.Unit):
228
281
  async def on_trigger(self, msg: SampleTriggerMessage) -> None:
229
282
  _ = self.STATE.gen.send(msg)
230
283
 
231
- @ez.subscriber(INPUT_SIGNAL)
284
+ @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
232
285
  @ez.publisher(OUTPUT_SAMPLE)
233
- async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
286
+ async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator:
234
287
  pub_samples = self.STATE.gen.send(msg)
235
288
  for sample in pub_samples:
236
289
  yield self.OUTPUT_SAMPLE, sample
237
290
 
238
291
 
239
292
  class TriggerGeneratorSettings(ez.Settings):
240
- period: Tuple[float, float] # sec
241
- prewait: float = 0.5 # sec
242
- publish_period: float = 5.0 # sec
293
+ period: typing.Tuple[float, float]
294
+ """The period around the trigger event."""
295
+
296
+ prewait: float = 0.5
297
+ """The time before the first trigger (sec)"""
298
+
299
+ publish_period: float = 5.0
300
+ """The period between triggers (sec)"""
243
301
 
244
302
 
245
303
  class TriggerGenerator(ez.Unit):
246
- SETTINGS: TriggerGeneratorSettings
304
+ """
305
+ A unit to generate triggers every `publish_period` interval.
306
+ """
307
+
308
+ SETTINGS = TriggerGeneratorSettings
247
309
 
248
310
  OUTPUT_TRIGGER = ez.OutputStream(SampleTriggerMessage)
249
311
 
250
312
  @ez.publisher(OUTPUT_TRIGGER)
251
- async def generate(self) -> AsyncGenerator:
313
+ async def generate(self) -> typing.AsyncGenerator:
252
314
  await asyncio.sleep(self.SETTINGS.prewait)
253
315
 
254
316
  output = 0
ezmsg/sigproc/scaler.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from dataclasses import replace
2
- from typing import Generator, Optional
2
+ import typing
3
3
 
4
4
  import numpy as np
5
-
5
+ import numpy.typing as npt
6
6
  import ezmsg.core as ez
7
7
  from ezmsg.util.messages.axisarray import AxisArray
8
- from ezmsg.util.generator import consumer, GenAxisArray
8
+ from ezmsg.util.generator import consumer
9
+
10
+ from .base import GenAxisArray
9
11
 
10
12
 
11
13
  def _tau_from_alpha(alpha: float, dt: float) -> float:
@@ -27,24 +29,39 @@ def _alpha_from_tau(tau: float, dt: float) -> float:
27
29
 
28
30
 
29
31
  @consumer
30
- def scaler(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[AxisArray, AxisArray, None]:
32
+ def scaler(
33
+ time_constant: float = 1.0, axis: typing.Optional[str] = None
34
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
35
+ """
36
+ Create a generator function that applies the
37
+ adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
38
+ This is faster than :obj:`scaler_np` for single-channel data.
39
+
40
+ Args:
41
+ time_constant: Decay constant `tau` in seconds.
42
+ axis: The name of the axis to accumulate statistics over.
43
+
44
+ Returns:
45
+ A primed generator object that expects `.send(axis_array)` and yields a
46
+ standardized, or "Z-scored" version of the input.
47
+ """
31
48
  from river import preprocessing
32
- axis_arr_in = AxisArray(np.array([]), dims=[""])
33
- axis_arr_out = AxisArray(np.array([]), dims=[""])
49
+
50
+ msg_out = AxisArray(np.array([]), dims=[""])
34
51
  _scaler = None
35
52
  while True:
36
- axis_arr_in = yield axis_arr_out
37
- data = axis_arr_in.data
53
+ msg_in: AxisArray = yield msg_out
54
+ data = msg_in.data
38
55
  if axis is None:
39
- axis = axis_arr_in.dims[0]
56
+ axis = msg_in.dims[0]
40
57
  axis_idx = 0
41
58
  else:
42
- axis_idx = axis_arr_in.get_axis_idx(axis)
59
+ axis_idx = msg_in.get_axis_idx(axis)
43
60
  if axis_idx != 0:
44
61
  data = np.moveaxis(data, axis_idx, 0)
45
62
 
46
63
  if _scaler is None:
47
- alpha = _alpha_from_tau(time_constant, axis_arr_in.axes[axis].gain)
64
+ alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
48
65
  _scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)
49
66
 
50
67
  result = []
@@ -57,17 +74,39 @@ def scaler(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[
57
74
 
58
75
  result = np.stack(result)
59
76
  result = np.moveaxis(result, 0, axis_idx)
60
- axis_arr_out = replace(axis_arr_in, data=result)
77
+ msg_out = replace(msg_in, data=result)
61
78
 
62
79
 
63
80
  @consumer
64
- def scaler_np(time_constant: float = 1.0, axis: Optional[str] = None) -> Generator[AxisArray, AxisArray, None]:
65
- # The only dependency is numpy.
66
- # This is faster for multi-channel data but slower for single-channel data.
67
- axis_arr_in = AxisArray(np.array([]), dims=[""])
68
- axis_arr_out = AxisArray(np.array([]), dims=[""])
69
- means = vars_means = vars_sq_means = None
70
- alpha = None
81
+ def scaler_np(
82
+ time_constant: float = 1.0, axis: typing.Optional[str] = None
83
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
84
+ """
85
+ Create a generator function that applies an adaptive standard scaler.
86
+ This is faster than :obj:`scaler` for multichannel data.
87
+
88
+ Args:
89
+ time_constant: Decay constant `tau` in seconds.
90
+ axis: The name of the axis to accumulate statistics over.
91
+
92
+ Returns:
93
+ A primed generator object that expects `.send(axis_array)` and yields a
94
+ standardized, or "Z-scored" version of the input.
95
+ """
96
+ msg_out = AxisArray(np.array([]), dims=[""])
97
+
98
+ # State variables
99
+ alpha: float = 0.0
100
+ means: typing.Optional[npt.NDArray] = None
101
+ vars_means: typing.Optional[npt.NDArray] = None
102
+ vars_sq_means: typing.Optional[npt.NDArray] = None
103
+
104
+ # Reset if input changes
105
+ check_input = {
106
+ "gain": None, # Resets alpha
107
+ "shape": None,
108
+ "key": None, # Key change implies buffered means/vars are invalid.
109
+ }
71
110
 
72
111
  def _ew_update(arr, prev, _alpha):
73
112
  if np.all(prev == 0):
@@ -77,51 +116,58 @@ def scaler_np(time_constant: float = 1.0, axis: Optional[str] = None) -> Generat
77
116
  return prev + _alpha * (arr - prev)
78
117
 
79
118
  while True:
80
- axis_arr_in = yield axis_arr_out
119
+ msg_in: AxisArray = yield msg_out
81
120
 
82
- data = axis_arr_in.data
83
- if axis is None:
84
- axis = axis_arr_in.dims[0]
85
- axis_idx = 0
86
- else:
87
- axis_idx = axis_arr_in.get_axis_idx(axis)
88
- data = np.moveaxis(data, axis_idx, 0)
121
+ axis = axis or msg_in.dims[0]
122
+ axis_idx = msg_in.get_axis_idx(axis)
89
123
 
90
- if alpha is None:
91
- alpha = _alpha_from_tau(time_constant, axis_arr_in.axes[axis].gain)
124
+ if msg_in.axes[axis].gain != check_input["gain"]:
125
+ alpha = _alpha_from_tau(time_constant, msg_in.axes[axis].gain)
126
+ check_input["gain"] = msg_in.axes[axis].gain
92
127
 
93
- if means is None or means.shape != data[0].shape:
128
+ data: npt.NDArray = np.moveaxis(msg_in.data, axis_idx, 0)
129
+ b_reset = data.shape[1:] != check_input["shape"]
130
+ b_reset |= msg_in.key != check_input["key"]
131
+ if b_reset:
132
+ check_input["shape"] = data.shape[1:]
133
+ check_input["key"] = msg_in.key
94
134
  vars_sq_means = np.zeros_like(data[0], dtype=float)
95
135
  vars_means = np.zeros_like(data[0], dtype=float)
96
136
  means = np.zeros_like(data[0], dtype=float)
97
137
 
98
- result = []
99
- for sample in data:
138
+ result = np.zeros_like(data)
139
+ for sample_ix in range(data.shape[0]):
140
+ sample = data[sample_ix]
100
141
  # Update step
101
142
  vars_means = _ew_update(sample, vars_means, alpha)
102
143
  vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha)
103
144
  means = _ew_update(sample, means, alpha)
104
145
  # Get step
105
- varis = vars_sq_means - vars_means ** 2
106
- y = ((sample - means) / (varis**0.5))
107
- y[np.isnan(y)] = 0.0
108
- result.append(y)
146
+ varis = vars_sq_means - vars_means**2
147
+ y = (sample - means) / (varis**0.5)
148
+ result[sample_ix] = y
109
149
 
110
- result = np.stack(result, axis=0)
150
+ result[np.isnan(result)] = 0.0
111
151
  result = np.moveaxis(result, 0, axis_idx)
112
- axis_arr_out = replace(axis_arr_in, data=result)
152
+ msg_out = replace(msg_in, data=result)
113
153
 
114
154
 
115
155
  class AdaptiveStandardScalerSettings(ez.Settings):
156
+ """
157
+ Settings for :obj:`AdaptiveStandardScaler`.
158
+ See :obj:`scaler_np` for a description of the parameters.
159
+ """
160
+
116
161
  time_constant: float = 1.0
117
- axis: Optional[str] = None
162
+ axis: typing.Optional[str] = None
118
163
 
119
164
 
120
165
  class AdaptiveStandardScaler(GenAxisArray):
121
- SETTINGS: AdaptiveStandardScalerSettings
166
+ """Unit for :obj:`scaler_np`"""
167
+
168
+ SETTINGS = AdaptiveStandardScalerSettings
122
169
 
123
170
  def construct_generator(self):
124
171
  self.STATE.gen = scaler_np(
125
- time_constant=self.SETTINGS.time_constant,
126
- axis=self.SETTINGS.axis
172
+ time_constant=self.SETTINGS.time_constant, axis=self.SETTINGS.axis
127
173
  )
@@ -1,17 +1,15 @@
1
+ from dataclasses import replace
1
2
  import typing
2
3
 
3
4
  import ezmsg.core as ez
4
-
5
- from dataclasses import replace
6
5
  from ezmsg.util.messages.axisarray import AxisArray
7
-
8
6
  import numpy as np
9
7
  import numpy.typing as npt
10
8
 
11
9
 
12
10
  class SignalInjectorSettings(ez.Settings):
13
- time_dim: str = 'time' # Input signal needs a time dimension with units in sec.
14
- frequency: typing.Optional[float] = None # Hz
11
+ time_dim: str = "time" # Input signal needs a time dimension with units in sec.
12
+ frequency: typing.Optional[float] = None # Hz
15
13
  amplitude: float = 1.0
16
14
  mixing_seed: typing.Optional[int] = None
17
15
 
@@ -24,8 +22,8 @@ class SignalInjectorState(ez.State):
24
22
 
25
23
 
26
24
  class SignalInjector(ez.Unit):
27
- SETTINGS: SignalInjectorSettings
28
- STATE: SignalInjectorState
25
+ SETTINGS = SignalInjectorSettings
26
+ STATE = SignalInjectorState
29
27
 
30
28
  INPUT_FREQUENCY = ez.InputStream(typing.Optional[float])
31
29
  INPUT_AMPLITUDE = ez.InputStream(float)
@@ -48,7 +46,6 @@ class SignalInjector(ez.Unit):
48
46
  @ez.subscriber(INPUT_SIGNAL)
49
47
  @ez.publisher(OUTPUT_SIGNAL)
50
48
  async def inject(self, msg: AxisArray) -> typing.AsyncGenerator:
51
-
52
49
  if self.STATE.cur_shape != msg.shape:
53
50
  self.STATE.cur_shape = msg.shape
54
51
  rng = np.random.default_rng(self.SETTINGS.mixing_seed)
@@ -58,10 +55,10 @@ class SignalInjector(ez.Unit):
58
55
  if self.STATE.cur_frequency is None:
59
56
  yield self.OUTPUT_SIGNAL, msg
60
57
  else:
61
- out_msg = replace(msg, data = msg.data.copy())
58
+ out_msg = replace(msg, data=msg.data.copy())
62
59
  t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis]
63
60
  signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t)
64
61
  mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude
65
62
  with out_msg.view2d(self.SETTINGS.time_dim) as view:
66
63
  view[...] = view + mixed_signal.astype(view.dtype)
67
- yield self.OUTPUT_SIGNAL, out_msg
64
+ yield self.OUTPUT_SIGNAL, out_msg