ezmsg-sigproc 2.13.1__py3-none-any.whl → 2.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.13.1'
32
- __version_tuple__ = version_tuple = (2, 13, 1)
31
+ __version__ = version = '2.15.0'
32
+ __version_tuple__ = version_tuple = (2, 15, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -23,6 +23,8 @@ from ezmsg.baseproc import (
23
23
  from ezmsg.util.messages.axisarray import AxisArray, AxisBase
24
24
  from ezmsg.util.messages.util import replace
25
25
 
26
+ from ezmsg.sigproc.util.array import array_device, is_float_dtype, xp_asarray, xp_create
27
+
26
28
 
27
29
  def _find_block_diagonal_clusters(weights: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]] | None:
28
30
  """Detect block-diagonal structure in a weight matrix.
@@ -245,13 +247,25 @@ class AffineTransformTransformer(
245
247
 
246
248
  self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
247
249
 
248
- # Convert to match message.data namespace for efficient operations in _process
250
+ # Convert to match message.data namespace and device for _process.
251
+ # Weights are stored as numpy float64 after cluster detection; some
252
+ # devices (e.g. MPS) don't support float64, so we downcast weight
253
+ # arrays to the message's dtype when the message is floating-point.
249
254
  xp = get_namespace(message.data)
255
+ dev = array_device(message.data)
256
+ msg_dt = message.data.dtype
257
+ # Downcast weights dtype only for float message data (avoids casting
258
+ # float weights to integer when message data happens to be int).
259
+ w_dt = msg_dt if is_float_dtype(xp, msg_dt) else None
250
260
  if self._state.weights is not None:
251
- self._state.weights = xp.asarray(self._state.weights)
261
+ self._state.weights = xp_asarray(xp, self._state.weights, dtype=w_dt, device=dev)
252
262
  if self._state.clusters is not None:
253
263
  self._state.clusters = [
254
- (xp.asarray(in_idx), xp.asarray(out_idx), xp.asarray(sub_w))
264
+ (
265
+ xp_asarray(xp, in_idx, device=dev),
266
+ xp_asarray(xp, out_idx, device=dev),
267
+ xp_asarray(xp, sub_w, dtype=w_dt, device=dev),
268
+ )
255
269
  for in_idx, out_idx, sub_w in self._state.clusters
256
270
  ]
257
271
 
@@ -345,7 +359,7 @@ class AffineTransformTransformer(
345
359
 
346
360
  # Pre-allocate output (omitted channels stay zero)
347
361
  out_shape = data.shape[:-1] + (self._state.n_out,)
348
- result = xp.zeros(out_shape, dtype=data.dtype)
362
+ result = xp_create(xp.zeros, out_shape, dtype=data.dtype, device=array_device(data))
349
363
 
350
364
  for in_idx, out_idx, sub_weights in self._state.clusters:
351
365
  chunk = xp.take(data, in_idx, axis=data.ndim - 1)
@@ -371,7 +385,10 @@ class AffineTransformTransformer(
371
385
  # The weights are stacked A|B where A is the transform and B is a single row
372
386
  # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
373
387
  sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
374
- data = xp.concat((data, xp.ones(sample_shape, dtype=data.dtype)), axis=axis_idx)
388
+ data = xp.concat(
389
+ (data, xp_create(xp.ones, sample_shape, dtype=data.dtype, device=array_device(data))),
390
+ axis=axis_idx,
391
+ )
375
392
 
376
393
  if axis_idx in [-1, len(message.dims) - 1]:
377
394
  data = xp.matmul(data, self._state.weights)
@@ -12,8 +12,6 @@ from ezmsg.baseproc import (
12
12
  from ezmsg.util.messages.axisarray import AxisArray
13
13
  from ezmsg.util.messages.util import replace
14
14
 
15
- from ezmsg.sigproc.sampler import SampleMessage
16
-
17
15
 
18
16
  class RollingScalerSettings(ez.Settings):
19
17
  axis: str = "time"
@@ -168,8 +166,8 @@ class RollingScalerProcessor(BaseAdaptiveTransformer[RollingScalerSettings, Axis
168
166
 
169
167
  self._state.samples.append((n_b, mean_b, M2_b))
170
168
 
171
- def partial_fit(self, message: SampleMessage) -> None:
172
- x = message.sample.data
169
+ def partial_fit(self, message: AxisArray) -> None:
170
+ x = message.data
173
171
  self._add_batch_stats(x)
174
172
 
175
173
  def _process(self, message: AxisArray) -> AxisArray:
ezmsg/sigproc/sampler.py CHANGED
@@ -21,7 +21,7 @@ from ezmsg.util.messages.util import replace
21
21
 
22
22
  from .util.axisarray_buffer import HybridAxisArrayBuffer
23
23
  from .util.buffer import UpdateStrategy
24
- from .util.message import SampleMessage, SampleTriggerMessage
24
+ from .util.message import SampleTriggerMessage
25
25
  from .util.profile import profile_subpub
26
26
 
27
27
 
@@ -75,17 +75,7 @@ class SamplerState:
75
75
 
76
76
 
77
77
  class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, AxisArray, SamplerState]):
78
- def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[SampleMessage]:
79
- # TODO: Currently we have a single entry point that accepts both
80
- # data and trigger messages and we choose a code path based on
81
- # the message type. However, in the future we will likely replace
82
- # SampleTriggerMessage with an agumented form of AxisArray,
83
- # leveraging its attrs field, which makes this a bit harder.
84
- # We should probably force callers of this object to explicitly
85
- # call `push_trigger` for trigger messages. This will also
86
- # simplify typing somewhat because `push_trigger` should not
87
- # return anything yet we currently have it returning an empty
88
- # list just to be compatible with __call__.
78
+ def __call__(self, message: AxisArray | SampleTriggerMessage) -> list[AxisArray]:
89
79
  if isinstance(message, AxisArray):
90
80
  return super().__call__(message)
91
81
  else:
@@ -109,7 +99,7 @@ class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, Axi
109
99
  self._state.triggers = deque()
110
100
  self._state.triggers.clear()
111
101
 
112
- def _process(self, message: AxisArray) -> list[SampleMessage]:
102
+ def _process(self, message: AxisArray) -> list[AxisArray]:
113
103
  self._state.buffer.write(message)
114
104
 
115
105
  # How much data in the buffer?
@@ -119,7 +109,7 @@ class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, Axi
119
109
  )
120
110
 
121
111
  # Process in reverse order so that we can remove triggers safely as we iterate.
122
- msgs_out: list[SampleMessage] = []
112
+ msgs_out: list[AxisArray] = []
123
113
  for trig_ix in range(len(self._state.triggers) - 1, -1, -1):
124
114
  trig = self._state.triggers[trig_ix]
125
115
  if trig.period is None:
@@ -152,13 +142,13 @@ class SamplerTransformer(BaseStatefulTransformer[SamplerSettings, AxisArray, Axi
152
142
  # Note: buffer will trim itself as needed based on buffer_dur.
153
143
 
154
144
  # Prepare output and drop trigger
155
- msgs_out.append(SampleMessage(trigger=copy.copy(trig), sample=buff_axarr))
145
+ msgs_out.append(replace(buff_axarr, attrs={**buff_axarr.attrs, "trigger": copy.copy(trig)}))
156
146
  del self._state.triggers[trig_ix]
157
147
 
158
148
  msgs_out.reverse() # in-place
159
149
  return msgs_out
160
150
 
161
- def push_trigger(self, message: SampleTriggerMessage) -> list[SampleMessage]:
151
+ def push_trigger(self, message: SampleTriggerMessage) -> list[AxisArray]:
162
152
  # Input is a trigger message that we will use to sample the buffer.
163
153
 
164
154
  if self._state.buffer is None:
@@ -198,7 +188,7 @@ class Sampler(BaseTransformerUnit[SamplerSettings, AxisArray, AxisArray, Sampler
198
188
  SETTINGS = SamplerSettings
199
189
 
200
190
  INPUT_TRIGGER = ez.InputStream(SampleTriggerMessage)
201
- OUTPUT_SIGNAL = ez.OutputStream(SampleMessage)
191
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
202
192
 
203
193
  @ez.subscriber(INPUT_TRIGGER)
204
194
  async def on_trigger(self, msg: SampleTriggerMessage) -> None:
@@ -228,7 +218,8 @@ def sampler(
228
218
 
229
219
  Returns:
230
220
  A generator that expects `.send` either an :obj:`AxisArray` containing streaming data messages,
231
- or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`SampleMessage` s.
221
+ or a :obj:`SampleTriggerMessage` containing a trigger, and yields the list of :obj:`AxisArray` s
222
+ with trigger info stored in ``attrs["trigger"]``.
232
223
  """
233
224
  return SamplerTransformer(
234
225
  settings=SamplerSettings(
@@ -0,0 +1,59 @@
1
+ """Portable helpers for Array API interoperability.
2
+
3
+ These utilities smooth over differences between Array API libraries
4
+ (NumPy, PyTorch, MLX, CuPy, etc.) — in particular around ``device``
5
+ placement and ``dtype`` introspection, which are not uniformly supported.
6
+ """
7
+
8
+ import numpy as np
9
+
10
+
11
def array_device(x):
    """Report which device *x* is stored on.

    The lookup is delegated to ``array_api_compat.device``. If that lookup
    raises ``AttributeError`` or ``TypeError`` (the device-less case), ``None``
    is returned instead of propagating the error. Any other exception,
    including ``ImportError`` when ``array_api_compat`` is unavailable,
    propagates unchanged.
    """
    result = None
    try:
        from array_api_compat import device as _lookup_device

        result = _lookup_device(x)
    except (AttributeError, TypeError):
        # No usable device information for this array type.
        pass
    return result
19
+
20
+
21
def xp_asarray(xp, obj, *, dtype=None, device=None, **extra):
    """Portable ``xp.asarray`` that drops ``None``-valued ``dtype``/``device``.

    Some Array API libraries (e.g. MLX) don't accept a ``device`` keyword.
    This helper forwards ``dtype`` and ``device`` only when they are not
    ``None``. A caller that has no device (for example because a device lookup
    returned ``None``) therefore produces a call without the ``device``
    keyword at all. Note that a non-``None`` value is always forwarded; this
    helper does not inspect whether *xp* actually supports the keyword.

    Args:
        xp: Array namespace providing an ``asarray`` function.
        obj: Object to convert.
        dtype: Target dtype, or ``None`` to let *xp* choose.
        device: Target device, or ``None`` to omit the keyword.
        **extra: Additional keyword arguments (e.g. ``copy``) passed to
            ``xp.asarray`` unchanged, mirroring :func:`xp_create`.

    Returns:
        Whatever ``xp.asarray`` returns for the given arguments.
    """
    kwargs = dict(extra)
    if dtype is not None:
        kwargs["dtype"] = dtype
    if device is not None:
        kwargs["device"] = device
    return xp.asarray(obj, **kwargs)
34
+
35
+
36
def xp_create(fn, *args, dtype=None, device=None, **extra):
    """Invoke an array-creation function (``zeros``, ``ones``, ``eye``) portably.

    ``dtype`` and ``device`` are passed to *fn* only when they are not
    ``None``. A ``None`` device is therefore dropped from the call, so
    libraries whose creation functions take no ``device`` keyword still work.
    Every other keyword in ``extra`` is always passed through.
    """
    optional = {"dtype": dtype, "device": device}
    call_kwargs = {
        **extra,
        **{name: value for name, value in optional.items() if value is not None},
    }
    return fn(*args, **call_kwargs)
47
+
48
+
49
def is_float_dtype(xp, dtype) -> bool:
    """Return whether *dtype* is a real floating-point dtype, across libraries.

    Three strategies are tried in order:

    1. ``xp.isdtype(dtype, "real floating")``. An ``AttributeError`` moves on
       to the next strategy.
    2. ``xp.issubdtype(dtype, xp.floating)``, for namespaces without
       ``isdtype`` (e.g. MLX). ``AttributeError`` or ``TypeError`` moves on.
    3. ``np.issubdtype(np.dtype(dtype), np.floating)`` as a final fallback.
       Any error raised here propagates to the caller.
    """
    probes = (
        (lambda: xp.isdtype(dtype, "real floating"), (AttributeError,)),
        (lambda: xp.issubdtype(dtype, xp.floating), (AttributeError, TypeError)),
    )
    for probe, tolerated in probes:
        try:
            return probe()
        except tolerated:
            continue
    # Neither namespace-level check was usable; let NumPy interpret the dtype.
    return np.issubdtype(np.dtype(dtype), np.floating)
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.13.1
3
+ Version: 2.15.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
7
7
  License-File: LICENSE
8
8
  Requires-Python: >=3.10.15
9
9
  Requires-Dist: array-api-compat>=1.11.1
10
- Requires-Dist: ezmsg-baseproc>=1.1.0
10
+ Requires-Dist: ezmsg-baseproc>=1.3.0
11
11
  Requires-Dist: ezmsg>=3.6.0
12
12
  Requires-Dist: numba>=0.61.0
13
13
  Requires-Dist: numpy>=1.26.0
@@ -1,8 +1,8 @@
1
1
  ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
2
- ezmsg/sigproc/__version__.py,sha256=LrOgVsOxSmuj7RbXPLo3yIvC77lH9VlW4tk7Ihs46rY,706
2
+ ezmsg/sigproc/__version__.py,sha256=a2xWIti3jyJFUZUTWeA586-E0pruFwlKCcTQl-q5UQQ,706
3
3
  ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
4
4
  ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
5
- ezmsg/sigproc/affinetransform.py,sha256=mjA21DRVYm0kS2tK7dNR_mU5XAxDbJGuuhnuzz0gtw4,21679
5
+ ezmsg/sigproc/affinetransform.py,sha256=rnbQ2QtLq15eoFrqL-ij12yXSoLLq5ncOwRRzQgB1lA,22574
6
6
  ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
7
7
  ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
8
8
  ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
@@ -33,8 +33,8 @@ ezmsg/sigproc/merge.py,sha256=LmuN3LDIZF7DynMcjLp7eGc2G3Yxks9Zd8-luSqqXuA,15436
33
33
  ezmsg/sigproc/messages.py,sha256=KQczHTeifn4BZycChN8ZcpfZoQW3lC_xuFmN72QT97s,925
34
34
  ezmsg/sigproc/quantize.py,sha256=y7T4_67BHZluX3gyl2anp8iL6EEI6JvsK7Pmp1vapsk,2268
35
35
  ezmsg/sigproc/resample.py,sha256=3mm9pvxryNVhQuTCIMW3ToUkUfbVOCsIgvXUiurit1Y,11389
36
- ezmsg/sigproc/rollingscaler.py,sha256=GcLctVAWTmx9J39r0-dt3e7C_hs25s7M0dDnKiGhkC4,8955
37
- ezmsg/sigproc/sampler.py,sha256=iOk2YoUX22u9iTjFKimzP5V074RDBVcmswgfyxvZRZo,10761
36
+ ezmsg/sigproc/rollingscaler.py,sha256=1cJf-wESetWJihrIWKJUycSkH_5OG3K5JncxjM-1TTI,8895
37
+ ezmsg/sigproc/sampler.py,sha256=IgQ4d1VE-aaf0U2vkGDXtIwsKfpoW9ZR_s8Fi9VW3Kc,10103
38
38
  ezmsg/sigproc/scaler.py,sha256=kVXjRbqoxJ5yJICGsGagRXYIDW3-oihSnbBj-n3s55o,6816
39
39
  ezmsg/sigproc/signalinjector.py,sha256=mB62H2b-ScgPtH1jajEpxgDHqdb-RKekQfgyNncsE8Y,2874
40
40
  ezmsg/sigproc/singlebandpow.py,sha256=BVlWhFI6zU3ME3EVdZbwf-FMz1d2sfuNFDKXs1hn5HM,4353
@@ -55,6 +55,7 @@ ezmsg/sigproc/math/log.py,sha256=JhjSqLnQnvx_3F4txRYHuUPSJ12Yj2HvRTsCMNvlxpo,202
55
55
  ezmsg/sigproc/math/pow.py,sha256=0sdlXFUEBXmpEV_i75oshGRjMguv8L13nLt7hlvdX3E,1284
56
56
  ezmsg/sigproc/math/scale.py,sha256=4_xHcHNuf13E1fxIF5vbkPfkN4En6zkfPIKID7lCERk,1133
57
57
  ezmsg/sigproc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
+ ezmsg/sigproc/util/array.py,sha256=S6wTGRbtJpU5I9u2t0MTUAMSKkm1PJref145s87xHLg,1831
58
59
  ezmsg/sigproc/util/asio.py,sha256=aAj0e7OoBvkRy28k05HL2s9YPCTxOddc05xMN-qd4lQ,577
59
60
  ezmsg/sigproc/util/axisarray_buffer.py,sha256=TGDeC6CXmmp7OUuiGd6xYQijRGYDE4QGdWxjK5Vs3nE,14057
60
61
  ezmsg/sigproc/util/buffer.py,sha256=83Gm0IuowmcMlXgLFB_rz8_ZPhkwG4DNNejyWJDKJl8,19658
@@ -62,7 +63,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
62
63
  ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
63
64
  ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
64
65
  ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
65
- ezmsg_sigproc-2.13.1.dist-info/METADATA,sha256=mjyCiq1zy3JbgayLzq3BQLB-i1LOouUeGHe0tn3v7MY,1909
66
- ezmsg_sigproc-2.13.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
67
- ezmsg_sigproc-2.13.1.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
68
- ezmsg_sigproc-2.13.1.dist-info/RECORD,,
66
+ ezmsg_sigproc-2.15.0.dist-info/METADATA,sha256=WHVPI25XYKw4BAzF8749TZ5k-rg0Ik1iiH-p5ze8byc,1909
67
+ ezmsg_sigproc-2.15.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
68
+ ezmsg_sigproc-2.15.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
69
+ ezmsg_sigproc-2.15.0.dist-info/RECORD,,