ezmsg-sigproc 2.13.1__py3-none-any.whl → 2.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.13.1'
32
- __version_tuple__ = version_tuple = (2, 13, 1)
31
+ __version__ = version = '2.14.0'
32
+ __version_tuple__ = version_tuple = (2, 14, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -23,6 +23,8 @@ from ezmsg.baseproc import (
23
23
  from ezmsg.util.messages.axisarray import AxisArray, AxisBase
24
24
  from ezmsg.util.messages.util import replace
25
25
 
26
+ from ezmsg.sigproc.util.array import array_device, is_float_dtype, xp_asarray, xp_create
27
+
26
28
 
27
29
  def _find_block_diagonal_clusters(weights: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]] | None:
28
30
  """Detect block-diagonal structure in a weight matrix.
@@ -245,13 +247,25 @@ class AffineTransformTransformer(
245
247
 
246
248
  self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
247
249
 
248
- # Convert to match message.data namespace for efficient operations in _process
250
+ # Convert to match message.data namespace and device for _process.
251
+ # Weights are stored as numpy float64 after cluster detection; some
252
+ # devices (e.g. MPS) don't support float64, so we downcast weight
253
+ # arrays to the message's dtype when the message is floating-point.
249
254
  xp = get_namespace(message.data)
255
+ dev = array_device(message.data)
256
+ msg_dt = message.data.dtype
257
+ # Downcast weights dtype only for float message data (avoids casting
258
+ # float weights to integer when message data happens to be int).
259
+ w_dt = msg_dt if is_float_dtype(xp, msg_dt) else None
250
260
  if self._state.weights is not None:
251
- self._state.weights = xp.asarray(self._state.weights)
261
+ self._state.weights = xp_asarray(xp, self._state.weights, dtype=w_dt, device=dev)
252
262
  if self._state.clusters is not None:
253
263
  self._state.clusters = [
254
- (xp.asarray(in_idx), xp.asarray(out_idx), xp.asarray(sub_w))
264
+ (
265
+ xp_asarray(xp, in_idx, device=dev),
266
+ xp_asarray(xp, out_idx, device=dev),
267
+ xp_asarray(xp, sub_w, dtype=w_dt, device=dev),
268
+ )
255
269
  for in_idx, out_idx, sub_w in self._state.clusters
256
270
  ]
257
271
 
@@ -345,7 +359,7 @@ class AffineTransformTransformer(
345
359
 
346
360
  # Pre-allocate output (omitted channels stay zero)
347
361
  out_shape = data.shape[:-1] + (self._state.n_out,)
348
- result = xp.zeros(out_shape, dtype=data.dtype)
362
+ result = xp_create(xp.zeros, out_shape, dtype=data.dtype, device=array_device(data))
349
363
 
350
364
  for in_idx, out_idx, sub_weights in self._state.clusters:
351
365
  chunk = xp.take(data, in_idx, axis=data.ndim - 1)
@@ -371,7 +385,10 @@ class AffineTransformTransformer(
371
385
  # The weights are stacked A|B where A is the transform and B is a single row
372
386
  # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
373
387
  sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
374
- data = xp.concat((data, xp.ones(sample_shape, dtype=data.dtype)), axis=axis_idx)
388
+ data = xp.concat(
389
+ (data, xp_create(xp.ones, sample_shape, dtype=data.dtype, device=array_device(data))),
390
+ axis=axis_idx,
391
+ )
375
392
 
376
393
  if axis_idx in [-1, len(message.dims) - 1]:
377
394
  data = xp.matmul(data, self._state.weights)
@@ -0,0 +1,59 @@
1
+ """Portable helpers for Array API interoperability.
2
+
3
+ These utilities smooth over differences between Array API libraries
4
+ (NumPy, PyTorch, MLX, CuPy, etc.) — in particular around ``device``
5
+ placement and ``dtype`` introspection, which are not uniformly supported.
6
+ """
7
+
8
+ import numpy as np
9
+
10
+
11
def array_device(x):
    """Return the device of an array, or ``None`` when it cannot be determined.

    The device is queried via :func:`array_api_compat.device`. ``None`` is
    returned (instead of raising) when:

    * ``array_api_compat`` cannot be imported, or
    * the array's library has no device concept and ``device(x)`` raises
      ``AttributeError`` / ``TypeError``.

    Callers can treat ``None`` as "omit the ``device=`` keyword" (see
    ``xp_asarray`` / ``xp_create`` in this module).
    """
    # Keep the import and the call in separate try blocks so that each one
    # only catches the failures it is expected to produce.
    try:
        from array_api_compat import device
    except ImportError:
        # Without array_api_compat there is no portable way to query the
        # device; report "unknown" rather than crashing this best-effort helper.
        return None
    try:
        return device(x)
    except (AttributeError, TypeError):
        # Device-less array libraries (e.g. MLX) end up here.
        return None
19
+
20
+
21
def xp_asarray(xp, obj, *, dtype=None, device=None):
    """Call ``xp.asarray(obj, ...)`` forwarding only the options that are set.

    Not every Array API library accepts every keyword (MLX, for instance,
    rejects ``device``). Any of ``dtype`` / ``device`` left as ``None`` is
    therefore dropped from the call instead of being passed explicitly.

    Args:
        xp: Array namespace providing ``asarray``.
        obj: Object to convert.
        dtype: Optional target dtype; omitted from the call when ``None``.
        device: Optional target device; omitted from the call when ``None``.

    Returns:
        Whatever ``xp.asarray`` returns.
    """
    candidate_options = {"dtype": dtype, "device": device}
    forwarded = {
        option: value
        for option, value in candidate_options.items()
        if value is not None
    }
    return xp.asarray(obj, **forwarded)
34
+
35
+
36
def xp_create(fn, *args, dtype=None, device=None, **extra):
    """Invoke an array-creation function (``zeros``, ``ones``, ``eye``, ...) portably.

    Positional arguments and any ``extra`` keywords are passed straight through
    to ``fn``. ``dtype`` and ``device`` are appended only when they are not
    ``None``, so libraries that lack a ``device`` keyword still work when no
    device is requested.

    Args:
        fn: The creation function to call.
        *args: Positional arguments for ``fn`` (e.g. the shape).
        dtype: Optional dtype; omitted when ``None``.
        device: Optional device; omitted when ``None``.
        **extra: Additional keyword arguments forwarded unchanged.

    Returns:
        The result of ``fn``.
    """
    call_kwargs = {**extra}
    for option, value in (("dtype", dtype), ("device", device)):
        if value is None:
            continue
        call_kwargs[option] = value
    return fn(*args, **call_kwargs)
47
+
48
+
49
def is_float_dtype(xp, dtype) -> bool:
    """Check whether *dtype* is a real floating-point type, portably.

    Strategies are tried in order; a strategy is skipped when the namespace
    lacks the needed attribute or raises ``TypeError`` for this dtype:

    1. ``xp.isdtype(dtype, "real floating")`` -- the Array API predicate.
    2. ``xp.issubdtype(dtype, xp.floating)`` -- for libraries without
       ``isdtype`` (e.g. MLX).
    3. ``np.issubdtype(np.dtype(dtype), np.floating)`` -- last resort.

    Complex dtypes are not "real floating", so they yield ``False``.
    If no strategy can interpret *dtype*, ``False`` is returned instead of
    raising, so callers can treat "unknown" the same as "not float".
    """
    try:
        return bool(xp.isdtype(dtype, "real floating"))
    except (AttributeError, TypeError):
        # No isdtype in this namespace, or it does not recognise this dtype
        # (NumPy >= 2 raises TypeError for non-NumPy dtypes).
        pass
    # Fallback for libraries without isdtype (e.g. MLX).
    try:
        return bool(xp.issubdtype(dtype, xp.floating))
    except (AttributeError, TypeError):
        pass
    try:
        return bool(np.issubdtype(np.dtype(dtype), np.floating))
    except TypeError:
        # A foreign dtype NumPy cannot interpret: answer conservatively
        # rather than crashing the predicate.
        return False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.13.1
3
+ Version: 2.14.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -1,8 +1,8 @@
1
1
  ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
2
- ezmsg/sigproc/__version__.py,sha256=LrOgVsOxSmuj7RbXPLo3yIvC77lH9VlW4tk7Ihs46rY,706
2
+ ezmsg/sigproc/__version__.py,sha256=ho7Izv0FLYm0b-otHGFFngHF9OeMd4aMhL-uSzYcJ5U,706
3
3
  ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
4
4
  ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
5
- ezmsg/sigproc/affinetransform.py,sha256=mjA21DRVYm0kS2tK7dNR_mU5XAxDbJGuuhnuzz0gtw4,21679
5
+ ezmsg/sigproc/affinetransform.py,sha256=rnbQ2QtLq15eoFrqL-ij12yXSoLLq5ncOwRRzQgB1lA,22574
6
6
  ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
7
7
  ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
8
8
  ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
@@ -55,6 +55,7 @@ ezmsg/sigproc/math/log.py,sha256=JhjSqLnQnvx_3F4txRYHuUPSJ12Yj2HvRTsCMNvlxpo,202
55
55
  ezmsg/sigproc/math/pow.py,sha256=0sdlXFUEBXmpEV_i75oshGRjMguv8L13nLt7hlvdX3E,1284
56
56
  ezmsg/sigproc/math/scale.py,sha256=4_xHcHNuf13E1fxIF5vbkPfkN4En6zkfPIKID7lCERk,1133
57
57
  ezmsg/sigproc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
+ ezmsg/sigproc/util/array.py,sha256=S6wTGRbtJpU5I9u2t0MTUAMSKkm1PJref145s87xHLg,1831
58
59
  ezmsg/sigproc/util/asio.py,sha256=aAj0e7OoBvkRy28k05HL2s9YPCTxOddc05xMN-qd4lQ,577
59
60
  ezmsg/sigproc/util/axisarray_buffer.py,sha256=TGDeC6CXmmp7OUuiGd6xYQijRGYDE4QGdWxjK5Vs3nE,14057
60
61
  ezmsg/sigproc/util/buffer.py,sha256=83Gm0IuowmcMlXgLFB_rz8_ZPhkwG4DNNejyWJDKJl8,19658
@@ -62,7 +63,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
62
63
  ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
63
64
  ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
64
65
  ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
65
- ezmsg_sigproc-2.13.1.dist-info/METADATA,sha256=mjyCiq1zy3JbgayLzq3BQLB-i1LOouUeGHe0tn3v7MY,1909
66
- ezmsg_sigproc-2.13.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
67
- ezmsg_sigproc-2.13.1.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
68
- ezmsg_sigproc-2.13.1.dist-info/RECORD,,
66
+ ezmsg_sigproc-2.14.0.dist-info/METADATA,sha256=Ml-3BXd3vRBRe3H1j1pmWdlXmvnvphbYQ7SH_n4pw9Y,1909
67
+ ezmsg_sigproc-2.14.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
68
+ ezmsg_sigproc-2.14.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
69
+ ezmsg_sigproc-2.14.0.dist-info/RECORD,,