ezmsg-sigproc 2.13.0__py3-none-any.whl → 2.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.13.0'
32
- __version_tuple__ = version_tuple = (2, 13, 0)
31
+ __version__ = version = '2.13.1'
32
+ __version_tuple__ = version_tuple = (2, 13, 1)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -215,58 +215,11 @@ class AffineTransformTransformer(
215
215
  if weights is not None:
216
216
  weights = np.ascontiguousarray(weights)
217
217
 
218
- self._state.weights = weights
219
-
220
- # Note: If weights were scipy.sparse BSR then maybe we could use automate this next part.
221
- # However, that would break compatibility with Array API.
222
-
223
- # --- Block-diagonal cluster detection ---
224
- # Clusters are a list of (input_indices, output_indices) tuples.
225
- n_in, n_out = weights.shape
226
- if self.settings.channel_clusters is not None:
227
- # Validate input index bounds
228
- all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
229
- if np.any((all_in < 0) | (all_in >= n_in)):
230
- raise ValueError(
231
- "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
232
- )
233
-
234
- # Derive output indices from non-zero weights for each input cluster
235
- clusters = []
236
- for group in self.settings.channel_clusters:
237
- in_idx = np.asarray(group)
238
- out_idx = np.where(np.any(weights[in_idx, :] != 0, axis=0))[0]
239
- clusters.append((in_idx, out_idx))
240
-
241
- max_cross = _max_cross_cluster_weight(weights, clusters)
242
- if max_cross > 0:
243
- ez.logger.warning(
244
- f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
245
- "These will be ignored in block-diagonal multiplication."
246
- )
247
- else:
248
- clusters = _find_block_diagonal_clusters(weights)
249
- if clusters is not None:
250
- ez.logger.info(
251
- f"Auto-detected {len(clusters)} block-diagonal clusters "
252
- f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
253
- )
254
-
255
- # Merge small clusters to avoid excessive loop overhead
256
- if clusters is not None:
257
- clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
258
-
259
- if clusters is not None and len(clusters) > 1:
260
- self._state.n_out = n_out
261
- self._state.clusters = [
262
- (in_idx, out_idx, np.ascontiguousarray(weights[np.ix_(in_idx, out_idx)]))
263
- for in_idx, out_idx in clusters
264
- ]
265
- self._state.weights = None
266
- else:
267
- self._state.clusters = None
218
+ # Cluster detection + weight storage (delegated)
219
+ self.set_weights(weights, recalc_clusters=True)
268
220
 
269
221
  # --- Axis label handling (for non-square transforms, non-cluster path) ---
222
+ n_in, n_out = weights.shape
270
223
  axis = self.settings.axis or message.dims[-1]
271
224
  if axis in message.axes and hasattr(message.axes[axis], "data") and n_in != n_out:
272
225
  in_labels = message.axes[axis].data
@@ -302,6 +255,80 @@ class AffineTransformTransformer(
302
255
  for in_idx, out_idx, sub_w in self._state.clusters
303
256
  ]
304
257
 
258
+ def set_weights(self, weights, *, recalc_clusters=False) -> None:
259
+ """Replace weight values, optionally recalculating cluster decomposition.
260
+
261
+ *weights* must be in **canonical orientation** (``right_multiply``
262
+ already applied by the caller or by ``_reset_state``). The array may
263
+ live in any Array-API namespace (NumPy, CuPy, etc.).
264
+
265
+ Args:
266
+ weights: Weight matrix in canonical orientation.
267
+ recalc_clusters: When True, re-run block-diagonal cluster detection
268
+ and store the new decomposition. When False (default), reuse
269
+ the existing cluster structure and only update weight values.
270
+ """
271
+ if recalc_clusters:
272
+ # Note: If weights were scipy.sparse BSR then maybe we could automate this next part.
273
+ # However, that would break compatibility with Array API.
274
+
275
+ # --- Block-diagonal cluster detection ---
276
+ # Clusters are a list of (input_indices, output_indices) tuples.
277
+ w_np = np.ascontiguousarray(weights)
278
+ n_in, n_out = w_np.shape
279
+ if self.settings.channel_clusters is not None:
280
+ # Validate input index bounds
281
+ all_in = np.concatenate([np.asarray(group) for group in self.settings.channel_clusters])
282
+ if np.any((all_in < 0) | (all_in >= n_in)):
283
+ raise ValueError(
284
+ "channel_clusters contains out-of-range input indices " f"(valid range: 0..{n_in - 1})"
285
+ )
286
+
287
+ # Derive output indices from non-zero weights for each input cluster
288
+ clusters = []
289
+ for group in self.settings.channel_clusters:
290
+ in_idx = np.asarray(group)
291
+ out_idx = np.where(np.any(w_np[in_idx, :] != 0, axis=0))[0]
292
+ clusters.append((in_idx, out_idx))
293
+
294
+ max_cross = _max_cross_cluster_weight(w_np, clusters)
295
+ if max_cross > 0:
296
+ ez.logger.warning(
297
+ f"Non-zero cross-cluster weights detected (max abs: {max_cross:.2e}). "
298
+ "These will be ignored in block-diagonal multiplication."
299
+ )
300
+ else:
301
+ clusters = _find_block_diagonal_clusters(w_np)
302
+ if clusters is not None:
303
+ ez.logger.info(
304
+ f"Auto-detected {len(clusters)} block-diagonal clusters "
305
+ f"(sizes: {[(len(i), len(o)) for i, o in clusters]})"
306
+ )
307
+
308
+ # Merge small clusters to avoid excessive loop overhead
309
+ if clusters is not None:
310
+ clusters = _merge_small_clusters(clusters, self.settings.min_cluster_size)
311
+
312
+ if clusters is not None and len(clusters) > 1:
313
+ self._state.n_out = n_out
314
+ self._state.clusters = [
315
+ (in_idx, out_idx, np.ascontiguousarray(w_np[np.ix_(in_idx, out_idx)]))
316
+ for in_idx, out_idx in clusters
317
+ ]
318
+ self._state.weights = None
319
+ else:
320
+ self._state.weights = weights
321
+ self._state.clusters = None
322
+ else:
323
+ xp = get_namespace(weights)
324
+ if self._state.clusters is not None:
325
+ self._state.clusters = [
326
+ (in_idx, out_idx, xp.take(xp.take(weights, in_idx, axis=0), out_idx, axis=1))
327
+ for in_idx, out_idx, _ in self._state.clusters
328
+ ]
329
+ else:
330
+ self._state.weights = weights
331
+
305
332
  def _block_diagonal_matmul(self, xp, data, axis_idx):
306
333
  """Perform matmul using block-diagonal decomposition.
307
334
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.13.0
3
+ Version: 2.13.1
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -1,8 +1,8 @@
1
1
  ezmsg/sigproc/__init__.py,sha256=8K4IcOA3-pfzadoM6s2Sfg5460KlJUocGgyTJTJl96U,52
2
- ezmsg/sigproc/__version__.py,sha256=_4LOjlEcfZzfuqIlglDZmVBPO4LyQ8P97qO716YoUL8,706
2
+ ezmsg/sigproc/__version__.py,sha256=LrOgVsOxSmuj7RbXPLo3yIvC77lH9VlW4tk7Ihs46rY,706
3
3
  ezmsg/sigproc/activation.py,sha256=83vnTa3ZcC4Q3VSWcGfaqhCEqYRNySUOyVpMHZXfz-c,2755
4
4
  ezmsg/sigproc/adaptive_lattice_notch.py,sha256=ThUR48mbSHuThkimtD0j4IXNMrOVcpZgGhE7PCYfXhU,8818
5
- ezmsg/sigproc/affinetransform.py,sha256=ZugiQg89Ly1I9SDgf0ZzgU2XdwVDmPrU7-orO9yrt7w,20210
5
+ ezmsg/sigproc/affinetransform.py,sha256=mjA21DRVYm0kS2tK7dNR_mU5XAxDbJGuuhnuzz0gtw4,21679
6
6
  ezmsg/sigproc/aggregate.py,sha256=7Hdz1m-S6Cl9h0oRQHeS_UTGBemhOB4XdFyX6cGcdHo,9362
7
7
  ezmsg/sigproc/bandpower.py,sha256=dAhH56sUrXNhcRFymTTwjdM_KcU5OxFzrR_sxIPAxyw,2264
8
8
  ezmsg/sigproc/base.py,sha256=SJvKEb8gw6mUMwlV5sH0iPG0bXrgS8tvkPwhI-j89MQ,3672
@@ -62,7 +62,7 @@ ezmsg/sigproc/util/message.py,sha256=ppN3IYtIAwrxWG9JOvgWFn1wDdIumkEzYFfqpH9VQkY
62
62
  ezmsg/sigproc/util/profile.py,sha256=eVOo9pXgusrnH1yfRdd2RsM7Dbe2UpyC0LJ9MfGpB08,416
63
63
  ezmsg/sigproc/util/sparse.py,sha256=NjbJitCtO0B6CENTlyd9c-lHEJwoCan-T3DIgPyeShw,4834
64
64
  ezmsg/sigproc/util/typeresolution.py,sha256=fMFzLi63dqCIclGFLcMdM870OYxJnkeWw6aWKNMk718,362
65
- ezmsg_sigproc-2.13.0.dist-info/METADATA,sha256=RXENX541lABAic8oUDuT8vQwx9nlWY9JETyXYKxdeTQ,1909
66
- ezmsg_sigproc-2.13.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
67
- ezmsg_sigproc-2.13.0.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
68
- ezmsg_sigproc-2.13.0.dist-info/RECORD,,
65
+ ezmsg_sigproc-2.13.1.dist-info/METADATA,sha256=mjyCiq1zy3JbgayLzq3BQLB-i1LOouUeGHe0tn3v7MY,1909
66
+ ezmsg_sigproc-2.13.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
67
+ ezmsg_sigproc-2.13.1.dist-info/licenses/LICENSE,sha256=seu0tKhhAMPCUgc1XpXGGaCxY1YaYvFJwqFuQZAl2go,1100
68
+ ezmsg_sigproc-2.13.1.dist-info/RECORD,,