ezmsg-sigproc 2.11.0__py3-none-any.whl → 2.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.11.0'
32
- __version_tuple__ = version_tuple = (2, 11, 0)
31
+ __version__ = version = '2.13.0'
32
+ __version_tuple__ = version_tuple = (2, 13, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -14,9 +14,9 @@ from pathlib import Path
14
14
  import ezmsg.core as ez
15
15
  import numpy as np
16
16
  import numpy.typing as npt
17
+ from array_api_compat import get_namespace
17
18
  from ezmsg.baseproc import (
18
19
  BaseStatefulTransformer,
19
- BaseTransformer,
20
20
  BaseTransformerUnit,
21
21
  processor_state,
22
22
  )
@@ -24,6 +24,117 @@ from ezmsg.util.messages.axisarray import AxisArray, AxisBase
24
24
  from ezmsg.util.messages.util import replace
25
25
 
26
26
 
27
+ def _find_block_diagonal_clusters(weights: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]] | None:
28
+ """Detect block-diagonal structure in a weight matrix.
29
+
30
+ Finds connected components in the bipartite graph of non-zero weights,
31
+ where input channels and output channels are separate node sets.
32
+
33
+ Args:
34
+ weights: 2-D weight matrix of shape (n_in, n_out).
35
+
36
+ Returns:
37
+ List of (input_indices, output_indices) tuples, one per block, or
38
+ None if the matrix is not block-diagonal (single connected component).
39
+ """
40
+ if weights.ndim != 2:
41
+ return None
42
+
43
+ n_in, n_out = weights.shape
44
+ if n_in + n_out <= 2:
45
+ return None
46
+
47
+ from scipy.sparse import coo_matrix
48
+ from scipy.sparse.csgraph import connected_components
49
+
50
+ rows, cols = np.nonzero(weights)
51
+ if len(rows) == 0:
52
+ return None
53
+
54
+ # Bipartite graph: input nodes [0, n_in), output nodes [n_in, n_in + n_out)
55
+ shifted_cols = cols + n_in
56
+ adj_rows = np.concatenate([rows, shifted_cols])
57
+ adj_cols = np.concatenate([shifted_cols, rows])
58
+ adj_data = np.ones(len(adj_rows), dtype=bool)
59
+ n_nodes = n_in + n_out
60
+ adj = coo_matrix((adj_data, (adj_rows, adj_cols)), shape=(n_nodes, n_nodes))
61
+
62
+ n_components, labels = connected_components(adj, directed=False)
63
+
64
+ if n_components <= 1:
65
+ return None
66
+
67
+ clusters = []
68
+ for comp in range(n_components):
69
+ members = np.where(labels == comp)[0]
70
+ in_idx = np.sort(members[members < n_in])
71
+ out_idx = np.sort(members[members >= n_in] - n_in)
72
+ if len(in_idx) > 0 and len(out_idx) > 0:
73
+ clusters.append((in_idx, out_idx))
74
+
75
+ return clusters if len(clusters) > 1 else None
76
+
77
+
78
+ def _max_cross_cluster_weight(weights: np.ndarray, clusters: list[tuple[np.ndarray, np.ndarray]]) -> float:
79
+ """Return the maximum absolute weight between different clusters."""
80
+ mask = np.zeros(weights.shape, dtype=bool)
81
+ for in_idx, out_idx in clusters:
82
+ mask[np.ix_(in_idx, out_idx)] = True
83
+ cross = np.abs(weights[~mask])
84
+ return float(cross.max()) if cross.size > 0 else 0.0
85
+
86
+
87
+ def _merge_small_clusters(
88
+ clusters: list[tuple[np.ndarray, np.ndarray]], min_size: int
89
+ ) -> list[tuple[np.ndarray, np.ndarray]]:
90
+ """Merge clusters smaller than *min_size* into combined groups.
91
+
92
+ Small clusters are greedily concatenated until each merged group has
93
+ at least *min_size* channels (measured as ``max(n_in, n_out)``).
94
+ Any leftover small clusters that don't reach the threshold are
95
+ combined into a final group.
96
+
97
+ The merged group's sub-weight-matrix will contain the original small
98
+ diagonal blocks with zeros between them — a dense matmul on that
99
+ sub-matrix is cheaper than iterating over many tiny matmuls.
100
+ """
101
+ if min_size <= 1:
102
+ return clusters
103
+
104
+ large = []
105
+ small = []
106
+ for cluster in clusters:
107
+ in_idx, out_idx = cluster
108
+ if max(len(in_idx), len(out_idx)) >= min_size:
109
+ large.append(cluster)
110
+ else:
111
+ small.append(cluster)
112
+
113
+ if not small:
114
+ return clusters
115
+
116
+ current_in: list[np.ndarray] = []
117
+ current_out: list[np.ndarray] = []
118
+ current_in_size = 0
119
+ current_out_size = 0
120
+ for in_idx, out_idx in small:
121
+ current_in.append(in_idx)
122
+ current_out.append(out_idx)
123
+ current_in_size += len(in_idx)
124
+ current_out_size += len(out_idx)
125
+ if max(current_in_size, current_out_size) >= min_size:
126
+ large.append((np.sort(np.concatenate(current_in)), np.sort(np.concatenate(current_out))))
127
+ current_in = []
128
+ current_out = []
129
+ current_in_size = 0
130
+ current_out_size = 0
131
+
132
+ if current_in:
133
+ large.append((np.sort(np.concatenate(current_in)), np.sort(np.concatenate(current_out))))
134
+
135
+ return large
136
+
137
+
27
138
  class AffineTransformSettings(ez.Settings):
28
139
  """
29
140
  Settings for :obj:`AffineTransform`.
@@ -38,11 +149,32 @@ class AffineTransformSettings(ez.Settings):
38
149
  right_multiply: bool = True
39
150
  """Set False to transpose the weights before applying."""
40
151
 
152
+ channel_clusters: list[list[int]] | None = None
153
+ """Optional explicit input channel cluster specification for block-diagonal optimization.
154
+
155
+ Each element is a list of input channel indices forming one cluster. The
156
+ corresponding output indices are derived automatically from the non-zero
157
+ columns of the weight matrix for those input rows.
158
+
159
+ When provided, the weight matrix is decomposed into per-cluster sub-matrices
160
+ and multiplied separately, which is faster when cross-cluster weights are zero.
161
+
162
+ If None, block-diagonal structure is auto-detected from the zero pattern
163
+ of the weights."""
164
+
165
+ min_cluster_size: int = 32
166
+ """Minimum number of channels per cluster for the block-diagonal optimization.
167
+ Clusters smaller than this are greedily merged together to avoid excessive
168
+ Python loop overhead. Set to 1 to disable merging."""
169
+
41
170
 
42
171
  @processor_state
43
172
  class AffineTransformState:
44
173
  weights: npt.NDArray | None = None
45
174
  new_axis: AxisBase | None = None
175
+ n_out: int = 0
176
+ clusters: list | None = None
177
+ """list of (in_indices_xp, out_indices_xp, sub_weights_xp) tuples when block-diagonal."""
46
178
 
47
179
 
48
180
  class AffineTransformTransformer(
@@ -85,11 +217,60 @@ class AffineTransformTransformer(
85
217
 
86
218
  self._state.weights = weights
87
219
 
220
+ # Note: If weights were scipy.sparse BSR then maybe we could automate this next part.
221
+ # However, that would break compatibility with Array API.
222
+
223
+ # --- Block-diagonal cluster detection ---
224
+ # Clusters are a list of (input_indices, output_indices) tuples.
225
+ n_in, n_out = weights.shape
226
+ if self.settings.channel_clusters is not None:
227
+ # Validate input index bounds
228
+ all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
229
+ if np.any((all_in < 0) | (all_in >= n_in)):
230
+ raise ValueError(
231
+ "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
232
+ )
233
+
234
+ # Derive output indices from non-zero weights for each input cluster
235
+ clusters = []
236
+ for group in self.settings.channel_clusters:
237
+ in_idx = np.asarray(group)
238
+ out_idx = np.where(np.any(weights[in_idx, :] != 0, axis=0))[0]
239
+ clusters.append((in_idx, out_idx))
240
+
241
+ max_cross = _max_cross_cluster_weight(weights, clusters)
242
+ if max_cross > 0:
243
+ ez.logger.warning(
244
+ f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
245
+ "These will be ignored in block-diagonal multiplication."
246
+ )
247
+ else:
248
+ clusters = _find_block_diagonal_clusters(weights)
249
+ if clusters is not None:
250
+ ez.logger.info(
251
+ f"Auto-detected {len(clusters)} block-diagonal clusters "
252
+ f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
253
+ )
254
+
255
+ # Merge small clusters to avoid excessive loop overhead
256
+ if clusters is not None:
257
+ clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
258
+
259
+ if clusters is not None and len(clusters) > 1:
260
+ self._state.n_out = n_out
261
+ self._state.clusters = [
262
+ (in_idx, out_idx, np.ascontiguousarray(weights[np.ix_(in_idx, out_idx)]))
263
+ for in_idx, out_idx in clusters
264
+ ]
265
+ self._state.weights = None
266
+ else:
267
+ self._state.clusters = None
268
+
269
+ # --- Axis label handling (for non-square transforms, non-cluster path) ---
88
270
  axis = self.settings.axis or message.dims[-1]
89
- if axis in message.axes and hasattr(message.axes[axis], "data") and weights.shape[0] != weights.shape[1]:
271
+ if axis in message.axes and hasattr(message.axes[axis], "data") and n_in != n_out:
90
272
  in_labels = message.axes[axis].data
91
273
  new_labels = []
92
- n_in, n_out = weights.shape
93
274
  if len(in_labels) != n_in:
94
275
  ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
95
276
  else:
@@ -111,23 +292,70 @@ class AffineTransformTransformer(
111
292
 
112
293
  self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
113
294
 
295
+ # Convert to match message.data namespace for efficient operations in _process
296
+ xp = get_namespace(message.data)
297
+ if self._state.weights is not None:
298
+ self._state.weights = xp.asarray(self._state.weights)
299
+ if self._state.clusters is not None:
300
+ self._state.clusters = [
301
+ (xp.asarray(in_idx), xp.asarray(out_idx), xp.asarray(sub_w))
302
+ for in_idx, out_idx, sub_w in self._state.clusters
303
+ ]
304
+
305
+ def _block_diagonal_matmul(self, xp, data, axis_idx):
306
+ """Perform matmul using block-diagonal decomposition.
307
+
308
+ For each cluster, gathers input channels via ``xp.take``, performs a
309
+ matmul with the cluster's sub-weight matrix, and writes the result
310
+ directly into the pre-allocated output at the cluster's output indices.
311
+ Omitted output channels naturally remain zero.
312
+ """
313
+ needs_permute = axis_idx not in [-1, data.ndim - 1]
314
+ if needs_permute:
315
+ dim_perm = list(range(data.ndim))
316
+ dim_perm.append(dim_perm.pop(axis_idx))
317
+ data = xp.permute_dims(data, dim_perm)
318
+
319
+ # Pre-allocate output (omitted channels stay zero)
320
+ out_shape = data.shape[:-1] + (self._state.n_out,)
321
+ result = xp.zeros(out_shape, dtype=data.dtype)
322
+
323
+ for in_idx, out_idx, sub_weights in self._state.clusters:
324
+ chunk = xp.take(data, in_idx, axis=data.ndim - 1)
325
+ result[..., out_idx] = xp.matmul(chunk, sub_weights)
326
+
327
+ if needs_permute:
328
+ inv_dim_perm = list(range(result.ndim))
329
+ inv_dim_perm.insert(axis_idx, inv_dim_perm.pop(-1))
330
+ result = xp.permute_dims(result, inv_dim_perm)
331
+
332
+ return result
333
+
114
334
  def _process(self, message: AxisArray) -> AxisArray:
335
+ xp = get_namespace(message.data)
115
336
  axis = self.settings.axis or message.dims[-1]
116
337
  axis_idx = message.get_axis_idx(axis)
117
338
  data = message.data
118
339
 
119
- if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
120
- # The weights are stacked A|B where A is the transform and B is a single row
121
- # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
122
- sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
123
- data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
124
-
125
- if axis_idx in [-1, len(message.dims) - 1]:
126
- data = np.matmul(data, self._state.weights)
340
+ if self._state.clusters is not None:
341
+ data = self._block_diagonal_matmul(xp, data, axis_idx)
127
342
  else:
128
- data = np.moveaxis(data, axis_idx, -1)
129
- data = np.matmul(data, self._state.weights)
130
- data = np.moveaxis(data, -1, axis_idx)
343
+ if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
344
+ # The weights are stacked A|B where A is the transform and B is a single row
345
+ # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
346
+ sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
347
+ data = xp.concat((data, xp.ones(sample_shape, dtype=data.dtype)), axis=axis_idx)
348
+
349
+ if axis_idx in [-1, len(message.dims) - 1]:
350
+ data = xp.matmul(data, self._state.weights)
351
+ else:
352
+ perm = list(range(data.ndim))
353
+ perm.append(perm.pop(axis_idx))
354
+ data = xp.permute_dims(data, perm)
355
+ data = xp.matmul(data, self._state.weights)
356
+ inv_perm = list(range(data.ndim))
357
+ inv_perm.insert(axis_idx, inv_perm.pop(-1))
358
+ data = xp.permute_dims(data, inv_perm)
131
359
 
132
360
  replace_kwargs = {"data": data}
133
361
  if self._state.new_axis is not None:
@@ -144,6 +372,8 @@ def affine_transform(
144
372
  weights: np.ndarray | str | Path,
145
373
  axis: str | None = None,
146
374
  right_multiply: bool = True,
375
+ channel_clusters: list[list[int]] | None = None,
376
+ min_cluster_size: int = 32,
147
377
  ) -> AffineTransformTransformer:
148
378
  """
149
379
  Perform affine transformations on streaming data.
@@ -152,19 +382,25 @@ def affine_transform(
152
382
  weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
153
383
  axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
154
384
  right_multiply: Set False to transpose the weights before applying.
385
+ channel_clusters: Optional explicit channel cluster specification. See
386
+ :attr:`AffineTransformSettings.channel_clusters`.
387
+ min_cluster_size: Minimum channels per cluster; smaller clusters are merged. See
388
+ :attr:`AffineTransformSettings.min_cluster_size`.
155
389
 
156
390
  Returns:
157
391
  :obj:`AffineTransformTransformer`.
158
392
  """
159
393
  return AffineTransformTransformer(
160
- AffineTransformSettings(weights=weights, axis=axis, right_multiply=right_multiply)
394
+ AffineTransformSettings(
395
+ weights=weights,
396
+ axis=axis,
397
+ right_multiply=right_multiply,
398
+ channel_clusters=channel_clusters,
399
+ min_cluster_size=min_cluster_size,
400
+ )
161
401
  )
162
402
 
163
403
 
164
- def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
165
- return np.zeros_like(data)
166
-
167
-
168
404
  class CommonRereferenceSettings(ez.Settings):
169
405
  """
170
406
  Settings for :obj:`CommonRereference`
@@ -179,35 +415,64 @@ class CommonRereferenceSettings(ez.Settings):
179
415
  include_current: bool = True
180
416
  """Set False to exclude each channel from participating in the calculation of its reference."""
181
417
 
418
+ channel_clusters: list[list[int]] | None = None
419
+ """Optional channel clusters for per-cluster rereferencing. Each element is a
420
+ list of channel indices forming one cluster. The common reference is computed
421
+ independently within each cluster. If None, all channels form a single cluster."""
422
+
423
+
424
+ @processor_state
425
+ class CommonRereferenceState:
426
+ clusters: list | None = None
427
+ """list of xp arrays of channel indices, one per cluster."""
428
+
429
+
430
+ class CommonRereferenceTransformer(
431
+ BaseStatefulTransformer[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceState]
432
+ ):
433
+ def _hash_message(self, message: AxisArray) -> int:
434
+ axis = self.settings.axis or message.dims[-1]
435
+ axis_idx = message.get_axis_idx(axis)
436
+ return hash((message.key, message.data.shape[axis_idx]))
437
+
438
+ def _reset_state(self, message: AxisArray) -> None:
439
+ xp = get_namespace(message.data)
440
+ axis = self.settings.axis or message.dims[-1]
441
+ axis_idx = message.get_axis_idx(axis)
442
+ n_chans = message.data.shape[axis_idx]
443
+
444
+ if self.settings.channel_clusters is not None:
445
+ self._state.clusters = [xp.asarray(group) for group in self.settings.channel_clusters]
446
+ else:
447
+ self._state.clusters = [xp.arange(n_chans)]
182
448
 
183
- class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]):
184
449
  def _process(self, message: AxisArray) -> AxisArray:
185
450
  if self.settings.mode == "passthrough":
186
451
  return message
187
452
 
453
+ xp = get_namespace(message.data)
188
454
  axis = self.settings.axis or message.dims[-1]
189
455
  axis_idx = message.get_axis_idx(axis)
456
+ func = {"mean": xp.mean, "median": np.median}[self.settings.mode]
190
457
 
191
- func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[self.settings.mode]
458
+ # Use result_type to match dtype promotion from data - float operations.
459
+ out_dtype = np.result_type(message.data.dtype, np.float64)
460
+ output = xp.zeros(message.data.shape, dtype=out_dtype)
192
461
 
193
- ref_data = func(message.data, axis=axis_idx, keepdims=True)
462
+ for cluster_idx in self._state.clusters:
463
+ cluster_data = xp.take(message.data, cluster_idx, axis=axis_idx)
464
+ ref_data = func(cluster_data, axis=axis_idx, keepdims=True)
194
465
 
195
- if not self.settings.include_current:
196
- # Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
197
- # and is the same for all i, so it is calculated only once in `ref_data`.
198
- # However, if we had excluded the current channel,
199
- # then we would have omitted the contribution of the current channel:
200
- # `CAR[i] = x[0]/(N-1) + x[1]/(N-1) + ... x[i-1]/(N-1) + x[i+1]/(N-1) + ... + x[N-1]/(N-1)`
201
- # The majority of the calculation is the same as when the current channel is included;
202
- # we need only rescale CAR so the divisor is `N-1` instead of `N`, then subtract the contribution
203
- # from the current channel (i.e., `x[i] / (N-1)`)
204
- # i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
205
- # We can use broadcasting subtraction instead of looping over channels.
206
- N = message.data.shape[axis_idx]
207
- ref_data = (N / (N - 1)) * ref_data - message.data / (N - 1)
208
- # Note: I profiled using AffineTransformTransformer; it's ~30x slower than this implementation.
466
+ if not self.settings.include_current:
467
+ N = cluster_data.shape[axis_idx]
468
+ ref_data = (N / (N - 1)) * ref_data - cluster_data / (N - 1)
209
469
 
210
- return replace(message, data=message.data - ref_data)
470
+ # Write per-cluster result into output at the correct axis position
471
+ idx = [slice(None)] * output.ndim
472
+ idx[axis_idx] = cluster_idx
473
+ output[tuple(idx)] = cluster_data - ref_data
474
+
475
+ return replace(message, data=output)
211
476
 
212
477
 
213
478
  class CommonRereference(
@@ -217,19 +482,26 @@ class CommonRereference(
217
482
 
218
483
 
219
484
  def common_rereference(
220
- mode: str = "mean", axis: str | None = None, include_current: bool = True
485
+ mode: str = "mean",
486
+ axis: str | None = None,
487
+ include_current: bool = True,
488
+ channel_clusters: list[list[int]] | None = None,
221
489
  ) -> CommonRereferenceTransformer:
222
490
  """
223
491
  Perform common average referencing (CAR) on streaming data.
224
492
 
225
493
  Args:
226
494
  mode: The statistical mode to apply -- either "mean" or "median"
227
- axis: The name of hte axis to apply the transformation to.
495
+ axis: The name of the axis to apply the transformation to.
228
496
  include_current: Set False to exclude each channel from participating in the calculation of its reference.
497
+ channel_clusters: Optional channel clusters for per-cluster rereferencing. See
498
+ :attr:`CommonRereferenceSettings.channel_clusters`.
229
499
 
230
500
  Returns:
231
501
  :obj:`CommonRereferenceTransformer`
232
502
  """
233
503
  return CommonRereferenceTransformer(
234
- CommonRereferenceSettings(mode=mode, axis=axis, include_current=include_current)
504
+ CommonRereferenceSettings(
505
+ mode=mode, axis=axis, include_current=include_current, channel_clusters=channel_clusters
506
+ )
235
507
  )
ezmsg/sigproc/merge.py ADDED
@@ -0,0 +1,358 @@
1
+ """Time-aligned merge of two AxisArray streams along a non-time axis."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import typing
6
+
7
+ import ezmsg.core as ez
8
+ import numpy as np
9
+ from array_api_compat import get_namespace
10
+ from ezmsg.baseproc.protocols import processor_state
11
+ from ezmsg.baseproc.stateful import BaseStatefulTransformer
12
+ from ezmsg.baseproc.units import BaseProcessorUnit
13
+ from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
14
+ from ezmsg.util.messages.util import replace
15
+
16
+ from .util.axisarray_buffer import HybridAxisArrayBuffer
17
+
18
+
19
+ class MergeSettings(ez.Settings):
20
+ axis: str = "ch"
21
+ """Axis along which to concatenate the two signals."""
22
+
23
+ align_axis: str | None = "time"
24
+ """Axis used for alignment. If None, defaults to the first dimension."""
25
+
26
+ buffer_dur: float = 10.0
27
+ """Buffer duration in seconds for each input stream."""
28
+
29
+ relabel_axis: bool = True
30
+ """Whether to relabel coordinate axis labels to ensure uniqueness."""
31
+
32
+ label_a: str = "_a"
33
+ """Suffix appended to signal A labels when relabel_axis is True."""
34
+
35
+ label_b: str = "_b"
36
+ """Suffix appended to signal B labels when relabel_axis is True."""
37
+
38
+ new_key: str | None = None
39
+ """Output AxisArray key. If None, uses the key from signal A."""
40
+
41
+
42
+ @processor_state
43
+ class MergeState:
44
+ # Common state
45
+ gain: float | None = None
46
+ align_axis: str | None = None
47
+ aligned: bool = False
48
+ merged_concat_axis: CoordinateAxis | None = None
49
+
50
+ # A state
51
+ buf_a: HybridAxisArrayBuffer | None = None
52
+ concat_axis_a: CoordinateAxis | None = None
53
+ a_concat_dim: int | None = None
54
+ a_other_dims: tuple[int, ...] | None = None
55
+
56
+ # B state
57
+ buf_b: HybridAxisArrayBuffer | None = None
58
+ concat_axis_b: CoordinateAxis | None = None
59
+ b_concat_dim: int | None = None
60
+ b_other_dims: tuple[int, ...] | None = None
61
+
62
+
63
+ class MergeProcessor(BaseStatefulTransformer[MergeSettings, AxisArray, AxisArray | None, MergeState]):
64
+ """Processor that time-aligns two AxisArray streams and concatenates them.
65
+
66
+ Input A flows through the standard ``__call__`` / ``_process`` path,
67
+ getting automatic ``_hash_message`` / ``_reset_state`` handling from
68
+ :class:`BaseStatefulTransformer`. Input B flows through :meth:`push_b`,
69
+ which independently tracks its own structure.
70
+
71
+ Invalidation rules:
72
+
73
+ - Gain mismatch (either input vs stored common gain) → full reset.
74
+ - Concat-axis dimensionality change → per-input buffer reset +
75
+ alignment and merged-axis cache invalidation.
76
+ - Non-align/non-concat axis shape change → per-input buffer reset +
77
+ alignment invalidation.
78
+ """
79
+
80
+ # -- Structural extraction helpers ---------------------------------------
81
+
82
+ def _extract_gain(self, message: AxisArray) -> float | None:
83
+ """Extract the align-axis gain from a message."""
84
+ align_name = self.settings.align_axis or message.dims[0]
85
+ ax = message.axes.get(align_name)
86
+ if ax is not None and hasattr(ax, "gain"):
87
+ return ax.gain
88
+ if ax is not None and hasattr(ax, "data") and len(ax.data) > 1:
89
+ return float(ax.data[-1] - ax.data[0]) / (len(ax.data) - 1)
90
+ return None
91
+
92
+ # -- Reset helpers -------------------------------------------------------
93
+
94
+ def _full_reset(self, align_axis: str) -> None:
95
+ """Reset all state — both inputs and common merge state."""
96
+ self._state.align_axis = align_axis
97
+ self._state.buf_a = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
98
+ self._state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
99
+ self._state.gain = None
100
+ self._state.aligned = False
101
+ self._state.concat_axis_a = None
102
+ self._state.concat_axis_b = None
103
+ self._state.merged_concat_axis = None
104
+ self._state.a_concat_dim = None
105
+ self._state.a_other_dims = None
106
+ self._state.b_concat_dim = None
107
+ self._state.b_other_dims = None
108
+
109
+ def _reset_a_state(self) -> None:
110
+ """Reset input-A buffer and concat-axis cache."""
111
+ self._state.buf_a = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=self._state.align_axis)
112
+ self._state.concat_axis_a = None
113
+
114
+ def _reset_b_state(self) -> None:
115
+ """Reset input-B buffer and concat-axis cache."""
116
+ self._state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=self._state.align_axis)
117
+ self._state.concat_axis_b = None
118
+
119
+ # -- BaseStatefulTransformer interface ------------------------------------
120
+
121
+ def _hash_message(self, message: AxisArray) -> int:
122
+ """Hash the align-axis gain only.
123
+
124
+ Gain changes trigger a full reset via ``_reset_state``. Concat-axis
125
+ and non-merge dimension changes are handled as partial resets inside
126
+ ``_process`` and ``push_b``.
127
+ """
128
+ return hash(self._extract_gain(message))
129
+
130
+ def _reset_state(self, message: AxisArray) -> None:
131
+ """Full reset — called by the base class when gain changes."""
132
+ align_axis = self.settings.align_axis or message.dims[0]
133
+ self._full_reset(align_axis)
134
+
135
+ def _process(self, message: AxisArray) -> AxisArray | None:
136
+ """Process input A: detect structural changes, buffer, try merge."""
137
+ # Detect per-input structural changes.
138
+ align_idx = message.dims.index(self._state.align_axis)
139
+ concat_idx = message.dims.index(self.settings.axis) if self.settings.axis in message.dims else None
140
+ concat_dim = message.data.shape[concat_idx] if concat_idx is not None else None
141
+ other_dims = tuple(s for i, s in enumerate(message.data.shape) if i != align_idx and i != concat_idx)
142
+
143
+ if self._state.a_concat_dim is not None and concat_dim != self._state.a_concat_dim:
144
+ self._reset_a_state()
145
+ self._state.aligned = False
146
+ self._state.merged_concat_axis = None
147
+ elif self._state.a_other_dims is not None and other_dims != self._state.a_other_dims:
148
+ self._reset_a_state()
149
+ self._state.aligned = False
150
+
151
+ self._state.a_concat_dim = concat_dim
152
+ self._state.a_other_dims = other_dims
153
+
154
+ self._state.buf_a.write(message)
155
+ if self._state.gain is None:
156
+ self._state.gain = self._state.buf_a.axis_gain
157
+ self._update_concat_axis(message, "a")
158
+ return self._try_merge()
159
+
160
+ # -- Input B entry point ------------------------------------------------
161
+
162
+ def push_b(self, message: AxisArray) -> AxisArray | None:
163
+ """Process input B: check gain, detect structural changes, buffer, try merge."""
164
+ align_axis = self.settings.align_axis or message.dims[0]
165
+
166
+ # Gain compatibility check.
167
+ b_gain = self._extract_gain(message)
168
+ if self._state.gain is not None and b_gain != self._state.gain:
169
+ self._full_reset(align_axis)
170
+ # Set the base-class hash so the next compatible A goes straight
171
+ # to _process instead of triggering another full reset.
172
+ self._hash = self._hash_message(message)
173
+
174
+ # Lazy-create buf_b if B arrives before A.
175
+ if self._state.buf_b is None:
176
+ if self._state.align_axis is None:
177
+ self._state.align_axis = align_axis
178
+ self._state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
179
+
180
+ # Detect per-input structural changes.
181
+ align_idx = message.dims.index(align_axis)
182
+ concat_idx = message.dims.index(self.settings.axis) if self.settings.axis in message.dims else None
183
+ concat_dim = message.data.shape[concat_idx] if concat_idx is not None else None
184
+ other_dims = tuple(s for i, s in enumerate(message.data.shape) if i != align_idx and i != concat_idx)
185
+
186
+ if self._state.b_concat_dim is not None and concat_dim != self._state.b_concat_dim:
187
+ self._reset_b_state()
188
+ self._state.aligned = False
189
+ self._state.merged_concat_axis = None
190
+ elif self._state.b_other_dims is not None and other_dims != self._state.b_other_dims:
191
+ self._reset_b_state()
192
+ self._state.aligned = False
193
+
194
+ self._state.b_concat_dim = concat_dim
195
+ self._state.b_other_dims = other_dims
196
+
197
+ self._state.buf_b.write(message)
198
+ if self._state.gain is None:
199
+ self._state.gain = self._state.buf_b.axis_gain
200
+ self._update_concat_axis(message, "b")
201
+ return self._try_merge()
202
+
203
+ # -- Concat-axis caching ------------------------------------------------
204
+
205
+ def _update_concat_axis(self, message: AxisArray, which: str) -> None:
206
+ """Track each input's concat-axis labels; invalidate cache on change."""
207
+ concat_dim = self.settings.axis
208
+ if concat_dim not in message.axes:
209
+ return
210
+ ax = message.axes[concat_dim]
211
+ if not hasattr(ax, "data"):
212
+ return
213
+
214
+ if which == "a":
215
+ if self._state.concat_axis_a is None or not np.array_equal(self._state.concat_axis_a.data, ax.data):
216
+ self._state.concat_axis_a = ax
217
+ self._state.merged_concat_axis = None
218
+ else:
219
+ if self._state.concat_axis_b is None or not np.array_equal(self._state.concat_axis_b.data, ax.data):
220
+ self._state.concat_axis_b = ax
221
+ self._state.merged_concat_axis = None
222
+
223
+ def _build_merged_concat_axis(self) -> CoordinateAxis | None:
224
+ """Build the merged CoordinateAxis from the two cached per-input axes."""
225
+ if self._state.concat_axis_a is None or self._state.concat_axis_b is None:
226
+ return None
227
+ if self.settings.relabel_axis:
228
+ labels_a = np.array([str(lbl) + self.settings.label_a for lbl in self._state.concat_axis_a.data])
229
+ labels_b = np.array([str(lbl) + self.settings.label_b for lbl in self._state.concat_axis_b.data])
230
+ else:
231
+ labels_a = self._state.concat_axis_a.data
232
+ labels_b = self._state.concat_axis_b.data
233
+ return CoordinateAxis(
234
+ data=np.concatenate([labels_a, labels_b]),
235
+ dims=self._state.concat_axis_a.dims,
236
+ unit=self._state.concat_axis_a.unit,
237
+ )
238
+
239
+ # -- Core merge logic ---------------------------------------------------
240
+
241
+ def _try_merge(self) -> AxisArray | None:
242
+ """Align and read from both buffers, returning the merged result.
243
+
244
+ Initial alignment is performed once. After the first successful
245
+ merge the two streams are assumed to share a common clock and
246
+ never drop samples, so we simply read
247
+ ``min(available_a, available_b)`` on every subsequent call.
248
+ """
249
+ if self._state.buf_a is None or self._state.buf_b is None:
250
+ return None
251
+ if self._state.buf_a.is_empty() or self._state.buf_b.is_empty():
252
+ return None
253
+
254
+ gain = self._state.gain
255
+
256
+ # --- Initial alignment (runs only until the first successful merge) ---
257
+ if not self._state.aligned:
258
+ first_a = self._state.buf_a.axis_first_value
259
+ final_a = self._state.buf_a.axis_final_value
260
+ first_b = self._state.buf_b.axis_first_value
261
+ final_b = self._state.buf_b.axis_final_value
262
+
263
+ overlap_start = max(first_a, first_b)
264
+ overlap_end = min(final_a, final_b)
265
+
266
+ if overlap_end < overlap_start - gain / 2:
267
+ if final_a < first_b:
268
+ self._state.buf_a.seek(self._state.buf_a.available())
269
+ elif final_b < first_a:
270
+ self._state.buf_b.seek(self._state.buf_b.available())
271
+ return None
272
+
273
+ if first_a < overlap_start - gain / 2:
274
+ self._state.buf_a.seek(int(round((overlap_start - first_a) / gain)))
275
+ if first_b < overlap_start - gain / 2:
276
+ self._state.buf_b.seek(int(round((overlap_start - first_b) / gain)))
277
+
278
+ # --- Read aligned samples ---
279
+ n_read = min(self._state.buf_a.available(), self._state.buf_b.available())
280
+ if n_read <= 0:
281
+ return None
282
+
283
+ aa_a = self._state.buf_a.read(n_read)
284
+ aa_b = self._state.buf_b.read(n_read)
285
+ if aa_a is None or aa_b is None:
286
+ return None
287
+
288
+ if not self._state.aligned:
289
+ axis_a = aa_a.axes.get(self._state.align_axis)
290
+ axis_b = aa_b.axes.get(self._state.align_axis)
291
+ if axis_a is not None and axis_b is not None:
292
+ off_a = axis_a.value(0) if hasattr(axis_a, "value") else None
293
+ off_b = axis_b.value(0) if hasattr(axis_b, "value") else None
294
+ if off_a is not None and off_b is not None:
295
+ if not np.isclose(off_a, off_b, atol=abs(gain) * 1e-6):
296
+ raise RuntimeError(
297
+ f"Offset mismatch after alignment: " f"off_a={off_a}, off_b={off_b}, gain={gain}"
298
+ )
299
+ self._state.aligned = True
300
+
301
+ return self._concat(aa_a, aa_b)
302
+
303
+ def _concat(self, a: AxisArray, b: AxisArray) -> AxisArray:
304
+ """Concatenate *a* and *b* along the configured merge axis."""
305
+ merge_dim = self.settings.axis
306
+
307
+ # If the merge dim doesn't exist in an input, add it as a trailing axis.
308
+ if merge_dim not in a.dims:
309
+ xp = get_namespace(a.data)
310
+ a = replace(a, data=xp.expand_dims(a.data, axis=-1), dims=[*a.dims, merge_dim])
311
+ if merge_dim not in b.dims:
312
+ xp = get_namespace(b.data)
313
+ b = replace(b, data=xp.expand_dims(b.data, axis=-1), dims=[*b.dims, merge_dim])
314
+
315
+ # Use the cached merged axis (rebuilt lazily when labels change).
316
+ if self._state.merged_concat_axis is None:
317
+ self._state.merged_concat_axis = self._build_merged_concat_axis()
318
+
319
+ key = self.settings.new_key if self.settings.new_key is not None else a.key
320
+ result = AxisArray.concatenate(a, b, dim=merge_dim, axis=self._state.merged_concat_axis)
321
+ if key != result.key:
322
+ result = replace(result, key=key)
323
+ return result
324
+
325
+
326
+ class Merge(BaseProcessorUnit[MergeSettings]):
327
+ """Merge two AxisArray streams by time-aligning and concatenating along a non-time axis.
328
+
329
+ Input A routes through the processor's ``__acall__`` (triggering
330
+ hash-based reset when the stream structure changes). Input B
331
+ routes through ``push_b`` which independently tracks its own structure.
332
+
333
+ Inherits ``INPUT_SETTINGS`` and ``on_settings`` → ``create_processor``
334
+ from :class:`BaseProcessorUnit`.
335
+ """
336
+
337
+ SETTINGS = MergeSettings
338
+
339
+ INPUT_SIGNAL_A = ez.InputStream(AxisArray)
340
+ INPUT_SIGNAL_B = ez.InputStream(AxisArray)
341
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
342
+
343
+ def create_processor(self) -> None:
344
+ self.processor = MergeProcessor(settings=self.SETTINGS)
345
+
346
+ @ez.subscriber(INPUT_SIGNAL_A, zero_copy=True)
347
+ @ez.publisher(OUTPUT_SIGNAL)
348
+ async def on_a(self, msg: AxisArray) -> typing.AsyncGenerator:
349
+ result = await self.processor.__acall__(msg)
350
+ if result is not None:
351
+ yield self.OUTPUT_SIGNAL, result
352
+
353
+ @ez.subscriber(INPUT_SIGNAL_B, zero_copy=True)
354
+ @ez.publisher(OUTPUT_SIGNAL)
355
+ async def on_b(self, msg: AxisArray) -> typing.AsyncGenerator:
356
+ result = self.processor.push_b(msg)
357
+ if result is not None:
358
+ yield self.OUTPUT_SIGNAL, result
ezmsg/sigproc/quantize.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import ezmsg.core as ez
2
- import numpy as np
2
+ from array_api_compat import get_namespace
3
3
  from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
4
4
  from ezmsg.util.messages.axisarray import AxisArray, replace
5
5
 
@@ -33,32 +33,33 @@ class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray
33
33
  self,
34
34
  message: AxisArray,
35
35
  ) -> AxisArray:
36
+ xp = get_namespace(message.data)
36
37
  expected_range = self.settings.max_val - self.settings.min_val
37
38
  scale_factor = 2**self.settings.bits - 1
38
39
  clip_max = self.settings.max_val
39
40
 
40
41
  # Determine appropriate integer type based on bits
41
42
  if self.settings.bits <= 1:
42
- dtype = bool
43
+ dtype = xp.bool
43
44
  elif self.settings.bits <= 8:
44
- dtype = np.uint8
45
+ dtype = xp.uint8
45
46
  elif self.settings.bits <= 16:
46
- dtype = np.uint16
47
+ dtype = xp.uint16
47
48
  elif self.settings.bits <= 32:
48
- dtype = np.uint32
49
+ dtype = xp.uint32
49
50
  else:
50
- dtype = np.uint64
51
+ dtype = xp.uint64
51
52
  if self.settings.bits == 64:
52
53
  # The practical upper bound before converting to int is: 2**64 - 1025
53
54
  # Anything larger will wrap around to 0.
54
55
  #
55
56
  clip_max *= 1 - 2e-16
56
57
 
57
- data = message.data.clip(self.settings.min_val, clip_max)
58
+ data = xp.clip(message.data, self.settings.min_val, clip_max)
58
59
  data = (data - self.settings.min_val) / expected_range
59
60
 
60
61
  # Scale to the quantized range [0, 2^bits - 1]
61
- data = np.rint(scale_factor * data).astype(dtype)
62
+ data = xp.round(scale_factor * data).astype(dtype)
62
63
 
63
64
  # Create a new AxisArray with the quantized data
64
65
  return replace(message, data=data)
@@ -1,8 +1,9 @@
1
+ import math
1
2
  from collections import deque
2
3
 
3
4
  import ezmsg.core as ez
4
- import numpy as np
5
5
  import numpy.typing as npt
6
+ from array_api_compat import get_namespace
6
7
  from ezmsg.baseproc import (
7
8
  BaseAdaptiveTransformer,
8
9
  BaseAdaptiveTransformerUnit,
@@ -111,12 +112,13 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
111
112
  return hash((message.key, samp_shape, gain))
112
113
 
113
114
  def _reset_state(self, message: AxisArray) -> None:
115
+ xp = get_namespace(message.data)
114
116
  ch = message.data.shape[-1]
115
- self._state.mean = np.zeros(ch)
117
+ self._state.mean = xp.zeros(ch, dtype=xp.float64)
116
118
  self._state.N = 0
117
- self._state.M2 = np.zeros(ch)
119
+ self._state.M2 = xp.zeros(ch, dtype=xp.float64)
118
120
  self._state.k_samples = (
119
- int(np.ceil(self.settings.window_size / message.axes[self.settings.axis].gain))
121
+ math.ceil(self.settings.window_size / message.axes[self.settings.axis].gain)
120
122
  if self.settings.window_size is not None
121
123
  else self.settings.k_samples
122
124
  )
@@ -127,7 +129,7 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
127
129
  ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
128
130
  self._state.samples = deque(maxlen=self._state.k_samples)
129
131
  self._state.min_samples = (
130
- int(np.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain))
132
+ math.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain)
131
133
  if self.settings.window_size is not None
132
134
  else self.settings.min_samples
133
135
  )
@@ -136,10 +138,11 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
136
138
  self._state.min_samples = self._state.k_samples
137
139
 
138
140
  def _add_batch_stats(self, x: npt.NDArray) -> None:
139
- x = np.asarray(x, dtype=np.float64)
141
+ xp = get_namespace(x)
142
+ x = xp.asarray(x, dtype=xp.float64)
140
143
  n_b = x.shape[0]
141
- mean_b = np.mean(x, axis=0)
142
- M2_b = np.sum((x - mean_b) ** 2, axis=0)
144
+ mean_b = xp.mean(x, axis=0)
145
+ M2_b = xp.sum((x - mean_b) ** 2, axis=0)
143
146
 
144
147
  if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
145
148
  n_old, mean_old, M2_old = self._state.samples.popleft()
@@ -148,8 +151,8 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
148
151
 
149
152
  if N_new <= 0:
150
153
  self._state.N = 0
151
- self._state.mean = np.zeros_like(self._state.mean)
152
- self._state.M2 = np.zeros_like(self._state.M2)
154
+ self._state.mean = xp.zeros_like(self._state.mean)
155
+ self._state.M2 = xp.zeros_like(self._state.M2)
153
156
  else:
154
157
  delta = mean_old - self._state.mean
155
158
  self._state.N = N_new
@@ -170,32 +173,37 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
170
173
  self._add_batch_stats(x)
171
174
 
172
175
  def _process(self, message: AxisArray) -> AxisArray:
176
+ xp = get_namespace(message.data)
177
+
173
178
  if self._state.N == 0 or self._state.N < self._state.min_samples:
174
179
  if self.settings.update_with_signal:
175
180
  x = message.data
176
181
  if self.settings.artifact_z_thresh is not None and self._state.N > 0:
177
182
  varis = self._state.M2 / self._state.N
178
- std = np.maximum(np.sqrt(varis), 1e-8)
179
- z = np.abs((x - self._state.mean) / std)
180
- mask = np.any(z > self.settings.artifact_z_thresh, axis=1)
183
+ raw_std = varis**0.5
184
+ std = xp.where(xp.isnan(raw_std), raw_std, xp.clip(raw_std, min=1e-8))
185
+ z = xp.abs((x - self._state.mean) / std)
186
+ mask = xp.any(z > self.settings.artifact_z_thresh, axis=1)
181
187
  x = x[~mask]
182
188
  if x.size > 0:
183
189
  self._add_batch_stats(x)
184
190
  return message
185
191
 
186
192
  varis = self._state.M2 / self._state.N
187
- std = np.maximum(np.sqrt(varis), 1e-8)
188
- with np.errstate(divide="ignore", invalid="ignore"):
189
- result = (message.data - self._state.mean) / std
190
- result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
193
+ raw_std = varis**0.5
194
+ # Preserve NaN from negative variance (will be caught below), clip positive std to floor
195
+ std = xp.where(xp.isnan(raw_std), raw_std, xp.clip(raw_std, min=1e-8))
196
+ result = (message.data - self._state.mean) / std
197
+ # Replace NaN/inf with 0 (equivalent to nan_to_num with nan=0, posinf=0, neginf=0)
198
+ result = xp.where(xp.isfinite(result), result, xp.asarray(0.0, dtype=result.dtype))
191
199
  if self.settings.clip is not None:
192
- result = np.clip(result, -self.settings.clip, self.settings.clip)
200
+ result = xp.clip(result, -self.settings.clip, self.settings.clip)
193
201
 
194
202
  if self.settings.update_with_signal:
195
203
  x = message.data
196
204
  if self.settings.artifact_z_thresh is not None:
197
- z_scores = np.abs((x - self._state.mean) / std)
198
- mask = np.any(z_scores > self.settings.artifact_z_thresh, axis=1)
205
+ z_scores = xp.abs((x - self._state.mean) / std)
206
+ mask = xp.any(z_scores > self.settings.artifact_z_thresh, axis=1)
199
207
  x = x[~mask]
200
208
  if x.size > 0:
201
209
  self._add_batch_stats(x)
ezmsg/sigproc/scaler.py CHANGED
@@ -2,6 +2,7 @@ import typing
2
2
 
3
3
  import ezmsg.core as ez
4
4
  import numpy as np
5
+ from array_api_compat import get_namespace
5
6
  from ezmsg.baseproc import (
6
7
  BaseStatefulTransformer,
7
8
  BaseTransformerUnit,
@@ -132,15 +133,20 @@ class AdaptiveStandardScalerTransformer(
132
133
  self._state.vars_sq_ewma.settings = replace(self._state.vars_sq_ewma.settings, accumulate=value)
133
134
 
134
135
  def _process(self, message: AxisArray) -> AxisArray:
136
+ xp = get_namespace(message.data)
137
+
135
138
  # Update step (respects accumulate setting via child EWMAs)
136
139
  mean_message = self._state.samps_ewma(message)
137
140
  var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))
138
141
 
139
- # Get step
142
+ # Get step: safe division avoids warnings from zero/negative variance
140
143
  varis = var_sq_message.data - mean_message.data**2
141
- with np.errstate(divide="ignore", invalid="ignore"):
142
- result = (message.data - mean_message.data) / (varis**0.5)
143
- result[np.isnan(result)] = 0.0
144
+ std = varis**0.5
145
+ mask = std > 0
146
+ safe_std = xp.where(mask, std, xp.asarray(1.0, dtype=std.dtype))
147
+ result = xp.where(
148
+ mask, (message.data - mean_message.data) / safe_std, xp.asarray(0.0, dtype=message.data.dtype)
149
+ )
144
150
  return replace(message, data=result)
145
151
 
146
152
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.11.0
3
+ Version: 2.13.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -1,8 +1,8 @@
1
1
  ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
2
- ezmsg/sigproc/__version__.py,sha256=eqKbWb9LnxuZWE9-pafopBz45ugg0beSlKLIOIjeSzc,706
2
+ ezmsg/sigproc/__version__.py,sha256=_4LOjlEcfZzfuqIlglDZmVBPO4LyQ8P97qO716YoUL8,706
3
3
  ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
4
4
  ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
5
- ezmsg/sigproc/affinetransform.py,sha256=jl7DiSa5Yb0qsmFJbfSiSeGmvK1SGoBgycFC5JU5DVY,9434
5
+ ezmsg/sigproc/affinetransform.py,sha256=ZugiQg89Ly1I9SDgf0ZzgU2XdwVDmPrU7-orO9yrt7w,20210
6
6
  ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
7
7
  ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
8
8
  ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
@@ -29,12 +29,13 @@ ezmsg/sigproc/firfilter.py,sha256=7r_I476nYuixJsuwc_hQ0Fbq8WB0gnYBZUKs3zultOQ,37
29
29
  ezmsg/sigproc/gaussiansmoothing.py,sha256=BfXm9YQoOtieM4ABK2KRgxeQz055rd7mqtTVqmjT3Rk,2672
30
30
  ezmsg/sigproc/kaiser.py,sha256=dIwgHbLXUHPtdotsGNLE9VG_clhcMgvVnSoFkMVgF9M,3483
31
31
  ezmsg/sigproc/linear.py,sha256=b3NRzQNBvdU2jqenZT9XXFHax9Mavbj2xFiVxOwl1Ms,4662
32
+ ezmsg/sigproc/merge.py,sha256=LmuN3LDIZF7DynMcjLp7eGc2G3Yxks9Zd8-luSqqXuA,15436
32
33
  ezmsg/sigproc/messages.py,sha256=KQczHTeifn4BZycChN8ZcpfZoQW3lC_xuFmN72QT97s,925
33
- ezmsg/sigproc/quantize.py,sha256=uSM2z2xXwL0dgSltyzLEmlKjaJZ2meA3PDWX8_bM0Hs,2195
34
+ ezmsg/sigproc/quantize.py,sha256=y7T4_67BHZluX3gyl2anp8iL6EEI6JvsK7Pmp1vapsk,2268
34
35
  ezmsg/sigproc/resample.py,sha256=3mm9pvxryNVhQuTCIMW3ToUkUfbVOCsIgvXUiurit1Y,11389
35
- ezmsg/sigproc/rollingscaler.py,sha256=e-smSKDhmDD2nWIf6I77CtRxQp_7sHS268SGPi7aXp8,8499
36
+ ezmsg/sigproc/rollingscaler.py,sha256=GcLctVAWTmx9J39r0-dt3e7C_hs25s7M0dDnKiGhkC4,8955
36
37
  ezmsg/sigproc/sampler.py,sha256=iOk2YoUX22u9iTjFKimzP5V074RDBVcmswgfyxvZRZo,10761
37
- ezmsg/sigproc/scaler.py,sha256=nCgShZufPId_b-Sbsc8Si31lbtOb3nPImNcnksd774w,6578
38
+ ezmsg/sigproc/scaler.py,sha256=kVXjRbqoxJ5yJICGsGagRXYIDW3-oihSnbBj-n3s55o,6816
38
39
  ezmsg/sigproc/signalinjector.py,sha256=mB62H2b-ScgPtH1jajEpxgDHqdb-RKekQfgyNncsE8Y,2874
39
40
  ezmsg/sigproc/singlebandpow.py,sha256=BVlWhFI6zU3ME3EVdZbwf-FMz1d2sfuNFDKXs1hn5HM,4353
40
41
  ezmsg/sigproc/slicer.py,sha256=xLXxWf722V08ytVwvPimYjDKKj0pkC2HjdgCVaoaOvs,5195
@@ -61,7 +62,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
61
62
  ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
62
63
  ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
63
64
  ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
64
- ezmsg_sigproc-2.11.0.dist-info/METADATA,sha256=8XB8fu3sNqsrwV-ff8xtlWUKsFdERMSqqkotMhfNtu0,1909
65
- ezmsg_sigproc-2.11.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
66
- ezmsg_sigproc-2.11.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
67
- ezmsg_sigproc-2.11.0.dist-info/RECORD,,
65
+ ezmsg_sigproc-2.13.0.dist-info/METADATA,sha256=RXENX541lABAic8oUDuT8vQwx9nlWY9JETyXYKxdeTQ,1909
66
+ ezmsg_sigproc-2.13.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
67
+ ezmsg_sigproc-2.13.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
68
+ ezmsg_sigproc-2.13.0.dist-info/RECORD,,