braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; see the package's release page on the registry for more details.

Files changed (108):
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,11 @@
3
3
  #
4
4
  # License: BSD-3
5
5
 
6
- import warnings
7
- from typing import Dict, Iterable, List, Optional, Tuple
6
+ from __future__ import annotations
8
7
 
8
+ import warnings
9
9
  from collections import OrderedDict
10
+ from typing import Dict, Iterable, Optional
10
11
 
11
12
  import numpy as np
12
13
  import torch
@@ -21,11 +22,11 @@ def deprecated_args(obj, *old_new_args):
21
22
  out_args.append(new_val)
22
23
  else:
23
24
  warnings.warn(
24
- f'{obj.__class__.__name__}: {old_name!r} is depreciated. Use {new_name!r} instead.'
25
+ f"{obj.__class__.__name__}: {old_name!r} is depreciated. Use {new_name!r} instead."
25
26
  )
26
27
  if new_val is not None:
27
28
  raise ValueError(
28
- f'{obj.__class__.__name__}: Both {old_name!r} and {new_name!r} were specified.'
29
+ f"{obj.__class__.__name__}: Both {old_name!r} and {new_name!r} were specified."
29
30
  )
30
31
  out_args.append(old_val)
31
32
  return out_args
@@ -51,21 +52,12 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
51
52
  Length of the input window in seconds.
52
53
  sfreq : float
53
54
  Sampling frequency of the EEG recordings.
54
- add_log_softmax: bool
55
- Whether to use log-softmax non-linearity as the output function.
56
- LogSoftmax final layer will be removed in the future.
57
- Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
58
- Check the documentation of the torch.nn loss functions:
59
- https://pytorch.org/docs/stable/nn.html#loss-functions.
60
55
 
61
56
  Raises
62
57
  ------
63
58
  ValueError: If some input signal-related parameters are not specified
64
59
  and can not be inferred.
65
60
 
66
- FutureWarning: If add_log_softmax is True, since LogSoftmax final layer
67
- will be removed in the future.
68
-
69
61
  Notes
70
62
  -----
71
63
  If some input signal-related parameters are not specified,
@@ -73,139 +65,132 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
73
65
  """
74
66
 
75
67
  def __init__(
76
- self,
77
- n_outputs: Optional[int] = None,
78
- n_chans: Optional[int] = None,
79
- chs_info: Optional[List[Dict]] = None,
80
- n_times: Optional[int] = None,
81
- input_window_seconds: Optional[float] = None,
82
- sfreq: Optional[float] = None,
83
- add_log_softmax: Optional[bool] = False,
68
+ self,
69
+ n_outputs: Optional[int] = None, # type: ignore[assignment]
70
+ n_chans: Optional[int] = None, # type: ignore[assignment]
71
+ chs_info=None, # type: ignore[assignment]
72
+ n_times: Optional[int] = None, # type: ignore[assignment]
73
+ input_window_seconds: Optional[float] = None, # type: ignore[assignment]
74
+ sfreq: Optional[float] = None, # type: ignore[assignment]
84
75
  ):
76
+ if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
77
+ raise ValueError(f"{n_chans=} different from {chs_info=} length")
85
78
  if (
86
- n_chans is not None and
87
- chs_info is not None and
88
- len(chs_info) != n_chans
89
- ):
90
- raise ValueError(f'{n_chans=} different from {chs_info=} length')
91
- if (
92
- n_times is not None and
93
- input_window_seconds is not None and
94
- sfreq is not None and
95
- n_times != int(input_window_seconds * sfreq)
79
+ n_times is not None
80
+ and input_window_seconds is not None
81
+ and sfreq is not None
82
+ and n_times != int(input_window_seconds * sfreq)
96
83
  ):
97
84
  raise ValueError(
98
- f'{n_times=} different from '
99
- f'{input_window_seconds=} * {sfreq=}'
85
+ f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
100
86
  )
101
- self._n_outputs = n_outputs
102
- self._n_chans = n_chans
103
- self._chs_info = chs_info
104
- self._n_times = n_times
105
- self._input_window_seconds = input_window_seconds
106
- self._sfreq = sfreq
107
- self._add_log_softmax = add_log_softmax
87
+
88
+ self._input_window_seconds = input_window_seconds # type: ignore[assignment]
89
+ self._chs_info = chs_info # type: ignore[assignment]
90
+ self._n_outputs = n_outputs # type: ignore[assignment]
91
+ self._n_chans = n_chans # type: ignore[assignment]
92
+ self._n_times = n_times # type: ignore[assignment]
93
+ self._sfreq = sfreq # type: ignore[assignment]
94
+
108
95
  super().__init__()
109
96
 
110
97
  @property
111
- def n_outputs(self):
98
+ def n_outputs(self) -> int:
112
99
  if self._n_outputs is None:
113
- raise ValueError('n_outputs not specified.')
100
+ raise ValueError("n_outputs not specified.")
114
101
  return self._n_outputs
115
102
 
116
103
  @property
117
- def n_chans(self):
104
+ def n_chans(self) -> int:
118
105
  if self._n_chans is None and self._chs_info is not None:
119
106
  return len(self._chs_info)
120
107
  elif self._n_chans is None:
121
108
  raise ValueError(
122
- 'n_chans could not be inferred. Either specify n_chans or chs_info.'
109
+ "n_chans could not be inferred. Either specify n_chans or chs_info."
123
110
  )
124
111
  return self._n_chans
125
112
 
126
113
  @property
127
- def chs_info(self):
114
+ def chs_info(self) -> list[str]:
128
115
  if self._chs_info is None:
129
- raise ValueError('chs_info not specified.')
116
+ raise ValueError("chs_info not specified.")
130
117
  return self._chs_info
131
118
 
132
119
  @property
133
- def n_times(self):
120
+ def n_times(self) -> int:
134
121
  if (
135
- self._n_times is None and
136
- self._input_window_seconds is not None and
137
- self._sfreq is not None
122
+ self._n_times is None
123
+ and self._input_window_seconds is not None
124
+ and self._sfreq is not None
138
125
  ):
139
126
  return int(self._input_window_seconds * self._sfreq)
140
127
  elif self._n_times is None:
141
128
  raise ValueError(
142
- 'n_times could not be inferred. '
143
- 'Either specify n_times or input_window_seconds and sfreq.'
129
+ "n_times could not be inferred. "
130
+ "Either specify n_times or input_window_seconds and sfreq."
144
131
  )
145
132
  return self._n_times
146
133
 
147
134
  @property
148
- def input_window_seconds(self):
135
+ def input_window_seconds(self) -> float:
149
136
  if (
150
- self._input_window_seconds is None and
151
- self._n_times is not None and
152
- self._sfreq is not None
137
+ self._input_window_seconds is None
138
+ and self._n_times is not None
139
+ and self._sfreq is not None
153
140
  ):
154
- return self._n_times / self._sfreq
141
+ return float(self._n_times / self._sfreq)
155
142
  elif self._input_window_seconds is None:
156
143
  raise ValueError(
157
- 'input_window_seconds could not be inferred. '
158
- 'Either specify input_window_seconds or n_times and sfreq.'
144
+ "input_window_seconds could not be inferred. "
145
+ "Either specify input_window_seconds or n_times and sfreq."
159
146
  )
160
147
  return self._input_window_seconds
161
148
 
162
149
  @property
163
- def sfreq(self):
150
+ def sfreq(self) -> float:
164
151
  if (
165
- self._sfreq is None and
166
- self._input_window_seconds is not None and
167
- self._n_times is not None
152
+ self._sfreq is None
153
+ and self._input_window_seconds is not None
154
+ and self._n_times is not None
168
155
  ):
169
- return self._n_times / self._input_window_seconds
156
+ return float(self._n_times / self._input_window_seconds)
170
157
  elif self._sfreq is None:
171
158
  raise ValueError(
172
- 'sfreq could not be inferred. '
173
- 'Either specify sfreq or input_window_seconds and n_times.'
159
+ "sfreq could not be inferred. "
160
+ "Either specify sfreq or input_window_seconds and n_times."
174
161
  )
175
162
  return self._sfreq
176
163
 
177
164
  @property
178
- def add_log_softmax(self):
179
- if self._add_log_softmax:
180
- warnings.warn("LogSoftmax final layer will be removed! " +
181
- "Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!")
182
- return self._add_log_softmax
183
-
184
- @property
185
- def input_shape(self) -> Tuple[int]:
165
+ def input_shape(self) -> tuple[int, int, int]:
186
166
  """Input data shape."""
187
167
  return (1, self.n_chans, self.n_times)
188
168
 
189
- def get_output_shape(self) -> Tuple[int]:
169
+ def get_output_shape(self) -> tuple[int, ...]:
190
170
  """Returns shape of neural network output for batch size equal 1.
191
171
 
192
172
  Returns
193
173
  -------
194
- output_shape: Tuple[int]
174
+ output_shape: tuple[int, ...]
195
175
  shape of the network output for `batch_size==1` (1, ...)
196
- """
176
+ """
197
177
  with torch.inference_mode():
198
178
  try:
199
- return tuple(self.forward(
200
- torch.zeros(
201
- self.input_shape,
202
- dtype=next(self.parameters()).dtype,
203
- device=next(self.parameters()).device
204
- )).shape)
179
+ return tuple(
180
+ self.forward( # type: ignore
181
+ torch.zeros(
182
+ self.input_shape,
183
+ dtype=next(self.parameters()).dtype, # type: ignore
184
+ device=next(self.parameters()).device, # type: ignore
185
+ )
186
+ ).shape
187
+ )
205
188
  except RuntimeError as exc:
206
189
  if str(exc).endswith(
207
- ("Output size is too small",
208
- "Kernel size can't be greater than actual input size")
190
+ (
191
+ "Output size is too small",
192
+ "Kernel size can't be greater than actual input size",
193
+ )
209
194
  ):
210
195
  msg = (
211
196
  "During model prediction RuntimeError was thrown showing that at some "
@@ -217,10 +202,9 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
217
202
  raise ValueError(msg) from exc
218
203
  raise exc
219
204
 
220
- mapping = None
205
+ mapping: Optional[Dict[str, str]] = None
221
206
 
222
207
  def load_state_dict(self, state_dict, *args, **kwargs):
223
-
224
208
  mapping = self.mapping if self.mapping else {}
225
209
  new_state_dict = OrderedDict()
226
210
  for k, v in state_dict.items():
@@ -231,7 +215,7 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
231
215
 
232
216
  return super().load_state_dict(new_state_dict, *args, **kwargs)
233
217
 
234
- def to_dense_prediction_model(self, axis: Tuple[int] = (2, 3)) -> None:
218
+ def to_dense_prediction_model(self, axis: tuple[int, ...] | int = (2, 3)) -> None:
235
219
  """
236
220
  Transform a sequential model with strides to a model that outputs
237
221
  dense predictions by removing the strides and instead inserting dilations.
@@ -250,19 +234,19 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
250
234
  backwards one layer.
251
235
 
252
236
  """
253
- if not hasattr(axis, "__len__"):
254
- axis = [axis]
255
- assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis"
237
+ if not hasattr(axis, "__iter__"):
238
+ axis = (axis,)
239
+ assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis" # type: ignore[union-attr]
256
240
  axis = np.array(axis) - 2
257
241
  stride_so_far = np.array([1, 1])
258
- for module in self.modules():
242
+ for module in self.modules(): # type: ignore
259
243
  if hasattr(module, "dilation"):
260
244
  assert module.dilation == 1 or (module.dilation == (1, 1)), (
261
245
  "Dilation should equal 1 before conversion, maybe the model is "
262
246
  "already converted?"
263
247
  )
264
248
  new_dilation = [1, 1]
265
- for ax in axis:
249
+ for ax in axis: # type: ignore[union-attr]
266
250
  new_dilation[ax] = int(stride_so_far[ax])
267
251
  module.dilation = tuple(new_dilation)
268
252
  if hasattr(module, "stride"):
@@ -270,19 +254,19 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
270
254
  module.stride = (module.stride, module.stride)
271
255
  stride_so_far *= np.array(module.stride)
272
256
  new_stride = list(module.stride)
273
- for ax in axis:
257
+ for ax in axis: # type: ignore[union-attr]
274
258
  new_stride[ax] = 1
275
259
  module.stride = tuple(new_stride)
276
260
 
277
261
  def get_torchinfo_statistics(
278
- self,
279
- col_names: Optional[Iterable[str]] = (
280
- "input_size",
281
- "output_size",
282
- "num_params",
283
- "kernel_size",
284
- ),
285
- row_settings: Optional[Iterable[str]] = ("var_names", "depth"),
262
+ self,
263
+ col_names: Optional[Iterable[str]] = (
264
+ "input_size",
265
+ "output_size",
266
+ "num_params",
267
+ "kernel_size",
268
+ ),
269
+ row_settings: Optional[Iterable[str]] = ("var_names", "depth"),
286
270
  ) -> ModelStatistics:
287
271
  """Generate table describing the model using torchinfo.summary.
288
272