braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/util.py ADDED
@@ -0,0 +1,419 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+ import glob
+ import os
+ import random
+ import re
+ from warnings import warn
+
+ import h5py
+ import mne
+ import numpy as np
+ import torch
+ from sklearn.utils import check_random_state
+ from torch import Tensor
+
+
+ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
+     """Set seeds for the Python random module, numpy.random and torch.
+
+     For more details about reproducibility in pytorch see
+     https://pytorch.org/docs/stable/notes/randomness.html
+
+     Parameters
+     ----------
+     seed : int
+         Random seed.
+     cuda : bool
+         Whether to set the cuda seed with torch.
+     cudnn_benchmark : bool (default=None)
+         Whether pytorch should use cudnn benchmarking. When set to `None`,
+         torch.backends.cudnn.benchmark is left unchanged, and a warning is
+         displayed if it is already True, since that can hurt reproducibility.
+         When set to True, results may not be reproducible (no warning is
+         displayed). When set to False, computations may be slower.
+
+     Notes
+     -----
+     In some cases, setting the environment variable `PYTHONHASHSEED` before
+     running a script may be needed to ensure full reproducibility. See
+     https://forums.fast.ai/t/solved-reproducibility-where-is-the-randomness-coming-in/31628/14
+
+     Using this function may not ensure full reproducibility of the results, as
+     we do not set `torch.use_deterministic_algorithms(True)`.
+     """
+     random.seed(seed)
+     torch.manual_seed(seed)
+     if cuda:
+         if isinstance(cudnn_benchmark, bool):
+             torch.backends.cudnn.benchmark = cudnn_benchmark
+         elif cudnn_benchmark is None:
+             if torch.backends.cudnn.benchmark:
+                 warn(
+                     "torch.backends.cudnn.benchmark was set to True which may result in a "
+                     "lack of reproducibility. In some cases, to ensure reproducibility, "
+                     "you may need to set torch.backends.cudnn.benchmark to False.",
+                     UserWarning,
+                 )
+         else:
+             raise ValueError(
+                 f"cudnn_benchmark expected to be bool or None, got '{cudnn_benchmark}'"
+             )
+         torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+
+
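As a quick usage illustration (a minimal sketch, not part of the packaged file), a script would typically call this once, before building a model:

    import torch
    from braindecode.util import set_random_seeds

    cuda = torch.cuda.is_available()
    set_random_seeds(seed=20200220, cuda=cuda)
    # model creation and training below now start from a reproducible state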
+ def np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs):
+     """
+     Convenience function to transform a numpy array to a `torch.Tensor`.
+
+     Converts `X` to an ndarray using `np.asarray` if necessary.
+
+     Parameters
+     ----------
+     X : ndarray or list or number
+         Input array.
+     requires_grad : bool
+         Passed on to the tensor constructor.
+     dtype : numpy dtype, optional
+         If given, `X` is cast to this dtype before conversion.
+     pin_memory : bool
+         If True, the resulting tensor is moved to pinned (page-locked) memory.
+     tensor_kwargs :
+         Passed on to the tensor constructor.
+
+     Returns
+     -------
+     var : `torch.Tensor`
+     """
+     if not hasattr(X, "__len__"):
+         X = [X]
+     X = np.asarray(X)
+     if dtype is not None:
+         X = X.astype(dtype)
+     X_tensor = torch.tensor(X, requires_grad=requires_grad, **tensor_kwargs)
+     if pin_memory:
+         X_tensor = X_tensor.pin_memory()
+     return X_tensor
+
+
+ def th_to_np(var: Tensor):
+     """Convenience function to transform a `torch.Tensor` to a numpy array.
+
+     Works for tensors on both CPU and GPU.
+     """
+     return var.cpu().data.numpy()
+
+
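For example, the two helpers round-trip between numpy and torch (an illustrative sketch, not part of the packaged file):

    import numpy as np
    from braindecode.util import np_to_th, th_to_np

    X = np.random.rand(2, 3)
    X_tensor = np_to_th(X, dtype=np.float32)  # torch.Tensor of shape (2, 3)
    X_back = th_to_np(X_tensor)               # numpy array again
    assert X_back.shape == (2, 3)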
+ def corr(a, b):
+     """
+     Compute correlations only between the features of `a` and the features
+     of `b`, not within `a` or within `b`.
+
+     Parameters
+     ----------
+     a, b : 2darray, features x samples
+
+     Returns
+     -------
+     Correlation between the features in `a` and the features in `b`.
+     """
+     # Difference to numpy:
+     # correlation only between the features of a and b,
+     # not between a and a or b and b
+     this_cov = cov(a, b)
+     return _cov_to_corr(this_cov, a, b)
+
+
+ def cov(a, b):
+     """
+     Compute covariances only between the features of `a` and the features
+     of `b`, not within `a` or within `b`.
+
+     Parameters
+     ----------
+     a, b : 2darray, features x samples
+
+     Returns
+     -------
+     Covariance between the features in `a` and the features in `b`.
+     """
+     demeaned_a = a - np.mean(a, axis=1, keepdims=True)
+     demeaned_b = b - np.mean(b, axis=1, keepdims=True)
+     this_cov = np.dot(demeaned_a, demeaned_b.T) / (b.shape[1] - 1)
+     return this_cov
+
+
+ def _cov_to_corr(this_cov, a, b):
+     # compute "unbiased" correlation: ddof=1 gives unbiased variance estimates
+     var_a = np.var(a, axis=1, ddof=1)
+     var_b = np.var(b, axis=1, ddof=1)
+     return _cov_and_var_to_corr(this_cov, var_a, var_b)
+
+
+ def _cov_and_var_to_corr(this_cov, var_a, var_b):
+     divisor = np.outer(np.sqrt(var_a), np.sqrt(var_b))
+     return this_cov / divisor
+
+
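A short sketch of the shape contract (illustrative, not part of the packaged file): for `a` with shape (n_features_a, n_samples) and `b` with shape (n_features_b, n_samples), `corr` returns an (n_features_a, n_features_b) matrix:

    import numpy as np
    from braindecode.util import corr

    rng = np.random.RandomState(0)
    a = rng.rand(4, 100)  # 4 features, 100 samples
    b = rng.rand(5, 100)  # 5 features, same number of samples
    c = corr(a, b)
    assert c.shape == (4, 5)           # feature-by-feature correlations
    assert np.all(np.abs(c) <= 1 + 1e-9)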
+ def wrap_reshape_apply_fn(stat_fn, a, b, axis_a, axis_b):
+     """
+     Reshape two nd-arrays into 2d-arrays, apply a function, and reshape the
+     result back.
+
+     Parameters
+     ----------
+     stat_fn : function
+         Function to apply to the two 2d-arrays.
+     a : nd-array
+     b : nd-array
+     axis_a : int or list of int
+         Sample axis (or axes) of `a`.
+     axis_b : int or list of int
+         Sample axis (or axes) of `b`.
+
+     Returns
+     -------
+     result : nd-array
+         The result, reshaped to remaining_dims_a + remaining_dims_b.
+     """
+     if not hasattr(axis_a, "__len__"):
+         axis_a = [axis_a]
+     if not hasattr(axis_b, "__len__"):
+         axis_b = [axis_b]
+     other_axis_a = [i for i in range(a.ndim) if i not in axis_a]
+     other_axis_b = [i for i in range(b.ndim) if i not in axis_b]
+     transposed_topo_a = a.transpose(tuple(other_axis_a) + tuple(axis_a))
+     n_stat_axis_a = [a.shape[i] for i in axis_a]
+     n_other_axis_a = [a.shape[i] for i in other_axis_a]
+     flat_topo_a = transposed_topo_a.reshape(
+         np.prod(n_other_axis_a), np.prod(n_stat_axis_a)
+     )
+     transposed_topo_b = b.transpose(tuple(other_axis_b) + tuple(axis_b))
+     n_stat_axis_b = [b.shape[i] for i in axis_b]
+     n_other_axis_b = [b.shape[i] for i in other_axis_b]
+     flat_topo_b = transposed_topo_b.reshape(
+         np.prod(n_other_axis_b), np.prod(n_stat_axis_b)
+     )
+     assert np.array_equal(n_stat_axis_a, n_stat_axis_b)
+     stat_result = stat_fn(flat_topo_a, flat_topo_b)
+     topo_result = stat_result.reshape(tuple(n_other_axis_a) + tuple(n_other_axis_b))
+     return topo_result
+
+
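For instance, correlating two arrays of different dimensionality over their time axes could look like this (a minimal sketch, not part of the packaged file):

    import numpy as np
    from braindecode.util import corr, wrap_reshape_apply_fn

    rng = np.random.RandomState(0)
    a = rng.rand(3, 4, 100)  # e.g. trials x channels x time
    b = rng.rand(2, 100)     # e.g. units x time
    # correlate over the time axes (axis 2 of a, axis 1 of b)
    result = wrap_reshape_apply_fn(corr, a, b, axis_a=2, axis_b=1)
    assert result.shape == (3, 4, 2)  # remaining dims of a, then of b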
+ def get_balanced_batches(n_trials, rng, shuffle, n_batches=None, batch_size=None):
+     """Create indices for batches balanced in size
+     (batches will have a maximum size difference of 1).
+
+     Supply either a batch size or a number of batches. The resulting batches
+     will not necessarily have the given batch size, but rather a batch size
+     that allows splitting the set into balanced batches (maximum size
+     difference of 1).
+
+     Parameters
+     ----------
+     n_trials : int
+         Size of the set.
+     rng : RandomState
+         Random number generator used for shuffling.
+     shuffle : bool
+         Whether to shuffle the indices before splitting the set.
+     n_batches : int, optional
+         Number of batches to create.
+     batch_size : int, optional
+         Desired batch size.
+
+     Returns
+     -------
+     batches : list of list of int
+         Indices for each batch.
+     """
+     assert batch_size is not None or n_batches is not None
+     if n_batches is None:
+         n_batches = int(np.round(n_trials / float(batch_size)))
+
+     if n_batches > 0:
+         min_batch_size = n_trials // n_batches
+         n_batches_with_extra_trial = n_trials % n_batches
+     else:
+         n_batches = 1
+         min_batch_size = n_trials
+         n_batches_with_extra_trial = 0
+     assert n_batches_with_extra_trial < n_batches
+     all_inds = np.array(range(n_trials))
+     if shuffle:
+         rng.shuffle(all_inds)
+     i_start_trial = 0
+     i_stop_trial = 0
+     batches = []
+     for i_batch in range(n_batches):
+         i_stop_trial += min_batch_size
+         if i_batch < n_batches_with_extra_trial:
+             i_stop_trial += 1
+         batch_inds = all_inds[range(i_start_trial, i_stop_trial)]
+         batches.append(batch_inds)
+         i_start_trial = i_stop_trial
+     assert i_start_trial == n_trials
+     return batches
+
+
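For example (an illustrative sketch, not part of the packaged file), ten trials split into three batches gives sizes that differ by at most one:

    import numpy as np
    from braindecode.util import get_balanced_batches

    rng = np.random.RandomState(0)
    batches = get_balanced_batches(n_trials=10, rng=rng, shuffle=True, n_batches=3)
    print([len(b) for b in batches])  # [4, 3, 3]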
+ def create_mne_dummy_raw(
+     n_channels,
+     n_times,
+     sfreq,
+     include_anns=True,
+     description=None,
+     savedir=None,
+     save_format="fif",
+     overwrite=True,
+     random_state=None,
+ ):
+     """Create an mne.io.RawArray with fake data, and optionally save it.
+
+     By default, this will overwrite already existing files.
+
+     Parameters
+     ----------
+     n_channels : int
+         Number of channels.
+     n_times : int
+         Number of samples.
+     sfreq : float
+         Sampling frequency.
+     include_anns : bool
+         If True, also create annotations.
+     description : list | None
+         List of descriptions used for creating annotations. It should contain
+         10 elements.
+     savedir : str | None
+         If provided as a string, the file will be saved under that directory.
+     save_format : str | list
+         If `savedir` is provided, this specifies the file format(s) the data
+         should be saved to. Can be 'fif' or 'hdf5', or a list containing both.
+     overwrite : bool
+         If True, overwrite existing files when saving.
+     random_state : int | RandomState
+         Random state for the generation of random data.
+
+     Returns
+     -------
+     raw : mne.io.Raw
+         The created Raw object.
+     save_fname : dict
+         Dictionary mapping each save format to the filename the raw data was
+         saved under (empty if `savedir` is None).
+     """
+     random_state = check_random_state(random_state)
+     data = random_state.rand(n_channels, n_times)
+     ch_names = [f"ch{i}" for i in range(n_channels)]
+     ch_types = ["eeg"] * n_channels
+     info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
+
+     raw = mne.io.RawArray(data, info)
+
+     if include_anns:
+         n_anns = 10
+         inds = np.linspace(int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(
+             int
+         )
+         onset = raw.times[inds]
+         duration = [1] * n_anns
+         if description is None:
+             description = ["test"] * n_anns
+         anns = mne.Annotations(onset, duration, description)
+         raw = raw.set_annotations(anns)
+
+     save_fname = dict()
+     if savedir is not None:
+         if not isinstance(save_format, list):
+             save_format = [save_format]
+         fname = os.path.join(savedir, "fake_eeg_raw")
+
+         if "fif" in save_format:
+             fif_fname = fname + ".fif"
+             raw.save(fif_fname, overwrite=overwrite)
+             save_fname["fif"] = fif_fname
+         if "hdf5" in save_format:
+             h5_fname = fname + ".h5"
+             with h5py.File(h5_fname, "w") as f:
+                 f.create_dataset("fake_raw", dtype="f8", data=raw.get_data())
+             save_fname["hdf5"] = h5_fname
+
+     return raw, save_fname
+
+
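A typical call in a test might look like this (an illustrative sketch, not part of the packaged file):

    import tempfile
    from braindecode.util import create_mne_dummy_raw

    with tempfile.TemporaryDirectory() as tmpdir:
        raw, fnames = create_mne_dummy_raw(
            n_channels=2, n_times=10000, sfreq=100.0, savedir=tmpdir
        )
        print(fnames)  # {'fif': '<tmpdir>/fake_eeg_raw.fif'}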
+ class ThrowAwayIndexLoader(object):
+     """Wrap a DataLoader to strip window indices from its batches.
+
+     If a batch contains three elements, the third (the window indices) is
+     stored on the net for use by scoring callbacks, and only (x, y) is
+     yielded. Tensor inputs are cast to float32, and targets to float32
+     (regression) or int64 (classification).
+     """
+
+     def __init__(self, net, loader, is_regression):
+         self.net = net
+         self.loader = loader
+         self.last_i = None
+         self.is_regression = is_regression
+
+     def __iter__(self):
+         normal_iter = self.loader.__iter__()
+         for batch in normal_iter:
+             if len(batch) == 3:
+                 x, y, i = batch
+                 # Store for scoring callbacks
+                 self.net._last_window_inds_ = i
+             else:
+                 x, y = batch
+
+             # TODO: should be on dataset side
+             if hasattr(x, "type"):
+                 x = x.type(torch.float32)
+                 if self.is_regression:
+                     y = y.type(torch.float32)
+                 else:
+                     y = y.type(torch.int64)
+             yield x, y
+
+
+ def update_estimator_docstring(base_class, docstring):
+     base_doc = base_class.__doc__.replace(" : ", ": ")
+     idx = base_doc.find("callbacks:")
+     idx_end = idx + base_doc[idx:].find("\n\n")
+     # remove the callback description already included in the braindecode docstring
+     filtered_doc = base_doc[:idx] + base_doc[idx_end + 6 :]
+     splitted = docstring.split("Parameters\n    ----------\n    ")
+     out_docstring = (
+         splitted[0]
+         + filtered_doc[
+             filtered_doc.find("Parameters") : filtered_doc.find("Attributes")
+         ]
+         + splitted[1]
+         + filtered_doc[filtered_doc.find("Attributes") :]
+     )
+     return out_docstring
+
+
+ def _update_moabb_docstring(base_class, docstring):
+     base_doc = base_class.__doc__
+     # Clean up malformed rubrics from moabb docstrings: remove ".. rubric::"
+     # lines that are immediately followed by a ".. note::" directive.
+     base_doc = re.sub(r"\.\. rubric:: (.+?)\n\s+\.\. note::", r".. note::", base_doc)
+     out_docstring = base_doc + f"\n\n{docstring}"
+     return out_docstring
+
+
+ def read_all_file_names(directory, extension):
+     """Read all files with the specified extension from the given directory,
+     searching subdirectories recursively.
+
+     Parameters
+     ----------
+     directory : str
+         Parent directory to be searched for files of the specified type.
+     extension : str
+         File extension, i.e. ".edf" or ".txt".
+
+     Returns
+     -------
+     file_paths : list(str)
+         List of all files found in (sub)directories of the path.
+     """
+     assert extension.startswith(".")
+     file_paths = glob.glob(
+         os.path.join(directory, "**", "*" + extension), recursive=True
+     )
+     assert len(file_paths) > 0, (
+         f"something went wrong. Found no {extension} files in {directory}"
+     )
+     return file_paths
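Example usage (an illustrative sketch; the data path is a hypothetical placeholder):

    from braindecode.util import read_all_file_names

    # recursively collect every EDF file below a data directory
    # (raises an AssertionError if none are found)
    edf_paths = read_all_file_names("/path/to/eeg_data", ".edf")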
braindecode/version.py ADDED
@@ -0,0 +1 @@
+ __version__ = "1.3.0.dev177069446"
braindecode/visualization/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """
+ Functions for visualisations, especially of the ConvNets.
+ """
+
+ from .confusion_matrices import plot_confusion_matrix
+ from .gradients import compute_amplitude_gradients
+
+ __all__ = ["compute_amplitude_gradients", "plot_confusion_matrix"]
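A minimal usage sketch (not part of the packaged file; it assumes plot_confusion_matrix accepts a confusion-matrix ndarray and an optional list of class names):

    import numpy as np
    from braindecode.visualization import plot_confusion_matrix

    # e.g. obtained from sklearn.metrics.confusion_matrix(y_true, y_pred)
    confusion_mat = np.array([[52, 8], [11, 49]])
    fig = plot_confusion_matrix(confusion_mat, class_names=["left_hand", "right_hand"])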