ezmsg-sigproc 2.11.0__py3-none-any.whl → 2.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/affinetransform.py +312 -40
- ezmsg/sigproc/merge.py +358 -0
- ezmsg/sigproc/quantize.py +9 -8
- ezmsg/sigproc/rollingscaler.py +28 -20
- ezmsg/sigproc/scaler.py +10 -4
- {ezmsg_sigproc-2.11.0.dist-info → ezmsg_sigproc-2.13.0.dist-info}/METADATA +1 -1
- {ezmsg_sigproc-2.11.0.dist-info → ezmsg_sigproc-2.13.0.dist-info}/RECORD +10 -9
- {ezmsg_sigproc-2.11.0.dist-info → ezmsg_sigproc-2.13.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-2.11.0.dist-info → ezmsg_sigproc-2.13.0.dist-info}/licenses/LICENSE +0 -0
ezmsg/sigproc/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.
|
|
32
|
-
__version_tuple__ = version_tuple = (2,
|
|
31
|
+
__version__ = version = '2.13.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 13, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
ezmsg/sigproc/affinetransform.py
CHANGED
|
@@ -14,9 +14,9 @@ from pathlib import Path
|
|
|
14
14
|
import ezmsg.core as ez
|
|
15
15
|
import numpy as np
|
|
16
16
|
import numpy.typing as npt
|
|
17
|
+
from array_api_compat import get_namespace
|
|
17
18
|
from ezmsg.baseproc import (
|
|
18
19
|
BaseStatefulTransformer,
|
|
19
|
-
BaseTransformer,
|
|
20
20
|
BaseTransformerUnit,
|
|
21
21
|
processor_state,
|
|
22
22
|
)
|
|
@@ -24,6 +24,117 @@ from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
|
24
24
|
from ezmsg.util.messages.util import replace
|
|
25
25
|
|
|
26
26
|
|
|
27
|
+
def _find_block_diagonal_clusters(weights: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]] | None:
|
|
28
|
+
"""Detect block-diagonal structure in a weight matrix.
|
|
29
|
+
|
|
30
|
+
Finds connected components in the bipartite graph of non-zero weights,
|
|
31
|
+
where input channels and output channels are separate node sets.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
weights: 2-D weight matrix of shape (n_in, n_out).
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
List of (input_indices, output_indices) tuples, one per block, or
|
|
38
|
+
None if the matrix is not block-diagonal (single connected component).
|
|
39
|
+
"""
|
|
40
|
+
if weights.ndim != 2:
|
|
41
|
+
return None
|
|
42
|
+
|
|
43
|
+
n_in, n_out = weights.shape
|
|
44
|
+
if n_in + n_out <= 2:
|
|
45
|
+
return None
|
|
46
|
+
|
|
47
|
+
from scipy.sparse import coo_matrix
|
|
48
|
+
from scipy.sparse.csgraph import connected_components
|
|
49
|
+
|
|
50
|
+
rows, cols = np.nonzero(weights)
|
|
51
|
+
if len(rows) == 0:
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
# Bipartite graph: input nodes [0, n_in), output nodes [n_in, n_in + n_out)
|
|
55
|
+
shifted_cols = cols + n_in
|
|
56
|
+
adj_rows = np.concatenate([rows, shifted_cols])
|
|
57
|
+
adj_cols = np.concatenate([shifted_cols, rows])
|
|
58
|
+
adj_data = np.ones(len(adj_rows), dtype=bool)
|
|
59
|
+
n_nodes = n_in + n_out
|
|
60
|
+
adj = coo_matrix((adj_data, (adj_rows, adj_cols)), shape=(n_nodes, n_nodes))
|
|
61
|
+
|
|
62
|
+
n_components, labels = connected_components(adj, directed=False)
|
|
63
|
+
|
|
64
|
+
if n_components <= 1:
|
|
65
|
+
return None
|
|
66
|
+
|
|
67
|
+
clusters = []
|
|
68
|
+
for comp in range(n_components):
|
|
69
|
+
members = np.where(labels == comp)[0]
|
|
70
|
+
in_idx = np.sort(members[members < n_in])
|
|
71
|
+
out_idx = np.sort(members[members >= n_in] - n_in)
|
|
72
|
+
if len(in_idx) > 0 and len(out_idx) > 0:
|
|
73
|
+
clusters.append((in_idx, out_idx))
|
|
74
|
+
|
|
75
|
+
return clusters if len(clusters) > 1 else None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _max_cross_cluster_weight(weights: np.ndarray, clusters: list[tuple[np.ndarray, np.ndarray]]) -> float:
|
|
79
|
+
"""Return the maximum absolute weight between different clusters."""
|
|
80
|
+
mask = np.zeros(weights.shape, dtype=bool)
|
|
81
|
+
for in_idx, out_idx in clusters:
|
|
82
|
+
mask[np.ix_(in_idx, out_idx)] = True
|
|
83
|
+
cross = np.abs(weights[~mask])
|
|
84
|
+
return float(cross.max()) if cross.size > 0 else 0.0
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _merge_small_clusters(
|
|
88
|
+
clusters: list[tuple[np.ndarray, np.ndarray]], min_size: int
|
|
89
|
+
) -> list[tuple[np.ndarray, np.ndarray]]:
|
|
90
|
+
"""Merge clusters smaller than *min_size* into combined groups.
|
|
91
|
+
|
|
92
|
+
Small clusters are greedily concatenated until each merged group has
|
|
93
|
+
at least *min_size* channels (measured as ``max(n_in, n_out)``).
|
|
94
|
+
Any leftover small clusters that don't reach the threshold are
|
|
95
|
+
combined into a final group.
|
|
96
|
+
|
|
97
|
+
The merged group's sub-weight-matrix will contain the original small
|
|
98
|
+
diagonal blocks with zeros between them — a dense matmul on that
|
|
99
|
+
sub-matrix is cheaper than iterating over many tiny matmuls.
|
|
100
|
+
"""
|
|
101
|
+
if min_size <= 1:
|
|
102
|
+
return clusters
|
|
103
|
+
|
|
104
|
+
large = []
|
|
105
|
+
small = []
|
|
106
|
+
for cluster in clusters:
|
|
107
|
+
in_idx, out_idx = cluster
|
|
108
|
+
if max(len(in_idx), len(out_idx)) >= min_size:
|
|
109
|
+
large.append(cluster)
|
|
110
|
+
else:
|
|
111
|
+
small.append(cluster)
|
|
112
|
+
|
|
113
|
+
if not small:
|
|
114
|
+
return clusters
|
|
115
|
+
|
|
116
|
+
current_in: list[np.ndarray] = []
|
|
117
|
+
current_out: list[np.ndarray] = []
|
|
118
|
+
current_in_size = 0
|
|
119
|
+
current_out_size = 0
|
|
120
|
+
for in_idx, out_idx in small:
|
|
121
|
+
current_in.append(in_idx)
|
|
122
|
+
current_out.append(out_idx)
|
|
123
|
+
current_in_size += len(in_idx)
|
|
124
|
+
current_out_size += len(out_idx)
|
|
125
|
+
if max(current_in_size, current_out_size) >= min_size:
|
|
126
|
+
large.append((np.sort(np.concatenate(current_in)), np.sort(np.concatenate(current_out))))
|
|
127
|
+
current_in = []
|
|
128
|
+
current_out = []
|
|
129
|
+
current_in_size = 0
|
|
130
|
+
current_out_size = 0
|
|
131
|
+
|
|
132
|
+
if current_in:
|
|
133
|
+
large.append((np.sort(np.concatenate(current_in)), np.sort(np.concatenate(current_out))))
|
|
134
|
+
|
|
135
|
+
return large
|
|
136
|
+
|
|
137
|
+
|
|
27
138
|
class AffineTransformSettings(ez.Settings):
|
|
28
139
|
"""
|
|
29
140
|
Settings for :obj:`AffineTransform`.
|
|
@@ -38,11 +149,32 @@ class AffineTransformSettings(ez.Settings):
|
|
|
38
149
|
right_multiply: bool = True
|
|
39
150
|
"""Set False to transpose the weights before applying."""
|
|
40
151
|
|
|
152
|
+
channel_clusters: list[list[int]] | None = None
|
|
153
|
+
"""Optional explicit input channel cluster specification for block-diagonal optimization.
|
|
154
|
+
|
|
155
|
+
Each element is a list of input channel indices forming one cluster. The
|
|
156
|
+
corresponding output indices are derived automatically from the non-zero
|
|
157
|
+
columns of the weight matrix for those input rows.
|
|
158
|
+
|
|
159
|
+
When provided, the weight matrix is decomposed into per-cluster sub-matrices
|
|
160
|
+
and multiplied separately, which is faster when cross-cluster weights are zero.
|
|
161
|
+
|
|
162
|
+
If None, block-diagonal structure is auto-detected from the zero pattern
|
|
163
|
+
of the weights."""
|
|
164
|
+
|
|
165
|
+
min_cluster_size: int = 32
|
|
166
|
+
"""Minimum number of channels per cluster for the block-diagonal optimization.
|
|
167
|
+
Clusters smaller than this are greedily merged together to avoid excessive
|
|
168
|
+
Python loop overhead. Set to 1 to disable merging."""
|
|
169
|
+
|
|
41
170
|
|
|
42
171
|
@processor_state
class AffineTransformState:
    """State kept by the affine-transform processor between messages."""

    # Dense weight matrix; set to None by the reset logic when the block-diagonal path is active.
    weights: npt.NDArray | None = None
    # Replacement axis carrying remapped labels for the output, when one has been built.
    new_axis: AxisBase | None = None
    # Number of output channels; sizes the pre-allocated output in the block-diagonal path.
    n_out: int = 0
    clusters: list | None = None
    """list of (in_indices_xp, out_indices_xp, sub_weights_xp) tuples when block-diagonal."""
|
|
46
178
|
|
|
47
179
|
|
|
48
180
|
class AffineTransformTransformer(
|
|
@@ -85,11 +217,60 @@ class AffineTransformTransformer(
|
|
|
85
217
|
|
|
86
218
|
self._state.weights = weights
|
|
87
219
|
|
|
220
|
+
# Note: If weights were scipy.sparse BSR then maybe we could automate this next part.
|
|
221
|
+
# However, that would break compatibility with Array API.
|
|
222
|
+
|
|
223
|
+
# --- Block-diagonal cluster detection ---
|
|
224
|
+
# Clusters are a list of (input_indices, output_indices) tuples.
|
|
225
|
+
n_in, n_out = weights.shape
|
|
226
|
+
if self.settings.channel_clusters is not None:
|
|
227
|
+
# Validate input index bounds
|
|
228
|
+
all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
|
|
229
|
+
if np.any((all_in < 0) | (all_in >= n_in)):
|
|
230
|
+
raise ValueError(
|
|
231
|
+
"channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
# Derive output indices from non-zero weights for each input cluster
|
|
235
|
+
clusters = []
|
|
236
|
+
for group in self.settings.channel_clusters:
|
|
237
|
+
in_idx = np.asarray(group)
|
|
238
|
+
out_idx = np.where(np.any(weights[in_idx, :] != 0, axis=0))[0]
|
|
239
|
+
clusters.append((in_idx, out_idx))
|
|
240
|
+
|
|
241
|
+
max_cross = _max_cross_cluster_weight(weights, clusters)
|
|
242
|
+
if max_cross > 0:
|
|
243
|
+
ez.logger.warning(
|
|
244
|
+
f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
|
|
245
|
+
"These will be ignored in block-diagonal multiplication."
|
|
246
|
+
)
|
|
247
|
+
else:
|
|
248
|
+
clusters = _find_block_diagonal_clusters(weights)
|
|
249
|
+
if clusters is not None:
|
|
250
|
+
ez.logger.info(
|
|
251
|
+
f"Auto-detected {len(clusters)} block-diagonal clusters "
|
|
252
|
+
f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
# Merge small clusters to avoid excessive loop overhead
|
|
256
|
+
if clusters is not None:
|
|
257
|
+
clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
|
|
258
|
+
|
|
259
|
+
if clusters is not None and len(clusters) > 1:
|
|
260
|
+
self._state.n_out = n_out
|
|
261
|
+
self._state.clusters = [
|
|
262
|
+
(in_idx, out_idx, np.ascontiguousarray(weights[np.ix_(in_idx, out_idx)]))
|
|
263
|
+
for in_idx, out_idx in clusters
|
|
264
|
+
]
|
|
265
|
+
self._state.weights = None
|
|
266
|
+
else:
|
|
267
|
+
self._state.clusters = None
|
|
268
|
+
|
|
269
|
+
# --- Axis label handling (for non-square transforms, non-cluster path) ---
|
|
88
270
|
axis = self.settings.axis or message.dims[-1]
|
|
89
|
-
if axis in message.axes and hasattr(message.axes[axis], "data") and
|
|
271
|
+
if axis in message.axes and hasattr(message.axes[axis], "data") and n_in != n_out:
|
|
90
272
|
in_labels = message.axes[axis].data
|
|
91
273
|
new_labels = []
|
|
92
|
-
n_in, n_out = weights.shape
|
|
93
274
|
if len(in_labels) != n_in:
|
|
94
275
|
ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
|
|
95
276
|
else:
|
|
@@ -111,23 +292,70 @@ class AffineTransformTransformer(
|
|
|
111
292
|
|
|
112
293
|
self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
|
|
113
294
|
|
|
295
|
+
# Convert to match message.data namespace for efficient operations in _process
|
|
296
|
+
xp = get_namespace(message.data)
|
|
297
|
+
if self._state.weights is not None:
|
|
298
|
+
self._state.weights = xp.asarray(self._state.weights)
|
|
299
|
+
if self._state.clusters is not None:
|
|
300
|
+
self._state.clusters = [
|
|
301
|
+
(xp.asarray(in_idx), xp.asarray(out_idx), xp.asarray(sub_w))
|
|
302
|
+
for in_idx, out_idx, sub_w in self._state.clusters
|
|
303
|
+
]
|
|
304
|
+
|
|
305
|
+
def _block_diagonal_matmul(self, xp, data, axis_idx):
|
|
306
|
+
"""Perform matmul using block-diagonal decomposition.
|
|
307
|
+
|
|
308
|
+
For each cluster, gathers input channels via ``xp.take``, performs a
|
|
309
|
+
matmul with the cluster's sub-weight matrix, and writes the result
|
|
310
|
+
directly into the pre-allocated output at the cluster's output indices.
|
|
311
|
+
Omitted output channels naturally remain zero.
|
|
312
|
+
"""
|
|
313
|
+
needs_permute = axis_idx not in [-1, data.ndim - 1]
|
|
314
|
+
if needs_permute:
|
|
315
|
+
dim_perm = list(range(data.ndim))
|
|
316
|
+
dim_perm.append(dim_perm.pop(axis_idx))
|
|
317
|
+
data = xp.permute_dims(data, dim_perm)
|
|
318
|
+
|
|
319
|
+
# Pre-allocate output (omitted channels stay zero)
|
|
320
|
+
out_shape = data.shape[:-1] + (self._state.n_out,)
|
|
321
|
+
result = xp.zeros(out_shape, dtype=data.dtype)
|
|
322
|
+
|
|
323
|
+
for in_idx, out_idx, sub_weights in self._state.clusters:
|
|
324
|
+
chunk = xp.take(data, in_idx, axis=data.ndim - 1)
|
|
325
|
+
result[..., out_idx] = xp.matmul(chunk, sub_weights)
|
|
326
|
+
|
|
327
|
+
if needs_permute:
|
|
328
|
+
inv_dim_perm = list(range(result.ndim))
|
|
329
|
+
inv_dim_perm.insert(axis_idx, inv_dim_perm.pop(-1))
|
|
330
|
+
result = xp.permute_dims(result, inv_dim_perm)
|
|
331
|
+
|
|
332
|
+
return result
|
|
333
|
+
|
|
114
334
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
335
|
+
xp = get_namespace(message.data)
|
|
115
336
|
axis = self.settings.axis or message.dims[-1]
|
|
116
337
|
axis_idx = message.get_axis_idx(axis)
|
|
117
338
|
data = message.data
|
|
118
339
|
|
|
119
|
-
if
|
|
120
|
-
|
|
121
|
-
# in the equation y = Ax + B. This supports NeuroKey's weights matrices.
|
|
122
|
-
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
|
|
123
|
-
data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
|
|
124
|
-
|
|
125
|
-
if axis_idx in [-1, len(message.dims) - 1]:
|
|
126
|
-
data = np.matmul(data, self._state.weights)
|
|
340
|
+
if self._state.clusters is not None:
|
|
341
|
+
data = self._block_diagonal_matmul(xp, data, axis_idx)
|
|
127
342
|
else:
|
|
128
|
-
data
|
|
129
|
-
|
|
130
|
-
|
|
343
|
+
if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
|
|
344
|
+
# The weights are stacked A|B where A is the transform and B is a single row
|
|
345
|
+
# in the equation y = Ax + B. This supports NeuroKey's weights matrices.
|
|
346
|
+
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
|
|
347
|
+
data = xp.concat((data, xp.ones(sample_shape, dtype=data.dtype)), axis=axis_idx)
|
|
348
|
+
|
|
349
|
+
if axis_idx in [-1, len(message.dims) - 1]:
|
|
350
|
+
data = xp.matmul(data, self._state.weights)
|
|
351
|
+
else:
|
|
352
|
+
perm = list(range(data.ndim))
|
|
353
|
+
perm.append(perm.pop(axis_idx))
|
|
354
|
+
data = xp.permute_dims(data, perm)
|
|
355
|
+
data = xp.matmul(data, self._state.weights)
|
|
356
|
+
inv_perm = list(range(data.ndim))
|
|
357
|
+
inv_perm.insert(axis_idx, inv_perm.pop(-1))
|
|
358
|
+
data = xp.permute_dims(data, inv_perm)
|
|
131
359
|
|
|
132
360
|
replace_kwargs = {"data": data}
|
|
133
361
|
if self._state.new_axis is not None:
|
|
@@ -144,6 +372,8 @@ def affine_transform(
|
|
|
144
372
|
weights: np.ndarray | str | Path,
|
|
145
373
|
axis: str | None = None,
|
|
146
374
|
right_multiply: bool = True,
|
|
375
|
+
channel_clusters: list[list[int]] | None = None,
|
|
376
|
+
min_cluster_size: int = 32,
|
|
147
377
|
) -> AffineTransformTransformer:
|
|
148
378
|
"""
|
|
149
379
|
Perform affine transformations on streaming data.
|
|
@@ -152,19 +382,25 @@ def affine_transform(
|
|
|
152
382
|
weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
|
|
153
383
|
axis: The name of the axis to apply the transformation to. Defaults to the last axis in the array.
|
|
154
384
|
right_multiply: Set False to transpose the weights before applying.
|
|
385
|
+
channel_clusters: Optional explicit channel cluster specification. See
|
|
386
|
+
:attr:`AffineTransformSettings.channel_clusters`.
|
|
387
|
+
min_cluster_size: Minimum channels per cluster; smaller clusters are merged. See
|
|
388
|
+
:attr:`AffineTransformSettings.min_cluster_size`.
|
|
155
389
|
|
|
156
390
|
Returns:
|
|
157
391
|
:obj:`AffineTransformTransformer`.
|
|
158
392
|
"""
|
|
159
393
|
return AffineTransformTransformer(
|
|
160
|
-
AffineTransformSettings(
|
|
394
|
+
AffineTransformSettings(
|
|
395
|
+
weights=weights,
|
|
396
|
+
axis=axis,
|
|
397
|
+
right_multiply=right_multiply,
|
|
398
|
+
channel_clusters=channel_clusters,
|
|
399
|
+
min_cluster_size=min_cluster_size,
|
|
400
|
+
)
|
|
161
401
|
)
|
|
162
402
|
|
|
163
403
|
|
|
164
|
-
def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
|
|
165
|
-
return np.zeros_like(data)
|
|
166
|
-
|
|
167
|
-
|
|
168
404
|
class CommonRereferenceSettings(ez.Settings):
|
|
169
405
|
"""
|
|
170
406
|
Settings for :obj:`CommonRereference`
|
|
@@ -179,35 +415,64 @@ class CommonRereferenceSettings(ez.Settings):
|
|
|
179
415
|
include_current: bool = True
|
|
180
416
|
"""Set False to exclude each channel from participating in the calculation of its reference."""
|
|
181
417
|
|
|
418
|
+
channel_clusters: list[list[int]] | None = None
|
|
419
|
+
"""Optional channel clusters for per-cluster rereferencing. Each element is a
|
|
420
|
+
list of channel indices forming one cluster. The common reference is computed
|
|
421
|
+
independently within each cluster. If None, all channels form a single cluster."""
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@processor_state
class CommonRereferenceState:
    """State kept by the common-rereference processor between messages."""

    clusters: list | None = None
    """list of xp arrays of channel indices, one per cluster."""
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
class CommonRereferenceTransformer(
    BaseStatefulTransformer[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceState]
):
    """Subtract a per-cluster common reference (mean or median) from each channel.

    The channel axis is ``settings.axis``, or the last dimension of the
    message when that is unset. Channels are grouped by
    ``settings.channel_clusters``; when that is None, all channels form one
    cluster. Channels that belong to no cluster are left at zero in the output.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key and the channel count along the reference axis."""
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        return hash((message.key, message.data.shape[axis_idx]))

    def _reset_state(self, message: AxisArray) -> None:
        """Build the per-cluster index arrays in the namespace of ``message.data``."""
        xp = get_namespace(message.data)
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        n_chans = message.data.shape[axis_idx]

        if self.settings.channel_clusters is not None:
            self._state.clusters = [xp.asarray(group) for group in self.settings.channel_clusters]
        else:
            self._state.clusters = [xp.arange(n_chans)]

    def _process(self, message: AxisArray) -> AxisArray:
        """Return *message* with each cluster re-referenced independently.

        In ``"passthrough"`` mode the message is returned untouched.
        """
        if self.settings.mode == "passthrough":
            return message

        xp = get_namespace(message.data)
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        func = {"mean": xp.mean, "median": np.median}[self.settings.mode]

        # Use result_type to match dtype promotion from data - float operations.
        out_dtype = np.result_type(message.data.dtype, np.float64)
        output = xp.zeros(message.data.shape, dtype=out_dtype)

        for cluster_idx in self._state.clusters:
            cluster_data = xp.take(message.data, cluster_idx, axis=axis_idx)
            ref_data = func(cluster_data, axis=axis_idx, keepdims=True)

            if not self.settings.include_current:
                # Leave-one-out reference without looping over channels:
                #   ref[i] = (N / (N-1)) * common_ref - x[i] / (N-1)
                # i.e. rescale the common reference to a divisor of N-1 and remove
                # channel i's own contribution. The identity is exact for "mean".
                # NOTE(review): for "median" the same rescaling is applied as-is.
                N = cluster_data.shape[axis_idx]
                if N == 1:
                    # A lone channel has no peers to reference against; N - 1 == 0
                    # would raise ZeroDivisionError, so use a zero reference and
                    # pass the channel through unchanged.
                    ref_data = xp.zeros_like(ref_data)
                else:
                    ref_data = (N / (N - 1)) * ref_data - cluster_data / (N - 1)

            # Write per-cluster result into output at the correct axis position
            idx = [slice(None)] * output.ndim
            idx[axis_idx] = cluster_idx
            output[tuple(idx)] = cluster_data - ref_data

        return replace(message, data=output)
|
|
211
476
|
|
|
212
477
|
|
|
213
478
|
class CommonRereference(
|
|
@@ -217,19 +482,26 @@ class CommonRereference(
|
|
|
217
482
|
|
|
218
483
|
|
|
219
484
|
def common_rereference(
    mode: str = "mean",
    axis: str | None = None,
    include_current: bool = True,
    channel_clusters: list[list[int]] | None = None,
) -> CommonRereferenceTransformer:
    """
    Perform common average referencing (CAR) on streaming data.

    Args:
        mode: The statistical mode to apply -- either "mean" or "median"
        axis: The name of the axis to apply the transformation to.
        include_current: Set False to exclude each channel from participating in the calculation of its reference.
        channel_clusters: Optional channel clusters for per-cluster rereferencing. See
            :attr:`CommonRereferenceSettings.channel_clusters`.

    Returns:
        :obj:`CommonRereferenceTransformer`
    """
    settings = CommonRereferenceSettings(
        mode=mode,
        axis=axis,
        include_current=include_current,
        channel_clusters=channel_clusters,
    )
    return CommonRereferenceTransformer(settings)
|
ezmsg/sigproc/merge.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""Time-aligned merge of two AxisArray streams along a non-time axis."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import typing
|
|
6
|
+
|
|
7
|
+
import ezmsg.core as ez
|
|
8
|
+
import numpy as np
|
|
9
|
+
from array_api_compat import get_namespace
|
|
10
|
+
from ezmsg.baseproc.protocols import processor_state
|
|
11
|
+
from ezmsg.baseproc.stateful import BaseStatefulTransformer
|
|
12
|
+
from ezmsg.baseproc.units import BaseProcessorUnit
|
|
13
|
+
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
|
|
14
|
+
from ezmsg.util.messages.util import replace
|
|
15
|
+
|
|
16
|
+
from .util.axisarray_buffer import HybridAxisArrayBuffer
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MergeSettings(ez.Settings):
    """Settings for the time-aligned merge of two AxisArray streams."""

    axis: str = "ch"
    """Axis along which to concatenate the two signals."""

    align_axis: str | None = "time"
    """Axis used for alignment. If None, defaults to the first dimension."""

    buffer_dur: float = 10.0
    """Buffer duration in seconds for each input stream."""

    relabel_axis: bool = True
    """Whether to relabel coordinate axis labels to ensure uniqueness."""

    label_a: str = "_a"
    """Suffix appended to signal A labels when relabel_axis is True."""

    label_b: str = "_b"
    """Suffix appended to signal B labels when relabel_axis is True."""

    new_key: str | None = None
    """Output AxisArray key. If None, uses the key from signal A."""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@processor_state
class MergeState:
    """State shared by both inputs of the merge processor, plus per-input caches."""

    # Common state
    # Align-axis gain adopted from whichever buffer is written first after a reset.
    gain: float | None = None
    # Name of the axis the two streams are aligned on.
    align_axis: str | None = None
    # Alignment flag; cleared whenever either input's buffer is reset.
    aligned: bool = False
    # Cached merged concat axis; set to None when either input's labels or concat size change.
    merged_concat_axis: CoordinateAxis | None = None

    # A state
    # Buffer holding input-A samples.
    buf_a: HybridAxisArrayBuffer | None = None
    # Most recent concat axis seen on input A.
    concat_axis_a: CoordinateAxis | None = None
    # Size of input A along the concat axis (None when the message lacks that axis).
    a_concat_dim: int | None = None
    # Sizes of input A's dimensions other than the align and concat axes.
    a_other_dims: tuple[int, ...] | None = None

    # B state
    # Buffer holding input-B samples.
    buf_b: HybridAxisArrayBuffer | None = None
    # Most recent concat axis seen on input B.
    concat_axis_b: CoordinateAxis | None = None
    # Size of input B along the concat axis (None when the message lacks that axis).
    b_concat_dim: int | None = None
    # Sizes of input B's dimensions other than the align and concat axes.
    b_other_dims: tuple[int, ...] | None = None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class MergeProcessor(BaseStatefulTransformer[MergeSettings, AxisArray, AxisArray | None, MergeState]):
|
|
64
|
+
"""Processor that time-aligns two AxisArray streams and concatenates them.
|
|
65
|
+
|
|
66
|
+
Input A flows through the standard ``__call__`` / ``_process`` path,
|
|
67
|
+
getting automatic ``_hash_message`` / ``_reset_state`` handling from
|
|
68
|
+
:class:`BaseStatefulTransformer`. Input B flows through :meth:`push_b`,
|
|
69
|
+
which independently tracks its own structure.
|
|
70
|
+
|
|
71
|
+
Invalidation rules:
|
|
72
|
+
|
|
73
|
+
- Gain mismatch (either input vs stored common gain) → full reset.
|
|
74
|
+
- Concat-axis dimensionality change → per-input buffer reset +
|
|
75
|
+
alignment and merged-axis cache invalidation.
|
|
76
|
+
- Non-align/non-concat axis shape change → per-input buffer reset +
|
|
77
|
+
alignment invalidation.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
# -- Structural extraction helpers ---------------------------------------
|
|
81
|
+
|
|
82
|
+
def _extract_gain(self, message: AxisArray) -> float | None:
|
|
83
|
+
"""Extract the align-axis gain from a message."""
|
|
84
|
+
align_name = self.settings.align_axis or message.dims[0]
|
|
85
|
+
ax = message.axes.get(align_name)
|
|
86
|
+
if ax is not None and hasattr(ax, "gain"):
|
|
87
|
+
return ax.gain
|
|
88
|
+
if ax is not None and hasattr(ax, "data") and len(ax.data) > 1:
|
|
89
|
+
return float(ax.data[-1] - ax.data[0]) / (len(ax.data) - 1)
|
|
90
|
+
return None
|
|
91
|
+
|
|
92
|
+
# -- Reset helpers -------------------------------------------------------
|
|
93
|
+
|
|
94
|
+
def _full_reset(self, align_axis: str) -> None:
    """Discard everything: fresh buffers for both inputs and cleared merge caches."""
    state = self._state
    state.align_axis = align_axis
    state.buf_a = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
    state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
    state.gain = None
    state.aligned = False
    # Every cached axis and remembered shape goes back to "unknown".
    for name in (
        "concat_axis_a",
        "concat_axis_b",
        "merged_concat_axis",
        "a_concat_dim",
        "a_other_dims",
        "b_concat_dim",
        "b_other_dims",
    ):
        setattr(state, name, None)
|
|
108
|
+
|
|
109
|
+
def _reset_a_state(self) -> None:
    """Replace input A's buffer with an empty one and forget its cached concat axis."""
    state = self._state
    fresh = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=state.align_axis)
    state.buf_a = fresh
    state.concat_axis_a = None
|
|
113
|
+
|
|
114
|
+
def _reset_b_state(self) -> None:
    """Replace input B's buffer with an empty one and forget its cached concat axis."""
    state = self._state
    fresh = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=state.align_axis)
    state.buf_b = fresh
    state.concat_axis_b = None
|
|
118
|
+
|
|
119
|
+
# -- BaseStatefulTransformer interface ------------------------------------
|
|
120
|
+
|
|
121
|
+
def _hash_message(self, message: AxisArray) -> int:
|
|
122
|
+
"""Hash the align-axis gain only.
|
|
123
|
+
|
|
124
|
+
Gain changes trigger a full reset via ``_reset_state``. Concat-axis
|
|
125
|
+
and non-merge dimension changes are handled as partial resets inside
|
|
126
|
+
``_process`` and ``push_b``.
|
|
127
|
+
"""
|
|
128
|
+
return hash(self._extract_gain(message))
|
|
129
|
+
|
|
130
|
+
def _reset_state(self, message: AxisArray) -> None:
|
|
131
|
+
"""Full reset — called by the base class when gain changes."""
|
|
132
|
+
align_axis = self.settings.align_axis or message.dims[0]
|
|
133
|
+
self._full_reset(align_axis)
|
|
134
|
+
|
|
135
|
+
def _process(self, message: AxisArray) -> AxisArray | None:
|
|
136
|
+
"""Process input A: detect structural changes, buffer, try merge."""
|
|
137
|
+
# Detect per-input structural changes.
|
|
138
|
+
align_idx = message.dims.index(self._state.align_axis)
|
|
139
|
+
concat_idx = message.dims.index(self.settings.axis) if self.settings.axis in message.dims else None
|
|
140
|
+
concat_dim = message.data.shape[concat_idx] if concat_idx is not None else None
|
|
141
|
+
other_dims = tuple(s for i, s in enumerate(message.data.shape) if i != align_idx and i != concat_idx)
|
|
142
|
+
|
|
143
|
+
if self._state.a_concat_dim is not None and concat_dim != self._state.a_concat_dim:
|
|
144
|
+
self._reset_a_state()
|
|
145
|
+
self._state.aligned = False
|
|
146
|
+
self._state.merged_concat_axis = None
|
|
147
|
+
elif self._state.a_other_dims is not None and other_dims != self._state.a_other_dims:
|
|
148
|
+
self._reset_a_state()
|
|
149
|
+
self._state.aligned = False
|
|
150
|
+
|
|
151
|
+
self._state.a_concat_dim = concat_dim
|
|
152
|
+
self._state.a_other_dims = other_dims
|
|
153
|
+
|
|
154
|
+
self._state.buf_a.write(message)
|
|
155
|
+
if self._state.gain is None:
|
|
156
|
+
self._state.gain = self._state.buf_a.axis_gain
|
|
157
|
+
self._update_concat_axis(message, "a")
|
|
158
|
+
return self._try_merge()
|
|
159
|
+
|
|
160
|
+
# -- Input B entry point ------------------------------------------------
|
|
161
|
+
|
|
162
|
+
def push_b(self, message: AxisArray) -> AxisArray | None:
    """Handle a message on input B: check gain and structure, buffer it, attempt a merge."""
    align_axis = self.settings.align_axis or message.dims[0]
    concat_name = self.settings.axis

    # A gain that disagrees with the stored common gain invalidates everything.
    incoming_gain = self._extract_gain(message)
    if self._state.gain is not None and incoming_gain != self._state.gain:
        self._full_reset(align_axis)
        # Record this message's hash on the base class so that the next
        # compatible A message goes to _process rather than causing a second
        # full reset.
        self._hash = self._hash_message(message)

    state = self._state

    # B may arrive before any A message has created the buffers.
    if state.buf_b is None:
        if state.align_axis is None:
            state.align_axis = align_axis
        state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)

    # Describe the incoming message's shape relative to the align and concat axes.
    align_idx = message.dims.index(align_axis)
    if concat_name in message.dims:
        concat_idx = message.dims.index(concat_name)
        concat_dim = message.data.shape[concat_idx]
    else:
        concat_idx = None
        concat_dim = None
    other_dims = tuple(
        size for pos, size in enumerate(message.data.shape) if pos != align_idx and pos != concat_idx
    )

    # Compare against what input B looked like last time (None means "first message").
    concat_changed = state.b_concat_dim is not None and concat_dim != state.b_concat_dim
    other_changed = state.b_other_dims is not None and other_dims != state.b_other_dims
    if concat_changed or other_changed:
        self._reset_b_state()
        state.aligned = False
        if concat_changed:
            # Only a concat-size change invalidates the cached merged axis.
            state.merged_concat_axis = None

    state.b_concat_dim = concat_dim
    state.b_other_dims = other_dims

    state.buf_b.write(message)
    if state.gain is None:
        state.gain = state.buf_b.axis_gain
    self._update_concat_axis(message, "b")
    return self._try_merge()
|
|
202
|
+
|
|
203
|
+
# -- Concat-axis caching ------------------------------------------------
|
|
204
|
+
|
|
205
|
+
def _update_concat_axis(self, message: AxisArray, which: str) -> None:
|
|
206
|
+
"""Track each input's concat-axis labels; invalidate cache on change."""
|
|
207
|
+
concat_dim = self.settings.axis
|
|
208
|
+
if concat_dim not in message.axes:
|
|
209
|
+
return
|
|
210
|
+
ax = message.axes[concat_dim]
|
|
211
|
+
if not hasattr(ax, "data"):
|
|
212
|
+
return
|
|
213
|
+
|
|
214
|
+
if which == "a":
|
|
215
|
+
if self._state.concat_axis_a is None or not np.array_equal(self._state.concat_axis_a.data, ax.data):
|
|
216
|
+
self._state.concat_axis_a = ax
|
|
217
|
+
self._state.merged_concat_axis = None
|
|
218
|
+
else:
|
|
219
|
+
if self._state.concat_axis_b is None or not np.array_equal(self._state.concat_axis_b.data, ax.data):
|
|
220
|
+
self._state.concat_axis_b = ax
|
|
221
|
+
self._state.merged_concat_axis = None
|
|
222
|
+
|
|
223
|
+
def _build_merged_concat_axis(self) -> CoordinateAxis | None:
|
|
224
|
+
"""Build the merged CoordinateAxis from the two cached per-input axes."""
|
|
225
|
+
if self._state.concat_axis_a is None or self._state.concat_axis_b is None:
|
|
226
|
+
return None
|
|
227
|
+
if self.settings.relabel_axis:
|
|
228
|
+
labels_a = np.array([str(lbl) + self.settings.label_a for lbl in self._state.concat_axis_a.data])
|
|
229
|
+
labels_b = np.array([str(lbl) + self.settings.label_b for lbl in self._state.concat_axis_b.data])
|
|
230
|
+
else:
|
|
231
|
+
labels_a = self._state.concat_axis_a.data
|
|
232
|
+
labels_b = self._state.concat_axis_b.data
|
|
233
|
+
return CoordinateAxis(
|
|
234
|
+
data=np.concatenate([labels_a, labels_b]),
|
|
235
|
+
dims=self._state.concat_axis_a.dims,
|
|
236
|
+
unit=self._state.concat_axis_a.unit,
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
# -- Core merge logic ---------------------------------------------------
|
|
240
|
+
|
|
241
|
+
def _try_merge(self) -> AxisArray | None:
|
|
242
|
+
"""Align and read from both buffers, returning the merged result.
|
|
243
|
+
|
|
244
|
+
Initial alignment is performed once. After the first successful
|
|
245
|
+
merge the two streams are assumed to share a common clock and
|
|
246
|
+
never drop samples, so we simply read
|
|
247
|
+
``min(available_a, available_b)`` on every subsequent call.
|
|
248
|
+
"""
|
|
249
|
+
if self._state.buf_a is None or self._state.buf_b is None:
|
|
250
|
+
return None
|
|
251
|
+
if self._state.buf_a.is_empty() or self._state.buf_b.is_empty():
|
|
252
|
+
return None
|
|
253
|
+
|
|
254
|
+
gain = self._state.gain
|
|
255
|
+
|
|
256
|
+
# --- Initial alignment (runs only until the first successful merge) ---
|
|
257
|
+
if not self._state.aligned:
|
|
258
|
+
first_a = self._state.buf_a.axis_first_value
|
|
259
|
+
final_a = self._state.buf_a.axis_final_value
|
|
260
|
+
first_b = self._state.buf_b.axis_first_value
|
|
261
|
+
final_b = self._state.buf_b.axis_final_value
|
|
262
|
+
|
|
263
|
+
overlap_start = max(first_a, first_b)
|
|
264
|
+
overlap_end = min(final_a, final_b)
|
|
265
|
+
|
|
266
|
+
if overlap_end < overlap_start - gain / 2:
|
|
267
|
+
if final_a < first_b:
|
|
268
|
+
self._state.buf_a.seek(self._state.buf_a.available())
|
|
269
|
+
elif final_b < first_a:
|
|
270
|
+
self._state.buf_b.seek(self._state.buf_b.available())
|
|
271
|
+
return None
|
|
272
|
+
|
|
273
|
+
if first_a < overlap_start - gain / 2:
|
|
274
|
+
self._state.buf_a.seek(int(round((overlap_start - first_a) / gain)))
|
|
275
|
+
if first_b < overlap_start - gain / 2:
|
|
276
|
+
self._state.buf_b.seek(int(round((overlap_start - first_b) / gain)))
|
|
277
|
+
|
|
278
|
+
# --- Read aligned samples ---
|
|
279
|
+
n_read = min(self._state.buf_a.available(), self._state.buf_b.available())
|
|
280
|
+
if n_read <= 0:
|
|
281
|
+
return None
|
|
282
|
+
|
|
283
|
+
aa_a = self._state.buf_a.read(n_read)
|
|
284
|
+
aa_b = self._state.buf_b.read(n_read)
|
|
285
|
+
if aa_a is None or aa_b is None:
|
|
286
|
+
return None
|
|
287
|
+
|
|
288
|
+
if not self._state.aligned:
|
|
289
|
+
axis_a = aa_a.axes.get(self._state.align_axis)
|
|
290
|
+
axis_b = aa_b.axes.get(self._state.align_axis)
|
|
291
|
+
if axis_a is not None and axis_b is not None:
|
|
292
|
+
off_a = axis_a.value(0) if hasattr(axis_a, "value") else None
|
|
293
|
+
off_b = axis_b.value(0) if hasattr(axis_b, "value") else None
|
|
294
|
+
if off_a is not None and off_b is not None:
|
|
295
|
+
if not np.isclose(off_a, off_b, atol=abs(gain) * 1e-6):
|
|
296
|
+
raise RuntimeError(
|
|
297
|
+
f"Offset mismatch after alignment: " f"off_a={off_a}, off_b={off_b}, gain={gain}"
|
|
298
|
+
)
|
|
299
|
+
self._state.aligned = True
|
|
300
|
+
|
|
301
|
+
return self._concat(aa_a, aa_b)
|
|
302
|
+
|
|
303
|
+
def _concat(self, a: AxisArray, b: AxisArray) -> AxisArray:
    """Join *a* and *b* along the configured merge axis (``self.settings.axis``).

    Args:
        a: Chunk read from input A.
        b: Chunk read from input B.

    Returns:
        The concatenated message. Its key is ``settings.new_key`` when that
        is set, otherwise the key of *a*.
    """
    merge_dim = self.settings.axis

    def with_merge_dim(msg):
        # An input lacking the merge dim gets it appended as a trailing axis.
        if merge_dim in msg.dims:
            return msg
        xp = get_namespace(msg.data)
        return replace(msg, data=xp.expand_dims(msg.data, axis=-1), dims=[*msg.dims, merge_dim])

    a = with_merge_dim(a)
    b = with_merge_dim(b)

    # The merged axis is cached and only rebuilt after it has been invalidated.
    state = self._state
    if state.merged_concat_axis is None:
        state.merged_concat_axis = self._build_merged_concat_axis()

    wanted_key = a.key if self.settings.new_key is None else self.settings.new_key
    merged = AxisArray.concatenate(a, b, dim=merge_dim, axis=state.merged_concat_axis)
    return replace(merged, key=wanted_key) if wanted_key != merged.key else merged
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class Merge(BaseProcessorUnit[MergeSettings]):
    """Unit that time-aligns two AxisArray streams and concatenates them along a non-time axis.

    Messages on ``INPUT_SIGNAL_A`` are handed to the processor's
    ``__acall__`` (which triggers a hash-based reset when the stream
    structure changes). Messages on ``INPUT_SIGNAL_B`` are handed to
    ``push_b``, which tracks that input's structure independently. Either
    path may produce a merged message, which is published on
    ``OUTPUT_SIGNAL``.

    ``INPUT_SETTINGS`` and the ``on_settings`` -> ``create_processor`` flow
    are inherited from :class:`BaseProcessorUnit`.
    """

    SETTINGS = MergeSettings

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def create_processor(self) -> None:
        """Instantiate the ``MergeProcessor`` from this unit's settings."""
        self.processor = MergeProcessor(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL_A, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_a(self, msg: AxisArray) -> typing.AsyncGenerator:
        """Feed an input-A message to the processor and publish any merged output."""
        if (merged := await self.processor.__acall__(msg)) is not None:
            yield self.OUTPUT_SIGNAL, merged

    @ez.subscriber(INPUT_SIGNAL_B, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_b(self, msg: AxisArray) -> typing.AsyncGenerator:
        """Feed an input-B message to the processor and publish any merged output."""
        if (merged := self.processor.push_b(msg)) is not None:
            yield self.OUTPUT_SIGNAL, merged
|
ezmsg/sigproc/quantize.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import ezmsg.core as ez
|
|
2
|
-
|
|
2
|
+
from array_api_compat import get_namespace
|
|
3
3
|
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
|
|
4
4
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
5
5
|
|
|
@@ -33,32 +33,33 @@ class QuantizeTransformer(BaseTransformer[QuantizeSettings, AxisArray, AxisArray
|
|
|
33
33
|
self,
|
|
34
34
|
message: AxisArray,
|
|
35
35
|
) -> AxisArray:
|
|
36
|
+
xp = get_namespace(message.data)
|
|
36
37
|
expected_range = self.settings.max_val - self.settings.min_val
|
|
37
38
|
scale_factor = 2**self.settings.bits - 1
|
|
38
39
|
clip_max = self.settings.max_val
|
|
39
40
|
|
|
40
41
|
# Determine appropriate integer type based on bits
|
|
41
42
|
if self.settings.bits <= 1:
|
|
42
|
-
dtype = bool
|
|
43
|
+
dtype = xp.bool
|
|
43
44
|
elif self.settings.bits <= 8:
|
|
44
|
-
dtype =
|
|
45
|
+
dtype = xp.uint8
|
|
45
46
|
elif self.settings.bits <= 16:
|
|
46
|
-
dtype =
|
|
47
|
+
dtype = xp.uint16
|
|
47
48
|
elif self.settings.bits <= 32:
|
|
48
|
-
dtype =
|
|
49
|
+
dtype = xp.uint32
|
|
49
50
|
else:
|
|
50
|
-
dtype =
|
|
51
|
+
dtype = xp.uint64
|
|
51
52
|
if self.settings.bits == 64:
|
|
52
53
|
# The practical upper bound before converting to int is: 2**64 - 1025
|
|
53
54
|
# Anything larger will wrap around to 0.
|
|
54
55
|
#
|
|
55
56
|
clip_max *= 1 - 2e-16
|
|
56
57
|
|
|
57
|
-
data = message.data
|
|
58
|
+
data = xp.clip(message.data, self.settings.min_val, clip_max)
|
|
58
59
|
data = (data - self.settings.min_val) / expected_range
|
|
59
60
|
|
|
60
61
|
# Scale to the quantized range [0, 2^bits - 1]
|
|
61
|
-
data =
|
|
62
|
+
data = xp.round(scale_factor * data).astype(dtype)
|
|
62
63
|
|
|
63
64
|
# Create a new AxisArray with the quantized data
|
|
64
65
|
return replace(message, data=data)
|
ezmsg/sigproc/rollingscaler.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import math
|
|
1
2
|
from collections import deque
|
|
2
3
|
|
|
3
4
|
import ezmsg.core as ez
|
|
4
|
-
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
+
from array_api_compat import get_namespace
|
|
6
7
|
from ezmsg.baseproc import (
|
|
7
8
|
BaseAdaptiveTransformer,
|
|
8
9
|
BaseAdaptiveTransformerUnit,
|
|
@@ -111,12 +112,13 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
|
|
|
111
112
|
return hash((message.key, samp_shape, gain))
|
|
112
113
|
|
|
113
114
|
def _reset_state(self, message: AxisArray) -> None:
|
|
115
|
+
xp = get_namespace(message.data)
|
|
114
116
|
ch = message.data.shape[-1]
|
|
115
|
-
self._state.mean =
|
|
117
|
+
self._state.mean = xp.zeros(ch, dtype=xp.float64)
|
|
116
118
|
self._state.N = 0
|
|
117
|
-
self._state.M2 =
|
|
119
|
+
self._state.M2 = xp.zeros(ch, dtype=xp.float64)
|
|
118
120
|
self._state.k_samples = (
|
|
119
|
-
|
|
121
|
+
math.ceil(self.settings.window_size / message.axes[self.settings.axis].gain)
|
|
120
122
|
if self.settings.window_size is not None
|
|
121
123
|
else self.settings.k_samples
|
|
122
124
|
)
|
|
@@ -127,7 +129,7 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
|
|
|
127
129
|
ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
|
|
128
130
|
self._state.samples = deque(maxlen=self._state.k_samples)
|
|
129
131
|
self._state.min_samples = (
|
|
130
|
-
|
|
132
|
+
math.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain)
|
|
131
133
|
if self.settings.window_size is not None
|
|
132
134
|
else self.settings.min_samples
|
|
133
135
|
)
|
|
@@ -136,10 +138,11 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
|
|
|
136
138
|
self._state.min_samples = self._state.k_samples
|
|
137
139
|
|
|
138
140
|
def _add_batch_stats(self, x: npt.NDArray) -> None:
|
|
139
|
-
|
|
141
|
+
xp = get_namespace(x)
|
|
142
|
+
x = xp.asarray(x, dtype=xp.float64)
|
|
140
143
|
n_b = x.shape[0]
|
|
141
|
-
mean_b =
|
|
142
|
-
M2_b =
|
|
144
|
+
mean_b = xp.mean(x, axis=0)
|
|
145
|
+
M2_b = xp.sum((x - mean_b) ** 2, axis=0)
|
|
143
146
|
|
|
144
147
|
if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
|
|
145
148
|
n_old, mean_old, M2_old = self._state.samples.popleft()
|
|
@@ -148,8 +151,8 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
|
|
|
148
151
|
|
|
149
152
|
if N_new <= 0:
|
|
150
153
|
self._state.N = 0
|
|
151
|
-
self._state.mean =
|
|
152
|
-
self._state.M2 =
|
|
154
|
+
self._state.mean = xp.zeros_like(self._state.mean)
|
|
155
|
+
self._state.M2 = xp.zeros_like(self._state.M2)
|
|
153
156
|
else:
|
|
154
157
|
delta = mean_old - self._state.mean
|
|
155
158
|
self._state.N = N_new
|
|
@@ -170,32 +173,37 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
|
|
|
170
173
|
self._add_batch_stats(x)
|
|
171
174
|
|
|
172
175
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
176
|
+
xp = get_namespace(message.data)
|
|
177
|
+
|
|
173
178
|
if self._state.N == 0 or self._state.N < self._state.min_samples:
|
|
174
179
|
if self.settings.update_with_signal:
|
|
175
180
|
x = message.data
|
|
176
181
|
if self.settings.artifact_z_thresh is not None and self._state.N > 0:
|
|
177
182
|
varis = self._state.M2 / self._state.N
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
183
|
+
raw_std = varis**0.5
|
|
184
|
+
std = xp.where(xp.isnan(raw_std), raw_std, xp.clip(raw_std, min=1e-8))
|
|
185
|
+
z = xp.abs((x - self._state.mean) / std)
|
|
186
|
+
mask = xp.any(z > self.settings.artifact_z_thresh, axis=1)
|
|
181
187
|
x = x[~mask]
|
|
182
188
|
if x.size > 0:
|
|
183
189
|
self._add_batch_stats(x)
|
|
184
190
|
return message
|
|
185
191
|
|
|
186
192
|
varis = self._state.M2 / self._state.N
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
result =
|
|
193
|
+
raw_std = varis**0.5
|
|
194
|
+
# Preserve NaN from negative variance (will be caught below), clip positive std to floor
|
|
195
|
+
std = xp.where(xp.isnan(raw_std), raw_std, xp.clip(raw_std, min=1e-8))
|
|
196
|
+
result = (message.data - self._state.mean) / std
|
|
197
|
+
# Replace NaN/inf with 0 (equivalent to nan_to_num with nan=0, posinf=0, neginf=0)
|
|
198
|
+
result = xp.where(xp.isfinite(result), result, xp.asarray(0.0, dtype=result.dtype))
|
|
191
199
|
if self.settings.clip is not None:
|
|
192
|
-
result =
|
|
200
|
+
result = xp.clip(result, -self.settings.clip, self.settings.clip)
|
|
193
201
|
|
|
194
202
|
if self.settings.update_with_signal:
|
|
195
203
|
x = message.data
|
|
196
204
|
if self.settings.artifact_z_thresh is not None:
|
|
197
|
-
z_scores =
|
|
198
|
-
mask =
|
|
205
|
+
z_scores = xp.abs((x - self._state.mean) / std)
|
|
206
|
+
mask = xp.any(z_scores > self.settings.artifact_z_thresh, axis=1)
|
|
199
207
|
x = x[~mask]
|
|
200
208
|
if x.size > 0:
|
|
201
209
|
self._add_batch_stats(x)
|
ezmsg/sigproc/scaler.py
CHANGED
|
@@ -2,6 +2,7 @@ import typing
|
|
|
2
2
|
|
|
3
3
|
import ezmsg.core as ez
|
|
4
4
|
import numpy as np
|
|
5
|
+
from array_api_compat import get_namespace
|
|
5
6
|
from ezmsg.baseproc import (
|
|
6
7
|
BaseStatefulTransformer,
|
|
7
8
|
BaseTransformerUnit,
|
|
@@ -132,15 +133,20 @@ class AdaptiveStandardScalerTransformer(
|
|
|
132
133
|
self._state.vars_sq_ewma.settings = replace(self._state.vars_sq_ewma.settings, accumulate=value)
|
|
133
134
|
|
|
134
135
|
def _process(self, message: AxisArray) -> AxisArray:
    """Standardize *message* using exponentially weighted running statistics.

    Args:
        message: Input samples. They are also fed to the two child EWMA
            processors held in ``self._state`` (one for the samples, one for
            their squares).

    Returns:
        A copy of *message* whose data is ``(x - mean) / std``. Elements whose
        running standard deviation is not strictly positive are set to 0.
    """
    xp = get_namespace(message.data)

    # Update step (respects accumulate setting via child EWMAs)
    mean_message = self._state.samps_ewma(message)
    var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))

    # Get step: Var[x] = E[x^2] - E[x]^2. Floating-point cancellation can make
    # this slightly negative, and a fractional power of a negative number
    # yields NaN (plus a RuntimeWarning under NumPy). Clamp at zero first;
    # non-positive and NaN variances then fall into the masked branch below,
    # so the output is the same as without the clamp.
    varis = var_sq_message.data - mean_message.data**2
    varis = xp.where(varis > 0, varis, xp.zeros_like(varis))
    std = varis**0.5
    mask = std > 0
    # Divide by 1 where std is unusable so the division itself stays clean.
    safe_std = xp.where(mask, std, xp.asarray(1.0, dtype=std.dtype))
    result = xp.where(
        mask, (message.data - mean_message.data) / safe_std, xp.asarray(0.0, dtype=message.data.dtype)
    )
    return replace(message, data=result)
|
|
145
151
|
|
|
146
152
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.13.0
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
5
|
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
|
|
2
|
-
ezmsg/sigproc/__version__.py,sha256=
|
|
2
|
+
ezmsg/sigproc/__version__.py,sha256=_4LOjlEcfZzfuqIlglDZmVBPO4LyQ8P97qO716YoUL8,706
|
|
3
3
|
ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
|
|
4
4
|
ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
|
|
5
|
-
ezmsg/sigproc/affinetransform.py,sha256=
|
|
5
|
+
ezmsg/sigproc/affinetransform.py,sha256=ZugiQg89Ly1I9SDgf0ZzgU2XdwVDmPrU7-orO9yrt7w,20210
|
|
6
6
|
ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
|
|
7
7
|
ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
|
|
8
8
|
ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
|
|
@@ -29,12 +29,13 @@ ezmsg/sigproc/firfilter.py,sha256=7r_I476nYuixJsuwc_hQ0Fbq8WB0gnYBZUKs3zultOQ,37
|
|
|
29
29
|
ezmsg/sigproc/gaussiansmoothing.py,sha256=BfXm9YQoOtieM4ABK2KRgxeQz055rd7mqtTVqmjT3Rk,2672
|
|
30
30
|
ezmsg/sigproc/kaiser.py,sha256=dIwgHbLXUHPtdotsGNLE9VG_clhcMgvVnSoFkMVgF9M,3483
|
|
31
31
|
ezmsg/sigproc/linear.py,sha256=b3NRzQNBvdU2jqenZT9XXFHax9Mavbj2xFiVxOwl1Ms,4662
|
|
32
|
+
ezmsg/sigproc/merge.py,sha256=LmuN3LDIZF7DynMcjLp7eGc2G3Yxks9Zd8-luSqqXuA,15436
|
|
32
33
|
ezmsg/sigproc/messages.py,sha256=KQczHTeifn4BZycChN8ZcpfZoQW3lC_xuFmN72QT97s,925
|
|
33
|
-
ezmsg/sigproc/quantize.py,sha256=
|
|
34
|
+
ezmsg/sigproc/quantize.py,sha256=y7T4_67BHZluX3gyl2anp8iL6EEI6JvsK7Pmp1vapsk,2268
|
|
34
35
|
ezmsg/sigproc/resample.py,sha256=3mm9pvxryNVhQuTCIMW3ToUkUfbVOCsIgvXUiurit1Y,11389
|
|
35
|
-
ezmsg/sigproc/rollingscaler.py,sha256=
|
|
36
|
+
ezmsg/sigproc/rollingscaler.py,sha256=GcLctVAWTmx9J39r0-dt3e7C_hs25s7M0dDnKiGhkC4,8955
|
|
36
37
|
ezmsg/sigproc/sampler.py,sha256=iOk2YoUX22u9iTjFKimzP5V074RDBVcmswgfyxvZRZo,10761
|
|
37
|
-
ezmsg/sigproc/scaler.py,sha256=
|
|
38
|
+
ezmsg/sigproc/scaler.py,sha256=kVXjRbqoxJ5yJICGsGagRXYIDW3-oihSnbBj-n3s55o,6816
|
|
38
39
|
ezmsg/sigproc/signalinjector.py,sha256=mB62H2b-ScgPtH1jajEpxgDHqdb-RKekQfgyNncsE8Y,2874
|
|
39
40
|
ezmsg/sigproc/singlebandpow.py,sha256=BVlWhFI6zU3ME3EVdZbwf-FMz1d2sfuNFDKXs1hn5HM,4353
|
|
40
41
|
ezmsg/sigproc/slicer.py,sha256=xLXxWf722V08ytVwvPimYjDKKj0pkC2HjdgCVaoaOvs,5195
|
|
@@ -61,7 +62,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
|
|
|
61
62
|
ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
|
|
62
63
|
ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
|
|
63
64
|
ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
|
|
64
|
-
ezmsg_sigproc-2.
|
|
65
|
-
ezmsg_sigproc-2.
|
|
66
|
-
ezmsg_sigproc-2.
|
|
67
|
-
ezmsg_sigproc-2.
|
|
65
|
+
ezmsg_sigproc-2.13.0.dist-info/METADATA,sha256=RXENX541lABAic8oUDuT8vQwx9nlWY9JETyXYKxdeTQ,1909
|
|
66
|
+
ezmsg_sigproc-2.13.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
67
|
+
ezmsg_sigproc-2.13.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
|
|
68
|
+
ezmsg_sigproc-2.13.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|