braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,369 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ # Hubert Banville <hubert.jbanville@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+ import inspect
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Literal, Optional, Sequence
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ import braindecode.models as models
14
+
15
# Registry mapping class names to braindecode model classes. It is empty at
# import time and gets filled by ``_init_models_dict``.
models_dict = {}

# ``_init_models_dict`` walks the classes exposed by ``braindecode.models`` and
# registers each one that inherits from ``EEGModuleMixin`` in ``models_dict``.


def _init_models_dict():
    """Fill ``models_dict`` with the model classes of ``braindecode.models``.

    Every class exported by ``braindecode.models`` that subclasses
    ``models.base.EEGModuleMixin`` is stored under its attribute name. The
    mixin itself is not registered, and a class whose ``__name__`` is
    ``"EEGNetv4"`` is skipped.
    """
    base_mixin = models.base.EEGModuleMixin
    for attr_name, klass in inspect.getmembers(models, inspect.isclass):
        is_model = issubclass(klass, base_mixin) and klass != base_mixin
        if not is_model or klass.__name__ == "EEGNetv4":
            continue
        models_dict[attr_name] = klass
31
+
32
+
33
# Names of the signal-related constructor arguments of braindecode models,
# used as the key type of the signal-parameter dicts in this module.
SigArgName = Literal[
    "n_outputs", "n_chans", "chs_info", "n_times", "input_window_seconds", "sfreq"
]
41
+
42
+
43
################################################################
# Test cases for models
#
# This list should be updated whenever a new model is added to
# braindecode (otherwise `test_completeness__models_test_cases`
# will fail).
# Each element in the list should be a tuple with structure
# (model_name, required_params, signal_params), such that:
#
# model_name: str
#     The name of the class of the model to be tested.
# required_params: list[str]
#     The signal-related parameters that are needed to initialize
#     the model.
# signal_params: dict | None
#     The characteristics of the signal that should be passed to
#     the model being tested when the default_signal_params are not
#     compatible with this model (None means: use the defaults).
#     The keys of this dictionary can only be among those of
#     default_signal_params.
################################################################
models_mandatory_parameters: list[
    tuple[str, list[SigArgName], dict[SigArgName, Any] | None]
] = [
    ("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("BDTCN", ["n_chans", "n_outputs"], None),
    ("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
    ("DeepSleepNet", ["n_outputs"], None),
    ("EEGConformer", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGInceptionERP", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGInceptionMI", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGITNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ShallowFBCSPNet", ["n_chans", "n_outputs", "n_times"], None),
    (
        "SleepStagerBlanco2020",
        ["n_chans", "n_outputs", "n_times"],
        {"n_chans": 4},  # n_chans divisible by n_groups=2
    ),
    ("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    (
        "AttnSleep",
        ["n_outputs", "n_times", "sfreq"],
        {
            "sfreq": 100.0,
            "n_times": 3000,
            "chs_info": [{"ch_name": "C1", "kind": "eeg"}],
        },
    ),  # 1 channel
    ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
    ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 128.0}),
    ("BIOT", ["n_chans", "n_outputs", "sfreq", "n_times"], None),
    ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
    ("Labram", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
    ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], {"sfreq": 200.0}),
    ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGSym", ["chs_info", "n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("TSception", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
    ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
    ("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
    ("CTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 250.0}),
    ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
    ("SignalJEPA", ["chs_info"], None),
    ("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
    ("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
    ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
    ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
    ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], {"sfreq": 200.0}),
    ("PBT", ["n_chans", "n_outputs", "n_times"], None),
    ("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("BrainModule", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("BENDR", ["n_chans", "n_outputs", "n_times"], None),
    ("LUNA", ["n_chans", "n_times", "n_outputs"], None),
    ("MEDFormer", ["n_chans", "n_outputs", "n_times"], None),
    (
        "REVE",
        ["n_times", "n_outputs", "n_chans", "chs_info"],
        {
            "sfreq": 200.0,
            "n_chans": 19,
            "n_times": 1_000,
            "chs_info": [{"ch_name": f"E{i + 1}", "kind": "eeg"} for i in range(19)],
        },
    ),
]
135
+
136
################################################################
# Models that are not meant for classification.
#
# Their output shape may differ from the output shape expected
# of classification models.
################################################################
non_classification_models = ["SignalJEPA"]
145
+
146
+ ################################################################
147
+
148
# Seeded generator so that the random channel locations generated in this
# module are reproducible from run to run.
rng = np.random.default_rng(12)

# Three EEG channels named C1, C2 and C3, each with a random 12-element
# location vector.
chs_info = [
    dict(ch_name=f"C{idx}", kind="eeg", loc=rng.random(12)) for idx in range(1, 4)
]

# Default signal characteristics used to build models in tests. The time
# values are mutually consistent: n_times / sfreq == input_window_seconds.
default_signal_params: dict[SigArgName, Any] = dict(
    n_times=1000,
    sfreq=250.0,
    n_outputs=2,
    chs_info=chs_info,
    n_chans=len(chs_info),
    input_window_seconds=4.0,
)
166
+
167
+
168
def _get_signal_params(
    signal_params: dict[SigArgName, Any] | None,
    required_params: list[SigArgName] | None = None,
) -> dict[SigArgName, Any]:
    """Get signal parameters for model initialization in tests.

    Starts from a deep copy of ``default_signal_params`` and applies the
    overrides in ``signal_params``. Any parameter that was not overridden is
    then re-derived from the overridden ones, so that ``chs_info``/``n_chans``
    and ``n_times``/``sfreq``/``input_window_seconds`` stay consistent.

    Parameters
    ----------
    signal_params : dict | None
        Overrides for ``default_signal_params``. ``None`` returns the
        defaults unchanged (possibly filtered by ``required_params``).
    required_params : list of str | None
        If given, only these keys plus the keys of ``signal_params`` are
        kept in the returned dict.

    Returns
    -------
    dict
        The signal parameters.
    """
    sp = deepcopy(default_signal_params)
    if signal_params is not None:
        sp.update(signal_params)
        # Keep the channel description consistent with the overrides.
        if "chs_info" in signal_params and "n_chans" not in signal_params:
            sp["n_chans"] = len(signal_params["chs_info"])
        if "n_chans" in signal_params and "chs_info" not in signal_params:
            sp["chs_info"] = [
                {"ch_name": f"C{i}", "kind": "eeg", "loc": rng.random(12)}
                for i in range(signal_params["n_chans"])
            ]
        assert isinstance(sp["n_times"], int)
        assert isinstance(sp["sfreq"], float)
        assert isinstance(sp["input_window_seconds"], float)
        # Re-derive the time parameters that were not overridden, using
        # n_times == input_window_seconds * sfreq.
        if "input_window_seconds" not in signal_params:
            sp["input_window_seconds"] = sp["n_times"] / sp["sfreq"]
        if "sfreq" not in signal_params:
            sp["sfreq"] = sp["n_times"] / sp["input_window_seconds"]
        if "n_times" not in signal_params:
            # Round instead of truncating: the float product can land just
            # below the intended integer (e.g. 0.29 * 100 == 28.999999999999996)
            # and int() alone would silently drop one sample.
            sp["n_times"] = int(round(sp["input_window_seconds"] * sp["sfreq"]))
    if required_params is not None:
        kept_keys = set((signal_params or {}).keys()).union(required_params)
        sp = {k: sp[k] for k in kept_keys}
    return sp
197
+
198
+
199
def _get_possible_signal_params(
    signal_params: dict[SigArgName, Any], required_params: list[SigArgName]
):
    """List every keyword-argument combination to try when building a model.

    The result is the cartesian product of three option lists, nested in this
    order: output options, channel options, time options.

    - Outputs: ``n_outputs`` is always passed; ``n_outputs=None`` is also
      tried when it is not required.
    - Channels: ``chs_info`` with ``n_chans=None`` is always tried. ``n_chans``
      alone is tried when ``chs_info`` is not required. Neither is tried when
      both are not required.
    - Time: dropping any single one of ``n_times``/``sfreq``/
      ``input_window_seconds`` is always tried (presumably because it can be
      derived from the other two). Dropping two or three is only tried when
      all dropped names are not required.

    Parameters
    ----------
    signal_params : dict
        Concrete values for the signal arguments. ``n_chans`` is only read
        when ``chs_info`` is not required.
    required_params : list of str
        The signal-related parameters needed to initialize the model.

    Returns
    -------
    list of dict
        One kwargs dict per combination.
    """

    def _optional(*names):
        # True when none of ``names`` appears in ``required_params``.
        return all(name not in required_params for name in names)

    outputs = [{"n_outputs": signal_params["n_outputs"]}]
    if _optional("n_outputs"):
        outputs.append({"n_outputs": None})

    channels = [{"chs_info": signal_params["chs_info"], "n_chans": None}]
    if _optional("chs_info"):
        channels.append({"n_chans": signal_params["n_chans"], "chs_info": None})
    if _optional("n_chans", "chs_info"):
        channels.append({"n_chans": None, "chs_info": None})

    n_times = signal_params["n_times"]
    sfreq = signal_params["sfreq"]
    window = signal_params["input_window_seconds"]
    # Each entry: ((n_times, sfreq, input_window_seconds), names that must all
    # be optional for the entry to be tried).
    time_candidates = [
        ((n_times, sfreq, None), ()),
        ((None, sfreq, window), ()),
        ((n_times, None, window), ()),
        ((None, None, window), ("n_times", "sfreq")),
        ((None, sfreq, None), ("n_times", "input_window_seconds")),
        ((n_times, None, None), ("sfreq", "input_window_seconds")),
        ((None, None, None), ("n_times", "sfreq", "input_window_seconds")),
    ]
    times = [
        {"n_times": nt, "sfreq": sf, "input_window_seconds": iw}
        for (nt, sf, iw), must_be_optional in time_candidates
        if _optional(*must_be_optional)
    ]

    return [
        dict(**out, **chan, **tm)
        for out in outputs
        for chan in channels
        for tm in times
    ]
268
+
269
+
270
+ ################################################################
271
def get_summary_table(dir_name=None):
    """Load the model summary table from ``summary.csv``.

    Parameters
    ----------
    dir_name : str | pathlib.Path | None
        Directory containing ``summary.csv``. When None, the directory of
        this module is used.

    Returns
    -------
    pandas.DataFrame
        The table, indexed by its ``Model`` column. Whitespace following
        each delimiter is skipped.
    """
    if dir_name is None:
        folder = Path(__file__).parent
    elif isinstance(dir_name, Path):
        folder = dir_name
    else:
        folder = Path(dir_name)

    return pd.read_csv(
        folder / "summary.csv",
        header=0,
        index_col="Model",
        skipinitialspace=True,
    )
286
+
287
+
288
def extract_channel_locations_from_chs_info(
    chs_info: Optional[Sequence[Dict[str, Any]]],
    num_channels: Optional[int] = None,
) -> Optional[np.ndarray]:
    """Extract 3D channel locations from MNE-style channel information.

    This function provides a unified approach to extract 3D channel locations
    from MNE channel information. It's compatible with models like SignalJEPA
    and LUNA that need to work with channel spatial information.

    Parameters
    ----------
    chs_info : list of dict or None
        Channel information, typically from ``mne.Info.chs``. Each dict should
        contain a 'loc' key holding a 1D array whose first three entries are
        the channel position in cartesian coordinates. In MNE's 12-element
        format, ``loc[:3]`` is the channel position; for EEG channels
        ``loc[3:6]`` holds the *reference* electrode position, which is not
        used here.
    num_channels : int or None
        If specified, only extract the first ``num_channels`` channel locations.
        If None, extract all available channels.

    Returns
    -------
    channel_locations : np.ndarray of shape (n_channels, 3) or None
        Array (float32) of 3D channel locations in cartesian coordinates.
        Returns None if no valid locations are found.

    Notes
    -----
    - Any 1D location vector with at least 3 entries is accepted, both the
      12-element MNE format and a bare 3-element ``(x, y, z)``; the first
      three entries are used in every case.
    - Invalid or missing locations cause extraction to stop at that point:
      a non-dict entry, a missing ``'loc'``, a value that cannot be converted
      to a float array, fewer than 3 entries, or non-finite coordinates
      (e.g. the NaN placeholders of channels without a montage).
    - Returns None if no valid locations can be extracted.
    - This is a unified utility compatible with models like SignalJEPA and LUNA.

    Examples
    --------
    >>> import mne
    >>> from braindecode.models.util import extract_channel_locations_from_chs_info
    >>> info = mne.create_info(["Fz", "Cz", "Pz"], sfreq=250.0, ch_types="eeg")
    >>> _ = info.set_montage("standard_1020")
    >>> locs = extract_channel_locations_from_chs_info(info["chs"])
    >>> print(locs.shape)
    (3, 3)
    """
    if chs_info is None:
        return None

    n_to_extract = len(chs_info) if num_channels is None else num_channels
    locations = []

    for ch_info in chs_info[:n_to_extract]:
        if not isinstance(ch_info, dict):
            break

        loc = ch_info.get("loc")
        if loc is None:
            break

        try:
            loc_array = np.asarray(loc, dtype=np.float32)
        except (ValueError, TypeError):
            break

        # Need a flat vector holding at least the (x, y, z) position.
        if loc_array.ndim != 1 or loc_array.size < 3:
            break

        # The channel position is always the first three entries (MNE stores
        # the EEG reference position in [3:6], not the channel itself).
        coordinates = loc_array[:3]

        # Non-finite coordinates (e.g. NaN placeholders) are not a location.
        if not np.all(np.isfinite(coordinates)):
            break

        locations.append(coordinates)

    if not locations:
        return None

    return np.stack(locations, axis=0)
367
+
368
+
369
# Model summary table, read from the ``summary.csv`` located next to this
# module at import time (``get_summary_table`` with its default directory).
_summary_table = get_summary_table()
@@ -0,0 +1,92 @@
1
+ from .activation import LogActivation, SafeLog
2
+ from .attention import (
3
+ CAT,
4
+ CBAM,
5
+ ECA,
6
+ FCA,
7
+ GCT,
8
+ SRM,
9
+ CATLite,
10
+ EncNet,
11
+ GatherExcite,
12
+ GSoP,
13
+ MultiHeadAttention,
14
+ SqueezeAndExcitation,
15
+ )
16
+ from .blocks import MLP, FeedForwardBlock, InceptionBlock
17
+ from .convolution import (
18
+ AvgPool2dWithConv,
19
+ CausalConv1d,
20
+ CombinedConv,
21
+ Conv2dWithConstraint,
22
+ DepthwiseConv2d,
23
+ )
24
+ from .filter import FilterBankLayer, GeneralizedGaussianFilter
25
+ from .layers import (
26
+ Chomp1d,
27
+ DropPath,
28
+ Ensure4d,
29
+ SqueezeFinalOutput,
30
+ SubjectLayers,
31
+ TimeDistributed,
32
+ )
33
+ from .linear import LinearWithConstraint, MaxNormLinear
34
+ from .parametrization import MaxNorm, MaxNormParametrize
35
+ from .stats import (
36
+ LogPowerLayer,
37
+ LogVarLayer,
38
+ MaxLayer,
39
+ MeanLayer,
40
+ StatLayer,
41
+ StdLayer,
42
+ VarLayer,
43
+ )
44
+ from .util import aggregate_probas
45
+ from .wrapper import Expression, IntermediateOutputWrapper
46
+
47
# Public API of this package: the names re-exported by the imports above,
# grouped in the same order as the submodules they come from.
__all__ = [
    "LogActivation",
    "SafeLog",
    "CAT",
    "CBAM",
    "ECA",
    "FCA",
    "GCT",
    "SRM",
    "CATLite",
    "EncNet",
    "GatherExcite",
    "GSoP",
    "MultiHeadAttention",
    "SqueezeAndExcitation",
    "MLP",
    "FeedForwardBlock",
    "InceptionBlock",
    "AvgPool2dWithConv",
    "CausalConv1d",
    "CombinedConv",
    "Conv2dWithConstraint",
    "DepthwiseConv2d",
    "FilterBankLayer",
    "GeneralizedGaussianFilter",
    "Chomp1d",
    "DropPath",
    "Ensure4d",
    "SubjectLayers",
    "SqueezeFinalOutput",
    "TimeDistributed",
    "LinearWithConstraint",
    "MaxNormLinear",
    "MaxNorm",
    "MaxNormParametrize",
    "LogPowerLayer",
    "LogVarLayer",
    "MaxLayer",
    "MeanLayer",
    "StatLayer",
    "StdLayer",
    "VarLayer",
    "aggregate_probas",
    "Expression",
    "IntermediateOutputWrapper",
]
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+
4
+ import braindecode.functional as F
5
+
6
+
7
class SafeLog(nn.Module):
    r"""
    Logarithm with its input bounded from below by ``epsilon``.

    :math:`\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)`

    Parameters
    ----------
    epsilon : float, optional
        Lower bound applied to the input, so that log(0) or the log of a
        negative number is never computed. Default is 1e-6.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import SafeLog
    >>> module = SafeLog(epsilon=1e-6)
    >>> inputs = torch.rand(2, 3)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3])

    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x) -> Tensor:
        """
        Apply the safe logarithm element-wise.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            ``braindecode.functional.safe_log`` of ``x`` with ``eps=epsilon``.
        """
        eps = self.epsilon
        return F.safe_log(x=x, eps=eps)

    def extra_repr(self) -> str:
        return "eps={}".format(self.epsilon)
54
+
55
+
56
class LogActivation(nn.Module):
    """Logarithm activation function, ``log(x + epsilon)``.

    Parameters
    ----------
    epsilon : float, default=1e-6
        Small constant added to the input before taking the logarithm, so an
        input of exactly 0 gives ``log(epsilon)`` instead of ``-inf``. Inputs
        below ``-epsilon`` still give NaN; use ``SafeLog`` to clamp instead.
    *args, **kwargs
        Forwarded to :class:`torch.nn.Module`.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import LogActivation
    >>> module = LogActivation(epsilon=1e-6)
    >>> inputs = torch.rand(2, 3)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3])
    """

    def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``log(x + epsilon)`` element-wise."""
        return torch.log(x + self.epsilon)  # Adding epsilon to prevent log(0)

    def extra_repr(self) -> str:
        # Same format as SafeLog.extra_repr so both modules print alike.
        return f"eps={self.epsilon}"