ezmsg-sigproc 2.13.0__py3-none-any.whl → 2.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/affinetransform.py +99 -55
- ezmsg/sigproc/util/array.py +59 -0
- {ezmsg_sigproc-2.13.0.dist-info → ezmsg_sigproc-2.14.0.dist-info}/METADATA +1 -1
- {ezmsg_sigproc-2.13.0.dist-info → ezmsg_sigproc-2.14.0.dist-info}/RECORD +7 -6
- {ezmsg_sigproc-2.13.0.dist-info → ezmsg_sigproc-2.14.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-2.13.0.dist-info → ezmsg_sigproc-2.14.0.dist-info}/licenses/LICENSE +0 -0
ezmsg/sigproc/__version__.py
CHANGED

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '2.13.0'
-__version_tuple__ = version_tuple = (2, 13, 0)
+__version__ = version = '2.14.0'
+__version_tuple__ = version_tuple = (2, 14, 0)
 
 __commit_id__ = commit_id = None
ezmsg/sigproc/affinetransform.py
CHANGED

@@ -23,6 +23,8 @@ from ezmsg.baseproc import (
 from ezmsg.util.messages.axisarray import AxisArray, AxisBase
 from ezmsg.util.messages.util import replace
 
+from ezmsg.sigproc.util.array import array_device, is_float_dtype, xp_asarray, xp_create
+
 
 def _find_block_diagonal_clusters(weights: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]] | None:
     """Detect block-diagonal structure in a weight matrix.

@@ -215,58 +217,11 @@ class AffineTransformTransformer(
         if weights is not None:
             weights = np.ascontiguousarray(weights)
 
-
-
-            # Note: If weights were scipy.sparse BSR then maybe we could use automate this next part.
-            # However, that would break compatibility with Array API.
-
-            # --- Block-diagonal cluster detection ---
-            # Clusters are a list of (input_indices, output_indices) tuples.
-            n_in, n_out = weights.shape
-            if self.settings.channel_clusters is not None:
-                # Validate input index bounds
-                all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
-                if np.any((all_in < 0) | (all_in >= n_in)):
-                    raise ValueError(
-                        "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
-                    )
-
-                # Derive output indices from non-zero weights for each input cluster
-                clusters = []
-                for group in self.settings.channel_clusters:
-                    in_idx = np.asarray(group)
-                    out_idx = np.where(np.any(weights[in_idx, :] != 0, axis=0))[0]
-                    clusters.append((in_idx, out_idx))
-
-                max_cross = _max_cross_cluster_weight(weights, clusters)
-                if max_cross > 0:
-                    ez.logger.warning(
-                        f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
-                        "These will be ignored in block-diagonal multiplication."
-                    )
-            else:
-                clusters = _find_block_diagonal_clusters(weights)
-                if clusters is not None:
-                    ez.logger.info(
-                        f"Auto-detected {len(clusters)} block-diagonal clusters "
-                        f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
-                    )
-
-            # Merge small clusters to avoid excessive loop overhead
-            if clusters is not None:
-                clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
-
-            if clusters is not None and len(clusters) > 1:
-                self._state.n_out = n_out
-                self._state.clusters = [
-                    (in_idx, out_idx, np.ascontiguousarray(weights[np.ix_(in_idx, out_idx)]))
-                    for in_idx, out_idx in clusters
-                ]
-                self._state.weights = None
-            else:
-                self._state.clusters = None
+            # Cluster detection + weight storage (delegated)
+            self.set_weights(weights, recalc_clusters=True)
 
             # --- Axis label handling (for non-square transforms, non-cluster path) ---
+            n_in, n_out = weights.shape
             axis = self.settings.axis or message.dims[-1]
             if axis in message.axes and hasattr(message.axes[axis], "data") and n_in != n_out:
                 in_labels = message.axes[axis].data
@@ -292,16 +247,102 @@ class AffineTransformTransformer(
 
                 self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
 
-        # Convert to match message.data namespace
+        # Convert to match message.data namespace and device for _process.
+        # Weights are stored as numpy float64 after cluster detection; some
+        # devices (e.g. MPS) don't support float64, so we downcast weight
+        # arrays to the message's dtype when the message is floating-point.
         xp = get_namespace(message.data)
+        dev = array_device(message.data)
+        msg_dt = message.data.dtype
+        # Downcast weights dtype only for float message data (avoids casting
+        # float weights to integer when message data happens to be int).
+        w_dt = msg_dt if is_float_dtype(xp, msg_dt) else None
         if self._state.weights is not None:
-            self._state.weights = xp
+            self._state.weights = xp_asarray(xp, self._state.weights, dtype=w_dt, device=dev)
         if self._state.clusters is not None:
             self._state.clusters = [
-                (
+                (
+                    xp_asarray(xp, in_idx, device=dev),
+                    xp_asarray(xp, out_idx, device=dev),
+                    xp_asarray(xp, sub_w, dtype=w_dt, device=dev),
+                )
                 for in_idx, out_idx, sub_w in self._state.clusters
             ]
 
+    def set_weights(self, weights, *, recalc_clusters=False) -> None:
+        """Replace weight values, optionally recalculating cluster decomposition.
+
+        *weights* must be in **canonical orientation** (``right_multiply``
+        already applied by the caller or by ``_reset_state``). The array may
+        live in any Array-API namespace (NumPy, CuPy, etc.).
+
+        Args:
+            weights: Weight matrix in canonical orientation.
+            recalc_clusters: When True, re-run block-diagonal cluster detection
+                and store the new decomposition. When False (default), reuse
+                the existing cluster structure and only update weight values.
+        """
+        if recalc_clusters:
+            # Note: If weights were scipy.sparse BSR then maybe we could automate this next part.
+            # However, that would break compatibility with Array API.
+
+            # --- Block-diagonal cluster detection ---
+            # Clusters are a list of (input_indices, output_indices) tuples.
+            w_np = np.ascontiguousarray(weights)
+            n_in, n_out = w_np.shape
+            if self.settings.channel_clusters is not None:
+                # Validate input index bounds
+                all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
+                if np.any((all_in < 0) | (all_in >= n_in)):
+                    raise ValueError(
+                        "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
+                    )
+
+                # Derive output indices from non-zero weights for each input cluster
+                clusters = []
+                for group in self.settings.channel_clusters:
+                    in_idx = np.asarray(group)
+                    out_idx = np.where(np.any(w_np[in_idx, :] != 0, axis=0))[0]
+                    clusters.append((in_idx, out_idx))
+
+                max_cross = _max_cross_cluster_weight(w_np, clusters)
+                if max_cross > 0:
+                    ez.logger.warning(
+                        f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
+                        "These will be ignored in block-diagonal multiplication."
+                    )
+            else:
+                clusters = _find_block_diagonal_clusters(w_np)
+                if clusters is not None:
+                    ez.logger.info(
+                        f"Auto-detected {len(clusters)} block-diagonal clusters "
+                        f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
+                    )
+
+            # Merge small clusters to avoid excessive loop overhead
+            if clusters is not None:
+                clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
+
+            if clusters is not None and len(clusters) > 1:
+                self._state.n_out = n_out
+                self._state.clusters = [
+                    (in_idx, out_idx, np.ascontiguousarray(w_np[np.ix_(in_idx, out_idx)]))
+                    for in_idx, out_idx in clusters
+                ]
+                self._state.weights = None
+            else:
+                self._state.weights = weights
+                self._state.clusters = None
+        else:
+            xp = get_namespace(weights)
+            if self._state.clusters is not None:
+                self._state.clusters = [
+                    (in_idx, out_idx, xp.take(xp.take(weights, in_idx, axis=0), out_idx, axis=1))
+                    for in_idx, out_idx, _ in self._state.clusters
+                ]
+            else:
+                self._state.weights = weights
+
     def _block_diagonal_matmul(self, xp, data, axis_idx):
         """Perform matmul using block-diagonal decomposition.
 
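A note on the float-only downcast above: `w_dt` follows the message dtype only when that dtype is real floating, so float64 weights are cast down for a float32 message (e.g. data on an MPS device) but never cast to an integer type. A standalone NumPy sketch of the same rule (illustrative only, not package code; assumes NumPy >= 2.0 for `np.isdtype`):

import numpy as np

def pick_weight_dtype(msg_dtype):
    # Same rule as w_dt above: follow the message dtype only for real floats.
    return msg_dtype if np.isdtype(msg_dtype, "real floating") else None

assert pick_weight_dtype(np.dtype(np.float32)) == np.float32  # float data -> downcast weights
assert pick_weight_dtype(np.dtype(np.int16)) is None          # int data -> keep float64 weights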
@@ -318,7 +359,7 @@ class AffineTransformTransformer(
 
         # Pre-allocate output (omitted channels stay zero)
         out_shape = data.shape[:-1] + (self._state.n_out,)
-        result = xp.zeros
+        result = xp_create(xp.zeros, out_shape, dtype=data.dtype, device=array_device(data))
 
         for in_idx, out_idx, sub_weights in self._state.clusters:
             chunk = xp.take(data, in_idx, axis=data.ndim - 1)
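For context on what `_block_diagonal_matmul` relies on: when the weight matrix has no cross-cluster entries, multiplying each input cluster by its sub-block and scattering the result into a zero-initialized output reproduces the full matrix product. A minimal NumPy sketch with made-up clusters and weights (illustrative, not package code):

import numpy as np

# Two independent 2x2 blocks inside a 4x4 weight matrix.
weights = np.zeros((4, 4))
weights[np.ix_([0, 1], [0, 1])] = [[1.0, 2.0], [3.0, 4.0]]
weights[np.ix_([2, 3], [2, 3])] = [[5.0, 6.0], [7.0, 8.0]]
clusters = [
    (np.array([0, 1]), np.array([0, 1]), weights[np.ix_([0, 1], [0, 1])]),
    (np.array([2, 3]), np.array([2, 3]), weights[np.ix_([2, 3], [2, 3])]),
]

data = np.random.default_rng(0).standard_normal((10, 4))  # (time, channel)

result = np.zeros(data.shape[:-1] + (4,))  # omitted channels stay zero
for in_idx, out_idx, sub_w in clusters:
    result[..., out_idx] = np.take(data, in_idx, axis=-1) @ sub_w

assert np.allclose(result, data @ weights)  # identical to the full multiplication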
@@ -344,7 +385,10 @@ class AffineTransformTransformer(
             # The weights are stacked A|B where A is the transform and B is a single row
             # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
             sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
-            data = xp.concat(
+            data = xp.concat(
+                (data, xp_create(xp.ones, sample_shape, dtype=data.dtype, device=array_device(data))),
+                axis=axis_idx,
+            )
 
         if axis_idx in [-1, len(message.dims) - 1]:
             data = xp.matmul(data, self._state.weights)
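The ones-concatenation in the last hunk supports the stacked A|B convention described in the comments, where the extra weight row acts as the bias in y = Ax + B. A self-contained NumPy illustration with arbitrary values (not taken from the package):

import numpy as np

A = np.array([[2.0, 0.0], [0.0, 3.0]])    # transform, (n_in, n_out)
B = np.array([[10.0, 20.0]])              # bias row, (1, n_out)
weights = np.concatenate((A, B), axis=0)  # stacked A|B, (n_in + 1, n_out)

x = np.array([[1.0, 1.0], [2.0, 2.0]])    # (samples, n_in)
ones = np.ones(x.shape[:-1] + (1,), dtype=x.dtype)
x_aug = np.concatenate((x, ones), axis=-1)  # append a ones column along the channel axis

assert np.allclose(x_aug @ weights, x @ A + B)  # the last weight row contributes the bias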

ezmsg/sigproc/util/array.py
ADDED

@@ -0,0 +1,59 @@
+"""Portable helpers for Array API interoperability.
+
+These utilities smooth over differences between Array API libraries
+(NumPy, PyTorch, MLX, CuPy, etc.) — in particular around ``device``
+placement and ``dtype`` introspection, which are not uniformly supported.
+"""
+
+import numpy as np
+
+
+def array_device(x):
+    """Return the device of an array, or ``None`` for device-less libraries."""
+    try:
+        from array_api_compat import device
+
+        return device(x)
+    except (AttributeError, TypeError):
+        return None
+
+
+def xp_asarray(xp, obj, *, dtype=None, device=None):
+    """Portable ``xp.asarray`` that omits unsupported kwargs.
+
+    Some Array API libraries (e.g. MLX) don't accept a ``device`` keyword.
+    This helper builds the kwargs dict dynamically so that only supported
+    arguments are forwarded.
+    """
+    kwargs = {}
+    if dtype is not None:
+        kwargs["dtype"] = dtype
+    if device is not None:
+        kwargs["device"] = device
+    return xp.asarray(obj, **kwargs)
+
+
+def xp_create(fn, *args, dtype=None, device=None, **extra):
+    """Call a creation function (``zeros``, ``ones``, ``eye``) portably.
+
+    Omits ``device`` if it is ``None`` (for libraries that don't support it).
+    """
+    kwargs = dict(extra)
+    if dtype is not None:
+        kwargs["dtype"] = dtype
+    if device is not None:
+        kwargs["device"] = device
+    return fn(*args, **kwargs)
+
+
+def is_float_dtype(xp, dtype) -> bool:
+    """Check whether *dtype* is a real floating-point type, portably."""
+    try:
+        return xp.isdtype(dtype, "real floating")
+    except AttributeError:
+        pass
+    # Fallback for libraries without isdtype (e.g. MLX).
+    try:
+        return xp.issubdtype(dtype, xp.floating)
+    except (AttributeError, TypeError):
+        return np.issubdtype(np.dtype(dtype), np.floating)
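Illustrative usage of the new helpers (not part of the package) with NumPy as the Array API namespace; it assumes NumPy >= 2.0 and that `array_api_compat` is installed, since `array_device` relies on it. Other namespaces (CuPy, PyTorch) would be passed in the same way:

import numpy as np
from ezmsg.sigproc.util.array import array_device, is_float_dtype, xp_asarray, xp_create

x = np.arange(6, dtype=np.float32).reshape(2, 3)

dev = array_device(x)  # "cpu" for NumPy arrays (via array_api_compat)
w = xp_asarray(np, [[1.0], [0.5], [0.0]], dtype=x.dtype, device=dev)  # (3, 1) weights
out = xp_create(np.zeros, (2, 1), dtype=x.dtype, device=dev)          # pre-allocated output

out += x @ w  # namespace-agnostic code would use xp.matmul here
assert is_float_dtype(np, x.dtype)                 # float32 -> True
assert not is_float_dtype(np, np.dtype(np.int32))  # int32 -> False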

{ezmsg_sigproc-2.13.0.dist-info → ezmsg_sigproc-2.14.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ezmsg-sigproc
-Version: 2.13.0
+Version: 2.14.0
 Summary: Timeseries signal processing implementations in ezmsg
 Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
 License-Expression: MIT
|
|
|
1
1
|
ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
|
|
2
|
-
ezmsg/sigproc/__version__.py,sha256=
|
|
2
|
+
ezmsg/sigproc/__version__.py,sha256=ho7Izv0FLYm0b-otHGFFngHF9OeMd4aMhL-uSzYcJ5U,706
|
|
3
3
|
ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
|
|
4
4
|
ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
|
|
5
|
-
ezmsg/sigproc/affinetransform.py,sha256=
|
|
5
|
+
ezmsg/sigproc/affinetransform.py,sha256=rnbQ2QtLq15eoFrqL-ij12yXSoLLq5ncOwRRzQgB1lA,22574
|
|
6
6
|
ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
|
|
7
7
|
ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
|
|
8
8
|
ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
|
|
@@ -55,6 +55,7 @@ ezmsg/sigproc/math/log.py,sha256=JhjSqLnQnvx_3F4txRYHuUPSJ12Yj2HvRTsCMNvlxpo,202
|
|
|
55
55
|
ezmsg/sigproc/math/pow.py,sha256=0sdlXFUEBXmpEV_i75oshGRjMguv8L13nLt7hlvdX3E,1284
|
|
56
56
|
ezmsg/sigproc/math/scale.py,sha256=4_xHcHNuf13E1fxIF5vbkPfkN4En6zkfPIKID7lCERk,1133
|
|
57
57
|
ezmsg/sigproc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
58
|
+
ezmsg/sigproc/util/array.py,sha256=S6wTGRbtJpU5I9u2t0MTUAMSKkm1PJref145s87xHLg,1831
|
|
58
59
|
ezmsg/sigproc/util/asio.py,sha256=aAj0e7OoBvkRy28k05HL2s9YPCTxOddc05xMN-qd4lQ,577
|
|
59
60
|
ezmsg/sigproc/util/axisarray_buffer.py,sha256=TGDeC6CXmmp7OUuiGd6xYQijRGYDE4QGdWxjK5Vs3nE,14057
|
|
60
61
|
ezmsg/sigproc/util/buffer.py,sha256=83Gm0IuowmcMlXgLFB_rz8_ZPhkwG4DNNejyWJDKJl8,19658
|
|
@@ -62,7 +63,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
|
|
|
62
63
|
ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
|
|
63
64
|
ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
|
|
64
65
|
ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
|
|
65
|
-
ezmsg_sigproc-2.
|
|
66
|
-
ezmsg_sigproc-2.
|
|
67
|
-
ezmsg_sigproc-2.
|
|
68
|
-
ezmsg_sigproc-2.
|
|
66
|
+
ezmsg_sigproc-2.14.0.dist-info/METADATA,sha256=Ml-3BXd3vRBRe3H1j1pmWdlXmvnvphbYQ7SH_n4pw9Y,1909
|
|
67
|
+
ezmsg_sigproc-2.14.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
68
|
+
ezmsg_sigproc-2.14.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
|
|
69
|
+
ezmsg_sigproc-2.14.0.dist-info/RECORD,,
|
|

{ezmsg_sigproc-2.13.0.dist-info → ezmsg_sigproc-2.14.0.dist-info}/WHEEL
File without changes

{ezmsg_sigproc-2.13.0.dist-info → ezmsg_sigproc-2.14.0.dist-info}/licenses/LICENSE
File without changes