braindecode 1.3.0.dev174777731__py3-none-any.whl → 1.3.0.dev175955015__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. More details are available via the link on the original page.

@@ -5,7 +5,8 @@ from __future__ import annotations
5
5
 
6
6
  import math
7
7
  from copy import deepcopy
8
- from typing import Any, Sequence
8
+ from pathlib import Path
9
+ from typing import Any, Optional, Sequence
9
10
 
10
11
  import torch
11
12
  from einops.layers.torch import Rearrange
@@ -319,25 +320,50 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):
319
320
  @classmethod
320
321
  def from_pretrained(
321
322
  cls,
322
- model: SignalJEPA,
323
- n_outputs: int,
323
+ model: Optional[SignalJEPA | str | Path] = None, # type: ignore
324
+ n_outputs: Optional[int] = None, # type: ignore
324
325
  n_spat_filters: int = 4,
325
- chs_info: list[dict[str, Any]] | None = None,
326
+ chs_info: Optional[list[dict[str, Any]]] = None, # type: ignore
327
+ **kwargs,
326
328
  ):
327
- """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
329
+ """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
328
330
 
329
331
  Parameters
330
332
  ----------
331
- model: SignalJEPA
332
- Pre-trained model.
333
- n_outputs: int
334
- Number of classes for the new model.
333
+ model: SignalJEPA, str, Path, or None
334
+ Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
335
+ (for Hub-style loading), or None (for Hub loading via kwargs).
336
+ n_outputs: int or None
337
+ Number of classes for the new model. Required when loading from a SignalJEPA model,
338
+ optional when loading from Hub (will be read from config).
335
339
  n_spat_filters: int
336
340
  Number of spatial filters.
337
341
  chs_info: list of dict | None
338
342
  Information about each individual EEG channel. This should be filled with
339
343
  ``info["chs"]``. Refer to :class:`mne.Info` for more details.
344
+ **kwargs
345
+ Additional keyword arguments passed to the parent class for Hub loading.
340
346
  """
347
+ # Check if this is a Hub-style load (from a directory path)
348
+ if isinstance(model, (str, Path)) or (model is None and kwargs):
349
+ # This is a Hub load, delegate to parent class
350
+ if isinstance(model, (str, Path)):
351
+ # model is actually the repo_id or directory path
352
+ return super().from_pretrained(model, **kwargs)
353
+ else:
354
+ # model is None, treat as hub-style load
355
+ return super().from_pretrained(**kwargs)
356
+
357
+ # This is the original SignalJEPA transfer learning case
358
+ if not isinstance(model, SignalJEPA):
359
+ raise TypeError(
360
+ f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
361
+ )
362
+ if n_outputs is None:
363
+ raise ValueError(
364
+ "n_outputs must be provided when loading from a SignalJEPA model"
365
+ )
366
+
341
367
  feature_encoder = model.feature_encoder
342
368
  pos_encoder = model.pos_encoder
343
369
  transformer = model.transformer
@@ -463,22 +489,47 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):
463
489
 
464
490
  @classmethod
465
491
  def from_pretrained(
466
- cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
492
+ cls,
493
+ model: SignalJEPA | str | Path = None, # type: ignore
494
+ n_outputs: int = None, # type: ignore
495
+ n_spat_filters: int = 4,
496
+ **kwargs,
467
497
  ):
468
- """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
498
+ """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
469
499
 
470
500
  Parameters
471
501
  ----------
472
- model: SignalJEPA
473
- Pre-trained model.
474
- n_outputs: int
475
- Number of classes for the new model.
502
+ model: SignalJEPA, str, Path, or None
503
+ Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
504
+ (for Hub-style loading), or None (for Hub loading via kwargs).
505
+ n_outputs: int or None
506
+ Number of classes for the new model. Required when loading from a SignalJEPA model,
507
+ optional when loading from Hub (will be read from config).
476
508
  n_spat_filters: int
477
509
  Number of spatial filters.
478
- chs_info: list of dict | None
479
- Information about each individual EEG channel. This should be filled with
480
- ``info["chs"]``. Refer to :class:`mne.Info` for more details.
510
+ **kwargs
511
+ Additional keyword arguments passed to the parent class for Hub loading.
481
512
  """
513
+ # Check if this is a Hub-style load (from a directory path)
514
+ if isinstance(model, (str, Path)) or (model is None and kwargs):
515
+ # This is a Hub load, delegate to parent class
516
+ if isinstance(model, (str, Path)):
517
+ # model is actually the repo_id or directory path
518
+ return super().from_pretrained(model, **kwargs)
519
+ else:
520
+ # model is None, treat as hub-style load
521
+ return super().from_pretrained(**kwargs)
522
+
523
+ # This is the original SignalJEPA transfer learning case
524
+ if not isinstance(model, SignalJEPA):
525
+ raise TypeError(
526
+ f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
527
+ )
528
+ if n_outputs is None:
529
+ raise ValueError(
530
+ "n_outputs must be provided when loading from a SignalJEPA model"
531
+ )
532
+
482
533
  feature_encoder = model.feature_encoder
483
534
  assert feature_encoder is not None
484
535
  new_model = cls(
@@ -597,22 +648,47 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):
597
648
 
598
649
  @classmethod
599
650
  def from_pretrained(
600
- cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
651
+ cls,
652
+ model: SignalJEPA | str | Path = None, # type: ignore
653
+ n_outputs: int = None, # type: ignore
654
+ n_spat_filters: int = 4,
655
+ **kwargs,
601
656
  ):
602
- """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
657
+ """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.
603
658
 
604
659
  Parameters
605
660
  ----------
606
- model: SignalJEPA
607
- Pre-trained model.
608
- n_outputs: int
609
- Number of classes for the new model.
661
+ model: SignalJEPA, str, Path, or None
662
+ Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
663
+ (for Hub-style loading), or None (for Hub loading via kwargs).
664
+ n_outputs: int or None
665
+ Number of classes for the new model. Required when loading from a SignalJEPA model,
666
+ optional when loading from Hub (will be read from config).
610
667
  n_spat_filters: int
611
668
  Number of spatial filters.
612
- chs_info: list of dict | None
613
- Information about each individual EEG channel. This should be filled with
614
- ``info["chs"]``. Refer to :class:`mne.Info` for more details.
669
+ **kwargs
670
+ Additional keyword arguments passed to the parent class for Hub loading.
615
671
  """
672
+ # Check if this is a Hub-style load (from a directory path)
673
+ if isinstance(model, (str, Path)) or (model is None and kwargs):
674
+ # This is a Hub load, delegate to parent class
675
+ if isinstance(model, (str, Path)):
676
+ # model is actually the repo_id or directory path
677
+ return super().from_pretrained(model, **kwargs)
678
+ else:
679
+ # model is None, treat as hub-style load
680
+ return super().from_pretrained(**kwargs)
681
+
682
+ # This is the original SignalJEPA transfer learning case
683
+ if not isinstance(model, SignalJEPA):
684
+ raise TypeError(
685
+ f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
686
+ )
687
+ if n_outputs is None:
688
+ raise ValueError(
689
+ "n_outputs must be provided when loading from a SignalJEPA model"
690
+ )
691
+
616
692
  feature_encoder = model.feature_encoder
617
693
  assert feature_encoder is not None
618
694
  new_model = cls(