braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/training/scoring.py CHANGED
@@ -6,20 +6,19 @@
 #
 # License: BSD-3
 
-from contextlib import contextmanager
 import warnings
+from contextlib import contextmanager
 
 import numpy as np
 import torch
 from mne.utils.check import check_version
 from skorch.callbacks.scoring import EpochScoring
-from skorch.utils import to_numpy
 from skorch.dataset import unpack_data
+from skorch.utils import to_numpy
 from torch.utils.data import DataLoader
 
 
-def trial_preds_from_window_preds(
-        preds, i_window_in_trials, i_stop_in_trials):
+def trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials):
     """
     Assigning window predictions to trials while removing duplicate
     predictions.
@@ -41,7 +40,8 @@ def trial_preds_from_window_preds(
 
     """
     assert len(preds) == len(i_window_in_trials) == len(i_stop_in_trials), (
-        f'{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}')
+        f"{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}"
+    )
 
     # Algorithm for assigning window predictions to trials
     # while removing duplicate predictions:
@@ -64,11 +64,11 @@ def trial_preds_from_window_preds(
     i_last_stop = None
     i_last_window = -1
     for window_preds, i_window, i_stop in zip(
-            preds, i_window_in_trials, i_stop_in_trials):
+        preds, i_window_in_trials, i_stop_in_trials
+    ):
         window_preds = np.array(window_preds)
         if i_window != (i_last_window + 1):
-            assert i_window == 0, (
-                "window numbers in new trial should start from 0")
+            assert i_window == 0, "window numbers in new trial should start from 0"
             preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
             cur_trial_preds = []
             i_last_stop = None
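
Note: a minimal usage sketch of trial_preds_from_window_preds, with dummy shapes that are illustrative rather than taken from the release:

import numpy as np

from braindecode.training.scoring import trial_preds_from_window_preds

# One trial, two overlapping windows, each giving 2 classes x 4 predictions.
preds = [np.zeros((2, 4)), np.ones((2, 4))]
i_window_in_trials = [0, 1]  # window index within its trial
i_stop_in_trials = [4, 6]    # window 1 stops 2 samples after window 0

trial_preds = trial_preds_from_window_preds(
    preds, i_window_in_trials, i_stop_in_trials
)
# One array per trial, overlap removed: the 4 preds of window 0 plus the
# 2 non-duplicate preds of window 1 -> shape (2, 6).
print(trial_preds[0].shape)
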
@@ -117,16 +117,17 @@ class CroppedTrialEpochScoring(EpochScoring):
     """
     Class to compute scores for trials from a model that predicts (super)crops.
     """
+
     # XXX needs a docstring !!!
 
     def __init__(
-            self,
-            scoring,
-            lower_is_better=True,
-            on_train=False,
-            name=None,
-            target_extractor=to_numpy,
-            use_caching=True,
+        self,
+        scoring,
+        lower_is_better=True,
+        on_train=False,
+        name=None,
+        target_extractor=to_numpy,
+        use_caching=True,
     ):
         super().__init__(
             scoring=scoring,
@@ -147,8 +148,7 @@ class CroppedTrialEpochScoring(EpochScoring):
         if not self.on_train:
             self.window_inds_ = []
 
-    def on_batch_end(
-            self, net, batch, y_pred, training, **kwargs):
+    def on_batch_end(self, net, batch, y_pred, training, **kwargs):
         # Skorch saves the predictions without moving them from GPU
         # https://github.com/skorch-dev/skorch/blob/fe71e3d55a4ae5f5f94ef7bdfc00fca3b3fd267f/skorch/callbacks/scoring.py#L385
         # This can cause memory issues in case of a large number of predictions
@@ -164,41 +164,42 @@ class CroppedTrialEpochScoring(EpochScoring):
                 # Prevent that rng state of torch is changed by
                 # creation+usage of iterator
                 rng_state = torch.random.get_rng_state()
-                pred_results = net.predict_with_window_inds_and_ys(
-                    dataset_train)
+                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                 torch.random.set_rng_state(rng_state)
             else:
                 pred_results = {}
-                pred_results['i_window_in_trials'] = np.concatenate(
+                pred_results["i_window_in_trials"] = np.concatenate(
                     [i[0].cpu().numpy() for i in self.window_inds_]
                 )
-                pred_results['i_window_stops'] = np.concatenate(
+                pred_results["i_window_stops"] = np.concatenate(
                     [i[2].cpu().numpy() for i in self.window_inds_]
                 )
-                pred_results['preds'] = np.concatenate(
-                    [y_pred.cpu().numpy() for y_pred in self.y_preds_])
-                pred_results['window_ys'] = np.concatenate(
-                    [y.cpu().numpy() for y in self.y_trues_])
+                pred_results["preds"] = np.concatenate(
+                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
+                )
+                pred_results["window_ys"] = np.concatenate(
+                    [y.cpu().numpy() for y in self.y_trues_]
+                )
 
             # A new trial starts
             # when the index of the window in trials
             # does not increment by 1
             # Add dummy infinity at start
-            window_0_per_trial_mask = np.diff(
-                pred_results['i_window_in_trials'], prepend=[np.inf]) != 1
-            trial_ys = pred_results['window_ys'][window_0_per_trial_mask]
+            window_0_per_trial_mask = (
+                np.diff(pred_results["i_window_in_trials"], prepend=[np.inf]) != 1
+            )
+            trial_ys = pred_results["window_ys"][window_0_per_trial_mask]
             trial_preds = trial_preds_from_window_preds(
-                pred_results['preds'],
-                pred_results['i_window_in_trials'],
-                pred_results['i_window_stops'])
+                pred_results["preds"],
+                pred_results["i_window_in_trials"],
+                pred_results["i_window_stops"],
+            )
 
             # Average across the timesteps of each trial so we have per-trial
             # predictions already, these will be just passed through the forward
             # method of the classifier/regressor to the skorch scoring function.
             # trial_preds is a list, each item is a 2d array classes x time
-            y_preds_per_trial = np.array(
-                [np.mean(p, axis=1) for p in trial_preds]
-            )
+            y_preds_per_trial = np.array([np.mean(p, axis=1) for p in trial_preds])
             # Move into format expected by skorch (list of torch tensors)
             y_preds_per_trial = [torch.tensor(y_preds_per_trial)]
 
@@ -206,9 +207,10 @@ class CroppedTrialEpochScoring(EpochScoring):
             # that are also on same set
             cbs = net.callbacks_
             epoch_cbs = [
-                cb for name, cb in cbs if
-                isinstance(cb, CroppedTrialEpochScoring) and (
-                    cb.on_train == self.on_train)
+                cb
+                for name, cb in cbs
+                if isinstance(cb, CroppedTrialEpochScoring)
+                and (cb.on_train == self.on_train)
             ]
             for cb in epoch_cbs:
                 cb.y_preds_ = y_preds_per_trial
@@ -218,7 +220,7 @@ class CroppedTrialEpochScoring(EpochScoring):
         dataset = dataset_train if self.on_train else dataset_valid
 
         with _cache_net_forward_iter(
-                net, self.use_caching, self.y_preds_
+            net, self.use_caching, self.y_preds_
         ) as cached_net:
             current_score = self._scoring(cached_net, dataset, self.y_trues_)
         self._record_score(net.history, current_score)
@@ -231,6 +233,7 @@ class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
     Class to compute scores for trials from a model that predicts (super)crops with
     time series target.
     """
+
     def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
         assert self.use_caching
         if not self.crops_to_trials_computed:
@@ -238,37 +241,40 @@ class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
                 # Prevent that rng state of torch is changed by
                 # creation+usage of iterator
                 rng_state = torch.random.get_rng_state()
-                pred_results = net.predict_with_window_inds_and_ys(
-                    dataset_train)
+                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                 torch.random.set_rng_state(rng_state)
             else:
                 pred_results = {}
-                pred_results['i_window_in_trials'] = np.concatenate(
+                pred_results["i_window_in_trials"] = np.concatenate(
                     [i[0].cpu().numpy() for i in self.window_inds_]
                 )
-                pred_results['i_window_stops'] = np.concatenate(
+                pred_results["i_window_stops"] = np.concatenate(
                     [i[2].cpu().numpy() for i in self.window_inds_]
                 )
-                pred_results['preds'] = np.concatenate(
-                    [y_pred.cpu().numpy() for y_pred in self.y_preds_])
-                pred_results['window_ys'] = np.concatenate(
-                    [y.cpu().numpy() for y in self.y_trues_])
+                pred_results["preds"] = np.concatenate(
+                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
+                )
+                pred_results["window_ys"] = np.concatenate(
+                    [y.cpu().numpy() for y in self.y_trues_]
+                )
 
-            num_preds = pred_results['preds'][-1].shape[-1]
+            num_preds = pred_results["preds"][-1].shape[-1]
             # slice the targets to fit preds shape
-            pred_results['window_ys'] = [
-                targets[:, -num_preds:] for targets in pred_results['window_ys']
+            pred_results["window_ys"] = [
+                targets[:, -num_preds:] for targets in pred_results["window_ys"]
             ]
 
             trial_preds = trial_preds_from_window_preds(
-                pred_results['preds'],
-                pred_results['i_window_in_trials'],
-                pred_results['i_window_stops'])
+                pred_results["preds"],
+                pred_results["i_window_in_trials"],
+                pred_results["i_window_stops"],
+            )
 
             trial_ys = trial_preds_from_window_preds(
-                pred_results['window_ys'],
-                pred_results['i_window_in_trials'],
-                pred_results['i_window_stops'])
+                pred_results["window_ys"],
+                pred_results["i_window_in_trials"],
+                pred_results["i_window_stops"],
+            )
 
             # the output is a list of predictions/targets per trial where each item is a
             # timeseries of predictions/targets of shape (n_classes x timesteps)
@@ -290,9 +296,10 @@ class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
             # that are also on same set
             cbs = net.callbacks_
             epoch_cbs = [
-                cb for name, cb in cbs if
-                isinstance(cb, CroppedTimeSeriesEpochScoring) and (
-                    cb.on_train == self.on_train)
+                cb
+                for name, cb in cbs
+                if isinstance(cb, CroppedTimeSeriesEpochScoring)
+                and (cb.on_train == self.on_train)
             ]
             masked_preds = [torch.tensor(masked_preds.T)]
             for cb in epoch_cbs:
@@ -365,7 +372,7 @@ class PostEpochTrainScoring(EpochScoring):
             for batch in iterator:
                 batch_X, batch_y = unpack_data(batch)
                 # TODO: remove after skorch 0.10 release
-                if not check_version('skorch', min_version='0.10.1'):
+                if not check_version("skorch", min_version="0.10.1"):
                     yp = net.evaluation_step(batch_X, training=False)
                 # X, y unpacking has been pushed downstream in skorch 0.10
                 else:
@@ -394,9 +401,7 @@ class PostEpochTrainScoring(EpochScoring):
             with _cache_net_forward_iter(
                 net, use_caching=True, y_preds=self.y_preds_
             ) as cached_net:
-                current_score = self._scoring(
-                    cached_net, dataset_train, self.y_trues_
-                )
+                current_score = self._scoring(cached_net, dataset_train, self.y_trues_)
             self._record_score(net.history, current_score)
 
 
@@ -432,12 +437,14 @@ def predict_trials(module, dataset, return_targets=True, batch_size=1, num_worke
     module.eval()
     # we have a cropped dataset if there exists at least one trial with more
    # than one compute window
-    more_than_one_window = sum(dataset.get_metadata()['i_window_in_trial'] != 0) > 0
+    more_than_one_window = sum(dataset.get_metadata()["i_window_in_trial"] != 0) > 0
     if not more_than_one_window:
-        warnings.warn('This function was designed to predict trials from '
-                      'cropped datasets, which typically have multiple compute '
-                      'windows per trial. The given dataset has exactly one '
-                      'window per trial.')
+        warnings.warn(
+            "This function was designed to predict trials from "
+            "cropped datasets, which typically have multiple compute "
+            "windows per trial. The given dataset has exactly one "
+            "window per trial."
+        )
     loader = DataLoader(
         dataset=dataset,
         batch_size=batch_size,
@@ -463,7 +470,8 @@ def predict_trials(module, dataset, return_targets=True, batch_size=1, num_worke
     if all_ys[0].shape == ():
         all_ys = np.array(all_ys)
         ys_per_trial = all_ys[
-            np.diff(torch.cat(all_inds[0::3]), prepend=[np.inf]) != 1]
+            np.diff(torch.cat(all_inds[0::3]), prepend=[np.inf]) != 1
+        ]
     else:
         ys_per_trial = trial_preds_from_window_preds(
             preds=all_ys,
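
Note: the predict_trials hunks above change formatting only. A hedged sketch of the typical call; model and windows_ds are placeholders for a trained cropped-decoding module and a windowed braindecode dataset:

from braindecode.training.scoring import predict_trials

# With return_targets=True the function returns per-trial predictions
# and per-trial targets.
preds_per_trial, ys_per_trial = predict_trials(
    module=model,
    dataset=windows_ds,
    return_targets=True,
    batch_size=32,
)
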
braindecode/util.py CHANGED
@@ -12,6 +12,7 @@ import mne
 import numpy as np
 import torch
 from sklearn.utils import check_random_state
+from torch import Tensor
 
 
 def set_random_seeds(seed, cuda, cudnn_benchmark=None):
@@ -51,7 +52,9 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
             warn(
                 "torch.backends.cudnn.benchmark was set to True which may results in lack of "
                 "reproducibility. In some cases to ensure reproducibility you may need to "
-                "set torch.backends.cudnn.benchmark to False.", UserWarning)
+                "set torch.backends.cudnn.benchmark to False.",
+                UserWarning,
+            )
         else:
             raise ValueError(
                 f"cudnn_benchmark expected to be bool or None, got '{cudnn_benchmark}'"
@@ -60,19 +63,7 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
     np.random.seed(seed)
 
 
-def np_to_var(
-        X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
-):
-    warn("np_to_var has been renamed np_to_th, please use np_to_th instead")
-    return np_to_th(
-        X, requires_grad=requires_grad, dtype=dtype, pin_memory=pin_memory,
-        **tensor_kwargs
-    )
-
-
-def np_to_th(
-        X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
-):
+def np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs):
     """
     Convenience function to transform numpy array to `torch.Tensor`.
 
@@ -103,12 +94,7 @@ def np_to_th(
     return X_tensor
 
 
-def var_to_np(var):
-    warn("var_to_np has been renamed th_to_np, please use th_to_np instead")
-    return th_to_np(var)
-
-
-def th_to_np(var):
+def th_to_np(var: Tensor):
     """Convenience function to transform `torch.Tensor` to numpy
     array.
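
Note: the deprecated np_to_var/var_to_np aliases are removed in 1.1.0; only the *_th names remain. A round-trip sketch:

import numpy as np

from braindecode.util import np_to_th, th_to_np

X = np.random.rand(2, 3).astype(np.float32)
X_tensor = np_to_th(X, requires_grad=False)  # numpy -> torch.Tensor
X_back = th_to_np(X_tensor)                  # torch.Tensor -> numpy
assert np.allclose(X, X_back)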
 
@@ -209,15 +195,11 @@ def wrap_reshape_apply_fn(stat_fn, a, b, axis_a, axis_b):
     )
     assert np.array_equal(n_stat_axis_a, n_stat_axis_b)
     stat_result = stat_fn(flat_topo_a, flat_topo_b)
-    topo_result = stat_result.reshape(
-        tuple(n_other_axis_a) + tuple(n_other_axis_b)
-    )
+    topo_result = stat_result.reshape(tuple(n_other_axis_a) + tuple(n_other_axis_b))
     return topo_result
 
 
-def get_balanced_batches(
-        n_trials, rng, shuffle, n_batches=None, batch_size=None
-):
+def get_balanced_batches(n_trials, rng, shuffle, n_batches=None, batch_size=None):
     """Create indices for batches balanced in size
     (batches will have maximum size difference of 1).
     Supply either batch size or number of batches. Resulting batches
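
Note: a small sketch of get_balanced_batches after the signature reflow (numbers illustrative):

from numpy.random import RandomState

from braindecode.util import get_balanced_batches

rng = RandomState(42)
# 10 trials split into 3 batches -> sizes differ by at most 1, e.g. [4, 3, 3].
batches = get_balanced_batches(n_trials=10, rng=rng, shuffle=True, n_batches=3)
print([len(b) for b in batches])
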
@@ -268,9 +250,17 @@ def get_balanced_batches(
     return batches
 
 
-def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
-                         description=None, savedir=None, save_format='fif',
-                         overwrite=True, random_state=None):
+def create_mne_dummy_raw(
+    n_channels,
+    n_times,
+    sfreq,
+    include_anns=True,
+    description=None,
+    savedir=None,
+    save_format="fif",
+    overwrite=True,
+    random_state=None,
+):
     """Create an mne.io.RawArray with fake data, and optionally save it.
 
     This will overwrite already existing files.
@@ -305,20 +295,21 @@ def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
     """
     random_state = check_random_state(random_state)
     data = random_state.rand(n_channels, n_times)
-    ch_names = [f'ch{i}' for i in range(n_channels)]
-    ch_types = ['eeg'] * n_channels
+    ch_names = [f"ch{i}" for i in range(n_channels)]
+    ch_types = ["eeg"] * n_channels
     info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
 
     raw = mne.io.RawArray(data, info)
 
     if include_anns:
         n_anns = 10
-        inds = np.linspace(
-            int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(int)
+        inds = np.linspace(int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(
+            int
+        )
         onset = raw.times[inds]
         duration = [1] * n_anns
         if description is None:
-            description = ['test'] * n_anns
+            description = ["test"] * n_anns
         anns = mne.Annotations(onset, duration, description)
         raw = raw.set_annotations(anns)
 
@@ -326,18 +317,17 @@ def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
     if savedir is not None:
         if not isinstance(save_format, list):
             save_format = [save_format]
-        fname = os.path.join(savedir, 'fake_eeg_raw')
+        fname = os.path.join(savedir, "fake_eeg_raw")
 
-        if 'fif' in save_format:
-            fif_fname = fname + '.fif'
+        if "fif" in save_format:
+            fif_fname = fname + ".fif"
             raw.save(fif_fname, overwrite=overwrite)
-            save_fname['fif'] = fif_fname
-        if 'hdf5' in save_format:
-            h5_fname = fname + '.h5'
-            with h5py.File(h5_fname, 'w') as f:
-                f.create_dataset(
-                    'fake_raw', dtype='f8', data=raw.get_data())
-            save_fname['hdf5'] = h5_fname
+            save_fname["fif"] = fif_fname
+        if "hdf5" in save_format:
+            h5_fname = fname + ".h5"
+            with h5py.File(h5_fname, "w") as f:
+                f.create_dataset("fake_raw", dtype="f8", data=raw.get_data())
+            save_fname["hdf5"] = h5_fname
 
     return raw, save_fname
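
Note: a hedged example of create_mne_dummy_raw with the reformatted keyword list; the temporary directory is an assumption for the sketch:

import tempfile

from braindecode.util import create_mne_dummy_raw

with tempfile.TemporaryDirectory() as savedir:
    raw, fnames = create_mne_dummy_raw(
        n_channels=2,
        n_times=10000,
        sfreq=100,
        savedir=savedir,
        save_format="fif",
        random_state=87,
    )
    print(fnames["fif"])  # path to the saved fake_eeg_raw.fif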
 
@@ -349,7 +339,9 @@ class ThrowAwayIndexLoader(object):
         self.last_i = None
         self.is_regression = is_regression
 
-    def __iter__(self, ):
+    def __iter__(
+        self,
+    ):
         normal_iter = self.loader.__iter__()
         for batch in normal_iter:
             if len(batch) == 3:
@@ -360,7 +352,7 @@ class ThrowAwayIndexLoader(object):
                 x, y = batch
 
             # TODO: should be on dataset side
-            if hasattr(x, 'type'):
+            if hasattr(x, "type"):
                 x = x.type(torch.float32)
                 if self.is_regression:
                     y = y.type(torch.float32)
@@ -370,23 +362,26 @@ class ThrowAwayIndexLoader(object):
 
 
 def update_estimator_docstring(base_class, docstring):
-    base_doc = base_class.__doc__.replace(' : ', ': ')
-    idx = base_doc.find('callbacks:')
-    idx_end = idx + base_doc[idx:].find('\n\n')
+    base_doc = base_class.__doc__.replace(" : ", ": ")
+    idx = base_doc.find("callbacks:")
+    idx_end = idx + base_doc[idx:].find("\n\n")
     # remove callback descripiton already included in braindecode docstring
-    filtered_doc = base_doc[:idx] + base_doc[idx_end + 6:]
-    splitted = docstring.split('Parameters\n    ----------\n    ')
+    filtered_doc = base_doc[:idx] + base_doc[idx_end + 6 :]
+    splitted = docstring.split("Parameters\n    ----------\n    ")
     out_docstring = (
-        splitted[0] +
-        filtered_doc[filtered_doc.find('Parameters'):filtered_doc.find('Attributes')] +
-        splitted[1] +
-        filtered_doc[filtered_doc.find('Attributes'):])
+        splitted[0]
+        + filtered_doc[
+            filtered_doc.find("Parameters") : filtered_doc.find("Attributes")
+        ]
+        + splitted[1]
+        + filtered_doc[filtered_doc.find("Attributes") :]
+    )
     return out_docstring
 
 
 def _update_moabb_docstring(base_class, docstring):
     base_doc = base_class.__doc__
-    out_docstring = base_doc + f'\n\n{docstring}'
+    out_docstring = base_doc + f"\n\n{docstring}"
     return out_docstring
 
 
@@ -406,8 +401,9 @@ def read_all_file_names(directory, extension):
     file_paths: list(str)
         List of all files found in (sub)directories of path.
     """
-    assert extension.startswith('.')
-    file_paths = glob.glob(directory + '**/*' + extension, recursive=True)
+    assert extension.startswith(".")
+    file_paths = glob.glob(directory + "**/*" + extension, recursive=True)
     assert len(file_paths) > 0, (
-        f'something went wrong. Found no {extension} files in {directory}')
+        f"something went wrong. Found no {extension} files in {directory}"
+    )
     return file_paths
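
Note: read_all_file_names searches recursively and asserts at least one match; a short sketch (the data directory is a placeholder):

from braindecode.util import read_all_file_names

# The extension must include the leading dot; the search recurses into
# subdirectories of the given path.
edf_paths = read_all_file_names("/data/tuh_eeg/", ".edf")
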
braindecode/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.8"
+__version__ = "1.1.0"
braindecode/visualization/__init__.py CHANGED
@@ -2,8 +2,7 @@
 Functions for visualisations, especially of the ConvNets.
 """
 
-from .gradients import compute_amplitude_gradients
 from .confusion_matrices import plot_confusion_matrix
+from .gradients import compute_amplitude_gradients
 
-__all__ = ["compute_amplitude_gradients",
-           "plot_confusion_matrix"]
+__all__ = ["compute_amplitude_gradients", "plot_confusion_matrix"]