braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/braindecode/models/base.py
@@ -0,0 +1,296 @@
+# Authors: Pierre Guetschel
+#          Maciej Sliwowski
+#
+# License: BSD-3
+
+from __future__ import annotations
+
+import warnings
+from collections import OrderedDict
+from typing import Dict, Iterable, Optional
+
+import numpy as np
+import torch
+from docstring_inheritance import NumpyDocstringInheritanceInitMeta
+from torchinfo import ModelStatistics, summary
+
+
+def deprecated_args(obj, *old_new_args):
+    out_args = []
+    for old_name, new_name, old_val, new_val in old_new_args:
+        if old_val is None:
+            out_args.append(new_val)
+        else:
+            warnings.warn(
+                f"{obj.__class__.__name__}: {old_name!r} is depreciated. Use {new_name!r} instead."
+            )
+            if new_val is not None:
+                raise ValueError(
+                    f"{obj.__class__.__name__}: Both {old_name!r} and {new_name!r} were specified."
+                )
+            out_args.append(old_val)
+    return out_args
+
+
+class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
+    """
+    Mixin class for all EEG models in braindecode.
+
+    Parameters
+    ----------
+    n_outputs : int
+        Number of outputs of the model. This is the number of classes
+        in the case of classification.
+    n_chans : int
+        Number of EEG channels.
+    chs_info : list of dict
+        Information about each individual EEG channel. This should be filled with
+        ``info["chs"]``. Refer to :class:`mne.Info` for more details.
+    n_times : int
+        Number of time samples of the input window.
+    input_window_seconds : float
+        Length of the input window in seconds.
+    sfreq : float
+        Sampling frequency of the EEG recordings.
+
+    Raises
+    ------
+    ValueError: If some input signal-related parameters are not specified
+        and can not be inferred.
+
+    Notes
+    -----
+    If some input signal-related parameters are not specified,
+    there will be an attempt to infer them from the other parameters.
+    """
+
+    def __init__(
+        self,
+        n_outputs: Optional[int] = None,  # type: ignore[assignment]
+        n_chans: Optional[int] = None,  # type: ignore[assignment]
+        chs_info=None,  # type: ignore[assignment]
+        n_times: Optional[int] = None,  # type: ignore[assignment]
+        input_window_seconds: Optional[float] = None,  # type: ignore[assignment]
+        sfreq: Optional[float] = None,  # type: ignore[assignment]
+    ):
+        if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
+            raise ValueError(f"{n_chans=} different from {chs_info=} length")
+        if (
+            n_times is not None
+            and input_window_seconds is not None
+            and sfreq is not None
+            and n_times != int(input_window_seconds * sfreq)
+        ):
+            raise ValueError(
+                f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
+            )
+
+        self._input_window_seconds = input_window_seconds  # type: ignore[assignment]
+        self._chs_info = chs_info  # type: ignore[assignment]
+        self._n_outputs = n_outputs  # type: ignore[assignment]
+        self._n_chans = n_chans  # type: ignore[assignment]
+        self._n_times = n_times  # type: ignore[assignment]
+        self._sfreq = sfreq  # type: ignore[assignment]
+
+        super().__init__()
+
+    @property
+    def n_outputs(self) -> int:
+        if self._n_outputs is None:
+            raise ValueError("n_outputs not specified.")
+        return self._n_outputs
+
+    @property
+    def n_chans(self) -> int:
+        if self._n_chans is None and self._chs_info is not None:
+            return len(self._chs_info)
+        elif self._n_chans is None:
+            raise ValueError(
+                "n_chans could not be inferred. Either specify n_chans or chs_info."
+            )
+        return self._n_chans
+
+    @property
+    def chs_info(self) -> list[str]:
+        if self._chs_info is None:
+            raise ValueError("chs_info not specified.")
+        return self._chs_info
+
+    @property
+    def n_times(self) -> int:
+        if (
+            self._n_times is None
+            and self._input_window_seconds is not None
+            and self._sfreq is not None
+        ):
+            return int(self._input_window_seconds * self._sfreq)
+        elif self._n_times is None:
+            raise ValueError(
+                "n_times could not be inferred. "
+                "Either specify n_times or input_window_seconds and sfreq."
+            )
+        return self._n_times
+
+    @property
+    def input_window_seconds(self) -> float:
+        if (
+            self._input_window_seconds is None
+            and self._n_times is not None
+            and self._sfreq is not None
+        ):
+            return float(self._n_times / self._sfreq)
+        elif self._input_window_seconds is None:
+            raise ValueError(
+                "input_window_seconds could not be inferred. "
+                "Either specify input_window_seconds or n_times and sfreq."
+            )
+        return self._input_window_seconds
+
+    @property
+    def sfreq(self) -> float:
+        if (
+            self._sfreq is None
+            and self._input_window_seconds is not None
+            and self._n_times is not None
+        ):
+            return float(self._n_times / self._input_window_seconds)
+        elif self._sfreq is None:
+            raise ValueError(
+                "sfreq could not be inferred. "
+                "Either specify sfreq or input_window_seconds and n_times."
+            )
+        return self._sfreq
+
+    @property
+    def input_shape(self) -> tuple[int, int, int]:
+        """Input data shape."""
+        return (1, self.n_chans, self.n_times)
+
+    def get_output_shape(self) -> tuple[int, ...]:
+        """Returns shape of neural network output for batch size equal 1.
+
+        Returns
+        -------
+        output_shape: tuple[int, ...]
+            shape of the network output for `batch_size==1` (1, ...)
+        """
+        with torch.inference_mode():
+            try:
+                return tuple(
+                    self.forward(  # type: ignore
+                        torch.zeros(
+                            self.input_shape,
+                            dtype=next(self.parameters()).dtype,  # type: ignore
+                            device=next(self.parameters()).device,  # type: ignore
+                        )
+                    ).shape
+                )
+            except RuntimeError as exc:
+                if str(exc).endswith(
+                    (
+                        "Output size is too small",
+                        "Kernel size can't be greater than actual input size",
+                    )
+                ):
+                    msg = (
+                        "During model prediction RuntimeError was thrown showing that at some "
+                        f"layer `{str(exc).split('.')[-1]}` (see above in the stacktrace). This "
+                        "could be caused by providing too small `n_times`/`input_window_seconds`. "
+                        "Model may require longer chunks of signal in the input than "
+                        f"{self.input_shape}."
+                    )
+                    raise ValueError(msg) from exc
+                raise exc
+
+    mapping: Optional[Dict[str, str]] = None
+
+    def load_state_dict(self, state_dict, *args, **kwargs):
+        mapping = self.mapping if self.mapping else {}
+        new_state_dict = OrderedDict()
+        for k, v in state_dict.items():
+            if k in mapping:
+                new_state_dict[mapping[k]] = v
+            else:
+                new_state_dict[k] = v
+
+        return super().load_state_dict(new_state_dict, *args, **kwargs)
+
+    def to_dense_prediction_model(self, axis: tuple[int, ...] | int = (2, 3)) -> None:
+        """
+        Transform a sequential model with strides to a model that outputs
+        dense predictions by removing the strides and instead inserting dilations.
+        Modifies model in-place.
+
+        Parameters
+        ----------
+        axis: int or (int,int)
+            Axis to transform (in terms of intermediate output axes)
+            can either be 2, 3, or (2,3).
+
+        Notes
+        -----
+        Does not yet work correctly for average pooling.
+        Prior to version 0.1.7, there had been a bug that could move strides
+        backwards one layer.
+
+        """
+        if not hasattr(axis, "__iter__"):
+            axis = (axis,)
+        assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis"  # type: ignore[union-attr]
+        axis = np.array(axis) - 2
+        stride_so_far = np.array([1, 1])
+        for module in self.modules():  # type: ignore
+            if hasattr(module, "dilation"):
+                assert module.dilation == 1 or (module.dilation == (1, 1)), (
+                    "Dilation should equal 1 before conversion, maybe the model is "
+                    "already converted?"
+                )
+                new_dilation = [1, 1]
+                for ax in axis:  # type: ignore[union-attr]
+                    new_dilation[ax] = int(stride_so_far[ax])
+                module.dilation = tuple(new_dilation)
+            if hasattr(module, "stride"):
+                if not hasattr(module.stride, "__len__"):
+                    module.stride = (module.stride, module.stride)
+                stride_so_far *= np.array(module.stride)
+                new_stride = list(module.stride)
+                for ax in axis:  # type: ignore[union-attr]
+                    new_stride[ax] = 1
+                module.stride = tuple(new_stride)
+
+    def get_torchinfo_statistics(
+        self,
+        col_names: Optional[Iterable[str]] = (
+            "input_size",
+            "output_size",
+            "num_params",
+            "kernel_size",
+        ),
+        row_settings: Optional[Iterable[str]] = ("var_names", "depth"),
+    ) -> ModelStatistics:
+        """Generate table describing the model using torchinfo.summary.
+
+        Parameters
+        ----------
+        col_names : tuple, optional
+            Specify which columns to show in the output, see torchinfo for details, by default
+            ("input_size", "output_size", "num_params", "kernel_size")
+        row_settings : tuple, optional
+            Specify which features to show in a row, see torchinfo for details, by default
+            ("var_names", "depth")
+
+        Returns
+        -------
+        torchinfo.ModelStatistics
+            ModelStatistics generated by torchinfo.summary.
+        """
+        return summary(
+            self,
+            input_size=(1, self.n_chans, self.n_times),
+            col_names=col_names,
+            row_settings=row_settings,
+            verbose=0,
+        )
+
+    def __str__(self) -> str:
+        return str(self.get_torchinfo_statistics())
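
For context, here is a minimal usage sketch of the new EEGModuleMixin from braindecode/models/base.py shown above. It is not part of the diff: the TinyEEGNet model and its layer sizes are hypothetical, invented only to exercise the mixin; EEGModuleMixin, its constructor parameters, get_output_shape(), and the torchinfo-based __str__() are the pieces taken from the diff. Braindecode models typically derive from EEGModuleMixin together with torch.nn.Module, and the mixin infers missing signal parameters (here n_times from input_window_seconds * sfreq).

from torch import nn

from braindecode.models.base import EEGModuleMixin


class TinyEEGNet(EEGModuleMixin, nn.Module):
    """Hypothetical two-layer model, used only to exercise the mixin."""

    def __init__(self, n_outputs=None, n_chans=None, chs_info=None,
                 n_times=None, input_window_seconds=None, sfreq=None):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # n_chans/n_times may have been inferred by the mixin's properties
        # (from chs_info, or from input_window_seconds * sfreq).
        self.conv = nn.Conv1d(self.n_chans, 8, kernel_size=25)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(8, self.n_outputs)

    def forward(self, x):
        # x: (batch, n_chans, n_times)
        return self.fc(self.pool(self.conv(x)).flatten(start_dim=1))


# n_times is not passed; the mixin's n_times property returns
# int(input_window_seconds * sfreq) = 500.
model = TinyEEGNet(n_outputs=4, n_chans=22, input_window_seconds=2.0, sfreq=250.0)
print(model.n_times)             # 500
print(model.get_output_shape())  # (1, 4), from a dry-run forward pass on zeros
print(model)                     # torchinfo summary via get_torchinfo_statistics()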