braindecode 1.3.0.dev175415232__py3-none-any.whl → 1.3.0.dev175955015__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/eegneuralnet.py +2 -0
- braindecode/models/attentionbasenet.py +2 -0
- braindecode/models/base.py +280 -2
- braindecode/models/labram.py +168 -69
- braindecode/models/signal_jepa.py +103 -27
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/METADATA +4 -2
- {braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/RECORD +12 -12
- {braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/top_level.txt +0 -0
braindecode/eegneuralnet.py
CHANGED
|
@@ -189,6 +189,8 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
189
189
|
"Skipping setting signal-related parameters from data."
|
|
190
190
|
)
|
|
191
191
|
return
|
|
192
|
+
if classes is None:
|
|
193
|
+
classes = getattr(self, "classes", None)
|
|
192
194
|
# get kwargs from signal:
|
|
193
195
|
signal_kwargs = dict()
|
|
194
196
|
# Using shape to work both with torch.tensor and numpy.array:
|
|
@@ -381,6 +381,8 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
|
|
|
381
381
|
for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
|
|
382
382
|
out = math.floor(out + 2 * (k // 2) - k + 1)
|
|
383
383
|
out = math.floor((out - pl) / ps + 1)
|
|
384
|
+
# Ensure output is at least 1 to avoid zero-sized tensors
|
|
385
|
+
out = max(1, out)
|
|
384
386
|
seq_lengths.append(int(out))
|
|
385
387
|
return seq_lengths
|
|
386
388
|
|
braindecode/models/base.py
CHANGED
|
@@ -5,15 +5,35 @@
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
+
import json
|
|
8
9
|
import warnings
|
|
9
10
|
from collections import OrderedDict
|
|
10
|
-
from
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, Iterable, Optional, Type, Union
|
|
11
13
|
|
|
12
14
|
import numpy as np
|
|
13
15
|
import torch
|
|
14
16
|
from docstring_inheritance import NumpyDocstringInheritanceInitMeta
|
|
17
|
+
from mne.utils import _soft_import
|
|
15
18
|
from torchinfo import ModelStatistics, summary
|
|
16
19
|
|
|
20
|
+
from braindecode.version import __version__
|
|
21
|
+
|
|
22
|
+
huggingface_hub = _soft_import(
|
|
23
|
+
"huggingface_hub", "Hugging Face Hub integration", strict=False
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
HAS_HF_HUB = huggingface_hub is not False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _BaseHubMixin:
|
|
30
|
+
pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Define base class for hub mixin
|
|
34
|
+
if HAS_HF_HUB:
|
|
35
|
+
_BaseHubMixin: Type = huggingface_hub.PyTorchModelHubMixin # type: ignore
|
|
36
|
+
|
|
17
37
|
|
|
18
38
|
def deprecated_args(obj, *old_new_args):
|
|
19
39
|
out_args = []
|
|
@@ -32,10 +52,14 @@ def deprecated_args(obj, *old_new_args):
|
|
|
32
52
|
return out_args
|
|
33
53
|
|
|
34
54
|
|
|
35
|
-
class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
55
|
+
class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
|
|
36
56
|
"""
|
|
37
57
|
Mixin class for all EEG models in braindecode.
|
|
38
58
|
|
|
59
|
+
This class integrates with Hugging Face Hub when the ``huggingface_hub`` package
|
|
60
|
+
is installed, enabling models to be pushed to and loaded from the Hub using
|
|
61
|
+
:func:`push_to_hub()` and :func:`from_pretrained()` methods.
|
|
62
|
+
|
|
39
63
|
Parameters
|
|
40
64
|
----------
|
|
41
65
|
n_outputs : int
|
|
@@ -62,8 +86,87 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
62
86
|
-----
|
|
63
87
|
If some input signal-related parameters are not specified,
|
|
64
88
|
there will be an attempt to infer them from the other parameters.
|
|
89
|
+
|
|
90
|
+
Hugging Face Hub Integration
|
|
91
|
+
-----------------------------
|
|
92
|
+
When the optional ``huggingface_hub`` package is installed, all models
|
|
93
|
+
automatically gain the ability to be pushed to and loaded from the
|
|
94
|
+
Hugging Face Hub. Install with::
|
|
95
|
+
|
|
96
|
+
pip install braindecode[hug]
|
|
97
|
+
|
|
98
|
+
**Pushing a model to the Hub:**
|
|
99
|
+
|
|
100
|
+
.. code-block:: python
|
|
101
|
+
|
|
102
|
+
from braindecode.models import EEGNetv4
|
|
103
|
+
|
|
104
|
+
# Train your model
|
|
105
|
+
model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
|
|
106
|
+
# ... training code ...
|
|
107
|
+
|
|
108
|
+
# Push to the Hub
|
|
109
|
+
model.push_to_hub(
|
|
110
|
+
repo_id="username/my-eegnet-model", commit_message="Initial model upload"
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
**Loading a model from the Hub:**
|
|
114
|
+
|
|
115
|
+
.. code-block:: python
|
|
116
|
+
|
|
117
|
+
from braindecode.models import EEGNetv4
|
|
118
|
+
|
|
119
|
+
# Load pretrained model
|
|
120
|
+
model = EEGNetv4.from_pretrained("username/my-eegnet-model")
|
|
121
|
+
|
|
122
|
+
The integration automatically handles EEG-specific parameters (n_chans,
|
|
123
|
+
n_times, sfreq, chs_info, etc.) by saving them in a config file alongside
|
|
124
|
+
the model weights. This ensures that loaded models are correctly configured
|
|
125
|
+
for their original data specifications.
|
|
126
|
+
|
|
127
|
+
.. important::
|
|
128
|
+
Currently, only EEG-specific parameters (n_outputs, n_chans, n_times,
|
|
129
|
+
input_window_seconds, sfreq, chs_info) are saved to the Hub. Model-specific
|
|
130
|
+
parameters (e.g., dropout rates, activation functions, number of filters)
|
|
131
|
+
are not preserved and will use their default values when loading from the Hub.
|
|
132
|
+
|
|
133
|
+
To use non-default model parameters, specify them explicitly when calling
|
|
134
|
+
:func:`from_pretrained()`::
|
|
135
|
+
|
|
136
|
+
model = EEGNet.from_pretrained("user/model", dropout=0.3, activation='relu')
|
|
137
|
+
|
|
138
|
+
Full parameter serialization will be addressed in a future update.
|
|
65
139
|
"""
|
|
66
140
|
|
|
141
|
+
def __init_subclass__(cls, **kwargs):
|
|
142
|
+
if not HAS_HF_HUB:
|
|
143
|
+
super().__init_subclass__(**kwargs)
|
|
144
|
+
return
|
|
145
|
+
|
|
146
|
+
base_tags = ["braindecode", cls.__name__]
|
|
147
|
+
user_tags = kwargs.pop("tags", None)
|
|
148
|
+
tags = list(user_tags) if user_tags is not None else []
|
|
149
|
+
for tag in base_tags:
|
|
150
|
+
if tag not in tags:
|
|
151
|
+
tags.append(tag)
|
|
152
|
+
|
|
153
|
+
docs_url = kwargs.pop(
|
|
154
|
+
"docs_url",
|
|
155
|
+
f"https://braindecode.org/stable/generated/braindecode.models.{cls.__name__}.html",
|
|
156
|
+
)
|
|
157
|
+
repo_url = kwargs.pop("repo_url", "https://braindecode.org")
|
|
158
|
+
library_name = kwargs.pop("library_name", "braindecode")
|
|
159
|
+
license = kwargs.pop("license", "bsd-3-clause")
|
|
160
|
+
# TODO: model_card_template can be added in the future for custom model cards
|
|
161
|
+
super().__init_subclass__(
|
|
162
|
+
tags=tags,
|
|
163
|
+
docs_url=docs_url,
|
|
164
|
+
repo_url=repo_url,
|
|
165
|
+
library_name=library_name,
|
|
166
|
+
license=license,
|
|
167
|
+
**kwargs,
|
|
168
|
+
)
|
|
169
|
+
|
|
67
170
|
def __init__(
|
|
68
171
|
self,
|
|
69
172
|
n_outputs: Optional[int] = None, # type: ignore[assignment]
|
|
@@ -73,6 +176,16 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
73
176
|
input_window_seconds: Optional[float] = None, # type: ignore[assignment]
|
|
74
177
|
sfreq: Optional[float] = None, # type: ignore[assignment]
|
|
75
178
|
):
|
|
179
|
+
# Deserialize chs_info if it comes as a list of dicts (from Hub)
|
|
180
|
+
if chs_info is not None and isinstance(chs_info, list):
|
|
181
|
+
if len(chs_info) > 0 and isinstance(chs_info[0], dict):
|
|
182
|
+
# Check if it needs deserialization (has 'loc' as list)
|
|
183
|
+
if "loc" in chs_info[0] and isinstance(chs_info[0]["loc"], list):
|
|
184
|
+
chs_info = self._deserialize_chs_info(chs_info)
|
|
185
|
+
warnings.warn(
|
|
186
|
+
"Modifying chs_info argument using the _deserialize_chs_info() method"
|
|
187
|
+
)
|
|
188
|
+
|
|
76
189
|
if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
|
|
77
190
|
raise ValueError(f"{n_chans=} different from {chs_info=} length")
|
|
78
191
|
if (
|
|
@@ -294,3 +407,168 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
|
|
|
294
407
|
|
|
295
408
|
def __str__(self) -> str:
|
|
296
409
|
return str(self.get_torchinfo_statistics())
|
|
410
|
+
|
|
411
|
+
@staticmethod
|
|
412
|
+
def _serialize_chs_info(chs_info):
|
|
413
|
+
"""
|
|
414
|
+
Serialize MNE channel info to JSON-compatible format.
|
|
415
|
+
|
|
416
|
+
Parameters
|
|
417
|
+
----------
|
|
418
|
+
chs_info : list of dict or None
|
|
419
|
+
Channel information from MNE Info object.
|
|
420
|
+
|
|
421
|
+
Returns
|
|
422
|
+
-------
|
|
423
|
+
list of dict or None
|
|
424
|
+
Serialized channel information that can be saved to JSON.
|
|
425
|
+
"""
|
|
426
|
+
if chs_info is None:
|
|
427
|
+
return None
|
|
428
|
+
|
|
429
|
+
serialized = []
|
|
430
|
+
for ch in chs_info:
|
|
431
|
+
# Extract serializable fields from MNE channel info
|
|
432
|
+
ch_dict = {
|
|
433
|
+
"ch_name": ch.get("ch_name", ""),
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
# Handle kind field - can be either string or integer
|
|
437
|
+
kind_val = ch.get("kind")
|
|
438
|
+
if kind_val is not None:
|
|
439
|
+
ch_dict["kind"] = (
|
|
440
|
+
kind_val if isinstance(kind_val, str) else int(kind_val)
|
|
441
|
+
)
|
|
442
|
+
|
|
443
|
+
# Add numeric fields with safe conversion
|
|
444
|
+
coil_type = ch.get("coil_type")
|
|
445
|
+
if coil_type is not None:
|
|
446
|
+
ch_dict["coil_type"] = int(coil_type)
|
|
447
|
+
|
|
448
|
+
unit = ch.get("unit")
|
|
449
|
+
if unit is not None:
|
|
450
|
+
ch_dict["unit"] = int(unit)
|
|
451
|
+
|
|
452
|
+
cal = ch.get("cal")
|
|
453
|
+
if cal is not None:
|
|
454
|
+
ch_dict["cal"] = float(cal)
|
|
455
|
+
|
|
456
|
+
range_val = ch.get("range")
|
|
457
|
+
if range_val is not None:
|
|
458
|
+
ch_dict["range"] = float(range_val)
|
|
459
|
+
|
|
460
|
+
# Serialize location array if present
|
|
461
|
+
if "loc" in ch and ch["loc"] is not None:
|
|
462
|
+
ch_dict["loc"] = (
|
|
463
|
+
ch["loc"].tolist()
|
|
464
|
+
if hasattr(ch["loc"], "tolist")
|
|
465
|
+
else list(ch["loc"])
|
|
466
|
+
)
|
|
467
|
+
serialized.append(ch_dict)
|
|
468
|
+
|
|
469
|
+
return serialized
|
|
470
|
+
|
|
471
|
+
@staticmethod
|
|
472
|
+
def _deserialize_chs_info(chs_info_dict):
|
|
473
|
+
"""
|
|
474
|
+
Deserialize channel info from JSON-compatible format to MNE-like structure.
|
|
475
|
+
|
|
476
|
+
Parameters
|
|
477
|
+
----------
|
|
478
|
+
chs_info_dict : list of dict or None
|
|
479
|
+
Serialized channel information.
|
|
480
|
+
|
|
481
|
+
Returns
|
|
482
|
+
-------
|
|
483
|
+
list of dict or None
|
|
484
|
+
Deserialized channel information compatible with MNE.
|
|
485
|
+
"""
|
|
486
|
+
if chs_info_dict is None:
|
|
487
|
+
return None
|
|
488
|
+
|
|
489
|
+
deserialized = []
|
|
490
|
+
for ch_dict in chs_info_dict:
|
|
491
|
+
ch = ch_dict.copy()
|
|
492
|
+
# Convert location back to numpy array if present
|
|
493
|
+
if "loc" in ch and ch["loc"] is not None:
|
|
494
|
+
ch["loc"] = np.array(ch["loc"])
|
|
495
|
+
deserialized.append(ch)
|
|
496
|
+
|
|
497
|
+
return deserialized
|
|
498
|
+
|
|
499
|
+
def _save_pretrained(self, save_directory):
|
|
500
|
+
"""
|
|
501
|
+
Save model configuration and weights to the Hub.
|
|
502
|
+
|
|
503
|
+
This method is called by PyTorchModelHubMixin.push_to_hub() to save
|
|
504
|
+
model-specific configuration alongside the model weights.
|
|
505
|
+
|
|
506
|
+
Parameters
|
|
507
|
+
----------
|
|
508
|
+
save_directory : str or Path
|
|
509
|
+
Directory where the configuration should be saved.
|
|
510
|
+
"""
|
|
511
|
+
if not HAS_HF_HUB:
|
|
512
|
+
return
|
|
513
|
+
|
|
514
|
+
save_directory = Path(save_directory)
|
|
515
|
+
|
|
516
|
+
# Collect EEG-specific configuration
|
|
517
|
+
config = {
|
|
518
|
+
"n_outputs": self._n_outputs,
|
|
519
|
+
"n_chans": self._n_chans,
|
|
520
|
+
"n_times": self._n_times,
|
|
521
|
+
"input_window_seconds": self._input_window_seconds,
|
|
522
|
+
"sfreq": self._sfreq,
|
|
523
|
+
"chs_info": self._serialize_chs_info(self._chs_info),
|
|
524
|
+
"braindecode_version": __version__,
|
|
525
|
+
}
|
|
526
|
+
|
|
527
|
+
# Save to config.json
|
|
528
|
+
config_path = save_directory / "config.json"
|
|
529
|
+
with open(config_path, "w") as f:
|
|
530
|
+
json.dump(config, f, indent=2)
|
|
531
|
+
|
|
532
|
+
# Save model weights with standard Hub filename
|
|
533
|
+
weights_path = save_directory / "pytorch_model.bin"
|
|
534
|
+
torch.save(self.state_dict(), weights_path)
|
|
535
|
+
|
|
536
|
+
# Also save in safetensors format using parent's implementation
|
|
537
|
+
try:
|
|
538
|
+
super()._save_pretrained(save_directory)
|
|
539
|
+
except (ImportError, RuntimeError) as e:
|
|
540
|
+
# Fallback to pytorch_model.bin if safetensors saving fails
|
|
541
|
+
warnings.warn(
|
|
542
|
+
f"Could not save model in safetensors format: {e}. "
|
|
543
|
+
"Model weights saved in pytorch_model.bin instead.",
|
|
544
|
+
stacklevel=2,
|
|
545
|
+
)
|
|
546
|
+
|
|
547
|
+
if HAS_HF_HUB:
|
|
548
|
+
|
|
549
|
+
@classmethod
|
|
550
|
+
def _from_pretrained(
|
|
551
|
+
cls,
|
|
552
|
+
*,
|
|
553
|
+
model_id: str,
|
|
554
|
+
revision: Optional[str],
|
|
555
|
+
cache_dir: Optional[Union[str, Path]],
|
|
556
|
+
force_download: bool,
|
|
557
|
+
local_files_only: bool,
|
|
558
|
+
token: Union[str, bool, None],
|
|
559
|
+
map_location: str = "cpu",
|
|
560
|
+
strict: bool = False,
|
|
561
|
+
**model_kwargs,
|
|
562
|
+
):
|
|
563
|
+
model_kwargs.pop("braindecode_version", None)
|
|
564
|
+
return super()._from_pretrained( # type: ignore
|
|
565
|
+
model_id=model_id,
|
|
566
|
+
revision=revision,
|
|
567
|
+
cache_dir=cache_dir,
|
|
568
|
+
force_download=force_download,
|
|
569
|
+
local_files_only=local_files_only,
|
|
570
|
+
token=token,
|
|
571
|
+
map_location=map_location,
|
|
572
|
+
strict=strict,
|
|
573
|
+
**model_kwargs,
|
|
574
|
+
)
|
braindecode/models/labram.py
CHANGED
|
@@ -61,13 +61,24 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
61
61
|
|
|
62
62
|
.. versionadded:: 0.9
|
|
63
63
|
|
|
64
|
+
|
|
65
|
+
Examples on how to load pre-trained weights:
|
|
66
|
+
--------------------------------------------
|
|
67
|
+
>>> import torch
|
|
68
|
+
>>> from braindecode.models import Labram
|
|
69
|
+
>>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
|
|
70
|
+
>>> url = 'https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt'
|
|
71
|
+
>>> state = torch.hub.load_state_dict_from_url(url, progress=True)
|
|
72
|
+
>>> model.load_state_dict(state)
|
|
73
|
+
|
|
74
|
+
|
|
64
75
|
Parameters
|
|
65
76
|
----------
|
|
66
77
|
patch_size : int
|
|
67
78
|
The size of the patch to be used in the patch embedding.
|
|
68
79
|
emb_size : int
|
|
69
80
|
The dimension of the embedding.
|
|
70
|
-
|
|
81
|
+
in_conv_channels : int
|
|
71
82
|
The number of convolutional input channels.
|
|
72
83
|
out_channels : int
|
|
73
84
|
The number of convolutional output channels.
|
|
@@ -79,8 +90,10 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
79
90
|
The expansion ratio of the mlp layer
|
|
80
91
|
qkv_bias : bool (default=False)
|
|
81
92
|
If True, add a learnable bias to the query, key, and value tensors.
|
|
82
|
-
qk_norm : Pytorch Normalize layer (default=
|
|
83
|
-
If not None, apply LayerNorm to the query and key tensors
|
|
93
|
+
qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
|
|
94
|
+
If not None, apply LayerNorm to the query and key tensors.
|
|
95
|
+
Default is nn.LayerNorm for better weight transfer from original LaBraM.
|
|
96
|
+
Set to None to disable Q,K normalization.
|
|
84
97
|
qk_scale : float (default=None)
|
|
85
98
|
If not None, use this value as the scale factor. If None,
|
|
86
99
|
use head_dim**-0.5, where head_dim = dim // num_heads.
|
|
@@ -92,9 +105,10 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
92
105
|
Dropout rate for the attention weights used on DropPath.
|
|
93
106
|
norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
|
|
94
107
|
The normalization layer to be used.
|
|
95
|
-
init_values : float (default=
|
|
108
|
+
init_values : float (default=0.1)
|
|
96
109
|
If not None, use this value to initialize the gamma_1 and gamma_2
|
|
97
|
-
parameters.
|
|
110
|
+
parameters for residual scaling. Default is 0.1 for better weight
|
|
111
|
+
transfer from original LaBraM. Set to None to disable.
|
|
98
112
|
use_abs_pos_emb : bool (default=True)
|
|
99
113
|
If True, use absolute position embedding.
|
|
100
114
|
use_mean_pooling : bool (default=True)
|
|
@@ -135,19 +149,19 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
135
149
|
input_window_seconds=None,
|
|
136
150
|
patch_size=200,
|
|
137
151
|
emb_size=200,
|
|
138
|
-
|
|
152
|
+
in_conv_channels=1,
|
|
139
153
|
out_channels=8,
|
|
140
154
|
n_layers=12,
|
|
141
155
|
att_num_heads=10,
|
|
142
156
|
mlp_ratio=4.0,
|
|
143
157
|
qkv_bias=False,
|
|
144
|
-
qk_norm=
|
|
158
|
+
qk_norm=nn.LayerNorm,
|
|
145
159
|
qk_scale=None,
|
|
146
160
|
drop_prob=0.0,
|
|
147
161
|
attn_drop_prob=0.0,
|
|
148
162
|
drop_path_prob=0.0,
|
|
149
163
|
norm_layer=nn.LayerNorm,
|
|
150
|
-
init_values=
|
|
164
|
+
init_values=0.1,
|
|
151
165
|
use_abs_pos_emb=True,
|
|
152
166
|
use_mean_pooling=True,
|
|
153
167
|
init_scale=0.001,
|
|
@@ -183,15 +197,15 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
183
197
|
self.patch_size = patch_size
|
|
184
198
|
self.n_path = self.n_times // self.patch_size
|
|
185
199
|
|
|
186
|
-
if neural_tokenizer and
|
|
200
|
+
if neural_tokenizer and in_conv_channels != 1:
|
|
187
201
|
warn(
|
|
188
202
|
"The model is in Neural Tokenizer mode, but the variable "
|
|
189
|
-
+ "`
|
|
190
|
-
+ "`
|
|
191
|
-
+ "
|
|
203
|
+
+ "`in_conv_channels` is different from the default values."
|
|
204
|
+
+ "`in_conv_channels` is only needed for the Neural Decoder mode."
|
|
205
|
+
+ "in_conv_channels is not used in the Neural Tokenizer mode.",
|
|
192
206
|
UserWarning,
|
|
193
207
|
)
|
|
194
|
-
|
|
208
|
+
in_conv_channels = 1
|
|
195
209
|
# If you can use the model in Neural Tokenizer mode,
|
|
196
210
|
# temporal conv layer will be use over the patched dataset
|
|
197
211
|
if neural_tokenizer:
|
|
@@ -228,7 +242,7 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
228
242
|
_PatchEmbed(
|
|
229
243
|
n_times=self.n_times,
|
|
230
244
|
patch_size=patch_size,
|
|
231
|
-
in_channels=
|
|
245
|
+
in_channels=in_conv_channels,
|
|
232
246
|
emb_dim=self.emb_size,
|
|
233
247
|
),
|
|
234
248
|
)
|
|
@@ -373,8 +387,7 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
373
387
|
Parameters
|
|
374
388
|
----------
|
|
375
389
|
x : torch.Tensor
|
|
376
|
-
The input data with shape (batch, n_chans,
|
|
377
|
-
if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
|
|
390
|
+
The input data with shape (batch, n_chans, n_times).
|
|
378
391
|
input_chans : int
|
|
379
392
|
The number of input channels.
|
|
380
393
|
return_patch_tokens : bool
|
|
@@ -387,37 +400,72 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
387
400
|
x : torch.Tensor
|
|
388
401
|
The output of the model.
|
|
389
402
|
"""
|
|
403
|
+
batch_size = x.shape[0]
|
|
404
|
+
|
|
390
405
|
if self.neural_tokenizer:
|
|
391
|
-
|
|
406
|
+
# For neural tokenizer: input is (batch, n_chans, n_times)
|
|
407
|
+
# patch_embed returns (batch, n_chans, emb_dim)
|
|
408
|
+
x = self.patch_embed(x)
|
|
409
|
+
# x shape: (batch, n_chans, emb_dim)
|
|
410
|
+
n_patch = self.n_chans
|
|
411
|
+
temporal = self.emb_size
|
|
392
412
|
else:
|
|
393
|
-
|
|
394
|
-
|
|
413
|
+
# For neural decoder: input is (batch, n_chans, n_times)
|
|
414
|
+
# patch_embed returns (batch, n_patchs, emb_dim)
|
|
415
|
+
x = self.patch_embed(x)
|
|
416
|
+
# x shape: (batch, n_patchs, emb_dim)
|
|
417
|
+
batch_size, n_patch, temporal = x.shape
|
|
418
|
+
|
|
395
419
|
# add the [CLS] token to the embedded patch tokens
|
|
396
420
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
|
397
421
|
|
|
422
|
+
# Concatenate cls token with patch/channel embeddings
|
|
398
423
|
x = torch.cat((cls_tokens, x), dim=1)
|
|
399
424
|
|
|
400
425
|
# Positional Embedding
|
|
401
|
-
if input_chans is not None:
|
|
402
|
-
pos_embed_used = self.position_embedding[:, input_chans]
|
|
403
|
-
else:
|
|
404
|
-
pos_embed_used = self.position_embedding
|
|
405
|
-
|
|
406
426
|
if self.position_embedding is not None:
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
427
|
+
if self.neural_tokenizer:
|
|
428
|
+
# In tokenizer mode, use channel-based position embedding
|
|
429
|
+
if input_chans is not None:
|
|
430
|
+
pos_embed_used = self.position_embedding[:, input_chans]
|
|
431
|
+
else:
|
|
432
|
+
pos_embed_used = self.position_embedding
|
|
433
|
+
|
|
434
|
+
pos_embed = self._adj_position_embedding(
|
|
435
|
+
pos_embed_used=pos_embed_used, batch_size=batch_size
|
|
436
|
+
)
|
|
437
|
+
else:
|
|
438
|
+
# In decoder mode, we have different number of patches
|
|
439
|
+
# Adapt position embedding for n_patch patches
|
|
440
|
+
# Use the first n_patch+1 positions from position_embedding
|
|
441
|
+
n_pos = min(self.position_embedding.shape[1], n_patch + 1)
|
|
442
|
+
pos_embed_used = self.position_embedding[:, :n_pos, :]
|
|
443
|
+
pos_embed = pos_embed_used.expand(batch_size, -1, -1)
|
|
444
|
+
|
|
410
445
|
x += pos_embed
|
|
411
446
|
|
|
412
447
|
# The time embedding is added across the channels after the [CLS] token
|
|
413
448
|
if self.neural_tokenizer:
|
|
414
449
|
num_ch = self.n_chans
|
|
450
|
+
time_embed = self._adj_temporal_embedding(
|
|
451
|
+
num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
|
|
452
|
+
)
|
|
453
|
+
x[:, 1:, :] += time_embed
|
|
415
454
|
else:
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
455
|
+
# In decoder mode, we have n_patch patches and don't need to expand
|
|
456
|
+
# Just broadcast the temporal embedding
|
|
457
|
+
if temporal is None:
|
|
458
|
+
temporal = self.emb_size
|
|
459
|
+
|
|
460
|
+
# Get temporal embeddings for n_patch patches
|
|
461
|
+
n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
|
|
462
|
+
time_embed = self.temporal_embedding[
|
|
463
|
+
:, 1 : n_time_tokens + 1, :
|
|
464
|
+
] # (1, n_patch, emb_dim)
|
|
465
|
+
time_embed = time_embed.expand(
|
|
466
|
+
batch_size, -1, -1
|
|
467
|
+
) # (batch, n_patch, emb_dim)
|
|
468
|
+
x[:, 1:, :] += time_embed
|
|
421
469
|
|
|
422
470
|
x = self.pos_drop(x)
|
|
423
471
|
|
|
@@ -428,10 +476,10 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
428
476
|
if self.fc_norm is not None:
|
|
429
477
|
if return_all_tokens:
|
|
430
478
|
return self.fc_norm(x)
|
|
431
|
-
|
|
479
|
+
tokens = x[:, 1:, :]
|
|
432
480
|
if return_patch_tokens:
|
|
433
|
-
return self.fc_norm(
|
|
434
|
-
return self.fc_norm(
|
|
481
|
+
return self.fc_norm(tokens)
|
|
482
|
+
return self.fc_norm(tokens.mean(1))
|
|
435
483
|
else:
|
|
436
484
|
if return_all_tokens:
|
|
437
485
|
return x
|
|
@@ -505,14 +553,16 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
505
553
|
def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
|
|
506
554
|
"""
|
|
507
555
|
Adjust the dimensions of the time embedding to match the
|
|
508
|
-
number of channels.
|
|
556
|
+
number of channels or patches.
|
|
509
557
|
|
|
510
558
|
Parameters
|
|
511
559
|
----------
|
|
512
560
|
num_ch : int
|
|
513
|
-
The number of channels or number of
|
|
561
|
+
The number of channels or number of patches.
|
|
514
562
|
batch_size : int
|
|
515
563
|
Batch size of the input data.
|
|
564
|
+
dim_embed : int
|
|
565
|
+
The embedding dimension (temporal feature dimension).
|
|
516
566
|
|
|
517
567
|
Returns
|
|
518
568
|
-------
|
|
@@ -523,17 +573,24 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
523
573
|
if dim_embed is None:
|
|
524
574
|
cut_dimension = self.patch_size
|
|
525
575
|
else:
|
|
526
|
-
cut_dimension = dim_embed
|
|
527
|
-
|
|
528
|
-
|
|
576
|
+
cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
|
|
577
|
+
|
|
578
|
+
# Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
|
|
579
|
+
# Slice to cut_dimension: (1, cut_dimension, emb_size)
|
|
580
|
+
temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
|
|
581
|
+
|
|
529
582
|
# Add a new dimension to the time embedding
|
|
530
|
-
# e.g. (
|
|
583
|
+
# e.g. (1, 5, 200) -> (1, 1, 5, 200)
|
|
531
584
|
temporal_embedding = temporal_embedding.unsqueeze(1)
|
|
532
|
-
|
|
533
|
-
#
|
|
585
|
+
|
|
586
|
+
# Expand the time embedding to match the number of channels or patches
|
|
587
|
+
# (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
|
|
534
588
|
temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
|
|
589
|
+
|
|
535
590
|
# Flatten the intermediate dimensions
|
|
591
|
+
# (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
|
|
536
592
|
temporal_embedding = temporal_embedding.flatten(1, 2)
|
|
593
|
+
|
|
537
594
|
return temporal_embedding
|
|
538
595
|
|
|
539
596
|
def _adj_position_embedding(self, pos_embed_used, batch_size):
|
|
@@ -679,25 +736,27 @@ class _SegmentPatch(nn.Module):
|
|
|
679
736
|
|
|
680
737
|
|
|
681
738
|
class _PatchEmbed(nn.Module):
|
|
682
|
-
"""EEG to Patch Embedding.
|
|
739
|
+
"""EEG to Patch Embedding for Neural Decoder mode.
|
|
683
740
|
|
|
684
741
|
This code is used when we want to apply the patch embedding
|
|
685
|
-
after the codebook layer.
|
|
742
|
+
after the codebook layer (Neural Decoder mode).
|
|
743
|
+
|
|
744
|
+
The input is expected to be in the format (Batch, n_channels, n_times),
|
|
745
|
+
but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
|
|
746
|
+
This class reshapes the input to the pre-patched format, then applies a 2D
|
|
747
|
+
convolution to project this pre-patched data to the embedding dimension,
|
|
748
|
+
and finally flattens across channels to produce a unified embedding.
|
|
686
749
|
|
|
687
750
|
Parameters:
|
|
688
751
|
-----------
|
|
689
752
|
n_times: int (default=2000)
|
|
690
|
-
Number of temporal components of the input tensor.
|
|
753
|
+
Number of temporal components of the input tensor (used for dimension calculation).
|
|
691
754
|
patch_size: int (default=200)
|
|
692
755
|
Size of the patch, default is 1-seconds with 200Hz.
|
|
693
756
|
in_channels: int (default=1)
|
|
694
|
-
Number of input channels
|
|
757
|
+
Number of input channels (from VQVAE codebook).
|
|
695
758
|
emb_dim: int (default=200)
|
|
696
|
-
Number of
|
|
697
|
-
we used the same as patch_size.
|
|
698
|
-
n_codebooks: int (default=62)
|
|
699
|
-
Number of patches to be used in the convolution, here,
|
|
700
|
-
we used the same as n_times // patch_size.
|
|
759
|
+
Number of output embedding dimension.
|
|
701
760
|
"""
|
|
702
761
|
|
|
703
762
|
def __init__(
|
|
@@ -707,10 +766,13 @@ class _PatchEmbed(nn.Module):
|
|
|
707
766
|
self.n_times = n_times
|
|
708
767
|
self.patch_size = patch_size
|
|
709
768
|
self.patch_shape = (1, self.n_times // self.patch_size)
|
|
710
|
-
n_patchs =
|
|
711
|
-
|
|
712
|
-
self.
|
|
769
|
+
self.n_patchs = self.n_times // self.patch_size
|
|
770
|
+
self.emb_dim = emb_dim
|
|
771
|
+
self.in_channels = in_channels
|
|
713
772
|
|
|
773
|
+
# 2D Conv to project the pre-patched data
|
|
774
|
+
# Input: (Batch, in_channels, n_patches, patch_size)
|
|
775
|
+
# After proj: (Batch, emb_dim, n_patches, 1)
|
|
714
776
|
self.proj = nn.Conv2d(
|
|
715
777
|
in_channels=in_channels,
|
|
716
778
|
out_channels=emb_dim,
|
|
@@ -718,27 +780,64 @@ class _PatchEmbed(nn.Module):
|
|
|
718
780
|
stride=(1, self.patch_size),
|
|
719
781
|
)
|
|
720
782
|
|
|
721
|
-
self.merge_transpose = Rearrange(
|
|
722
|
-
"Batch ch patch spatch -> Batch patch spatch ch",
|
|
723
|
-
)
|
|
724
|
-
|
|
725
783
|
def forward(self, x):
|
|
726
784
|
"""
|
|
727
|
-
Apply the
|
|
728
|
-
then merge the output tensor to the desired shape.
|
|
785
|
+
Apply the temporal projection to the input tensor after grouping channels.
|
|
729
786
|
|
|
730
|
-
Parameters
|
|
731
|
-
|
|
732
|
-
x: torch.Tensor
|
|
733
|
-
Input tensor of shape (Batch,
|
|
787
|
+
Parameters
|
|
788
|
+
----------
|
|
789
|
+
x : torch.Tensor
|
|
790
|
+
Input tensor of shape (Batch, n_channels, n_times) or
|
|
791
|
+
(Batch, n_channels, n_patches, patch_size).
|
|
734
792
|
|
|
735
|
-
|
|
793
|
+
Returns
|
|
736
794
|
-------
|
|
737
|
-
|
|
738
|
-
Output tensor of shape (Batch, n_patchs,
|
|
795
|
+
torch.Tensor
|
|
796
|
+
Output tensor of shape (Batch, n_patchs, emb_dim).
|
|
739
797
|
"""
|
|
798
|
+
if x.ndim == 4:
|
|
799
|
+
batch_size, n_channels, n_patchs, patch_len = x.shape
|
|
800
|
+
if patch_len != self.patch_size:
|
|
801
|
+
raise ValueError(
|
|
802
|
+
"When providing a 4D tensor, the last dimension "
|
|
803
|
+
f"({patch_len}) must match patch_size ({self.patch_size})."
|
|
804
|
+
)
|
|
805
|
+
n_times = n_patchs * patch_len
|
|
806
|
+
x = x.reshape(batch_size, n_channels, n_times)
|
|
807
|
+
elif x.ndim == 3:
|
|
808
|
+
batch_size, n_channels, n_times = x.shape
|
|
809
|
+
else:
|
|
810
|
+
raise ValueError(
|
|
811
|
+
"Input must be either 3D (batch, channels, times) or "
|
|
812
|
+
"4D (batch, channels, n_patches, patch_size)."
|
|
813
|
+
)
|
|
814
|
+
|
|
815
|
+
if n_times % self.patch_size != 0:
|
|
816
|
+
raise ValueError(
|
|
817
|
+
f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
|
|
818
|
+
)
|
|
819
|
+
if n_channels % self.in_channels != 0:
|
|
820
|
+
raise ValueError(
|
|
821
|
+
"The input channel dimension "
|
|
822
|
+
f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
|
|
823
|
+
)
|
|
824
|
+
|
|
825
|
+
group_size = n_channels // self.in_channels
|
|
826
|
+
|
|
827
|
+
# Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
|
|
828
|
+
# EEG channels as the spatial height dimension.
|
|
829
|
+
# Shape after view: (Batch, in_channels, group_size, n_times)
|
|
830
|
+
x = x.view(batch_size, self.in_channels, group_size, n_times)
|
|
831
|
+
|
|
832
|
+
# Apply the temporal projection per group.
|
|
833
|
+
# Output shape: (Batch, emb_dim, group_size, n_patchs)
|
|
740
834
|
x = self.proj(x)
|
|
741
|
-
|
|
835
|
+
|
|
836
|
+
# THIS IS braindecode's MODIFICATION:
|
|
837
|
+
# Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
|
|
838
|
+
x = x.mean(dim=2)
|
|
839
|
+
x = x.transpose(1, 2).contiguous()
|
|
840
|
+
|
|
742
841
|
return x
|
|
743
842
|
|
|
744
843
|
|
|
@@ -5,7 +5,8 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
7
|
from copy import deepcopy
|
|
8
|
-
from
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Optional, Sequence
|
|
9
10
|
|
|
10
11
|
import torch
|
|
11
12
|
from einops.layers.torch import Rearrange
|
|
@@ -319,25 +320,50 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):
|
|
|
319
320
|
@classmethod
|
|
320
321
|
def from_pretrained(
|
|
321
322
|
cls,
|
|
322
|
-
model: SignalJEPA,
|
|
323
|
-
n_outputs: int,
|
|
323
|
+
model: Optional[SignalJEPA | str | Path] = None, # type: ignore
|
|
324
|
+
n_outputs: Optional[int] = None, # type: ignore
|
|
324
325
|
n_spat_filters: int = 4,
|
|
325
|
-
chs_info: list[dict[str, Any]]
|
|
326
|
+
chs_info: Optional[list[dict[str, Any]]] = None, # type: ignore
|
|
327
|
+
**kwargs,
|
|
326
328
|
):
|
|
327
|
-
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
|
|
329
|
+
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
|
|
328
330
|
|
|
329
331
|
Parameters
|
|
330
332
|
----------
|
|
331
|
-
model: SignalJEPA
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
333
|
+
model: SignalJEPA, str, Path, or None
|
|
334
|
+
Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
|
|
335
|
+
(for Hub-style loading), or None (for Hub loading via kwargs).
|
|
336
|
+
n_outputs: int or None
|
|
337
|
+
Number of classes for the new model. Required when loading from a SignalJEPA model,
|
|
338
|
+
optional when loading from Hub (will be read from config).
|
|
335
339
|
n_spat_filters: int
|
|
336
340
|
Number of spatial filters.
|
|
337
341
|
chs_info: list of dict | None
|
|
338
342
|
Information about each individual EEG channel. This should be filled with
|
|
339
343
|
``info["chs"]``. Refer to :class:`mne.Info` for more details.
|
|
344
|
+
**kwargs
|
|
345
|
+
Additional keyword arguments passed to the parent class for Hub loading.
|
|
340
346
|
"""
|
|
347
|
+
# Check if this is a Hub-style load (from a directory path)
|
|
348
|
+
if isinstance(model, (str, Path)) or (model is None and kwargs):
|
|
349
|
+
# This is a Hub load, delegate to parent class
|
|
350
|
+
if isinstance(model, (str, Path)):
|
|
351
|
+
# model is actually the repo_id or directory path
|
|
352
|
+
return super().from_pretrained(model, **kwargs)
|
|
353
|
+
else:
|
|
354
|
+
# model is None, treat as hub-style load
|
|
355
|
+
return super().from_pretrained(**kwargs)
|
|
356
|
+
|
|
357
|
+
# This is the original SignalJEPA transfer learning case
|
|
358
|
+
if not isinstance(model, SignalJEPA):
|
|
359
|
+
raise TypeError(
|
|
360
|
+
f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
|
|
361
|
+
)
|
|
362
|
+
if n_outputs is None:
|
|
363
|
+
raise ValueError(
|
|
364
|
+
"n_outputs must be provided when loading from a SignalJEPA model"
|
|
365
|
+
)
|
|
366
|
+
|
|
341
367
|
feature_encoder = model.feature_encoder
|
|
342
368
|
pos_encoder = model.pos_encoder
|
|
343
369
|
transformer = model.transformer
|
|
@@ -463,22 +489,47 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):
|
|
|
463
489
|
|
|
464
490
|
@classmethod
|
|
465
491
|
def from_pretrained(
|
|
466
|
-
cls,
|
|
492
|
+
cls,
|
|
493
|
+
model: SignalJEPA | str | Path = None, # type: ignore
|
|
494
|
+
n_outputs: int = None, # type: ignore
|
|
495
|
+
n_spat_filters: int = 4,
|
|
496
|
+
**kwargs,
|
|
467
497
|
):
|
|
468
|
-
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
|
|
498
|
+
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
|
|
469
499
|
|
|
470
500
|
Parameters
|
|
471
501
|
----------
|
|
472
|
-
model: SignalJEPA
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
502
|
+
model: SignalJEPA, str, Path, or None
|
|
503
|
+
Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
|
|
504
|
+
(for Hub-style loading), or None (for Hub loading via kwargs).
|
|
505
|
+
n_outputs: int or None
|
|
506
|
+
Number of classes for the new model. Required when loading from a SignalJEPA model,
|
|
507
|
+
optional when loading from Hub (will be read from config).
|
|
476
508
|
n_spat_filters: int
|
|
477
509
|
Number of spatial filters.
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
``info["chs"]``. Refer to :class:`mne.Info` for more details.
|
|
510
|
+
**kwargs
|
|
511
|
+
Additional keyword arguments passed to the parent class for Hub loading.
|
|
481
512
|
"""
|
|
513
|
+
# Check if this is a Hub-style load (from a directory path)
|
|
514
|
+
if isinstance(model, (str, Path)) or (model is None and kwargs):
|
|
515
|
+
# This is a Hub load, delegate to parent class
|
|
516
|
+
if isinstance(model, (str, Path)):
|
|
517
|
+
# model is actually the repo_id or directory path
|
|
518
|
+
return super().from_pretrained(model, **kwargs)
|
|
519
|
+
else:
|
|
520
|
+
# model is None, treat as hub-style load
|
|
521
|
+
return super().from_pretrained(**kwargs)
|
|
522
|
+
|
|
523
|
+
# This is the original SignalJEPA transfer learning case
|
|
524
|
+
if not isinstance(model, SignalJEPA):
|
|
525
|
+
raise TypeError(
|
|
526
|
+
f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
|
|
527
|
+
)
|
|
528
|
+
if n_outputs is None:
|
|
529
|
+
raise ValueError(
|
|
530
|
+
"n_outputs must be provided when loading from a SignalJEPA model"
|
|
531
|
+
)
|
|
532
|
+
|
|
482
533
|
feature_encoder = model.feature_encoder
|
|
483
534
|
assert feature_encoder is not None
|
|
484
535
|
new_model = cls(
|
|
@@ -597,22 +648,47 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):
|
|
|
597
648
|
|
|
598
649
|
@classmethod
|
|
599
650
|
def from_pretrained(
|
|
600
|
-
cls,
|
|
651
|
+
cls,
|
|
652
|
+
model: SignalJEPA | str | Path = None, # type: ignore
|
|
653
|
+
n_outputs: int = None, # type: ignore
|
|
654
|
+
n_spat_filters: int = 4,
|
|
655
|
+
**kwargs,
|
|
601
656
|
):
|
|
602
|
-
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
|
|
657
|
+
"""Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
|
|
603
658
|
|
|
604
659
|
Parameters
|
|
605
660
|
----------
|
|
606
|
-
model: SignalJEPA
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
661
|
+
model: SignalJEPA, str, Path, or None
|
|
662
|
+
Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
|
|
663
|
+
(for Hub-style loading), or None (for Hub loading via kwargs).
|
|
664
|
+
n_outputs: int or None
|
|
665
|
+
Number of classes for the new model. Required when loading from a SignalJEPA model,
|
|
666
|
+
optional when loading from Hub (will be read from config).
|
|
610
667
|
n_spat_filters: int
|
|
611
668
|
Number of spatial filters.
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
``info["chs"]``. Refer to :class:`mne.Info` for more details.
|
|
669
|
+
**kwargs
|
|
670
|
+
Additional keyword arguments passed to the parent class for Hub loading.
|
|
615
671
|
"""
|
|
672
|
+
# Check if this is a Hub-style load (from a directory path)
|
|
673
|
+
if isinstance(model, (str, Path)) or (model is None and kwargs):
|
|
674
|
+
# This is a Hub load, delegate to parent class
|
|
675
|
+
if isinstance(model, (str, Path)):
|
|
676
|
+
# model is actually the repo_id or directory path
|
|
677
|
+
return super().from_pretrained(model, **kwargs)
|
|
678
|
+
else:
|
|
679
|
+
# model is None, treat as hub-style load
|
|
680
|
+
return super().from_pretrained(**kwargs)
|
|
681
|
+
|
|
682
|
+
# This is the original SignalJEPA transfer learning case
|
|
683
|
+
if not isinstance(model, SignalJEPA):
|
|
684
|
+
raise TypeError(
|
|
685
|
+
f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
|
|
686
|
+
)
|
|
687
|
+
if n_outputs is None:
|
|
688
|
+
raise ValueError(
|
|
689
|
+
"n_outputs must be provided when loading from a SignalJEPA model"
|
|
690
|
+
)
|
|
691
|
+
|
|
616
692
|
feature_encoder = model.feature_encoder
|
|
617
693
|
assert feature_encoder is not None
|
|
618
694
|
new_model = cls(
|
braindecode/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "1.3.0.
|
|
1
|
+
__version__ = "1.3.0.dev175955015"
|
{braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.3.0.
|
|
3
|
+
Version: 1.3.0.dev175955015
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
5
|
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
@@ -40,6 +40,8 @@ Requires-Dist: linear_attention_transformer
|
|
|
40
40
|
Requires-Dist: docstring_inheritance
|
|
41
41
|
Provides-Extra: moabb
|
|
42
42
|
Requires-Dist: moabb>=1.2.0; extra == "moabb"
|
|
43
|
+
Provides-Extra: hug
|
|
44
|
+
Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hug"
|
|
43
45
|
Provides-Extra: tests
|
|
44
46
|
Requires-Dist: pytest; extra == "tests"
|
|
45
47
|
Requires-Dist: pytest-cov; extra == "tests"
|
|
@@ -65,7 +67,7 @@ Requires-Dist: pre-commit; extra == "docs"
|
|
|
65
67
|
Requires-Dist: openneuro-py; extra == "docs"
|
|
66
68
|
Requires-Dist: plotly; extra == "docs"
|
|
67
69
|
Provides-Extra: all
|
|
68
|
-
Requires-Dist: braindecode[docs,moabb,tests]; extra == "all"
|
|
70
|
+
Requires-Dist: braindecode[docs,hug,moabb,tests]; extra == "all"
|
|
69
71
|
Dynamic: license-file
|
|
70
72
|
|
|
71
73
|
.. image:: https://badges.gitter.im/braindecodechat/community.svg
|
{braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/RECORD
RENAMED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
braindecode/__init__.py,sha256=Ac3LEEyIHWFY_fFh3eAY1GZUqXcUxVSJwOSUCwGEDvQ,182
|
|
2
2
|
braindecode/classifier.py,sha256=k9vSCtfQbld0YVleDi5rrrmk6k_k5JYEPPBYcNxYjZ8,9807
|
|
3
|
-
braindecode/eegneuralnet.py,sha256=
|
|
3
|
+
braindecode/eegneuralnet.py,sha256=U6kRdT2u8A2Ca0axMTR8IAESBsvgjLMusAbYappKAOk,15368
|
|
4
4
|
braindecode/regressor.py,sha256=VLfrpiXklwI4onkwue3QmzlBWcvspu0tlrLo9RT1Oiw,9375
|
|
5
5
|
braindecode/util.py,sha256=J-tBcDJNlMTIFW2mfOy6Ko0nsgdP4obRoEVDeg2rFH0,12686
|
|
6
|
-
braindecode/version.py,sha256=
|
|
6
|
+
braindecode/version.py,sha256=_vqrBwFXZ8aTOQzNbG2-qCrZTT7wOQJENwjT0xpmojI,35
|
|
7
7
|
braindecode/augmentation/__init__.py,sha256=LG7ONqCufYAF9NZt8POIp10lYXb8iSueYkF-CWGK2Ls,1001
|
|
8
8
|
braindecode/augmentation/base.py,sha256=gg7wYsVfa9jfqBddtE03B5ZrPHFFmPl2sa3LOrRnGfo,7325
|
|
9
9
|
braindecode/augmentation/functional.py,sha256=lPhGpZcVtgfQ3oV6p6IQLBCWM_Psa60TwxH3Wj1WyOQ,41133
|
|
@@ -29,9 +29,9 @@ braindecode/functional/functions.py,sha256=CoEweM6YLhigx0tNmmz6yAc8iQ078sTFY2GeC
|
|
|
29
29
|
braindecode/functional/initialization.py,sha256=BUSC7y2TMsfShpMYBVwm3xg3ODFqWp-STH7yD4sn8zk,1388
|
|
30
30
|
braindecode/models/__init__.py,sha256=vB0ZFhucH1cRQPoAAAcc3S-hVTnAy674Eu0FjjjKJp0,2543
|
|
31
31
|
braindecode/models/atcnet.py,sha256=H2IWMscm3IM4PH8DA_iLkUaeMXgA120DmVld4jBFOCM,32242
|
|
32
|
-
braindecode/models/attentionbasenet.py,sha256=
|
|
32
|
+
braindecode/models/attentionbasenet.py,sha256=_bml0Ofy7yB12X19a026EYkcLuzZIab0v3sQTqZ5HGQ,30485
|
|
33
33
|
braindecode/models/attn_sleep.py,sha256=m6sdFfD4en2hHf_TpotLPC1hVweJcYZvjgf12bV5FZg,17822
|
|
34
|
-
braindecode/models/base.py,sha256=
|
|
34
|
+
braindecode/models/base.py,sha256=KjsHVQDdUCAJB4nS-a6eze-H7ayvU4565tsFUcDVxVQ,20212
|
|
35
35
|
braindecode/models/biot.py,sha256=d2P1i_8k98SU3FkN_dKPXcCoFVmyQIIrBbI1-F3g-8E,17509
|
|
36
36
|
braindecode/models/contrawr.py,sha256=eeR_ik4gNZ3rJLM6Mw9gJ2gTMkZ8CU8C4rN_GQMQTAE,10044
|
|
37
37
|
braindecode/models/ctnet.py,sha256=ce5F31q2weBKvg7PL80iDm7za9fhGaCFvNfHoJW_dtg,17315
|
|
@@ -51,12 +51,12 @@ braindecode/models/fblightconvnet.py,sha256=d5MwhawhkjilAMo0ckaYMxJhdGMEuorWgHX-
|
|
|
51
51
|
braindecode/models/fbmsnet.py,sha256=9bZn2_n1dTrI1Qh3Sz9zMZnH_a-Yq-13UHYSmF6r_UE,11659
|
|
52
52
|
braindecode/models/hybrid.py,sha256=hA8jwD3_3LL71BxUjRM1dkhqlHU9E9hjuDokh-jBq-4,4024
|
|
53
53
|
braindecode/models/ifnet.py,sha256=Y2bwfko3SDjD74AzgUEzgMhKJFGCCw_Q_Noh5VONEjQ,15137
|
|
54
|
-
braindecode/models/labram.py,sha256=
|
|
54
|
+
braindecode/models/labram.py,sha256=dnZpHbuB60pKZWZHNQaM01eNajGG0tkZB2iutT882PM,46563
|
|
55
55
|
braindecode/models/msvtnet.py,sha256=hxeCLkHS6w2w89YlLfEPCyQ4XQQpt45bEYPiQJ9SFzY,12642
|
|
56
56
|
braindecode/models/patchedtransformer.py,sha256=9TY9l2X4EoCuE9IoOObjubKFRdmsN5lbrVQLnmr66VY,23444
|
|
57
57
|
braindecode/models/sccnet.py,sha256=C7vdwIR5cI6wJCl5f8TnGQG6qinq21y4HG6l-D5AwbY,11971
|
|
58
58
|
braindecode/models/shallow_fbcsp.py,sha256=7U07DJBrm2JHV8v5ja-xuE5-IH5tfmryhJtrfO1n4jk,7531
|
|
59
|
-
braindecode/models/signal_jepa.py,sha256=
|
|
59
|
+
braindecode/models/signal_jepa.py,sha256=ObP8-AauGZHG9tXxRGnvEnlSiwZ1YssbARuXpUl7swk,41013
|
|
60
60
|
braindecode/models/sinc_shallow.py,sha256=Ilv8K1XhMGiRTBtQdq7L595i6cEFYOBe0_UDv-LqL7s,11907
|
|
61
61
|
braindecode/models/sleep_stager_blanco_2020.py,sha256=vXulnDYutEFLM0UPXyAI0YIj5QImUMVEmYZb78j34H8,6034
|
|
62
62
|
braindecode/models/sleep_stager_chambon_2018.py,sha256=8w8IR2PsfG0jSc3o0YVopgHpOvCHNIuMi7-QRJOYEW4,5245
|
|
@@ -95,9 +95,9 @@ braindecode/training/scoring.py,sha256=WRkwqbitA3m_dzRnGp2ZIZPge5Nhx9gAEQhIHzeH4
|
|
|
95
95
|
braindecode/visualization/__init__.py,sha256=4EER_xHqZIDzEvmgUEm7K1bgNKpyZAIClR9ZCkMuY4M,240
|
|
96
96
|
braindecode/visualization/confusion_matrices.py,sha256=qIWMLEHow5CJ7PhGggD8mnD55Le6xhma9HSzt4R33fc,9509
|
|
97
97
|
braindecode/visualization/gradients.py,sha256=KZo-GA0uwiwty2_94j2IjmCR2SKcfPb1Bi3sQq7vpTk,2170
|
|
98
|
-
braindecode-1.3.0.
|
|
99
|
-
braindecode-1.3.0.
|
|
100
|
-
braindecode-1.3.0.
|
|
101
|
-
braindecode-1.3.0.
|
|
102
|
-
braindecode-1.3.0.
|
|
103
|
-
braindecode-1.3.0.
|
|
98
|
+
braindecode-1.3.0.dev175955015.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
|
|
99
|
+
braindecode-1.3.0.dev175955015.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
|
|
100
|
+
braindecode-1.3.0.dev175955015.dist-info/METADATA,sha256=E4357DmrBx2DdYRbANQcKTrbPLlxlkAPfVRur72RBlU,7215
|
|
101
|
+
braindecode-1.3.0.dev175955015.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
102
|
+
braindecode-1.3.0.dev175955015.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
|
|
103
|
+
braindecode-1.3.0.dev175955015.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{braindecode-1.3.0.dev175415232.dist-info → braindecode-1.3.0.dev175955015.dist-info}/top_level.txt
RENAMED
|
File without changes
|