braindecode 1.5.0.dev1010__py3-none-any.whl → 1.5.0.dev1015__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +2 -0
- braindecode/augmentation/functional.py +127 -0
- braindecode/augmentation/transforms.py +97 -0
- braindecode/models/emg2qwerty.py +207 -7
- braindecode/version.py +1 -1
- {braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/METADATA +1 -1
- {braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/RECORD +11 -11
- {braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/WHEEL +0 -0
- {braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/top_level.txt +0 -0
|
@@ -4,6 +4,7 @@ from . import functional
|
|
|
4
4
|
from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
|
|
5
5
|
from .transforms import (
|
|
6
6
|
AmplitudeScale,
|
|
7
|
+
BandRotation,
|
|
7
8
|
BandstopFilter,
|
|
8
9
|
ChannelsDropout,
|
|
9
10
|
ChannelsReref,
|
|
@@ -47,6 +48,7 @@ __all__ = [
|
|
|
47
48
|
"SegmentationReconstruction",
|
|
48
49
|
"MaskEncoding",
|
|
49
50
|
"AmplitudeScale",
|
|
51
|
+
"BandRotation",
|
|
50
52
|
"ChannelsReref",
|
|
51
53
|
"functional",
|
|
52
54
|
]
|
|
@@ -1298,3 +1298,130 @@ def amplitude_scale(
|
|
|
1298
1298
|
X = s * X
|
|
1299
1299
|
|
|
1300
1300
|
return X, y
|
|
1301
|
+
|
|
1302
|
+
|
|
1303
|
+
def band_rotation(
    X: torch.Tensor,
    y: torch.Tensor,
    num_bands: int = 2,
    electrodes_per_band: int = 16,
    band_offsets: tuple[int, ...] = (-1, 0, 1),
    max_temporal_jitter: int = 0,
    circular_jitter: bool = True,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-band electrode rotation + inter-band temporal jitter.

    Models small wristband rotation between sessions and relative timing
    noise between two arms, as used in [Sivakumar2024]_ for emg2qwerty
    CTC keystroke decoding. Every electrode band is rolled circularly
    along the channel axis by an offset drawn uniformly from
    ``band_offsets``. When ``num_bands >= 2`` and
    ``max_temporal_jitter > 0``, band 1 is also shifted along the time
    axis by an integer drawn uniformly from
    ``[-max_temporal_jitter, +max_temporal_jitter]``.

    The channel axis must hold ``num_bands`` contiguous bands of
    ``electrodes_per_band`` electrodes each. A single set of offsets /
    shift is drawn per call and shared by every sample in the batch. The
    input is never modified; a cloned tensor is returned.

    Parameters
    ----------
    X : torch.Tensor
        EMG input batch of shape ``(B, C, T)`` with
        ``C == num_bands * electrodes_per_band``.
    y : torch.Tensor
        Labels (returned unchanged).
    num_bands : int, optional
        Number of electrode bands (e.g. ``2`` for left + right wristband).
        Must be ``>= 1``. Defaults to 2.
    electrodes_per_band : int, optional
        Electrodes per band. Must be ``>= 1``. Defaults to 16.
    band_offsets : tuple of int, optional
        Candidate channel-roll values, sampled uniformly per band. Any
        sequence (including ``np.ndarray``) is accepted. Must be non-empty
        and contain only integers. Defaults to ``(-1, 0, 1)``.
    max_temporal_jitter : int, optional
        Largest absolute time shift (in samples) applied to band 1. Only
        used when ``num_bands >= 2``. Must be ``>= 0``. Defaults to 0,
        which disables the jitter.
    circular_jitter : bool, optional
        If True (default) the time shift is a circular ``torch.roll``, so
        samples leaving one edge re-enter at the other. If False, the
        vacated ``|shift|`` samples are zero-filled and the displaced ones
        dropped. Ignored when no jitter is applied.
        NOTE(review): upstream emg2qwerty's ``TemporalAlignmentJitter``
        appears to crop both bands (shortening the window) rather than
        roll. Confirm before describing either mode as paper-faithful.
    random_state : int | numpy.random.RandomState, optional
        Seed / generator used for the rotation and jitter draws.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Labels (unchanged).

    Raises
    ------
    ValueError
        If a size argument is ``< 1``, ``band_offsets`` is empty or holds
        non-integers, ``max_temporal_jitter < 0``, or ``X.shape[1]`` does
        not equal ``num_bands * electrodes_per_band``.

    References
    ----------
    .. [Sivakumar2024] Sivakumar, V., Seely, J., Du, A., Bittner, S. R.,
       Berenzweig, A., Bolarinwa, A., Gramfort, A., & Mandel, M. I. (2024).
       "emg2qwerty: A Large Dataset with Baselines for Touch Typing using
       Surface Electromyography." *NeurIPS Datasets and Benchmarks Track*.
    """
    # Validation runs in a fixed order; the first failing check wins.
    if num_bands < 1:
        raise ValueError(f"num_bands must be >= 1, got {num_bands}")
    if electrodes_per_band < 1:
        raise ValueError(f"electrodes_per_band must be >= 1, got {electrodes_per_band}")
    # Convert to a tuple before any truth test: testing an ``np.ndarray``
    # directly would raise numpy's ambiguous-truth-value error.
    offsets = tuple(band_offsets)
    if len(offsets) == 0:
        raise ValueError("band_offsets must be non-empty")
    if any(not isinstance(o, (int, np.integer)) for o in offsets):
        raise ValueError(f"band_offsets must contain integers, got {offsets!r}")
    if max_temporal_jitter < 0:
        raise ValueError(f"max_temporal_jitter must be >= 0, got {max_temporal_jitter}")
    n_channels = num_bands * electrodes_per_band
    if X.shape[1] != n_channels:
        raise ValueError(
            f"X.shape[1]={X.shape[1]} != num_bands * electrodes_per_band="
            f"{n_channels}"
        )

    rng = check_random_state(random_state)
    choices = np.asarray(offsets)
    result = X.clone()

    # One roll per band. A vectorised ``torch.gather`` was benchmarked
    # ~16 % slower for the common two-band CPU case (its index tensor is
    # larger than what two contiguous rolls touch). It only wins at
    # eight or more bands.
    for band in range(num_bands):
        roll_by = int(rng.choice(choices))
        if roll_by == 0:
            continue
        start = band * electrodes_per_band
        stop = start + electrodes_per_band
        result[:, start:stop, :] = result[:, start:stop, :].roll(roll_by, dims=1)

    # Temporal jitter targets band 1 only; skip unless it can apply.
    if not (max_temporal_jitter > 0 and num_bands >= 2):
        return result, y
    lag = int(rng.randint(-max_temporal_jitter, max_temporal_jitter + 1))
    if lag == 0:
        return result, y

    start, stop = electrodes_per_band, 2 * electrodes_per_band
    second = result[:, start:stop, :]
    if circular_jitter:
        # Wrap-around shift: end-of-window samples reappear at the start.
        result[:, start:stop, :] = torch.roll(second, lag, dims=2)
        return result, y

    # Crop-and-pad shift: a |lag|-sample zero margin instead of a
    # wrap-around discontinuity. Slices are empty when |lag| >= T, which
    # leaves the whole band zeroed.
    padded = torch.zeros_like(second)
    if lag > 0:
        padded[:, :, lag:] = second[:, :, :-lag]
    else:
        padded[:, :, :lag] = second[:, :, -lag:]
    result[:, start:stop, :] = padded
    return result, y
|
|
@@ -16,6 +16,7 @@ from mne.channels import make_standard_montage
|
|
|
16
16
|
from .base import Transform
|
|
17
17
|
from .functional import (
|
|
18
18
|
amplitude_scale,
|
|
19
|
+
band_rotation,
|
|
19
20
|
bandstop_filter,
|
|
20
21
|
channels_dropout,
|
|
21
22
|
channels_permute,
|
|
@@ -1356,3 +1357,99 @@ class AmplitudeScale(Transform):
|
|
|
1356
1357
|
def get_augmentation_params(self, *batch):
    """Return transform parameters.

    Any positional batch arguments are ignored. The returned dict maps
    ``"random_state"`` to ``self.rng`` and ``"scale"`` to ``self.scale``.
    """
    params = dict(random_state=self.rng, scale=self.scale)
    return params
|
|
1360
|
+
|
|
1361
|
+
|
|
1362
|
+
class BandRotation(Transform):
    """Per-band electrode rotation + inter-band temporal jitter.

    Models small wristband rotation between sessions and relative timing
    noise between two arms, following [Sivakumar2024]_ (emg2qwerty
    surface-EMG keystroke decoding). Channels are expected as
    ``(B, num_bands * electrodes_per_band, T)``, with each band occupying
    a contiguous block. Every band is circularly rolled along the channel
    axis by an offset drawn uniformly from ``band_offsets``. When
    ``num_bands >= 2``, band 1 is additionally shifted in time. One set
    of offsets / shift is drawn per call and shared by every sample of
    the transformed sub-batch.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    num_bands : int, optional
        Number of electrode bands (e.g. ``2`` for left + right wristband).
        Must be ``>= 1``. Defaults to 2.
    electrodes_per_band : int, optional
        Electrodes per band (e.g. ``16``). Must be ``>= 1``. Defaults
        to 16.
    band_offsets : tuple of int, optional
        Per-band roll values, sampled uniformly. ``(-1, 0, 1)`` covers
        ±1-electrode misalignment. Must be non-empty and integer-valued.
        Stored as a tuple. Defaults to ``(-1, 0, 1)``.
    max_temporal_jitter : int, optional
        Max ±-sample temporal shift applied to band 1. Defaults to 0
        (jitter disabled). Must be ``>= 0``. The emg2qwerty paper uses
        120 samples (60 ms at 2 kHz).
    circular_jitter : bool, optional
        If True (default) the jitter is a circular roll. If False, the
        gap left by the shift is zero-padded. See :func:`band_rotation`.
    random_state : int | numpy.random.RandomState, optional
        Seed for the rotation / jitter sampler. Defaults to None.

    Raises
    ------
    ValueError
        At construction, if ``num_bands`` or ``electrodes_per_band`` is
        ``< 1``, ``band_offsets`` is empty or contains non-integers, or
        ``max_temporal_jitter < 0``.

    References
    ----------
    .. [Sivakumar2024] Sivakumar, V., Seely, J., Du, A., Bittner, S. R.,
       Berenzweig, A., Bolarinwa, A., Gramfort, A., & Mandel, M. I. (2024).
       "emg2qwerty: A Large Dataset with Baselines for Touch Typing using
       Surface Electromyography." *NeurIPS Datasets and Benchmarks Track*.
    """

    operation = staticmethod(band_rotation)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        num_bands=2,
        electrodes_per_band=16,
        band_offsets=(-1, 0, 1),
        max_temporal_jitter=0,
        circular_jitter=True,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)
        # ``band_rotation`` re-validates at call time. Checking here as well
        # surfaces configuration errors when the transform is built, not
        # on the first batch.
        for name, value in (
            ("num_bands", num_bands),
            ("electrodes_per_band", electrodes_per_band),
        ):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        offsets = tuple(band_offsets)
        if not offsets:
            raise ValueError("band_offsets must be non-empty")
        non_integer = [o for o in offsets if not isinstance(o, (int, np.integer))]
        if non_integer:
            raise ValueError(f"band_offsets must contain integers, got {offsets!r}")

        if max_temporal_jitter < 0:
            raise ValueError(
                f"max_temporal_jitter must be >= 0, got {max_temporal_jitter}"
            )

        self.num_bands = num_bands
        self.electrodes_per_band = electrodes_per_band
        self.band_offsets = offsets
        self.max_temporal_jitter = max_temporal_jitter
        self.circular_jitter = circular_jitter

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Keys mirror the keyword arguments of :func:`band_rotation`; the
        batch arguments are ignored.
        """
        params = {
            key: getattr(self, key)
            for key in (
                "num_bands",
                "electrodes_per_band",
                "band_offsets",
                "max_temporal_jitter",
                "circular_jitter",
            )
        }
        params["random_state"] = self.rng
        return params
|
braindecode/models/emg2qwerty.py
CHANGED
|
@@ -56,7 +56,11 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
56
56
|
Returns ``(batch, T_out, n_outputs)``. With ``n_times=8000`` and
|
|
57
57
|
defaults, ``T_out=373``. For :class:`~torch.nn.CTCLoss`, transpose
|
|
58
58
|
to ``(T_out, batch, n_outputs)``; use :meth:`compute_output_lengths`
|
|
59
|
-
for emission lengths.
|
|
59
|
+
for emission lengths. Pass ``return_features=True`` to return the
|
|
60
|
+
pre-classifier encoder representation as a
|
|
61
|
+
``{"features": (batch, T_out, num_features), "cls_token": None}``
|
|
62
|
+
dict, matching the BIOT / signal-JEPA convention used by downstream
|
|
63
|
+
wrappers (e.g. neuroai's ``DownstreamWrapperModel``).
|
|
60
64
|
|
|
61
65
|
.. rubric:: Paper training recipe
|
|
62
66
|
|
|
@@ -69,7 +73,9 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
69
73
|
local minimum).
|
|
70
74
|
- **Augmentation**: per-band electrode rotations by -1/0/+1 positions,
|
|
71
75
|
±60-sample temporal jitter, and SpecAugment [park2019specaug]_ on
|
|
72
|
-
the log-spectrogram.
|
|
76
|
+
the log-spectrogram. SpecAugment is built into the model
|
|
77
|
+
(``spec_augment=True``) and only fires in training mode; the
|
|
78
|
+
time/frequency-jitter pieces are dataset-side augmentations.
|
|
73
79
|
- **Decoding**: greedy CTC. Upstream also reports a 6-gram KenLM
|
|
74
80
|
beam decoder, not ported here.
|
|
75
81
|
|
|
@@ -145,6 +151,37 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
145
151
|
layers and again after the second :class:`~torch.nn.Linear`.
|
|
146
152
|
Default ``0.0`` matches the upstream paper recipe (no dropout).
|
|
147
153
|
Set ``> 0`` for regularized training.
|
|
154
|
+
spec_augment : bool
|
|
155
|
+
If ``True``, apply SpecAugment [park2019specaug]_ time/frequency
|
|
156
|
+
masking on the log-spectrogram during training only. Disabled in
|
|
157
|
+
``eval`` mode and absent from the parameter / state-dict count.
|
|
158
|
+
Defaults to ``False``; set to ``True`` to match the upstream
|
|
159
|
+
emg2qwerty paper recipe.
|
|
160
|
+
n_time_masks : int
|
|
161
|
+
Maximum number of time masks applied per call. Each forward pass
|
|
162
|
+
samples a uniform integer in ``[0, n_time_masks]``. Defaults to
|
|
163
|
+
``3`` (Sivakumar et al. Sec 5.2).
|
|
164
|
+
time_mask_param : int
|
|
165
|
+
Maximum time-mask width in spectrogram frames. Defaults to ``25``.
|
|
166
|
+
n_freq_masks : int
|
|
167
|
+
Maximum number of frequency masks applied per call. Each forward
|
|
168
|
+
pass samples a uniform integer in ``[0, n_freq_masks]``. Defaults
|
|
169
|
+
to ``2``.
|
|
170
|
+
freq_mask_param : int
|
|
171
|
+
Maximum frequency-mask width in STFT bins. Defaults to ``4``.
|
|
172
|
+
spec_augment_prob : float
|
|
173
|
+
Probability of running SpecAugment on a given training batch
|
|
174
|
+
(Bernoulli gate before sampling mask counts). Defaults to ``1.0``.
|
|
175
|
+
return_feature : bool
|
|
176
|
+
If ``True``, ``forward`` returns a tuple
|
|
177
|
+
``(emissions, features)`` instead of just ``emissions`` —
|
|
178
|
+
:class:`braindecode.models.BIOT`-style legacy feature path. Lets
|
|
179
|
+
configuration-driven downstream wrappers (e.g. neuroai's
|
|
180
|
+
``DownstreamWrapperModel`` with ``model_output_key=1``) pick up
|
|
181
|
+
the encoder representation without passing a runtime kwarg.
|
|
182
|
+
Defaults to ``False``. Mutually compatible with the runtime
|
|
183
|
+
``return_features`` (plural) flag, which still wins when set
|
|
184
|
+
to ``True``.
|
|
148
185
|
|
|
149
186
|
Examples
|
|
150
187
|
--------
|
|
@@ -216,6 +253,13 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
216
253
|
log_softmax: bool = False,
|
|
217
254
|
activation: type[nn.Module] = nn.ReLU,
|
|
218
255
|
drop_prob: float = 0.0,
|
|
256
|
+
spec_augment: bool = False,
|
|
257
|
+
n_time_masks: int = 3,
|
|
258
|
+
time_mask_param: int = 25,
|
|
259
|
+
n_freq_masks: int = 2,
|
|
260
|
+
freq_mask_param: int = 4,
|
|
261
|
+
spec_augment_prob: float = 1.0,
|
|
262
|
+
return_feature: bool = False,
|
|
219
263
|
# Standard braindecode args
|
|
220
264
|
n_times: int | None = None,
|
|
221
265
|
input_window_seconds: float | None = None,
|
|
@@ -256,6 +300,7 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
256
300
|
self.hop_length = hop_length
|
|
257
301
|
self.kernel_width = kernel_width
|
|
258
302
|
self.log_softmax = log_softmax
|
|
303
|
+
self.return_feature = return_feature
|
|
259
304
|
|
|
260
305
|
n_freq_bins = n_fft // 2 + 1
|
|
261
306
|
in_features = electrodes_per_band * n_freq_bins
|
|
@@ -269,6 +314,23 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
269
314
|
log_eps=log_eps,
|
|
270
315
|
)
|
|
271
316
|
|
|
317
|
+
# Built-in SpecAugment lives between the spectrogram and the BatchNorm
|
|
318
|
+
# so it operates on the log-power tensor (matches upstream emg2qwerty
|
|
319
|
+
# and the previous neuralbench callback). ``nn.Identity`` keeps the
|
|
320
|
+
# forward path symmetrical without contributing parameters or
|
|
321
|
+
# state-dict keys when SpecAugment is disabled.
|
|
322
|
+
self.spec_augment: nn.Module
|
|
323
|
+
if spec_augment:
|
|
324
|
+
self.spec_augment = _SpecAugment(
|
|
325
|
+
n_time_masks=n_time_masks,
|
|
326
|
+
time_mask_param=time_mask_param,
|
|
327
|
+
n_freq_masks=n_freq_masks,
|
|
328
|
+
freq_mask_param=freq_mask_param,
|
|
329
|
+
prob=spec_augment_prob,
|
|
330
|
+
)
|
|
331
|
+
else:
|
|
332
|
+
self.spec_augment = nn.Identity()
|
|
333
|
+
|
|
272
334
|
# Indices 0/1/3 match upstream's ``TDSConvCTCModule.model``;
|
|
273
335
|
# index 2 is a parameter-free Flatten; upstream's index 4 (head)
|
|
274
336
|
# is broken out as ``self.final_layer`` and remapped via :attr:`mapping`.
|
|
@@ -298,7 +360,13 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
298
360
|
isinstance(m, _TDSConv2dBlock) for m in self.model[3].tds_conv_blocks
|
|
299
361
|
)
|
|
300
362
|
|
|
301
|
-
def forward(
|
|
363
|
+
def forward(
|
|
364
|
+
self, x: torch.Tensor, return_features: bool = False
|
|
365
|
+
) -> (
|
|
366
|
+
torch.Tensor
|
|
367
|
+
| dict[str, torch.Tensor | None]
|
|
368
|
+
| tuple[torch.Tensor, torch.Tensor]
|
|
369
|
+
):
|
|
302
370
|
"""Run the full pipeline.
|
|
303
371
|
|
|
304
372
|
Parameters
|
|
@@ -307,12 +375,37 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
307
375
|
Raw EMG of shape ``(batch, n_chans=32, n_times)``. ``n_times``
|
|
308
376
|
must be at least the encoder's receptive field, ``n_fft +
|
|
309
377
|
n_conv_blocks * (kernel_width - 1) * hop_length``.
|
|
378
|
+
return_features : bool
|
|
379
|
+
If ``True``, return a ``dict`` with the encoder representation
|
|
380
|
+
instead of the classification emissions. The encoder is the
|
|
381
|
+
full TDS-Conv stack up to (but not including)
|
|
382
|
+
``self.final_layer`` — i.e. what downstream wrappers want
|
|
383
|
+
when they apply their own probe/aggregation. Matches the
|
|
384
|
+
BIOT / signal-JEPA convention so the same neuroai
|
|
385
|
+
``DownstreamWrapperModel(model_output_key="features")``
|
|
386
|
+
can consume it. Wins over the constructor-time
|
|
387
|
+
``return_feature`` flag when set.
|
|
310
388
|
|
|
311
389
|
Returns
|
|
312
390
|
-------
|
|
313
|
-
|
|
314
|
-
|
|
391
|
+
torch.Tensor or dict or tuple
|
|
392
|
+
Default (``return_features=False``, init
|
|
393
|
+
``return_feature=False``): ``torch.Tensor`` of shape
|
|
394
|
+
``(batch, T_out, n_outputs)``. Log-probabilities if
|
|
315
395
|
``log_softmax=True``, otherwise logits.
|
|
396
|
+
|
|
397
|
+
If runtime ``return_features=True``: ``dict`` with
|
|
398
|
+
``"features"`` (shape ``(batch, T_out, num_features)``,
|
|
399
|
+
where ``num_features = num_bands * mlp_features[-1]``) and
|
|
400
|
+
``"cls_token"`` (always ``None`` — TDS-Conv has no
|
|
401
|
+
``[CLS]``).
|
|
402
|
+
|
|
403
|
+
If init ``return_feature=True`` and runtime
|
|
404
|
+
``return_features=False``: tuple ``(emissions, features)``
|
|
405
|
+
where ``features`` has shape ``(batch, T_out,
|
|
406
|
+
num_features)``. Same layout BIOT exposes for
|
|
407
|
+
configuration-driven feature extraction (e.g. neuroai's
|
|
408
|
+
``model_output_key=1``).
|
|
316
409
|
"""
|
|
317
410
|
if x.ndim != 3 or x.shape[-2] != self.n_chans:
|
|
318
411
|
raise ValueError(
|
|
@@ -331,11 +424,24 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
331
424
|
f"kernel_width={self.kernel_width})."
|
|
332
425
|
)
|
|
333
426
|
spectrogram = self.spectrogram(x)
|
|
427
|
+
spectrogram = self.spec_augment(spectrogram)
|
|
334
428
|
encoded = self.model(spectrogram)
|
|
429
|
+
# ``encoded`` is (T_out, B, num_features); only materialise the
|
|
430
|
+
# batch-first features tensor in the branches that actually return
|
|
431
|
+
# it, so the default emissions-only path skips the extra transpose
|
|
432
|
+
# + contiguous copy on every forward.
|
|
433
|
+
if return_features:
|
|
434
|
+
return {
|
|
435
|
+
"features": encoded.transpose(0, 1).contiguous(),
|
|
436
|
+
"cls_token": None,
|
|
437
|
+
}
|
|
335
438
|
emissions = self.final_layer(encoded)
|
|
336
439
|
if self.log_softmax:
|
|
337
440
|
emissions = F.log_softmax(emissions, dim=-1)
|
|
338
|
-
|
|
441
|
+
emissions = emissions.transpose(0, 1).contiguous()
|
|
442
|
+
if self.return_feature:
|
|
443
|
+
return emissions, encoded.transpose(0, 1).contiguous()
|
|
444
|
+
return emissions
|
|
339
445
|
|
|
340
446
|
def reset_head(self, n_outputs: int) -> None:
|
|
341
447
|
"""Replace the classification head for a new vocabulary size.
|
|
@@ -411,7 +517,13 @@ class EMG2QwertyNet(EEGModuleMixin, nn.Module):
|
|
|
411
517
|
dummy_input = torch.zeros(
|
|
412
518
|
1, self.n_chans, n_times, dtype=dtype, device=device
|
|
413
519
|
)
|
|
414
|
-
|
|
520
|
+
# ``return_features=False`` keeps the dict path off; the init
|
|
521
|
+
# ``return_feature`` flag may still produce a tuple, so unpack
|
|
522
|
+
# the emissions explicitly to report the public output shape.
|
|
523
|
+
out = self.forward(dummy_input, return_features=False)
|
|
524
|
+
emissions = out[0] if isinstance(out, tuple) else out
|
|
525
|
+
assert isinstance(emissions, torch.Tensor)
|
|
526
|
+
return tuple(emissions.shape)
|
|
415
527
|
|
|
416
528
|
|
|
417
529
|
class _LogSpectrogram(nn.Module):
|
|
@@ -483,6 +595,94 @@ class _LogSpectrogram(nn.Module):
|
|
|
483
595
|
).movedim(-1, 0)
|
|
484
596
|
|
|
485
597
|
|
|
598
|
+
class _SpecAugment(nn.Module):
    r"""SpecAugment masking on the log-spectrogram during training.

    Applies up to ``n_time_masks`` time masks, each at most
    ``time_mask_param`` frames wide, and up to ``n_freq_masks`` frequency
    masks, each at most ``freq_mask_param`` bins wide. Masks are
    independent per ``(sample × band × electrode)`` triple.

    This follows the same recipe as the upstream emg2qwerty
    :class:`emg2qwerty.transforms.SpecAugment` dataset transform
    (Sivakumar et al. Sec 5.2 / NeurIPS 2024). That transform uses
    :func:`torchaudio.functional.mask_along_axis_iid`-style masking,
    sampled per leading dim of a spectrogram with shape
    ``(..., freq, time)``. No-op outside ``training``.

    The mask fill value is the mean of the incoming spectrogram, for
    three reasons:

    - ``log(power=1)=0`` would sit well above the typical log-power
      distribution and inject artificial spikes.
    - It is kept as a 0-D on-device tensor, so the forward pass adds no
      host round-trip on GPU.
    - It is detached, so the fill acts as a constant. Otherwise backward
      would route gradients from every masked cell into all spectrogram
      entries through the mean.

    Parameters
    ----------
    n_time_masks : int
        Maximum number of time masks per call. The actual count is drawn
        uniformly from ``[0, n_time_masks]``. Must be ``>= 0``.
    time_mask_param : int
        Maximum time-mask width in spectrogram frames. Must be ``>= 0``.
    n_freq_masks : int
        Maximum number of frequency masks per call, drawn uniformly from
        ``[0, n_freq_masks]``. Must be ``>= 0``.
    freq_mask_param : int
        Maximum frequency-mask width in STFT bins. Must be ``>= 0``.
    prob : float
        Probability of masking a given training batch (Bernoulli gate
        drawn before the mask counts). Must be in ``[0, 1]``.

    Raises
    ------
    ValueError
        If a mask count or width is negative, or ``prob`` is outside
        ``[0, 1]``.
    """

    def __init__(
        self,
        n_time_masks: int = 3,
        time_mask_param: int = 25,
        n_freq_masks: int = 2,
        freq_mask_param: int = 4,
        prob: float = 1.0,
    ) -> None:
        super().__init__()
        if n_time_masks < 0 or n_freq_masks < 0:
            raise ValueError(
                f"n_time_masks and n_freq_masks must be >= 0; got "
                f"n_time_masks={n_time_masks}, n_freq_masks={n_freq_masks}."
            )
        if time_mask_param < 0 or freq_mask_param < 0:
            raise ValueError(
                f"time_mask_param and freq_mask_param must be >= 0; got "
                f"time_mask_param={time_mask_param}, "
                f"freq_mask_param={freq_mask_param}."
            )
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"prob must be in [0, 1]; got {prob}.")
        self.n_time_masks = n_time_masks
        self.time_mask_param = time_mask_param
        self.n_freq_masks = n_freq_masks
        self.freq_mask_param = freq_mask_param
        self.prob = prob
        # ``iid_masks=True`` samples one mask per index of every leading dim,
        # i.e. all dims except the trailing ``(freq, time)`` pair. On a 5-D
        # ``(B, num_bands, electrodes, freq, T)`` input that means one mask
        # per ``(sample × band × electrode)``, matching upstream
        # emg2qwerty's per-``(band × electrode)`` dataset-time recipe.
        self.time_mask = ta_transforms.TimeMasking(time_mask_param, iid_masks=True)
        self.freq_mask = ta_transforms.FrequencyMasking(freq_mask_param, iid_masks=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Randomly mask ``x`` in training mode.

        Assumes ``x`` is ``(T_spec, B, num_bands, electrodes, freq)``. It is
        returned unchanged in eval mode, when ``prob <= 0``, when both mask
        counts are zero, or when the Bernoulli gate skips this batch.
        """
        if (
            not self.training
            or self.prob <= 0.0
            or (self.n_time_masks == 0 and self.n_freq_masks == 0)
        ):
            return x
        # Every draw uses the default generator of ``x.device`` (seeded by
        # ``torch.manual_seed``), the same generator torchaudio uses for its
        # device-side mask sampling. ``.item()`` forces a host sync, which is
        # unavoidable for Python-level control flow.
        if self.prob < 1.0 and torch.rand((), device=x.device).item() >= self.prob:
            return x
        # torchaudio masks the trailing ``(freq, time)`` pair, so move time to
        # the end: ``(B, num_bands, electrodes, freq, T_spec)``. No reshape is
        # needed because ``mask_along_axis_iid`` already draws one mask per
        # leading-axis index. No ``.contiguous()`` either: each masking call
        # allocates a fresh output, so an up-front copy is wasted work.
        spec = x.movedim(0, -1)
        # Constant fill value: 0-D, on device, and cut from the autograd graph.
        mask_value = spec.mean().detach()
        # Draw order (n_t, time masks, n_f, freq masks) is kept fixed so a
        # given seed produces the same masks as before.
        n_t = int(torch.randint(self.n_time_masks + 1, (), device=x.device).item())
        for _ in range(n_t):
            spec = self.time_mask(spec, mask_value=mask_value)
        n_f = int(torch.randint(self.n_freq_masks + 1, (), device=x.device).item())
        for _ in range(n_f):
            spec = self.freq_mask(spec, mask_value=mask_value)
        return spec.movedim(-1, 0)
|
|
684
|
+
|
|
685
|
+
|
|
486
686
|
class _SpectrogramNorm(nn.Module):
|
|
487
687
|
r""":class:`~torch.nn.BatchNorm2d` over (band × electrode) channels.
|
|
488
688
|
|
braindecode/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "1.5.0.
|
|
1
|
+
__version__ = "1.5.0.dev1015"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.5.0.dev1010
|
|
3
|
+
Version: 1.5.0.dev1015
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
5
|
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
@@ -3,11 +3,11 @@ braindecode/classifier.py,sha256=7kC_oY_UzHEes_WWdCvEpiA1ZKxMeuLL5tIPp5rfcpg,962
|
|
|
3
3
|
braindecode/eegneuralnet.py,sha256=xjE6aPZdCQPs29NIpy_m1GLMMC2WZ3Db0Fuh1-xE1h4,13827
|
|
4
4
|
braindecode/regressor.py,sha256=KiMJpqCUPWA2k2JWk9HGYTzeoBqJ4gAKEudeUVcFZY4,9266
|
|
5
5
|
braindecode/util.py,sha256=f8bNIwt-SwsHqheH_BADQxTtA9oPt3Lb7GFnoI-Huwc,14101
|
|
6
|
-
braindecode/version.py,sha256=
|
|
7
|
-
braindecode/augmentation/__init__.py,sha256=
|
|
6
|
+
braindecode/version.py,sha256=H7kXWUs3T_eeGV04VfvDHjBTgxYEXSFgslEWx9TudIs,30
|
|
7
|
+
braindecode/augmentation/__init__.py,sha256=hmnjUsL_DX5BxYVdyNReh7T3YRQEJKYzciB1UwYHRvc,1119
|
|
8
8
|
braindecode/augmentation/base.py,sha256=OJ1shOljI1yTY9zh2qWxQwivlY43sfx9Q-MAyMhxtPs,7338
|
|
9
|
-
braindecode/augmentation/functional.py,sha256=
|
|
10
|
-
braindecode/augmentation/transforms.py,sha256=
|
|
9
|
+
braindecode/augmentation/functional.py,sha256=jGKTNgWf9ZIFGcShYD0Qlb9IuC47OIwC813eCEcTPsM,49757
|
|
10
|
+
braindecode/augmentation/transforms.py,sha256=QPS-cjbHz0TcKbd5Uuiag0s2Kt83xDhH1juBvAK6C5M,51426
|
|
11
11
|
braindecode/datasets/__init__.py,sha256=rVOBadwqYBiMz5kl7nGiBOmMgr11xvjS4nuzzZTOn1U,1102
|
|
12
12
|
braindecode/datasets/base.py,sha256=3lKLZQO4hfA-dv_JJEfPwyZ5nzRkLTu4qiRAqFVZUUQ,70508
|
|
13
13
|
braindecode/datasets/bbci.py,sha256=SCm7OnCObotILQ0B1EdmZPoyJtzsRXpeU_gNKtqQLSc,19288
|
|
@@ -66,7 +66,7 @@ braindecode/models/eegpt.py,sha256=5ZXItSURum1vfUciuCHbJGIkXLk6XISP-e3NsFIv3wc,4
|
|
|
66
66
|
braindecode/models/eegsimpleconv.py,sha256=suHO-v9laImwvXpLF2dwvoFFBKjiV-czAW1FHwRSscI,7306
|
|
67
67
|
braindecode/models/eegsym.py,sha256=-5wb28oxx3YSCkFUnla-6P0RdGYshBPhfke7vSj-tnA,34592
|
|
68
68
|
braindecode/models/eegtcnet.py,sha256=awEIwEIWSvS0b2Hb7ROfxV9DSwNe5z2224a-Teznuyo,10916
|
|
69
|
-
braindecode/models/emg2qwerty.py,sha256=
|
|
69
|
+
braindecode/models/emg2qwerty.py,sha256=ln5Gmf7u0dup4_PN6xLRXs0KY-TrRb606p0s6l5J0o8,38925
|
|
70
70
|
braindecode/models/fbcnet.py,sha256=YE5pCtF0Oo3J7rh8DDBl0oYZy9Tb2oyXkOYJJMr76Bo,7711
|
|
71
71
|
braindecode/models/fblightconvnet.py,sha256=bOo7DlFiqByVQ0e6ethv5n2J7N-tIhiasLObqGLAg4g,11107
|
|
72
72
|
braindecode/models/fbmsnet.py,sha256=prw9LcZBH_mEwV__fhUOOTbK4bmRdoKLLpjNuLA94Yg,12355
|
|
@@ -128,9 +128,9 @@ braindecode/visualization/frequency.py,sha256=gNwkn9yIik5SUp7d9HE9J_vPVGyzNsxxCO
|
|
|
128
128
|
braindecode/visualization/metrics.py,sha256=j01kc04P9uEkQ2g2Tt2C76yr6soIj31PAuBMflrmODg,13615
|
|
129
129
|
braindecode/visualization/sanity.py,sha256=nNClauUC8dCj_KCy_1RmaPDQAqExLczfPtUeQ7k9-Q0,4812
|
|
130
130
|
braindecode/visualization/topology.py,sha256=mXxUfCCUJqa_cMF4y6GC3_A-qBCcS4uTc0EzBolkytE,2274
|
|
131
|
-
braindecode-1.5.0.
|
|
132
|
-
braindecode-1.5.0.
|
|
133
|
-
braindecode-1.5.0.
|
|
134
|
-
braindecode-1.5.0.
|
|
135
|
-
braindecode-1.5.0.
|
|
136
|
-
braindecode-1.5.0.
|
|
131
|
+
braindecode-1.5.0.dev1015.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
|
|
132
|
+
braindecode-1.5.0.dev1015.dist-info/licenses/NOTICE.txt,sha256=ZFFhigxIaKgDcMjCzPyAVSFV42ztU0kLOENt_kvherw,857
|
|
133
|
+
braindecode-1.5.0.dev1015.dist-info/METADATA,sha256=6-m5pOAtFg3doXKj78dFw_JSObw3BgLBtZBy_0gYhYI,10275
|
|
134
|
+
braindecode-1.5.0.dev1015.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
135
|
+
braindecode-1.5.0.dev1015.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
|
|
136
|
+
braindecode-1.5.0.dev1015.dist-info/RECORD,,
|
|
File without changes
|
{braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/licenses/LICENSE.txt
RENAMED
|
File without changes
|
{braindecode-1.5.0.dev1010.dist-info → braindecode-1.5.0.dev1015.dist-info}/licenses/NOTICE.txt
RENAMED
|
File without changes
|
|
File without changes
|