ezmsg-sigproc 2.13.0__py3-none-any.whl → 2.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.13.0'
32
- __version_tuple__ = version_tuple = (2, 13, 0)
31
+ __version__ = version = '2.14.0'
32
+ __version_tuple__ = version_tuple = (2, 14, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -23,6 +23,8 @@ from ezmsg.baseproc import (
23
23
  from ezmsg.util.messages.axisarray import AxisArray, AxisBase
24
24
  from ezmsg.util.messages.util import replace
25
25
 
26
+ from ezmsg.sigproc.util.array import array_device, is_float_dtype, xp_asarray, xp_create
27
+
26
28
 
27
29
  def _find_block_diagonal_clusters(weights: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]] | None:
28
30
  """Detect block-diagonal structure in a weight matrix.
@@ -215,58 +217,11 @@ class AffineTransformTransformer(
215
217
  if weights is not None:
216
218
  weights = np.ascontiguousarray(weights)
217
219
 
218
- self._state.weights = weights
219
-
220
- # Note: If weights were scipy.sparse BSR then maybe we could use automate this next part.
221
- # However, that would break compatibility with Array API.
222
-
223
- # --- Block-diagonal cluster detection ---
224
- # Clusters are a list of (input_indices, output_indices) tuples.
225
- n_in, n_out = weights.shape
226
- if self.settings.channel_clusters is not None:
227
- # Validate input index bounds
228
- all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
229
- if np.any((all_in < 0) | (all_in >= n_in)):
230
- raise ValueError(
231
- "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
232
- )
233
-
234
- # Derive output indices from non-zero weights for each input cluster
235
- clusters = []
236
- for group in self.settings.channel_clusters:
237
- in_idx = np.asarray(group)
238
- out_idx = np.where(np.any(weights[in_idx, :] != 0, axis=0))[0]
239
- clusters.append((in_idx, out_idx))
240
-
241
- max_cross = _max_cross_cluster_weight(weights, clusters)
242
- if max_cross > 0:
243
- ez.logger.warning(
244
- f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
245
- "These will be ignored in block-diagonal multiplication."
246
- )
247
- else:
248
- clusters = _find_block_diagonal_clusters(weights)
249
- if clusters is not None:
250
- ez.logger.info(
251
- f"Auto-detected {len(clusters)} block-diagonal clusters "
252
- f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
253
- )
254
-
255
- # Merge small clusters to avoid excessive loop overhead
256
- if clusters is not None:
257
- clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
258
-
259
- if clusters is not None and len(clusters) > 1:
260
- self._state.n_out = n_out
261
- self._state.clusters = [
262
- (in_idx, out_idx, np.ascontiguousarray(weights[np.ix_(in_idx, out_idx)]))
263
- for in_idx, out_idx in clusters
264
- ]
265
- self._state.weights = None
266
- else:
267
- self._state.clusters = None
220
+ # Cluster detection + weight storage (delegated)
221
+ self.set_weights(weights, recalc_clusters=True)
268
222
 
269
223
  # --- Axis label handling (for non-square transforms, non-cluster path) ---
224
+ n_in, n_out = weights.shape
270
225
  axis = self.settings.axis or message.dims[-1]
271
226
  if axis in message.axes and hasattr(message.axes[axis], "data") and n_in != n_out:
272
227
  in_labels = message.axes[axis].data
@@ -292,16 +247,102 @@ class AffineTransformTransformer(
292
247
 
293
248
  self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
294
249
 
295
- # Convert to match message.data namespace for efficient operations in _process
250
+ # Convert to match message.data namespace and device for _process.
251
+ # Weights are stored as numpy float64 after cluster detection; some
252
+ # devices (e.g. MPS) don't support float64, so we downcast weight
253
+ # arrays to the message's dtype when the message is floating-point.
296
254
  xp = get_namespace(message.data)
255
+ dev = array_device(message.data)
256
+ msg_dt = message.data.dtype
257
+ # Downcast weights dtype only for float message data (avoids casting
258
+ # float weights to integer when message data happens to be int).
259
+ w_dt = msg_dt if is_float_dtype(xp, msg_dt) else None
297
260
  if self._state.weights is not None:
298
- self._state.weights = xp.asarray(self._state.weights)
261
+ self._state.weights = xp_asarray(xp, self._state.weights, dtype=w_dt, device=dev)
299
262
  if self._state.clusters is not None:
300
263
  self._state.clusters = [
301
- (xp.asarray(in_idx), xp.asarray(out_idx), xp.asarray(sub_w))
264
+ (
265
+ xp_asarray(xp, in_idx, device=dev),
266
+ xp_asarray(xp, out_idx, device=dev),
267
+ xp_asarray(xp, sub_w, dtype=w_dt, device=dev),
268
+ )
302
269
  for in_idx, out_idx, sub_w in self._state.clusters
303
270
  ]
304
271
 
272
+ def set_weights(self, weights, *, recalc_clusters=False) -> None:
273
+ """Replace weight values, optionally recalculating cluster decomposition.
274
+
275
+ *weights* must be in **canonical orientation** (``right_multiply``
276
+ already applied by the caller or by ``_reset_state``). The array may
277
+ live in any Array-API namespace (NumPy, CuPy, etc.).
278
+
279
+ Args:
280
+ weights: Weight matrix in canonical orientation.
281
+ recalc_clusters: When True, re-run block-diagonal cluster detection
282
+ and store the new decomposition. When False (default), reuse
283
+ the existing cluster structure and only update weight values.
284
+ """
285
+ if recalc_clusters:
286
+ # Note: If weights were scipy.sparse BSR then maybe we could automate this next part.
287
+ # However, that would break compatibility with Array API.
288
+
289
+ # --- Block-diagonal cluster detection ---
290
+ # Clusters are a list of (input_indices, output_indices) tuples.
291
+ w_np = np.ascontiguousarray(weights)
292
+ n_in, n_out = w_np.shape
293
+ if self.settings.channel_clusters is not None:
294
+ # Validate input index bounds
295
+ all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
296
+ if np.any((all_in < 0) | (all_in >= n_in)):
297
+ raise ValueError(
298
+ "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
299
+ )
300
+
301
+ # Derive output indices from non-zero weights for each input cluster
302
+ clusters = []
303
+ for group in self.settings.channel_clusters:
304
+ in_idx = np.asarray(group)
305
+ out_idx = np.where(np.any(w_np[in_idx, :] != 0, axis=0))[0]
306
+ clusters.append((in_idx, out_idx))
307
+
308
+ max_cross = _max_cross_cluster_weight(w_np, clusters)
309
+ if max_cross > 0:
310
+ ez.logger.warning(
311
+ f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
312
+ "These will be ignored in block-diagonal multiplication."
313
+ )
314
+ else:
315
+ clusters = _find_block_diagonal_clusters(w_np)
316
+ if clusters is not None:
317
+ ez.logger.info(
318
+ f"Auto-detected {len(clusters)} block-diagonal clusters "
319
+ f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
320
+ )
321
+
322
+ # Merge small clusters to avoid excessive loop overhead
323
+ if clusters is not None:
324
+ clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
325
+
326
+ if clusters is not None and len(clusters) > 1:
327
+ self._state.n_out = n_out
328
+ self._state.clusters = [
329
+ (in_idx, out_idx, np.ascontiguousarray(w_np[np.ix_(in_idx, out_idx)]))
330
+ for in_idx, out_idx in clusters
331
+ ]
332
+ self._state.weights = None
333
+ else:
334
+ self._state.weights = weights
335
+ self._state.clusters = None
336
+ else:
337
+ xp = get_namespace(weights)
338
+ if self._state.clusters is not None:
339
+ self._state.clusters = [
340
+ (in_idx, out_idx, xp.take(xp.take(weights, in_idx, axis=0), out_idx, axis=1))
341
+ for in_idx, out_idx, _ in self._state.clusters
342
+ ]
343
+ else:
344
+ self._state.weights = weights
345
+
305
346
  def _block_diagonal_matmul(self, xp, data, axis_idx):
306
347
  """Perform matmul using block-diagonal decomposition.
307
348
 
@@ -318,7 +359,7 @@ class AffineTransformTransformer(
318
359
 
319
360
  # Pre-allocate output (omitted channels stay zero)
320
361
  out_shape = data.shape[:-1] + (self._state.n_out,)
321
- result = xp.zeros(out_shape, dtype=data.dtype)
362
+ result = xp_create(xp.zeros, out_shape, dtype=data.dtype, device=array_device(data))
322
363
 
323
364
  for in_idx, out_idx, sub_weights in self._state.clusters:
324
365
  chunk = xp.take(data, in_idx, axis=data.ndim - 1)
@@ -344,7 +385,10 @@ class AffineTransformTransformer(
344
385
  # The weights are stacked A|B where A is the transform and B is a single row
345
386
  # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
346
387
  sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
347
- data = xp.concat((data, xp.ones(sample_shape, dtype=data.dtype)), axis=axis_idx)
388
+ data = xp.concat(
389
+ (data, xp_create(xp.ones, sample_shape, dtype=data.dtype, device=array_device(data))),
390
+ axis=axis_idx,
391
+ )
348
392
 
349
393
  if axis_idx in [-1, len(message.dims) - 1]:
350
394
  data = xp.matmul(data, self._state.weights)
@@ -0,0 +1,59 @@
1
+ """Portable helpers for Array API interoperability.
2
+
3
+ These utilities smooth over differences between Array API libraries
4
+ (NumPy, PyTorch, MLX, CuPy, etc.) — in particular around ``device``
5
+ placement and ``dtype`` introspection, which are not uniformly supported.
6
+ """
7
+
8
+ import numpy as np
9
+
10
+
11
+ def array_device(x):
12
+ """Return the device of an array, or ``None`` for device-less libraries."""
13
+ try:
14
+ from array_api_compat import device
15
+
16
+ return device(x)
17
+ except (AttributeError, TypeError):
18
+ return None
19
+
20
+
21
+ def xp_asarray(xp, obj, *, dtype=None, device=None):
22
+ """Portable ``xp.asarray`` that omits unsupported kwargs.
23
+
24
+ Some Array API libraries (e.g. MLX) don't accept a ``device`` keyword.
25
+ This helper builds the kwargs dict dynamically so that only supported
26
+ arguments are forwarded.
27
+ """
28
+ kwargs = {}
29
+ if dtype is not None:
30
+ kwargs["dtype"] = dtype
31
+ if device is not None:
32
+ kwargs["device"] = device
33
+ return xp.asarray(obj, **kwargs)
34
+
35
+
36
+ def xp_create(fn, *args, dtype=None, device=None, **extra):
37
+ """Call a creation function (``zeros``, ``ones``, ``eye``) portably.
38
+
39
+ Omits ``device`` if it is ``None`` (for libraries that don't support it).
40
+ """
41
+ kwargs = dict(extra)
42
+ if dtype is not None:
43
+ kwargs["dtype"] = dtype
44
+ if device is not None:
45
+ kwargs["device"] = device
46
+ return fn(*args, **kwargs)
47
+
48
+
49
+ def is_float_dtype(xp, dtype) -> bool:
50
+ """Check whether *dtype* is a real floating-point type, portably."""
51
+ try:
52
+ return xp.isdtype(dtype, "real floating")
53
+ except AttributeError:
54
+ pass
55
+ # Fallback for libraries without isdtype (e.g. MLX).
56
+ try:
57
+ return xp.issubdtype(dtype, xp.floating)
58
+ except (AttributeError, TypeError):
59
+ return np.issubdtype(np.dtype(dtype), np.floating)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.13.0
3
+ Version: 2.14.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -1,8 +1,8 @@
1
1
  ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
2
- ezmsg/sigproc/__version__.py,sha256=_4LOjlEcfZzfuqIlglDZmVBPO4LyQ8P97qO716YoUL8,706
2
+ ezmsg/sigproc/__version__.py,sha256=ho7Izv0FLYm0b-otHGFFngHF9OeMd4aMhL-uSzYcJ5U,706
3
3
  ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
4
4
  ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
5
- ezmsg/sigproc/affinetransform.py,sha256=ZugiQg89Ly1I9SDgf0ZzgU2XdwVDmPrU7-orO9yrt7w,20210
5
+ ezmsg/sigproc/affinetransform.py,sha256=rnbQ2QtLq15eoFrqL-ij12yXSoLLq5ncOwRRzQgB1lA,22574
6
6
  ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
7
7
  ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
8
8
  ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
@@ -55,6 +55,7 @@ ezmsg/sigproc/math/log.py,sha256=JhjSqLnQnvx_3F4txRYHuUPSJ12Yj2HvRTsCMNvlxpo,202
55
55
  ezmsg/sigproc/math/pow.py,sha256=0sdlXFUEBXmpEV_i75oshGRjMguv8L13nLt7hlvdX3E,1284
56
56
  ezmsg/sigproc/math/scale.py,sha256=4_xHcHNuf13E1fxIF5vbkPfkN4En6zkfPIKID7lCERk,1133
57
57
  ezmsg/sigproc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
+ ezmsg/sigproc/util/array.py,sha256=S6wTGRbtJpU5I9u2t0MTUAMSKkm1PJref145s87xHLg,1831
58
59
  ezmsg/sigproc/util/asio.py,sha256=aAj0e7OoBvkRy28k05HL2s9YPCTxOddc05xMN-qd4lQ,577
59
60
  ezmsg/sigproc/util/axisarray_buffer.py,sha256=TGDeC6CXmmp7OUuiGd6xYQijRGYDE4QGdWxjK5Vs3nE,14057
60
61
  ezmsg/sigproc/util/buffer.py,sha256=83Gm0IuowmcMlXgLFB_rz8_ZPhkwG4DNNejyWJDKJl8,19658
@@ -62,7 +63,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
62
63
  ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
63
64
  ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
64
65
  ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
65
- ezmsg_sigproc-2.13.0.dist-info/METADATA,sha256=RXENX541lABAic8oUDuT8vQwx9nlWY9JETyXYKxdeTQ,1909
66
- ezmsg_sigproc-2.13.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
67
- ezmsg_sigproc-2.13.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
68
- ezmsg_sigproc-2.13.0.dist-info/RECORD,,
66
+ ezmsg_sigproc-2.14.0.dist-info/METADATA,sha256=Ml-3BXd3vRBRe3H1j1pmWdlXmvnvphbYQ7SH_n4pw9Y,1909
67
+ ezmsg_sigproc-2.14.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
68
+ ezmsg_sigproc-2.14.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
69
+ ezmsg_sigproc-2.14.0.dist-info/RECORD,,