braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,125 @@
1
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+
6
+ from __future__ import annotations
7
+
8
+ import os
9
+
10
+ import mne
11
+ import numpy as np
12
+ import pandas as pd
13
+ from mne.datasets.sleep_physionet.age import fetch_data
14
+
15
+ from .base import BaseConcatDataset, RawDataset
16
+
17
+
18
class SleepPhysionet(BaseConcatDataset):
    """Sleep Physionet dataset.

    Sleep dataset from https://physionet.org/content/sleep-edfx/1.0.0/.
    Contains overnight recordings from 78 healthy subjects.

    See `MNE example <https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html>`_.

    Parameters
    ----------
    subject_ids : list(int) | int | None
        (list of) int of subject(s) to be loaded. If None, load all available
        subjects.
    recording_ids : list(int) | int | None
        Recordings (nights) to load per subject. Can be [1], [2] or [1, 2]
        (same as None); a single int is treated as a one-element list. Not
        every subject has both recordings (e.g. subject 13 has only one).
        Data is requested with ``on_missing="warn"``, so unavailable
        subjects/recordings produce a warning instead of an error.
    preload : bool
        If True, preload the data of the Raw objects.
    load_eeg_only : bool
        If True, only load the EEG channels and discard the others (EOG, EMG,
        temperature, respiration) to avoid resampling the other signals.
    crop_wake_mins : float
        Number of minutes of wake time to keep before the first sleep event
        and after the last sleep event. Used to reduce the imbalance in this
        dataset. Defaults to 30 minutes. Set to 0 (or None) to disable.
    crop : None | tuple
        If not None, crop the raw files (e.g. to use only the first 3h).
        Applied after the wake cropping above, so times are relative to the
        start of the already wake-cropped recording.
        Example: ``crop=(0, 3600*3)`` to keep only the first 3h.
    """

    def __init__(
        self,
        subject_ids: list[int] | int | None = None,
        recording_ids: list[int] | int | None = None,
        preload: bool = False,
        load_eeg_only: bool = True,
        crop_wake_mins: float | None = 30,
        crop: tuple[float, float] | None = None,
    ):
        if subject_ids is None:
            # Request every ID in 0-82; IDs absent from the corpus are
            # tolerated because fetch_data is called with on_missing="warn".
            subject_ids = list(range(83))
        elif isinstance(subject_ids, (int, np.integer)):
            # A single int is documented as valid, but fetch_data is handed
            # a list of subjects, so normalize it here.
            subject_ids = [subject_ids]

        if recording_ids is None:
            recording_ids = [1, 2]
        elif isinstance(recording_ids, (int, np.integer)):
            recording_ids = [recording_ids]

        paths = fetch_data(subject_ids, recording=recording_ids, on_missing="warn")

        all_base_ds = []
        for p in paths:
            # Each entry holds the PSG file path first, then its annotations.
            raw, desc = self._load_raw(
                p[0],
                p[1],
                preload=preload,
                load_eeg_only=load_eeg_only,
                crop_wake_mins=crop_wake_mins,
                crop=crop,
            )
            all_base_ds.append(RawDataset(raw, desc))
        super().__init__(all_base_ds)

    @staticmethod
    def _load_raw(
        raw_fname,
        ann_fname,
        preload,
        load_eeg_only=True,
        crop_wake_mins=0,
        crop=None,
    ):
        """Load one PSG recording together with its annotation file.

        Parameters
        ----------
        raw_fname : str
            Path to the EDF recording. Its basename is parsed for the subject
            number (characters 3-4) and recording number (character 5).
        ann_fname : str
            Path to the annotation file read with :func:`mne.read_annotations`.
        preload : bool
            Passed to :func:`mne.io.read_raw_edf`.
        load_eeg_only : bool
            If True, the non-EEG channels listed in ``ch_mapping`` are
            excluded at read time; otherwise they are kept and their channel
            types are set from ``ch_mapping``.
        crop_wake_mins : float | None
            Minutes of data kept before the first and after the last
            sleep-stage annotation. 0, False or None disables this cropping.
        crop : None | tuple
            Extra arguments for :meth:`mne.io.Raw.crop`, applied after the
            wake cropping.

        Returns
        -------
        raw : mne.io.Raw
            The (possibly cropped) recording with annotations attached.
        desc : pandas.Series
            Holds the ``subject`` and ``recording`` numbers.
        """
        # Non-EEG channels and the MNE channel type each gets when kept.
        ch_mapping = {
            "EOG horizontal": "eog",
            "Resp oro-nasal": "misc",
            "EMG submental": "misc",
            "Temp rectal": "misc",
            "Event marker": "misc",
        }
        exclude = list(ch_mapping.keys()) if load_eeg_only else ()

        raw = mne.io.read_raw_edf(raw_fname, preload=preload, exclude=exclude)
        annots = mne.read_annotations(ann_fname)
        raw.set_annotations(annots, emit_warning=False)

        if crop_wake_mins is not None and crop_wake_mins > 0:
            # Annotations whose description ends in one of these codes count
            # as sleep (presumably stages 1-4 and REM). ``endswith`` also
            # copes with empty descriptions, unlike indexing ``[-1]``.
            sleep_codes = ("1", "2", "3", "4", "R")
            is_sleep = [desc.endswith(sleep_codes) for desc in annots.description]
            sleep_event_inds = np.flatnonzero(is_sleep)

            if sleep_event_inds.size == 0:
                # Nothing to anchor the crop on: keep the whole recording
                # rather than failing with an opaque IndexError.
                import warnings

                warnings.warn(
                    f"No sleep-stage annotations found in {ann_fname}; "
                    "skipping wake cropping.",
                    RuntimeWarning,
                )
            else:
                margin = crop_wake_mins * 60  # minutes -> seconds
                tmin = annots.onset[sleep_event_inds[0]] - margin
                tmax = annots.onset[sleep_event_inds[-1]] + margin
                # Clamp to the recording's extent so crop() never goes out
                # of bounds.
                raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))

        # Drop the "EEG " prefix from EEG channel names.
        raw.rename_channels(
            {name: name.replace("EEG ", "") for name in raw.ch_names if "EEG" in name}
        )

        if not load_eeg_only:
            raw.set_channel_types(ch_mapping)

        if crop is not None:
            raw.crop(*crop)

        basename = os.path.basename(raw_fname)
        subj_nb = int(basename[3:5])
        sess_nb = int(basename[5])
        desc = pd.Series({"subject": subj_nb, "recording": sess_nb}, name="")

        return raw, desc