braindecode 1.3.0.dev174777731__py3-none-any.whl → 1.3.0.dev175955015__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode has been flagged as potentially problematic; see the package's registry page for more details.
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/base.py +1 -1
- braindecode/datasets/sleep_physio_challe_18.py +2 -1
- braindecode/datautil/serialization.py +11 -6
- braindecode/eegneuralnet.py +2 -0
- braindecode/models/__init__.py +4 -0
- braindecode/models/atcnet.py +7 -7
- braindecode/models/attentionbasenet.py +2 -0
- braindecode/models/base.py +280 -2
- braindecode/models/biot.py +1 -1
- braindecode/models/labram.py +168 -69
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/signal_jepa.py +103 -27
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +8 -6
- braindecode/models/util.py +2 -0
- braindecode/preprocessing/preprocess.py +11 -2
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev174777731.dist-info → braindecode-1.3.0.dev175955015.dist-info}/METADATA +4 -2
- {braindecode-1.3.0.dev174777731.dist-info → braindecode-1.3.0.dev175955015.dist-info}/RECORD +25 -23
- {braindecode-1.3.0.dev174777731.dist-info → braindecode-1.3.0.dev175955015.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev174777731.dist-info → braindecode-1.3.0.dev175955015.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev174777731.dist-info → braindecode-1.3.0.dev175955015.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev174777731.dist-info → braindecode-1.3.0.dev175955015.dist-info}/top_level.txt +0 -0
|
@@ -5,7 +5,8 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
7
|
from copy import deepcopy
|
|
8
|
-
from
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Optional, Sequence
|
|
9
10
|
|
|
10
11
|
import torch
|
|
11
12
|
from einops.layers.torch import Rearrange
|
|
@@ -319,25 +320,50 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):
|
|
|
319
320
|
@classmethod
|
|
320
321
|
def from_pretrained(
|
|
321
322
|
cls,
|
|
322
|
-
model: SignalJEPA,
|
|
323
|
-
n_outputs: int,
|
|
323
|
+
model: Optional[SignalJEPA | str | Path] = None, # type: ignore
|
|
324
|
+
n_outputs: Optional[int] = None, # type: ignore
|
|
324
325
|
n_spat_filters: int = 4,
|
|
325
|
-
chs_info: list[dict[str, Any]]
|
|
326
|
+
chs_info: Optional[list[dict[str, Any]]] = None, # type: ignore
|
|
327
|
+
**kwargs,
|
|
326
328
|
):
|
|
327
|
-
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
|
|
329
|
+
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
|
|
328
330
|
|
|
329
331
|
Parameters
|
|
330
332
|
----------
|
|
331
|
-
model: SignalJEPA
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
333
|
+
model: SignalJEPA, str, Path, or None
|
|
334
|
+
Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
|
|
335
|
+
(for Hub-style loading), or None (for Hub loading via kwargs).
|
|
336
|
+
n_outputs: int or None
|
|
337
|
+
Number of classes for the new model. Required when loading from a SignalJEPA model,
|
|
338
|
+
optional when loading from Hub (will be read from config).
|
|
335
339
|
n_spat_filters: int
|
|
336
340
|
Number of spatial filters.
|
|
337
341
|
chs_info: list of dict | None
|
|
338
342
|
Information about each individual EEG channel. This should be filled with
|
|
339
343
|
``info["chs"]``. Refer to :class:`mne.Info` for more details.
|
|
344
|
+
**kwargs
|
|
345
|
+
Additional keyword arguments passed to the parent class for Hub loading.
|
|
340
346
|
"""
|
|
347
|
+
# Check if this is a Hub-style load (from a directory path)
|
|
348
|
+
if isinstance(model, (str, Path)) or (model is None and kwargs):
|
|
349
|
+
# This is a Hub load, delegate to parent class
|
|
350
|
+
if isinstance(model, (str, Path)):
|
|
351
|
+
# model is actually the repo_id or directory path
|
|
352
|
+
return super().from_pretrained(model, **kwargs)
|
|
353
|
+
else:
|
|
354
|
+
# model is None, treat as hub-style load
|
|
355
|
+
return super().from_pretrained(**kwargs)
|
|
356
|
+
|
|
357
|
+
# This is the original SignalJEPA transfer learning case
|
|
358
|
+
if not isinstance(model, SignalJEPA):
|
|
359
|
+
raise TypeError(
|
|
360
|
+
f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
|
|
361
|
+
)
|
|
362
|
+
if n_outputs is None:
|
|
363
|
+
raise ValueError(
|
|
364
|
+
"n_outputs must be provided when loading from a SignalJEPA model"
|
|
365
|
+
)
|
|
366
|
+
|
|
341
367
|
feature_encoder = model.feature_encoder
|
|
342
368
|
pos_encoder = model.pos_encoder
|
|
343
369
|
transformer = model.transformer
|
|
@@ -463,22 +489,47 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):
|
|
|
463
489
|
|
|
464
490
|
@classmethod
|
|
465
491
|
def from_pretrained(
|
|
466
|
-
cls,
|
|
492
|
+
cls,
|
|
493
|
+
model: SignalJEPA | str | Path = None, # type: ignore
|
|
494
|
+
n_outputs: int = None, # type: ignore
|
|
495
|
+
n_spat_filters: int = 4,
|
|
496
|
+
**kwargs,
|
|
467
497
|
):
|
|
468
|
-
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
|
|
498
|
+
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
|
|
469
499
|
|
|
470
500
|
Parameters
|
|
471
501
|
----------
|
|
472
|
-
model: SignalJEPA
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
502
|
+
model: SignalJEPA, str, Path, or None
|
|
503
|
+
Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
|
|
504
|
+
(for Hub-style loading), or None (for Hub loading via kwargs).
|
|
505
|
+
n_outputs: int or None
|
|
506
|
+
Number of classes for the new model. Required when loading from a SignalJEPA model,
|
|
507
|
+
optional when loading from Hub (will be read from config).
|
|
476
508
|
n_spat_filters: int
|
|
477
509
|
Number of spatial filters.
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
``info["chs"]``. Refer to :class:`mne.Info` for more details.
|
|
510
|
+
**kwargs
|
|
511
|
+
Additional keyword arguments passed to the parent class for Hub loading.
|
|
481
512
|
"""
|
|
513
|
+
# Check if this is a Hub-style load (from a directory path)
|
|
514
|
+
if isinstance(model, (str, Path)) or (model is None and kwargs):
|
|
515
|
+
# This is a Hub load, delegate to parent class
|
|
516
|
+
if isinstance(model, (str, Path)):
|
|
517
|
+
# model is actually the repo_id or directory path
|
|
518
|
+
return super().from_pretrained(model, **kwargs)
|
|
519
|
+
else:
|
|
520
|
+
# model is None, treat as hub-style load
|
|
521
|
+
return super().from_pretrained(**kwargs)
|
|
522
|
+
|
|
523
|
+
# This is the original SignalJEPA transfer learning case
|
|
524
|
+
if not isinstance(model, SignalJEPA):
|
|
525
|
+
raise TypeError(
|
|
526
|
+
f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
|
|
527
|
+
)
|
|
528
|
+
if n_outputs is None:
|
|
529
|
+
raise ValueError(
|
|
530
|
+
"n_outputs must be provided when loading from a SignalJEPA model"
|
|
531
|
+
)
|
|
532
|
+
|
|
482
533
|
feature_encoder = model.feature_encoder
|
|
483
534
|
assert feature_encoder is not None
|
|
484
535
|
new_model = cls(
|
|
@@ -597,22 +648,47 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):
|
|
|
597
648
|
|
|
598
649
|
@classmethod
|
|
599
650
|
def from_pretrained(
|
|
600
|
-
cls,
|
|
651
|
+
cls,
|
|
652
|
+
model: SignalJEPA | str | Path = None, # type: ignore
|
|
653
|
+
n_outputs: int = None, # type: ignore
|
|
654
|
+
n_spat_filters: int = 4,
|
|
655
|
+
**kwargs,
|
|
601
656
|
):
|
|
602
|
-
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
|
|
657
|
+
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
|
|
603
658
|
|
|
604
659
|
Parameters
|
|
605
660
|
----------
|
|
606
|
-
model: SignalJEPA
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
661
|
+
model: SignalJEPA, str, Path, or None
|
|
662
|
+
Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
|
|
663
|
+
(for Hub-style loading), or None (for Hub loading via kwargs).
|
|
664
|
+
n_outputs: int or None
|
|
665
|
+
Number of classes for the new model. Required when loading from a SignalJEPA model,
|
|
666
|
+
optional when loading from Hub (will be read from config).
|
|
610
667
|
n_spat_filters: int
|
|
611
668
|
Number of spatial filters.
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
``info["chs"]``. Refer to :class:`mne.Info` for more details.
|
|
669
|
+
**kwargs
|
|
670
|
+
Additional keyword arguments passed to the parent class for Hub loading.
|
|
615
671
|
"""
|
|
672
|
+
# Check if this is a Hub-style load (from a directory path)
|
|
673
|
+
if isinstance(model, (str, Path)) or (model is None and kwargs):
|
|
674
|
+
# This is a Hub load, delegate to parent class
|
|
675
|
+
if isinstance(model, (str, Path)):
|
|
676
|
+
# model is actually the repo_id or directory path
|
|
677
|
+
return super().from_pretrained(model, **kwargs)
|
|
678
|
+
else:
|
|
679
|
+
# model is None, treat as hub-style load
|
|
680
|
+
return super().from_pretrained(**kwargs)
|
|
681
|
+
|
|
682
|
+
# This is the original SignalJEPA transfer learning case
|
|
683
|
+
if not isinstance(model, SignalJEPA):
|
|
684
|
+
raise TypeError(
|
|
685
|
+
f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
|
|
686
|
+
)
|
|
687
|
+
if n_outputs is None:
|
|
688
|
+
raise ValueError(
|
|
689
|
+
"n_outputs must be provided when loading from a SignalJEPA model"
|
|
690
|
+
)
|
|
691
|
+
|
|
616
692
|
feature_encoder = model.feature_encoder
|
|
617
693
|
assert feature_encoder is not None
|
|
618
694
|
new_model = cls(
|