braindecode 1.5.0.dev1010__py3-none-any.whl → 1.5.0.dev1013__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ from . import functional
4
4
  from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
5
5
  from .transforms import (
6
6
  AmplitudeScale,
7
+ BandRotation,
7
8
  BandstopFilter,
8
9
  ChannelsDropout,
9
10
  ChannelsReref,
@@ -47,6 +48,7 @@ __all__ = [
47
48
  "SegmentationReconstruction",
48
49
  "MaskEncoding",
49
50
  "AmplitudeScale",
51
+ "BandRotation",
50
52
  "ChannelsReref",
51
53
  "functional",
52
54
  ]
@@ -1298,3 +1298,130 @@ def amplitude_scale(
1298
1298
  X = s * X
1299
1299
 
1300
1300
  return X, y
1301
+
1302
+
1303
def band_rotation(
    X: torch.Tensor,
    y: torch.Tensor,
    num_bands: int = 2,
    electrodes_per_band: int = 16,
    band_offsets: tuple[int, ...] = (-1, 0, 1),
    max_temporal_jitter: int = 0,
    circular_jitter: bool = True,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-band electrode rotation + inter-band temporal jitter.

    Simulates small wristband rotations between sessions and relative
    timing noise between two arms, following [Sivakumar2024]_ (emg2qwerty
    CTC keystroke decoding). The channel axis is split into ``num_bands``
    contiguous groups of ``electrodes_per_band`` channels, and each group
    is circularly rolled along the channel axis by an offset drawn
    uniformly from ``band_offsets``. When ``num_bands >= 2`` and
    ``max_temporal_jitter > 0``, the group at zero-based index 1 (channels
    ``electrodes_per_band`` .. ``2 * electrodes_per_band - 1``) is also
    shifted along the time axis by an integer drawn uniformly from
    ``[-max_temporal_jitter, max_temporal_jitter]``.

    A single set of offsets / shift is drawn per call and shared by every
    example in ``X``.

    Parameters
    ----------
    X : torch.Tensor
        EMG batch of shape ``(B, C, T)`` with
        ``C == num_bands * electrodes_per_band``.
    y : torch.Tensor
        Labels, returned untouched.
    num_bands : int, optional
        Number of electrode bands (e.g. ``2`` for a left and a right
        wristband). Must be ``>= 1``. Defaults to 2.
    electrodes_per_band : int, optional
        Number of electrodes in each band. Must be ``>= 1``. Defaults to 16.
    band_offsets : tuple of int, optional
        Candidate channel-roll values, sampled uniformly per band. The
        default ``(-1, 0, 1)`` covers a misalignment of up to one
        electrode. Any sequence of integers (including an ``np.ndarray``)
        is accepted; it must be non-empty.
    max_temporal_jitter : int, optional
        Largest absolute time shift, in samples, applied to band 1 (only
        when ``num_bands >= 2``). ``0`` (default) disables the jitter.
        Must be ``>= 0``.
    circular_jitter : bool, optional
        If True (default), the time shift is a circular ``torch.roll``:
        samples pushed past one edge reappear at the other. If False,
        samples pushed past the edge are discarded and the vacated
        ``|shift|`` samples are filled with zeros. Ignored when no shift
        is applied.
    random_state : int | numpy.random.RandomState, optional
        Seed or generator used to draw the rotations and the shift.

    Returns
    -------
    torch.Tensor
        Transformed copy of ``X`` (``X`` itself is not modified).
    torch.Tensor
        ``y``, unchanged.

    Raises
    ------
    ValueError
        If ``num_bands`` or ``electrodes_per_band`` is below 1, if
        ``band_offsets`` is empty or holds non-integers, if
        ``max_temporal_jitter`` is negative, or if ``X.shape[1]`` does not
        equal ``num_bands * electrodes_per_band``.

    References
    ----------
    .. [Sivakumar2024] Sivakumar, V., Seely, J., Du, A., Bittner, S. R.,
       Berenzweig, A., Bolarinwa, A., Gramfort, A., & Mandel, M. I. (2024).
       "emg2qwerty: A Large Dataset with Baselines for Touch Typing using
       Surface Electromyography." *NeurIPS Datasets and Benchmarks Track*.
    """
    # ---- Argument checks; the first failing check determines the error.
    if num_bands < 1:
        raise ValueError(f"num_bands must be >= 1, got {num_bands}")
    if electrodes_per_band < 1:
        raise ValueError(
            f"electrodes_per_band must be >= 1, got {electrodes_per_band}"
        )
    # Materialise as a tuple first: truth-testing an ``np.ndarray`` with
    # more than one element would raise numpy's ambiguous-truth error.
    band_offsets = tuple(band_offsets)
    if not band_offsets:
        raise ValueError("band_offsets must be non-empty")
    if any(not isinstance(o, (int, np.integer)) for o in band_offsets):
        raise ValueError(f"band_offsets must contain integers, got {band_offsets!r}")
    if max_temporal_jitter < 0:
        raise ValueError(f"max_temporal_jitter must be >= 0, got {max_temporal_jitter}")
    expected_channels = num_bands * electrodes_per_band
    if X.shape[1] != expected_channels:
        raise ValueError(
            f"X.shape[1]={X.shape[1]} != num_bands * electrodes_per_band="
            f"{expected_channels}"
        )

    rng = check_random_state(random_state)
    candidates = np.asarray(band_offsets)

    # ---- Channel-axis rotation: one offset per band, drawn in band order.
    # Contiguous per-band rolls are used instead of one vectorised
    # ``torch.gather``; the original author's benchmark note found gather
    # slower on CPU for the common ``num_bands == 2`` case.
    rolls = [int(rng.choice(candidates)) for _ in range(num_bands)]
    out = X.clone()
    for band_idx, roll in enumerate(rolls):
        if roll == 0:
            continue  # rolling by zero is a no-op
        first = band_idx * electrodes_per_band
        chans = slice(first, first + electrodes_per_band)
        out[:, chans, :] = torch.roll(out[:, chans, :], roll, dims=1)

    # ---- Temporal jitter of band 1 relative to band 0.
    if not (max_temporal_jitter > 0 and num_bands >= 2):
        return out, y
    shift = int(rng.randint(-max_temporal_jitter, max_temporal_jitter + 1))
    if shift == 0:
        return out, y

    chans = slice(electrodes_per_band, 2 * electrodes_per_band)
    band = out[:, chans, :]
    if circular_jitter:
        # Wrap-around shift: samples leaving one edge re-enter at the other.
        out[:, chans, :] = torch.roll(band, shift, dims=2)
    else:
        # Crop-and-pad shift: copy the surviving samples into a zero
        # buffer, leaving a ``|shift|``-sample zero margin on the side the
        # signal moved away from (no wrap-around discontinuity).
        if shift > 0:
            dst, src = slice(shift, None), slice(None, -shift)
        else:
            dst, src = slice(None, shift), slice(-shift, None)
        padded = torch.zeros_like(band)
        padded[:, :, dst] = band[:, :, src]
        out[:, chans, :] = padded
    return out, y
@@ -16,6 +16,7 @@ from mne.channels import make_standard_montage
16
16
  from .base import Transform
17
17
  from .functional import (
18
18
  amplitude_scale,
19
+ band_rotation,
19
20
  bandstop_filter,
20
21
  channels_dropout,
21
22
  channels_permute,
@@ -1356,3 +1357,99 @@ class AmplitudeScale(Transform):
1356
1357
  def get_augmentation_params(self, *batch):
1357
1358
  """Return transform parameters."""
1358
1359
  return {"random_state": self.rng, "scale": self.scale}
1360
+
1361
+
1362
class BandRotation(Transform):
    """Per-band electrode rotation + inter-band temporal jitter.

    Transform wrapper around :func:`band_rotation`, after [Sivakumar2024]_
    (emg2qwerty surface-EMG keystroke decoding). Inputs are laid out as
    ``(B, num_bands * electrodes_per_band, T)`` with each band occupying a
    contiguous run of channels. Every band is circularly rolled along the
    channel axis by an offset drawn uniformly from ``band_offsets``. When
    ``num_bands >= 2``, band index 1 (the second band) is also shifted in
    time by up to ``max_temporal_jitter`` samples. One set of parameters is
    drawn per call and shared by every example of the transformed
    sub-batch.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    num_bands : int, optional
        Number of electrode bands (e.g. ``2`` for a left and a right
        wristband). Must be ``>= 1``. Defaults to 2.
    electrodes_per_band : int, optional
        Number of electrodes in each band. Must be ``>= 1``. Defaults to 16.
    band_offsets : tuple of int, optional
        Candidate channel-roll values, sampled uniformly per band. The
        default ``(-1, 0, 1)`` covers a misalignment of up to one
        electrode. Must be non-empty and contain only integers; stored as
        a tuple.
    max_temporal_jitter : int, optional
        Largest absolute time shift, in samples, applied to band 1.
        ``0`` (default) disables the jitter. Must be ``>= 0``. The
        emg2qwerty paper uses 120 samples (60 ms at 2 kHz).
    circular_jitter : bool, optional
        If True (default) the time shift is a circular roll; if False the
        vacated samples are zero-filled. See :func:`band_rotation`.
    random_state : int | numpy.random.RandomState, optional
        Seed for the rotation / jitter sampler. Defaults to None.

    Raises
    ------
    ValueError
        At construction time, if ``num_bands`` or ``electrodes_per_band``
        is below 1, if ``band_offsets`` is empty or holds non-integers, or
        if ``max_temporal_jitter`` is negative.

    References
    ----------
    .. [Sivakumar2024] Sivakumar, V., Seely, J., Du, A., Bittner, S. R.,
       Berenzweig, A., Bolarinwa, A., Gramfort, A., & Mandel, M. I. (2024).
       "emg2qwerty: A Large Dataset with Baselines for Touch Typing using
       Surface Electromyography." *NeurIPS Datasets and Benchmarks Track*.
    """

    operation = staticmethod(band_rotation)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        num_bands=2,
        electrodes_per_band=16,
        band_offsets=(-1, 0, 1),
        max_temporal_jitter=0,
        circular_jitter=True,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)
        # Fail fast on a bad configuration. ``band_rotation`` repeats these
        # checks on every call, but checking here reports the mistake when
        # the transform is built rather than on the first batch.
        offsets = self._checked_band_params(
            num_bands, electrodes_per_band, band_offsets, max_temporal_jitter
        )
        self.num_bands = num_bands
        self.electrodes_per_band = electrodes_per_band
        self.band_offsets = offsets
        self.max_temporal_jitter = max_temporal_jitter
        self.circular_jitter = circular_jitter

    @staticmethod
    def _checked_band_params(
        num_bands, electrodes_per_band, band_offsets, max_temporal_jitter
    ):
        """Validate the band settings and return ``band_offsets`` as a tuple."""
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")
        if electrodes_per_band < 1:
            raise ValueError(
                f"electrodes_per_band must be >= 1, got {electrodes_per_band}"
            )
        # Tuple first, so an ``np.ndarray`` can be truth-tested safely.
        band_offsets = tuple(band_offsets)
        if not band_offsets:
            raise ValueError("band_offsets must be non-empty")
        if any(not isinstance(o, (int, np.integer)) for o in band_offsets):
            raise ValueError(
                f"band_offsets must contain integers, got {band_offsets!r}"
            )
        if max_temporal_jitter < 0:
            raise ValueError(
                f"max_temporal_jitter must be >= 0, got {max_temporal_jitter}"
            )
        return band_offsets

    def get_augmentation_params(self, *batch):
        """Return transform parameters (keyword arguments of ``band_rotation``)."""
        return dict(
            num_bands=self.num_bands,
            electrodes_per_band=self.electrodes_per_band,
            band_offsets=self.band_offsets,
            max_temporal_jitter=self.max_temporal_jitter,
            circular_jitter=self.circular_jitter,
            random_state=self.rng,
        )
braindecode/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.5.0.dev1010"
1
+ __version__ = "1.5.0.dev1013"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev1010
3
+ Version: 1.5.0.dev1013
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -3,11 +3,11 @@ braindecode/classifier.py,sha256=7kC_oY_UzHEes_WWdCvEpiA1ZKxMeuLL5tIPp5rfcpg,962
3
3
  braindecode/eegneuralnet.py,sha256=xjE6aPZdCQPs29NIpy_m1GLMMC2WZ3Db0Fuh1-xE1h4,13827
4
4
  braindecode/regressor.py,sha256=KiMJpqCUPWA2k2JWk9HGYTzeoBqJ4gAKEudeUVcFZY4,9266
5
5
  braindecode/util.py,sha256=f8bNIwt-SwsHqheH_BADQxTtA9oPt3Lb7GFnoI-Huwc,14101
6
- braindecode/version.py,sha256=dhwXuCj7AR3Ot0DD1Q7xVlLGgTrkwdesqS9QyInsZcM,30
7
- braindecode/augmentation/__init__.py,sha256=4xune2QUK6KHMKsAqijF7I9eeiVbP0wEoQJjCNLNcKM,1081
6
+ braindecode/version.py,sha256=FNlUpdx-TzIz_DQr0F-sTkgDmsCNiRxs8Z1XdxuVyj0,30
7
+ braindecode/augmentation/__init__.py,sha256=hmnjUsL_DX5BxYVdyNReh7T3YRQEJKYzciB1UwYHRvc,1119
8
8
  braindecode/augmentation/base.py,sha256=OJ1shOljI1yTY9zh2qWxQwivlY43sfx9Q-MAyMhxtPs,7338
9
- braindecode/augmentation/functional.py,sha256=q2k6mAXrujYlOZUndcjZN8e8b-6oJF1gGsORAI23hyE,43998
10
- braindecode/augmentation/transforms.py,sha256=x-3pwX0PtMHfSnPLGKNXbpTSk7j17Ci2FG_-646scg4,47268
9
+ braindecode/augmentation/functional.py,sha256=jGKTNgWf9ZIFGcShYD0Qlb9IuC47OIwC813eCEcTPsM,49757
10
+ braindecode/augmentation/transforms.py,sha256=QPS-cjbHz0TcKbd5Uuiag0s2Kt83xDhH1juBvAK6C5M,51426
11
11
  braindecode/datasets/__init__.py,sha256=rVOBadwqYBiMz5kl7nGiBOmMgr11xvjS4nuzzZTOn1U,1102
12
12
  braindecode/datasets/base.py,sha256=3lKLZQO4hfA-dv_JJEfPwyZ5nzRkLTu4qiRAqFVZUUQ,70508
13
13
  braindecode/datasets/bbci.py,sha256=SCm7OnCObotILQ0B1EdmZPoyJtzsRXpeU_gNKtqQLSc,19288
@@ -128,9 +128,9 @@ braindecode/visualization/frequency.py,sha256=gNwkn9yIik5SUp7d9HE9J_vPVGyzNsxxCO
128
128
  braindecode/visualization/metrics.py,sha256=j01kc04P9uEkQ2g2Tt2C76yr6soIj31PAuBMflrmODg,13615
129
129
  braindecode/visualization/sanity.py,sha256=nNClauUC8dCj_KCy_1RmaPDQAqExLczfPtUeQ7k9-Q0,4812
130
130
  braindecode/visualization/topology.py,sha256=mXxUfCCUJqa_cMF4y6GC3_A-qBCcS4uTc0EzBolkytE,2274
131
- braindecode-1.5.0.dev1010.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
132
- braindecode-1.5.0.dev1010.dist-info/licenses/NOTICE.txt,sha256=ZFFhigxIaKgDcMjCzPyAVSFV42ztU0kLOENt_kvherw,857
133
- braindecode-1.5.0.dev1010.dist-info/METADATA,sha256=4sQBBGOi3h1EE-DCZxdlh3Hswxjd4jQJ_fTdFBUSsyc,10275
134
- braindecode-1.5.0.dev1010.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
135
- braindecode-1.5.0.dev1010.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
136
- braindecode-1.5.0.dev1010.dist-info/RECORD,,
131
+ braindecode-1.5.0.dev1013.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
132
+ braindecode-1.5.0.dev1013.dist-info/licenses/NOTICE.txt,sha256=ZFFhigxIaKgDcMjCzPyAVSFV42ztU0kLOENt_kvherw,857
133
+ braindecode-1.5.0.dev1013.dist-info/METADATA,sha256=gPiGhTDpD-62qMNoaF6556d2tZHAX_BbV3G2V1Fk790,10275
134
+ braindecode-1.5.0.dev1013.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
135
+ braindecode-1.5.0.dev1013.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
136
+ braindecode-1.5.0.dev1013.dist-info/RECORD,,