braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,240 @@
1
+ """Preprocessor objects based on mne methods."""
2
+
3
+ # Authors: Bruna Lopes <brunajaflopes@gmail.com>
4
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
5
+ #
6
+ # License: BSD-3
7
+ import inspect
8
+
9
+ import mne.channels
10
+ import mne.io
11
+ import mne.preprocessing
12
+
13
+ from braindecode.preprocessing.preprocess import Preprocessor
14
+
15
+
16
def _is_standalone_function(func):
    """
    Tell whether ``func`` is handled as a standalone MNE function.

    The decision is made purely by name: if :class:`mne.io.Raw` exposes an
    attribute called ``func.__name__``, the function is handled as a Raw
    method; otherwise it is handled as a standalone function (for example
    one living in ``mne.preprocessing`` or ``mne.channels``).

    Parameters
    ----------
    func : callable
        The MNE function or method to classify.

    Returns
    -------
    bool
        ``False`` when ``mne.io.Raw`` has an attribute with the same name as
        ``func``, ``True`` otherwise.
    """
    # NOTE(review): the lookup is name-based, so a module-level function that
    # shares its name with a Raw method is classified as a Raw method as
    # well -- confirm this is the intended behaviour.
    shares_name_with_raw_attribute = hasattr(mne.io.Raw, func.__name__)
    return not shares_name_with_raw_attribute
28
+
29
+
30
def _generate_init_method(func, force_copy_false=False):
    """
    Generate an ``__init__`` method for a class based on the function's signature.

    The generated initializer accepts the same arguments as ``func`` minus its
    first parameter, validates them, and forwards them as keyword arguments to
    ``Preprocessor.__init__`` together with ``fn=func.__name__`` and
    ``apply_on_array=False``.

    Parameters
    ----------
    func : callable
        The function to wrap.
    force_copy_false : bool
        If True, forces copy=False by default for functions that have a copy
        parameter. An explicit ``copy`` value, positional or keyword, always
        takes precedence.

    Returns
    -------
    callable
        The ``__init__`` implementation. Its ``__signature__`` is set to the
        signature of ``func`` so introspection shows the wrapped parameters.

    Raises
    ------
    TypeError
        Raised by the generated initializer (not by this factory) when too
        many positional arguments are given, an argument is supplied twice,
        an unknown keyword is used, or a required argument is missing.
    """
    func_name = func.__name__
    # Skip the first parameter ('self', 'raw' or 'epochs').
    parameters = list(inspect.signature(func).parameters.values())[1:]
    param_names = [param.name for param in parameters]

    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    # Identity comparison: ``==`` would call the default's own ``__eq__``,
    # which may raise or return a non-boolean (e.g. for array defaults).
    # Variadic parameters never carry a default but are not required either.
    all_mandatory = [
        param.name
        for param in parameters
        if param.default is inspect.Parameter.empty
        and param.kind not in variadic_kinds
    ]
    accepts_var_positional = any(
        param.kind is inspect.Parameter.VAR_POSITIONAL for param in parameters
    )

    def init_method(self, *args, **kwargs):
        # Surplus positional arguments used to be dropped silently by zip().
        if len(args) > len(param_names) and not accepts_var_positional:
            raise TypeError(
                f"{func_name}() takes at most {len(param_names)} positional "
                f"arguments but {len(args)} were given"
            )

        init_kwargs = dict(zip(param_names, args))

        for name, value in kwargs.items():
            if name in init_kwargs:
                raise TypeError(f"Multiple values for argument '{name}'")
            if name not in param_names:
                raise TypeError(
                    f"'{name}' is an invalid keyword argument for {func_name}()"
                )
            init_kwargs[name] = value

        # For standalone functions with copy parameter, set copy=False by
        # default. Done after binding so a positional ``copy`` does not clash
        # with the injected default.
        if force_copy_false and "copy" in param_names:
            init_kwargs.setdefault("copy", False)

        missing = [name for name in all_mandatory if name not in init_kwargs]
        if missing:
            raise TypeError(
                f"{func_name}() missing required arguments: {', '.join(missing)}"
            )
        Preprocessor.__init__(self, fn=func_name, apply_on_array=False, **init_kwargs)

    init_method.__signature__ = inspect.signature(func)
    return init_method
85
+
86
+
87
def _generate_repr_method(class_name):
    """
    Build a ``__repr__`` implementation for a generated preprocessor class.

    Parameters
    ----------
    class_name : str
        Name shown in front of the parenthesised argument list.

    Returns
    -------
    callable
        A method rendering ``self.kwargs`` as ``ClassName(key=value, ...)``.
    """

    def repr_method(self):
        # ``!r`` goes through ``repr()``; calling ``value.__repr__()`` directly
        # fails when the stored value is itself a class (e.g. ``int``).
        args_str = ", ".join(f"{key}={value!r}" for key, value in self.kwargs.items())
        return f"{class_name}({args_str})"

    return repr_method
93
+
94
+
95
def _generate_mne_pre_processor(function):
    """
    Generate a class based on an MNE function for preprocessing.

    The class is named after the function in CamelCase (with ``Eeg`` spelled
    ``EEG``), derives from ``Preprocessor``, and gets a generated ``__init__``
    and ``__repr__``. Its docstring is the wrapped function's docstring
    prefixed with a Sphinx cross-reference to the original MNE object.

    Parameters
    ----------
    function : callable
        The MNE function to wrap. Automatically determines if it's standalone
        or a Raw method based on the function's module and name.

    Returns
    -------
    type
        The generated ``Preprocessor`` subclass.
    """
    func_name = function.__name__
    class_name = "".join(word.title() for word in func_name.split("_")).replace(
        "Eeg", "EEG"
    )

    # Automatically determine if function is standalone
    is_standalone = _is_standalone_function(function)

    # Create a wrapper note that references the original MNE function
    # For Raw methods, use mne.io.Raw.method_name format with :meth:
    # For standalone functions, use the public module path with :func:
    if is_standalone:
        ref_role = "func"
        # Prefer the common MNE public APIs; these are more likely to be in
        # the intersphinx inventory than a private submodule path.
        for public_module in ("mne.preprocessing", "mne.channels", "mne.filter"):
            if function.__module__.startswith(public_module):
                ref_path = f"{public_module}.{func_name}"
                break
        else:
            ref_path = f"{function.__module__}.{func_name}"
    else:
        ref_role = "meth"
        ref_path = f"mne.io.Raw.{func_name}"

    # Use proper Sphinx cross-reference for intersphinx linking
    wrapper_note = (
        f"Braindecode preprocessor wrapper for :{ref_role}:`~{ref_path}`.\n\n"
    )

    # Only standalone functions that expose a 'copy' parameter get
    # copy=False injected as their default.
    has_copy_param = "copy" in inspect.signature(function).parameters
    force_copy_false = is_standalone and has_copy_param

    class_attrs = {
        "__init__": _generate_init_method(function, force_copy_false),
        "__doc__": wrapper_note + (function.__doc__ or ""),
        "__repr__": _generate_repr_method(class_name),
        "fn": function if is_standalone else func_name,
        "_is_standalone": is_standalone,
    }
    return type(class_name, (Preprocessor,), class_attrs)
160
+
161
+
162
+ # List of MNE functions to generate classes for
163
mne_functions = [
    # From mne.filter
    mne.filter.resample,
    mne.filter.filter_data,
    mne.filter.notch_filter,
    # From mne.io.Raw methods
    mne.io.Raw.add_channels,
    mne.io.Raw.add_events,
    mne.io.Raw.add_proj,
    mne.io.Raw.add_reference_channels,
    mne.io.Raw.anonymize,
    mne.io.Raw.apply_gradient_compensation,
    mne.io.Raw.apply_hilbert,
    mne.io.Raw.apply_proj,
    mne.io.Raw.crop,
    mne.io.Raw.crop_by_annotations,
    mne.io.Raw.del_proj,
    mne.io.Raw.drop_channels,
    mne.io.Raw.filter,
    mne.io.Raw.fix_mag_coil_types,
    mne.io.Raw.interpolate_bads,
    mne.io.Raw.interpolate_to,
    mne.io.Raw.notch_filter,
    mne.io.Raw.pick,
    mne.io.Raw.pick_channels,
    mne.io.Raw.pick_types,
    mne.io.Raw.rename_channels,
    mne.io.Raw.reorder_channels,
    mne.io.Raw.rescale,
    mne.io.Raw.resample,
    mne.io.Raw.savgol_filter,
    mne.io.Raw.set_annotations,
    mne.io.Raw.set_channel_types,
    mne.io.Raw.set_eeg_reference,
    mne.io.Raw.set_meas_date,
    mne.io.Raw.set_montage,
    # Standalone functions from mne.preprocessing
    mne.preprocessing.annotate_amplitude,
    mne.preprocessing.annotate_break,
    mne.preprocessing.annotate_movement,
    mne.preprocessing.annotate_muscle_zscore,
    mne.preprocessing.annotate_nan,
    mne.preprocessing.compute_current_source_density,
    mne.preprocessing.compute_bridged_electrodes,
    mne.preprocessing.equalize_bads,
    mne.preprocessing.find_bad_channels_lof,
    mne.preprocessing.fix_stim_artifact,
    mne.preprocessing.interpolate_bridged_electrodes,
    mne.preprocessing.maxwell_filter,
    mne.preprocessing.oversampled_temporal_projection,
    mne.preprocessing.realign_raw,
    mne.preprocessing.regress_artifact,
    # Standalone functions from mne.channels
    mne.channels.combine_channels,
    mne.channels.equalize_channels,
    mne.channels.rename_channels,
    # Top-level mne functions for referencing
    mne.add_reference_channels,
    mne.set_bipolar_reference,
    mne.set_eeg_reference,
]

# Automatically generate and add classes to the global namespace.
# NOTE: each class is stored under its own ``__name__``, so when two entries
# of ``mne_functions`` produce the same class name, the one listed last
# replaces the earlier one.
for function in mne_functions:
    class_obj = _generate_mne_pre_processor(function)
    globals()[class_obj.__name__] = class_obj

# Define __all__ based on the generated class names.
# The loop variable ``class_obj`` still points at the last generated class, so
# that class is reachable under two global names; ``dict.fromkeys`` keeps the
# first occurrence of every name and drops the resulting duplicate.
__all__ = list(
    dict.fromkeys(
        candidate.__name__
        for candidate in list(globals().values())
        if isinstance(candidate, type)
        and issubclass(candidate, Preprocessor)
        and candidate is not Preprocessor
    )
)

# Clean up unnecessary variables
del mne_functions, function, class_obj