braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,6 @@
1
- """Dataset classes."""
1
+ """
2
+ Dataset classes.
3
+ """
2
4
 
3
5
  # Authors: Hubert Banville <hubert.jbanville@gmail.com>
4
6
  # Lukas Gemein <l.gemein@gmail.com>
@@ -26,7 +28,7 @@ from mne.utils.docs import deprecated
26
28
  from torch.utils.data import ConcatDataset, Dataset
27
29
  from typing_extensions import TypeVar
28
30
 
29
- from .bids.hub import HubDatasetMixin
31
+ from .hub import HubDatasetMixin
30
32
  from .registry import register_dataset
31
33
 
32
34
 
@@ -63,9 +65,9 @@ class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]])
63
65
 
64
66
  Parameters
65
67
  ----------
66
- description : dict | pd.Series
68
+ description: dict | pd.Series
67
69
  Description in the form key: value.
68
- overwrite : bool
70
+ overwrite: bool
69
71
  Has to be True if a key in description already exists in the
70
72
  dataset description.
71
73
  """
@@ -399,6 +401,7 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
399
401
  list of RecordDataset
400
402
  target_transform : callable | None
401
403
  Optional function to call on targets before returning them.
404
+
402
405
  """
403
406
 
404
407
  datasets: list[T]
@@ -433,8 +436,8 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
433
436
 
434
437
  def __getitem__(self, idx: int | list):
435
438
  """
436
- ---
437
-
439
+ Parameters
440
+ ----------
438
441
  idx : int | list
439
442
  Index of window and target to return. If provided as a list of
440
443
  ints, multiple windows and targets will be extracted and
@@ -569,8 +572,7 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
569
572
  self._target_transform = fn
570
573
 
571
574
  def _outdated_save(self, path, overwrite=False):
572
- """This is a copy of the old saving function, that had inconsistent.
573
-
575
+ """This is a copy of the old saving function, that had inconsistent
574
576
  functionality for BaseDataset and WindowsDataset. It only exists to
575
577
  assure backwards compatibility by still being able to run the old tests.
576
578
 
@@ -666,10 +668,10 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
666
668
 
667
669
  Parameters
668
670
  ----------
669
- description : dict | pd.DataFrame
671
+ description: dict | pd.DataFrame
670
672
  Description in the form key: value where the length of the value
671
673
  has to match the number of datasets.
672
- overwrite : bool
674
+ overwrite: bool
673
675
  Has to be True if a key in description already exists in the
674
676
  dataset description.
675
677
  """
@@ -717,14 +719,8 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
717
719
  hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
718
720
  ):
719
721
  raise ValueError("dataset should have either raw or windows attribute")
720
-
721
- # Create path if it doesn't exist
722
- os.makedirs(path, exist_ok=True)
723
-
724
722
  path_contents = os.listdir(path)
725
- n_sub_dirs = len(
726
- [e for e in path_contents if os.path.isdir(os.path.join(path, e))]
727
- )
723
+ n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
728
724
  for i_ds, ds in enumerate(self.datasets):
729
725
  # remove subdirectory from list of untouched files / subdirectories
730
726
  if str(i_ds + offset) in path_contents:
@@ -27,11 +27,11 @@ class BBCIDataset(object):
27
27
 
28
28
  Parameters
29
29
  ----------
30
- filename : str
31
- load_sensor_names : list of str, optional
30
+ filename: str
31
+ load_sensor_names: list of str, optional
32
32
  Also speeds up loading if you only load some sensors.
33
33
  None means load all sensors.
34
- check_class_names : bool, optional
34
+ check_class_names: bool, optional
35
35
  check if the class names are part of some known class names at
36
36
  Translational NeuroTechnology Lab, AG Ball, Freiburg, Germany.
37
37
  """
@@ -134,13 +134,13 @@ class BBCIDataset(object):
134
134
 
135
135
  Parameters
136
136
  ----------
137
- filename : str
138
- pattern : str, optional
137
+ filename: str
138
+ pattern: str, optional
139
139
  Only return those sensor names that match the given pattern.
140
140
 
141
141
  Returns
142
142
  -------
143
- sensor_names : list of str
143
+ sensor_names: list of str
144
144
  Sensor names that match the pattern or all sensor names in the file.
145
145
  """
146
146
  with h5py.File(filename, "r") as h5file:
@@ -237,17 +237,17 @@ class BBCIDataset(object):
237
237
 
238
238
  def _check_class_names(all_class_names, event_times_in_ms, event_classes):
239
239
  """
240
- Checks if the class names are part of some known class names used in.
241
-
240
+ Checks if the class names are part of some known class names used in
242
241
  translational neurotechnology lab, AG Ball, Freiburg.
243
242
 
244
243
  Logs warning in case class names are not known.
245
244
 
246
245
  Parameters
247
246
  ----------
248
- all_class_names : list of str
249
- event_times_in_ms : list of number
250
- event_classes : list of number
247
+ all_class_names: list of str
248
+ event_times_in_ms: list of number
249
+ event_classes: list of number
250
+
251
251
  """
252
252
  if all_class_names == ["Right Hand", "Left Hand", "Rest", "Feet"]:
253
253
  pass
@@ -665,15 +665,16 @@ def load_bbci_sets_from_folder(
665
665
 
666
666
  Parameters
667
667
  ----------
668
- folder : str
668
+ folder: str
669
669
  Folder with .BBCI.mat files inside
670
- runs : list of int
670
+ runs: list of int
671
671
  If you only want to load specific runs.
672
672
  Assumes filenames with such kind of part: S001R02 for Run 2.
673
673
  Tries to match this regex: ``'S[0-9]{3,3}R[0-9]{2,2}_'``.
674
674
 
675
675
  Returns
676
676
  -------
677
+
677
678
  """
678
679
  bbci_mat_files = sorted(glob(os.path.join(folder, "*.BBCI.mat")))
679
680
  if runs != "all":
@@ -33,16 +33,16 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
33
33
  http://www.bbci.de/competition/iv/ for the dataset and competition description.
34
34
  ECoG library containing the dataset: https://searchworks.stanford.edu/view/zk881ps0522
35
35
 
36
+ Notes
37
+ -----
38
+ When using this dataset please cite [1]_ .
39
+
36
40
  Parameters
37
41
  ----------
38
42
  subject_ids : list(int) | int | None
39
43
  (list of) int of subject(s) to be loaded. If None, load all available
40
44
  subjects. Should be in range 1-3.
41
45
 
42
- Notes
43
- -----
44
- When using this dataset please cite [1]_ .
45
-
46
46
  References
47
47
  ----------
48
48
  .. [1] Miller, Kai J. "A library of human electrocorticographic data and analyses."
@@ -94,6 +94,7 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
94
94
 
95
95
  Returns
96
96
  -------
97
+
97
98
  """
98
99
  signature = "BCICompetitionIVDataset4"
99
100
  folder_name = "BCI_Competion4_dataset4_data_fingerflexions"
@@ -1,4 +1,3 @@
1
- # mypy: ignore-errors
2
1
  """Dataset for loading BIDS.
3
2
 
4
3
  More information on BIDS (Brain Imaging Data Structure) can be found at https://bids.neuroimaging.io
@@ -20,19 +19,26 @@ import numpy as np
20
19
  import pandas as pd
21
20
  from joblib import Parallel, delayed
22
21
 
23
- from ..base import BaseConcatDataset, RawDataset, WindowsDataset
22
+ from .base import BaseConcatDataset, RawDataset, WindowsDataset
24
23
 
25
24
 
26
25
  def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
27
- description = {"path": bids_path.fpath, **bids_path.entities}
28
- description.update(
29
- {
30
- "suffix": bids_path.suffix,
31
- "extension": bids_path.extension,
32
- "datatype": bids_path.datatype,
33
- }
34
- )
35
- return description
26
+ return {
27
+ "path": bids_path.fpath,
28
+ "subject": bids_path.subject,
29
+ "session": bids_path.session,
30
+ "task": bids_path.task,
31
+ "acquisition": bids_path.acquisition,
32
+ "run": bids_path.run,
33
+ "processing": bids_path.processing,
34
+ "recording": bids_path.recording,
35
+ "space": bids_path.space,
36
+ "split": bids_path.split,
37
+ "description": bids_path.description,
38
+ "suffix": bids_path.suffix,
39
+ "extension": bids_path.extension,
40
+ "datatype": bids_path.datatype,
41
+ }
36
42
 
37
43
 
38
44
  @dataclass
@@ -59,7 +65,7 @@ class BIDSDataset(BaseConcatDataset):
59
65
  The acquisition session. Corresponds to "ses".
60
66
  tasks : str | array-like of str | None
61
67
  The experimental task. Corresponds to "task".
62
- acquisitions : str | array-like of str | None
68
+ acquisitions: str | array-like of str | None
63
69
  The acquisition parameters. Corresponds to "acq".
64
70
  runs : str | array-like of str | None
65
71
  The run number. Corresponds to "run".
@@ -25,13 +25,14 @@ class BIDSIterableDataset(IterableDataset):
25
25
 
26
26
  Examples
27
27
  --------
28
- >>> from braindecode.datasets import BaseConcatDataset, RawDataset, RecordDataset
29
- >>> from braindecode.datasets.bids import BIDSIterableDataset
28
+ >>> from braindecode.datasets import RecordDataset, BaseConcatDataset
29
+ >>> from braindecode.datasets.bids import BIDSIterableDataset, _description_from_bids_path
30
30
  >>> from braindecode.preprocessing import create_fixed_length_windows
31
31
  >>>
32
32
  >>> def my_reader_fn(path):
33
33
  ... raw = mne_bids.read_raw_bids(path)
34
- ... ds: RecordDataset = RawDataset(raw, description={"path": path.fpath})
34
+ ... desc = _description_from_bids_path(path)
35
+ ... ds = RawDataset(raw, description=desc)
35
36
  ... windows_ds = create_fixed_length_windows(
36
37
  ... BaseConcatDataset([ds]),
37
38
  ... window_size_samples=400,
@@ -47,8 +48,7 @@ class BIDSIterableDataset(IterableDataset):
47
48
  Parameters
48
49
  ----------
49
50
  reader_fn : Callable[[mne_bids.BIDSPath], Sequence]
50
- A function that takes a BIDSPath and returns a dataset (e.g., a
51
- RecordDataset or BaseConcatDataset of RecordDataset).
51
+ A function that takes a BIDSPath and returns a dataset.
52
52
  pool_size : int
53
53
  The number of recordings to read and sample from.
54
54
  bids_paths : list[mne_bids.BIDSPath] | None
@@ -62,7 +62,7 @@ class BIDSIterableDataset(IterableDataset):
62
62
  The acquisition session. Corresponds to "ses".
63
63
  tasks : str | array-like of str | None
64
64
  The experimental task. Corresponds to "task".
65
- acquisitions : str | array-like of str | None
65
+ acquisitions: str | array-like of str | None
66
66
  The acquisition parameters. Corresponds to "acq".
67
67
  runs : str | array-like of str | None
68
68
  The run number. Corresponds to "run".
@@ -106,8 +106,6 @@ class BIDSIterableDataset(IterableDataset):
106
106
  If True, preload the data. Defaults to False.
107
107
  n_jobs : int
108
108
  Number of jobs to run in parallel. Defaults to 1.
109
-
110
-
111
109
  """
112
110
 
113
111
  def __init__(