ezmsg-sigproc 2.12.0__py3-none-any.whl → 2.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/affinetransform.py +333 -47
- {ezmsg_sigproc-2.12.0.dist-info → ezmsg_sigproc-2.13.1.dist-info}/METADATA +1 -1
- {ezmsg_sigproc-2.12.0.dist-info → ezmsg_sigproc-2.13.1.dist-info}/RECORD +6 -6
- {ezmsg_sigproc-2.12.0.dist-info → ezmsg_sigproc-2.13.1.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-2.12.0.dist-info → ezmsg_sigproc-2.13.1.dist-info}/licenses/LICENSE +0 -0
ezmsg/sigproc/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.
|
|
32
|
-
__version_tuple__ = version_tuple = (2,
|
|
31
|
+
__version__ = version = '2.13.1'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 13, 1)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
ezmsg/sigproc/affinetransform.py
CHANGED
|
@@ -17,7 +17,6 @@ import numpy.typing as npt
|
|
|
17
17
|
from array_api_compat import get_namespace
|
|
18
18
|
from ezmsg.baseproc import (
|
|
19
19
|
BaseStatefulTransformer,
|
|
20
|
-
BaseTransformer,
|
|
21
20
|
BaseTransformerUnit,
|
|
22
21
|
processor_state,
|
|
23
22
|
)
|
|
@@ -25,6 +24,117 @@ from ezmsg.util.messages.axisarray import AxisArray, AxisBase
|
|
|
25
24
|
from ezmsg.util.messages.util import replace
|
|
26
25
|
|
|
27
26
|
|
|
27
|
+
def _find_block_diagonal_clusters(weights: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]] | None:
|
|
28
|
+
"""Detect block-diagonal structure in a weight matrix.
|
|
29
|
+
|
|
30
|
+
Finds connected components in the bipartite graph of non-zero weights,
|
|
31
|
+
where input channels and output channels are separate node sets.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
weights: 2-D weight matrix of shape (n_in, n_out).
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
List of (input_indices, output_indices) tuples, one per block, or
|
|
38
|
+
None if the matrix is not block-diagonal (single connected component).
|
|
39
|
+
"""
|
|
40
|
+
if weights.ndim != 2:
|
|
41
|
+
return None
|
|
42
|
+
|
|
43
|
+
n_in, n_out = weights.shape
|
|
44
|
+
if n_in + n_out <= 2:
|
|
45
|
+
return None
|
|
46
|
+
|
|
47
|
+
from scipy.sparse import coo_matrix
|
|
48
|
+
from scipy.sparse.csgraph import connected_components
|
|
49
|
+
|
|
50
|
+
rows, cols = np.nonzero(weights)
|
|
51
|
+
if len(rows) == 0:
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
# Bipartite graph: input nodes [0, n_in), output nodes [n_in, n_in + n_out)
|
|
55
|
+
shifted_cols = cols + n_in
|
|
56
|
+
adj_rows = np.concatenate([rows, shifted_cols])
|
|
57
|
+
adj_cols = np.concatenate([shifted_cols, rows])
|
|
58
|
+
adj_data = np.ones(len(adj_rows), dtype=bool)
|
|
59
|
+
n_nodes = n_in + n_out
|
|
60
|
+
adj = coo_matrix((adj_data, (adj_rows, adj_cols)), shape=(n_nodes, n_nodes))
|
|
61
|
+
|
|
62
|
+
n_components, labels = connected_components(adj, directed=False)
|
|
63
|
+
|
|
64
|
+
if n_components <= 1:
|
|
65
|
+
return None
|
|
66
|
+
|
|
67
|
+
clusters = []
|
|
68
|
+
for comp in range(n_components):
|
|
69
|
+
members = np.where(labels == comp)[0]
|
|
70
|
+
in_idx = np.sort(members[members < n_in])
|
|
71
|
+
out_idx = np.sort(members[members >= n_in] - n_in)
|
|
72
|
+
if len(in_idx) > 0 and len(out_idx) > 0:
|
|
73
|
+
clusters.append((in_idx, out_idx))
|
|
74
|
+
|
|
75
|
+
return clusters if len(clusters) > 1 else None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _max_cross_cluster_weight(weights: np.ndarray, clusters: list[tuple[np.ndarray, np.ndarray]]) -> float:
|
|
79
|
+
"""Return the maximum absolute weight between different clusters."""
|
|
80
|
+
mask = np.zeros(weights.shape, dtype=bool)
|
|
81
|
+
for in_idx, out_idx in clusters:
|
|
82
|
+
mask[np.ix_(in_idx, out_idx)] = True
|
|
83
|
+
cross = np.abs(weights[~mask])
|
|
84
|
+
return float(cross.max()) if cross.size > 0 else 0.0
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _merge_small_clusters(
|
|
88
|
+
clusters: list[tuple[np.ndarray, np.ndarray]], min_size: int
|
|
89
|
+
) -> list[tuple[np.ndarray, np.ndarray]]:
|
|
90
|
+
"""Merge clusters smaller than *min_size* into combined groups.
|
|
91
|
+
|
|
92
|
+
Small clusters are greedily concatenated until each merged group has
|
|
93
|
+
at least *min_size* channels (measured as ``max(n_in, n_out)``).
|
|
94
|
+
Any leftover small clusters that don't reach the threshold are
|
|
95
|
+
combined into a final group.
|
|
96
|
+
|
|
97
|
+
The merged group's sub-weight-matrix will contain the original small
|
|
98
|
+
diagonal blocks with zeros between them — a dense matmul on that
|
|
99
|
+
sub-matrix is cheaper than iterating over many tiny matmuls.
|
|
100
|
+
"""
|
|
101
|
+
if min_size <= 1:
|
|
102
|
+
return clusters
|
|
103
|
+
|
|
104
|
+
large = []
|
|
105
|
+
small = []
|
|
106
|
+
for cluster in clusters:
|
|
107
|
+
in_idx, out_idx = cluster
|
|
108
|
+
if max(len(in_idx), len(out_idx)) >= min_size:
|
|
109
|
+
large.append(cluster)
|
|
110
|
+
else:
|
|
111
|
+
small.append(cluster)
|
|
112
|
+
|
|
113
|
+
if not small:
|
|
114
|
+
return clusters
|
|
115
|
+
|
|
116
|
+
current_in: list[np.ndarray] = []
|
|
117
|
+
current_out: list[np.ndarray] = []
|
|
118
|
+
current_in_size = 0
|
|
119
|
+
current_out_size = 0
|
|
120
|
+
for in_idx, out_idx in small:
|
|
121
|
+
current_in.append(in_idx)
|
|
122
|
+
current_out.append(out_idx)
|
|
123
|
+
current_in_size += len(in_idx)
|
|
124
|
+
current_out_size += len(out_idx)
|
|
125
|
+
if max(current_in_size, current_out_size) >= min_size:
|
|
126
|
+
large.append((np.sort(np.concatenate(current_in)), np.sort(np.concatenate(current_out))))
|
|
127
|
+
current_in = []
|
|
128
|
+
current_out = []
|
|
129
|
+
current_in_size = 0
|
|
130
|
+
current_out_size = 0
|
|
131
|
+
|
|
132
|
+
if current_in:
|
|
133
|
+
large.append((np.sort(np.concatenate(current_in)), np.sort(np.concatenate(current_out))))
|
|
134
|
+
|
|
135
|
+
return large
|
|
136
|
+
|
|
137
|
+
|
|
28
138
|
class AffineTransformSettings(ez.Settings):
|
|
29
139
|
"""
|
|
30
140
|
Settings for :obj:`AffineTransform`.
|
|
@@ -39,11 +149,32 @@ class AffineTransformSettings(ez.Settings):
|
|
|
39
149
|
right_multiply: bool = True
|
|
40
150
|
"""Set False to transpose the weights before applying."""
|
|
41
151
|
|
|
152
|
+
channel_clusters: list[list[int]] | None = None
|
|
153
|
+
"""Optional explicit input channel cluster specification for block-diagonal optimization.
|
|
154
|
+
|
|
155
|
+
Each element is a list of input channel indices forming one cluster. The
|
|
156
|
+
corresponding output indices are derived automatically from the non-zero
|
|
157
|
+
columns of the weight matrix for those input rows.
|
|
158
|
+
|
|
159
|
+
When provided, the weight matrix is decomposed into per-cluster sub-matrices
|
|
160
|
+
and multiplied separately, which is faster when cross-cluster weights are zero.
|
|
161
|
+
|
|
162
|
+
If None, block-diagonal structure is auto-detected from the zero pattern
|
|
163
|
+
of the weights."""
|
|
164
|
+
|
|
165
|
+
min_cluster_size: int = 32
|
|
166
|
+
"""Minimum number of channels per cluster for the block-diagonal optimization.
|
|
167
|
+
Clusters smaller than this are greedily merged together to avoid excessive
|
|
168
|
+
Python loop overhead. Set to 1 to disable merging."""
|
|
169
|
+
|
|
42
170
|
|
|
43
171
|
@processor_state
class AffineTransformState:
    """Runtime state for :obj:`AffineTransformTransformer`.

    After ``set_weights(..., recalc_clusters=True)`` exactly one of ``weights``
    and ``clusters`` is populated: ``clusters`` when a block-diagonal
    decomposition with more than one block is used, otherwise ``weights``.
    """

    # Dense weight matrix in canonical orientation; None while the
    # block-diagonal path is active.
    weights: npt.NDArray | None = None
    # Relabelled copy of the transformed axis, built in _reset_state for
    # non-square transforms whose input axis carries labels; otherwise None.
    new_axis: AxisBase | None = None
    # Number of output channels; only set and read by the block-diagonal path.
    n_out: int = 0
    # List of (in_indices, out_indices, sub_weights) tuples, one per block,
    # converted to the message's array namespace; None for the dense path.
    clusters: list | None = None
|
|
47
178
|
|
|
48
179
|
|
|
49
180
|
class AffineTransformTransformer(
|
|
@@ -84,13 +215,15 @@ class AffineTransformTransformer(
|
|
|
84
215
|
if weights is not None:
|
|
85
216
|
weights = np.ascontiguousarray(weights)
|
|
86
217
|
|
|
87
|
-
|
|
218
|
+
# Cluster detection + weight storage (delegated)
|
|
219
|
+
self.set_weights(weights, recalc_clusters=True)
|
|
88
220
|
|
|
221
|
+
# --- Axis label handling (for non-square transforms, non-cluster path) ---
|
|
222
|
+
n_in, n_out = weights.shape
|
|
89
223
|
axis = self.settings.axis or message.dims[-1]
|
|
90
|
-
if axis in message.axes and hasattr(message.axes[axis], "data") and
|
|
224
|
+
if axis in message.axes and hasattr(message.axes[axis], "data") and n_in != n_out:
|
|
91
225
|
in_labels = message.axes[axis].data
|
|
92
226
|
new_labels = []
|
|
93
|
-
n_in, n_out = weights.shape
|
|
94
227
|
if len(in_labels) != n_in:
|
|
95
228
|
ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
|
|
96
229
|
else:
|
|
@@ -112,10 +245,118 @@ class AffineTransformTransformer(
|
|
|
112
245
|
|
|
113
246
|
self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
|
|
114
247
|
|
|
115
|
-
# Convert
|
|
248
|
+
# Convert to match message.data namespace for efficient operations in _process
|
|
116
249
|
xp = get_namespace(message.data)
|
|
117
250
|
if self._state.weights is not None:
|
|
118
251
|
self._state.weights = xp.asarray(self._state.weights)
|
|
252
|
+
if self._state.clusters is not None:
|
|
253
|
+
self._state.clusters = [
|
|
254
|
+
(xp.asarray(in_idx), xp.asarray(out_idx), xp.asarray(sub_w))
|
|
255
|
+
for in_idx, out_idx, sub_w in self._state.clusters
|
|
256
|
+
]
|
|
257
|
+
|
|
258
|
+
def set_weights(self, weights, *, recalc_clusters=False) -> None:
    """Replace weight values, optionally recalculating the cluster decomposition.

    *weights* must be in **canonical orientation** (``right_multiply`` already
    applied by the caller or by ``_reset_state``), i.e. shape ``(n_in, n_out)``.

    Args:
        weights: Weight matrix in canonical orientation. With
            ``recalc_clusters=False`` it may live in any Array-API namespace
            (NumPy, CuPy, etc.). With ``recalc_clusters=True`` it is passed
            through ``np.ascontiguousarray`` for cluster detection, so it must
            be convertible to a NumPy array.
        recalc_clusters: When True, re-run block-diagonal cluster detection and
            store the new decomposition. When False (default), reuse the
            existing cluster structure and only update weight values. In that
            case the new weights should keep the same shape and zero pattern,
            because ``n_out`` and the cluster indices are not recomputed.

    Raises:
        ValueError: If ``settings.channel_clusters`` contains out-of-range
            input indices.
    """
    if not recalc_clusters:
        # Keep the current decomposition; only refresh the values.
        if self._state.clusters is None:
            self._state.weights = weights
        else:
            xp = get_namespace(weights)
            self._state.clusters = [
                (in_idx, out_idx, xp.take(xp.take(weights, in_idx, axis=0), out_idx, axis=1))
                for in_idx, out_idx, _ in self._state.clusters
            ]
        return

    # Note: If weights were scipy.sparse BSR then maybe we could automate this next part.
    # However, that would break compatibility with Array API.

    # --- Block-diagonal cluster detection ---
    # Clusters are a list of (input_indices, output_indices) tuples.
    w_np = np.ascontiguousarray(weights)
    n_in, n_out = w_np.shape
    explicit = self.settings.channel_clusters is not None

    if explicit:
        # Empty groups reference no channels. Skipping them also avoids an empty
        # float index array (np.asarray([]) is float64), which cannot index rows.
        groups = [np.asarray(group) for group in self.settings.channel_clusters if len(group) > 0]
        if groups:
            all_in = np.concatenate(groups)
            if np.any((all_in < 0) | (all_in >= n_in)):
                raise ValueError(
                    "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
                )

        # Derive output indices from non-zero weights for each input cluster
        clusters = [(in_idx, np.where(np.any(w_np[in_idx, :] != 0, axis=0))[0]) for in_idx in groups]

        # Each cluster's product is written to its own output columns. This is only
        # correct if no input is repeated and no output column is fed by two
        # clusters (which happens whenever non-zero weights couple clusters).
        # Otherwise partial sums would overwrite each other, so use the dense path.
        if clusters:
            in_cat = np.concatenate([in_idx for in_idx, _ in clusters])
            out_cat = np.concatenate([out_idx for _, out_idx in clusters])
            if np.unique(in_cat).size < in_cat.size or np.unique(out_cat).size < out_cat.size:
                ez.logger.warning(
                    "channel_clusters are not separable (clusters share input or output channels "
                    "through non-zero weights); falling back to dense multiplication."
                )
                clusters = None
    else:
        clusters = _find_block_diagonal_clusters(w_np)
        if clusters is not None:
            ez.logger.info(
                f"Auto-detected {len(clusters)} block-diagonal clusters "
                f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
            )

    # Merge small clusters to avoid excessive loop overhead
    if clusters is not None:
        clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)

    if clusters is not None and len(clusters) > 1:
        if explicit:
            # With disjoint clusters, the only weights outside every block belong to
            # input channels listed in no cluster; those inputs are dropped.
            max_cross = _max_cross_cluster_weight(w_np, clusters)
            if max_cross > 0:
                ez.logger.warning(
                    f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
                    "These will be ignored in block-diagonal multiplication."
                )
        self._state.n_out = n_out
        self._state.clusters = [
            (in_idx, out_idx, np.ascontiguousarray(w_np[np.ix_(in_idx, out_idx)]))
            for in_idx, out_idx in clusters
        ]
        self._state.weights = None
    else:
        self._state.weights = weights
        self._state.clusters = None
|
|
331
|
+
|
|
332
|
+
def _block_diagonal_matmul(self, xp, data, axis_idx):
|
|
333
|
+
"""Perform matmul using block-diagonal decomposition.
|
|
334
|
+
|
|
335
|
+
For each cluster, gathers input channels via ``xp.take``, performs a
|
|
336
|
+
matmul with the cluster's sub-weight matrix, and writes the result
|
|
337
|
+
directly into the pre-allocated output at the cluster's output indices.
|
|
338
|
+
Omitted output channels naturally remain zero.
|
|
339
|
+
"""
|
|
340
|
+
needs_permute = axis_idx not in [-1, data.ndim - 1]
|
|
341
|
+
if needs_permute:
|
|
342
|
+
dim_perm = list(range(data.ndim))
|
|
343
|
+
dim_perm.append(dim_perm.pop(axis_idx))
|
|
344
|
+
data = xp.permute_dims(data, dim_perm)
|
|
345
|
+
|
|
346
|
+
# Pre-allocate output (omitted channels stay zero)
|
|
347
|
+
out_shape = data.shape[:-1] + (self._state.n_out,)
|
|
348
|
+
result = xp.zeros(out_shape, dtype=data.dtype)
|
|
349
|
+
|
|
350
|
+
for in_idx, out_idx, sub_weights in self._state.clusters:
|
|
351
|
+
chunk = xp.take(data, in_idx, axis=data.ndim - 1)
|
|
352
|
+
result[..., out_idx] = xp.matmul(chunk, sub_weights)
|
|
353
|
+
|
|
354
|
+
if needs_permute:
|
|
355
|
+
inv_dim_perm = list(range(result.ndim))
|
|
356
|
+
inv_dim_perm.insert(axis_idx, inv_dim_perm.pop(-1))
|
|
357
|
+
result = xp.permute_dims(result, inv_dim_perm)
|
|
358
|
+
|
|
359
|
+
return result
|
|
119
360
|
|
|
120
361
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
121
362
|
xp = get_namespace(message.data)
|
|
@@ -123,22 +364,25 @@ class AffineTransformTransformer(
|
|
|
123
364
|
axis_idx = message.get_axis_idx(axis)
|
|
124
365
|
data = message.data
|
|
125
366
|
|
|
126
|
-
if
|
|
127
|
-
|
|
128
|
-
# in the equation y = Ax + B. This supports NeuroKey's weights matrices.
|
|
129
|
-
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
|
|
130
|
-
data = xp.concat((data, xp.ones(sample_shape, dtype=data.dtype)), axis=axis_idx)
|
|
131
|
-
|
|
132
|
-
if axis_idx in [-1, len(message.dims) - 1]:
|
|
133
|
-
data = xp.matmul(data, self._state.weights)
|
|
367
|
+
if self._state.clusters is not None:
|
|
368
|
+
data = self._block_diagonal_matmul(xp, data, axis_idx)
|
|
134
369
|
else:
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
370
|
+
if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
|
|
371
|
+
# The weights are stacked A|B where A is the transform and B is a single row
|
|
372
|
+
# in the equation y = Ax + B. This supports NeuroKey's weights matrices.
|
|
373
|
+
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
|
|
374
|
+
data = xp.concat((data, xp.ones(sample_shape, dtype=data.dtype)), axis=axis_idx)
|
|
375
|
+
|
|
376
|
+
if axis_idx in [-1, len(message.dims) - 1]:
|
|
377
|
+
data = xp.matmul(data, self._state.weights)
|
|
378
|
+
else:
|
|
379
|
+
perm = list(range(data.ndim))
|
|
380
|
+
perm.append(perm.pop(axis_idx))
|
|
381
|
+
data = xp.permute_dims(data, perm)
|
|
382
|
+
data = xp.matmul(data, self._state.weights)
|
|
383
|
+
inv_perm = list(range(data.ndim))
|
|
384
|
+
inv_perm.insert(axis_idx, inv_perm.pop(-1))
|
|
385
|
+
data = xp.permute_dims(data, inv_perm)
|
|
142
386
|
|
|
143
387
|
replace_kwargs = {"data": data}
|
|
144
388
|
if self._state.new_axis is not None:
|
|
@@ -155,6 +399,8 @@ def affine_transform(
|
|
|
155
399
|
weights: np.ndarray | str | Path,
|
|
156
400
|
axis: str | None = None,
|
|
157
401
|
right_multiply: bool = True,
|
|
402
|
+
channel_clusters: list[list[int]] | None = None,
|
|
403
|
+
min_cluster_size: int = 32,
|
|
158
404
|
) -> AffineTransformTransformer:
|
|
159
405
|
"""
|
|
160
406
|
Perform affine transformations on streaming data.
|
|
@@ -163,20 +409,25 @@ def affine_transform(
|
|
|
163
409
|
weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
|
|
164
410
|
axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
|
|
165
411
|
right_multiply: Set False to transpose the weights before applying.
|
|
412
|
+
channel_clusters: Optional explicit channel cluster specification. See
|
|
413
|
+
:attr:`AffineTransformSettings.channel_clusters`.
|
|
414
|
+
min_cluster_size: Minimum channels per cluster; smaller clusters are merged. See
|
|
415
|
+
:attr:`AffineTransformSettings.min_cluster_size`.
|
|
166
416
|
|
|
167
417
|
Returns:
|
|
168
418
|
:obj:`AffineTransformTransformer`.
|
|
169
419
|
"""
|
|
170
420
|
return AffineTransformTransformer(
|
|
171
|
-
AffineTransformSettings(
|
|
421
|
+
AffineTransformSettings(
|
|
422
|
+
weights=weights,
|
|
423
|
+
axis=axis,
|
|
424
|
+
right_multiply=right_multiply,
|
|
425
|
+
channel_clusters=channel_clusters,
|
|
426
|
+
min_cluster_size=min_cluster_size,
|
|
427
|
+
)
|
|
172
428
|
)
|
|
173
429
|
|
|
174
430
|
|
|
175
|
-
def zeros_for_noop(data, **ignore_kwargs):
|
|
176
|
-
xp = get_namespace(data)
|
|
177
|
-
return xp.zeros_like(data)
|
|
178
|
-
|
|
179
|
-
|
|
180
431
|
class CommonRereferenceSettings(ez.Settings):
|
|
181
432
|
"""
|
|
182
433
|
Settings for :obj:`CommonRereference`
|
|
@@ -191,8 +442,37 @@ class CommonRereferenceSettings(ez.Settings):
|
|
|
191
442
|
include_current: bool = True
|
|
192
443
|
"""Set False to exclude each channel from participating in the calculation of its reference."""
|
|
193
444
|
|
|
445
|
+
channel_clusters: list[list[int]] | None = None
|
|
446
|
+
"""Optional channel clusters for per-cluster rereferencing. Each element is a
|
|
447
|
+
list of channel indices forming one cluster. The common reference is computed
|
|
448
|
+
independently within each cluster. If None, all channels form a single cluster."""
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@processor_state
class CommonRereferenceState:
    """Runtime state for :obj:`CommonRereferenceTransformer`."""

    # One index array per cluster (in the message's array namespace), built by
    # _reset_state; the common reference is computed separately within each.
    clusters: list | None = None
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
class CommonRereferenceTransformer(
|
|
458
|
+
BaseStatefulTransformer[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceState]
|
|
459
|
+
):
|
|
460
|
+
def _hash_message(self, message: AxisArray) -> int:
    """Hash the stream identity: the message key plus the channel-axis length.

    Presumably consulted by the stateful base class to decide when
    ``_reset_state`` must rebuild the per-cluster index arrays.
    """
    chan_axis = self.settings.axis if self.settings.axis else message.dims[-1]
    n_chans = message.data.shape[message.get_axis_idx(chan_axis)]
    return hash((message.key, n_chans))
|
|
464
|
+
|
|
465
|
+
def _reset_state(self, message: AxisArray) -> None:
    """Build the per-cluster channel index arrays for the incoming stream.

    Without ``channel_clusters`` all channels form a single cluster. With it,
    each non-empty group becomes one index array in the message's namespace.

    Raises:
        ValueError: If ``channel_clusters`` contains indices outside
            ``0..n_chans-1``. Negative indices are rejected rather than
            silently wrapping in ``take``.
    """
    xp = get_namespace(message.data)
    axis = self.settings.axis or message.dims[-1]
    axis_idx = message.get_axis_idx(axis)
    n_chans = message.data.shape[axis_idx]

    if self.settings.channel_clusters is None:
        groups = None
        cluster_sizes = [n_chans]
    else:
        # Empty groups reference no channels. np.asarray([]) is also float-typed,
        # which cannot be used as take() indices, so drop them here.
        groups = [np.asarray(group) for group in self.settings.channel_clusters if len(group) > 0]
        all_idx = np.concatenate(groups) if groups else np.zeros(0, dtype=int)
        if np.any((all_idx < 0) | (all_idx >= n_chans)):
            raise ValueError(
                "channel_clusters contains out-of-range channel indices " f"(valid range: 0..{n_chans - 1})"
            )
        n_uncovered = n_chans - np.unique(all_idx).size
        if n_uncovered > 0:
            # _process writes only clustered channels into a zero-initialized output.
            ez.logger.warning(
                f"{n_uncovered} of {n_chans} channels are not in any channel_clusters group; "
                "their rereferenced output will be zero."
            )
        cluster_sizes = [group.size for group in groups]

    if not self.settings.include_current and any(size < 2 for size in cluster_sizes):
        # The include_current=False formula divides by (N - 1).
        ez.logger.warning(
            "include_current=False requires at least 2 channels per cluster; "
            "single-channel clusters cannot be rereferenced."
        )

    if groups is None:
        self._state.clusters = [xp.arange(n_chans)]
    else:
        self._state.clusters = [xp.asarray(group) for group in groups]
|
|
194
475
|
|
|
195
|
-
class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]):
|
|
196
476
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
197
477
|
if self.settings.mode == "passthrough":
|
|
198
478
|
return message
|
|
@@ -200,27 +480,26 @@ class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, Ax
|
|
|
200
480
|
xp = get_namespace(message.data)
|
|
201
481
|
axis = self.settings.axis or message.dims[-1]
|
|
202
482
|
axis_idx = message.get_axis_idx(axis)
|
|
483
|
+
func = {"mean": xp.mean, "median": np.median}[self.settings.mode]
|
|
484
|
+
|
|
485
|
+
# Use result_type to match dtype promotion from data - float operations.
|
|
486
|
+
out_dtype = np.result_type(message.data.dtype, np.float64)
|
|
487
|
+
output = xp.zeros(message.data.shape, dtype=out_dtype)
|
|
203
488
|
|
|
204
|
-
|
|
489
|
+
for cluster_idx in self._state.clusters:
|
|
490
|
+
cluster_data = xp.take(message.data, cluster_idx, axis=axis_idx)
|
|
491
|
+
ref_data = func(cluster_data, axis=axis_idx, keepdims=True)
|
|
205
492
|
|
|
206
|
-
|
|
493
|
+
if not self.settings.include_current:
|
|
494
|
+
N = cluster_data.shape[axis_idx]
|
|
495
|
+
ref_data = (N / (N - 1)) * ref_data - cluster_data / (N - 1)
|
|
207
496
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
# then we would have omitted the contribution of the current channel:
|
|
213
|
-
# `CAR[i] = x[0]/(N-1) + x[1]/(N-1) + ... x[i-1]/(N-1) + x[i+1]/(N-1) + ... + x[N-1]/(N-1)`
|
|
214
|
-
# The majority of the calculation is the same as when the current channel is included;
|
|
215
|
-
# we need only rescale CAR so the divisor is `N-1` instead of `N`, then subtract the contribution
|
|
216
|
-
# from the current channel (i.e., `x[i] / (N-1)`)
|
|
217
|
-
# i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
|
|
218
|
-
# We can use broadcasting subtraction instead of looping over channels.
|
|
219
|
-
N = message.data.shape[axis_idx]
|
|
220
|
-
ref_data = (N / (N - 1)) * ref_data - message.data / (N - 1)
|
|
221
|
-
# Note: I profiled using AffineTransformTransformer; it's ~30x slower than this implementation.
|
|
497
|
+
# Write per-cluster result into output at the correct axis position
|
|
498
|
+
idx = [slice(None)] * output.ndim
|
|
499
|
+
idx[axis_idx] = cluster_idx
|
|
500
|
+
output[tuple(idx)] = cluster_data - ref_data
|
|
222
501
|
|
|
223
|
-
return replace(message, data=
|
|
502
|
+
return replace(message, data=output)
|
|
224
503
|
|
|
225
504
|
|
|
226
505
|
class CommonRereference(
|
|
@@ -230,19 +509,26 @@ class CommonRereference(
|
|
|
230
509
|
|
|
231
510
|
|
|
232
511
|
def common_rereference(
    mode: str = "mean",
    axis: str | None = None,
    include_current: bool = True,
    channel_clusters: list[list[int]] | None = None,
) -> CommonRereferenceTransformer:
    """
    Perform common average referencing (CAR) on streaming data.

    Args:
        mode: The statistical mode to apply -- "mean" or "median". "passthrough"
            returns each message unchanged.
        axis: The name of the axis to apply the transformation to. If None, the
            last axis of each message is used.
        include_current: Set False to exclude each channel from participating in the calculation of its reference.
        channel_clusters: Optional channel clusters for per-cluster rereferencing. See
            :attr:`CommonRereferenceSettings.channel_clusters`.

    Returns:
        :obj:`CommonRereferenceTransformer`
    """
    return CommonRereferenceTransformer(
        CommonRereferenceSettings(
            mode=mode,
            axis=axis,
            include_current=include_current,
            channel_clusters=channel_clusters,
        )
    )
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.13.1
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
5
|
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
|
|
2
|
-
ezmsg/sigproc/__version__.py,sha256=
|
|
2
|
+
ezmsg/sigproc/__version__.py,sha256=LrOgVsOxSmuj7RbXPLo3yIvC77lH9VlW4tk7Ihs46rY,706
|
|
3
3
|
ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
|
|
4
4
|
ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
|
|
5
|
-
ezmsg/sigproc/affinetransform.py,sha256=
|
|
5
|
+
ezmsg/sigproc/affinetransform.py,sha256=mjA21DRVYm0kS2tK7dNR_mU5XAxDbJGuuhnuzz0gtw4,21679
|
|
6
6
|
ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
|
|
7
7
|
ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
|
|
8
8
|
ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
|
|
@@ -62,7 +62,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
|
|
|
62
62
|
ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
|
|
63
63
|
ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
|
|
64
64
|
ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
|
|
65
|
-
ezmsg_sigproc-2.
|
|
66
|
-
ezmsg_sigproc-2.
|
|
67
|
-
ezmsg_sigproc-2.
|
|
68
|
-
ezmsg_sigproc-2.
|
|
65
|
+
ezmsg_sigproc-2.13.1.dist-info/METADATA,sha256=mjyCiq1zy3JbgayLzq3BQLB-i1LOouUeGHe0tn3v7MY,1909
|
|
66
|
+
ezmsg_sigproc-2.13.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
67
|
+
ezmsg_sigproc-2.13.1.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
|
|
68
|
+
ezmsg_sigproc-2.13.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|