braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,233 @@
+ from collections.abc import Callable
+ from inspect import signature
+ from types import UnionType
+ from typing import Annotated, Any, Literal, Union, get_args, get_origin
+
+ import numpy as np
+ from mne.utils import _soft_import
+ from typing_extensions import TypedDict
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.models.util import SigArgName, models_dict, models_mandatory_parameters
+
+ pydantic = _soft_import(name="pydantic", purpose="model configuration", strict=False)
+
+ try:
+     from numpydantic import NDArray, Shape
+ except ImportError:
+     # we can't use soft import for numpydantic because numpydantic does not define its version in __init__
+     NDArray = Any  # type: ignore
+     Shape = Any  # type: ignore
+
+
+ class ChsInfoType(TypedDict, total=False, closed=True):  # type: ignore[call-arg]
+     cal: float
+     ch_name: str
+     coil_type: int
+     coord_frame: int
+     kind: str
+     loc: NDArray[Shape["12"], np.float64]  # type: ignore[misc]
+     logno: int
+     range: float
+     scanno: int
+     unit: int
+     unit_mul: int
+
+
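+ # Illustrative only: a minimal entry of the kind ChsInfoType accepts (the
+ # values are made up; real entries come from an MNE ``Info["chs"]`` list):
+ #
+ # ch: ChsInfoType = {"ch_name": "Cz", "kind": "eeg", "cal": 1.0, "range": 1.0}
+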
+ def _replace_type_hints(type_hint: Any) -> Any:
+     origin = get_origin(type_hint)
+     args = get_args(type_hint)
+     if origin is type or origin is Callable or type_hint is Callable:
+         return pydantic.ImportString
+     if origin is None:
+         return type_hint
+     replaced_args = tuple(_replace_type_hints(arg) for arg in args)
+     if origin is UnionType:
+         origin = Union
+     return origin[replaced_args]
+
+
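+ # Sketch of the effect (assumes pydantic is installed): class and callable
+ # hints become pydantic.ImportString, so configs can reference them by dotted
+ # path, and unions are rebuilt recursively:
+ #
+ # _replace_type_hints(type[nn.Module] | None)  # -> ImportString | None
+ # _replace_type_hints(int | None)              # -> int | None (unchanged)
+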
+ SIGNAL_ARGS_TYPES = {
+     "n_chans": int,
+     "n_times": int,
+     "sfreq": float,
+     "input_window_seconds": float,
+     "n_outputs": int,
+     "chs_info": list[ChsInfoType],
+ }
+
+
+ class BaseBraindecodeModelConfig(pydantic.BaseModel):  # type: ignore
+     def create_instance(self) -> EEGModuleMixin:
+         model_cls = models_dict[self.model_name_]
+         kwargs = self.model_dump(mode="python", exclude={"model_name_"})
+         if kwargs.get("n_chans") is not None and kwargs.get("chs_info") is not None:
+             kwargs.pop("n_chans")
+         if (
+             kwargs.get("n_times") is not None
+             and kwargs.get("input_window_seconds") is not None
+             and kwargs.get("sfreq") is not None
+         ):
+             kwargs.pop("n_times")
+         return model_cls(**kwargs)
+
+
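+ # Hedged sketch of the intended flow, assuming pydantic is installed and an
+ # "EEGNet" entry exists in models_dict (as in the example at the end of this
+ # file); the EEGNetConfig class is generated further below:
+ #
+ # cfg = EEGNetConfig(n_chans=16, n_outputs=2, n_times=200, sfreq=100.0)
+ # model = cfg.create_instance()  # redundant signal kwargs are dropped first
+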
+ def make_model_config(
+     model_class: type[EEGModuleMixin],
+     required: list[SigArgName],
+ ) -> type[BaseBraindecodeModelConfig]:
+     """Create a pydantic model config for a given model class.
+
+     Parameters
+     ----------
+     model_class : type[EEGModuleMixin]
+         The model class for which to create the config.
+     required : list of SigArgName
+         The required signal arguments for the model.
+
+     Returns
+     -------
+     type
+         A pydantic BaseModel subclass representing the model config.
+     """
+     if not pydantic:
+         raise ImportError(
+             "pydantic is required to use make_model_config. "
+             "Please install braindecode[typing]."
+         )
+
+     # ironically, we need to ignore the type here to have the soft dependency.
+     @pydantic.model_validator(mode="before")
+     def validate_signal_params(cls, data: Any):
+         n_outputs = data.get("n_outputs")
+         n_chans = data.get("n_chans")
+         chs_info = data.get("chs_info")
+         n_times = data.get("n_times")
+         input_window_seconds = data.get("input_window_seconds")
+         sfreq = data.get("sfreq")
+
+         # Check that required parameters are provided or can be inferred
+         if "n_outputs" in required and n_outputs is None:
+             raise ValueError("n_outputs is a required parameter but was not provided.")
+         if "n_chans" in required and n_chans is None and chs_info is None:
+             raise ValueError(
+                 "n_chans is required and could not be inferred. Either specify n_chans or chs_info."
+             )
+         if "chs_info" in required and chs_info is None:
+             raise ValueError("chs_info is a required parameter but was not provided.")
+         if "n_times" in required and (
+             n_times is None and (sfreq is None or input_window_seconds is None)
+         ):
+             raise ValueError(
+                 "n_times is required and could not be inferred. "
+                 "Either specify n_times or input_window_seconds and sfreq."
+             )
+         if "sfreq" in required and (
+             sfreq is None and (n_times is None or input_window_seconds is None)
+         ):
+             raise ValueError(
+                 "sfreq is required and could not be inferred. "
+                 "Either specify sfreq or input_window_seconds and n_times."
+             )
+         if "input_window_seconds" in required and (
+             input_window_seconds is None and (n_times is None or sfreq is None)
+         ):
+             raise ValueError(
+                 "input_window_seconds is required and could not be inferred. "
+                 "Either specify input_window_seconds or n_times and sfreq."
+             )
+
+         # Infer missing parameters if possible, and check consistency
+         if chs_info is not None:
+             if n_chans is None:
+                 data["n_chans"] = len(chs_info)
+             elif n_chans != len(chs_info):
+                 raise ValueError(
+                     f"Provided {n_chans=} does not match length of chs_info: {len(chs_info)}."
+                 )
+         if (
+             n_times is not None
+             and sfreq is not None
+             and input_window_seconds is not None
+         ):
+             if n_times != round(input_window_seconds * sfreq):
+                 raise ValueError(
+                     f"Provided {n_times=} does not match {input_window_seconds=} * {sfreq=}."
+                 )
+         elif n_times is None and sfreq is not None and input_window_seconds is not None:
+             data["n_times"] = round(input_window_seconds * sfreq)
+         elif sfreq is None and n_times is not None and input_window_seconds is not None:
+             data["sfreq"] = n_times / input_window_seconds
+         elif input_window_seconds is None and n_times is not None and sfreq is not None:
+             data["input_window_seconds"] = n_times / sfreq
+         return data
+
+     signature_params = signature(model_class.__init__, eval_str=True).parameters
+     has_args = any(p.kind == p.VAR_POSITIONAL for p in signature_params.values())
+     has_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature_params.values())
+     if has_args:
+         raise ValueError("Model __init__ methods cannot have *args")
+
+     extra = "allow" if has_kwargs else "forbid"
+     fields = {}
+     for name, p in signature_params.items():
+         if name == "self" or p.kind == p.VAR_KEYWORD:
+             continue
+
+         annot = p.annotation
+         if annot is p.empty:
+             annot = Any
+         else:
+             # case with type[nn.Module] or callable
+             annot = _replace_type_hints(annot)
+         # Most models did not specify types for signal args, so we add them here
+         if name in SIGNAL_ARGS_TYPES:
+             annot = SIGNAL_ARGS_TYPES[name] | None
+
+         fields[name] = (annot, p.default) if p.default is not p.empty else annot
+
+     name = model_class.__name__
+     model_config = pydantic.create_model(
+         f"{name}Config",
+         model_name_=(Literal[name], name),
+         __config__=pydantic.ConfigDict(
+             arbitrary_types_allowed=True, extra=extra, validate_default=True
+         ),
+         __doc__=f"Pydantic config of model {model_class.__name__}\n\n{model_class.__doc__}",
+         __base__=BaseBraindecodeModelConfig,
+         __module__="braindecode.models.config",
+         __validators__={"validate_signal_params": validate_signal_params},
+         **fields,
+     )
+     return model_config
+
+
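+ # Illustrative direct use (hedged; assumes pydantic is installed and "EEGNet"
+ # is a key of models_dict):
+ #
+ # EEGNetConfig = make_model_config(models_dict["EEGNet"], ["n_chans", "n_outputs", "n_times"])
+ # EEGNetConfig(n_chans=16, n_outputs=2, n_times=200)   # validates and infers signal args
+ # EEGNetConfig(n_chans=16, n_outputs=2)                # raises: n_times not inferable
+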
+ # Automatically generate and add classes to the global namespace
+ # and define __all__ based on generated classes
+ __all__ = ["make_model_config"]
+
+ if not pydantic:
+     pass
+ else:
+     models_configs: list[type[BaseBraindecodeModelConfig]] = []
+     for model_name, req, _ in models_mandatory_parameters:
+         model_cls = models_dict[model_name]
+         model_cfg = make_model_config(model_cls, req)
+         globals()[model_cfg.__name__] = model_cfg
+         __all__.append(model_cfg.__name__)
+         models_configs.append(model_cfg)
+
+     BraindecodeModelConfig = Annotated[  # type: ignore
+         Union[tuple(models_configs)],
+         pydantic.Field(
+             discriminator="model_name_", description="Braindecode model configuration"
+         ),
+     ]
+
+ # # Example usage:
+ #
+ # class DummyConfigWithModel(pydantic.BaseModel):
+ #     model: BraindecodeModelConfig
+ #
+ # DummyConfigWithModel.model_validate({'model': dict(model_name_='EEGNet', n_chans=16, n_outputs=1, n_times=200)})
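+ #
+ # Because model_name_ is the discriminator, a config also round-trips through
+ # JSON (sketch): DummyConfigWithModel.model_validate_json(cfg.model_dump_json())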
@@ -0,0 +1,319 @@
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ from mne.utils import warn
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class ContraWR(EEGModuleMixin, nn.Module):
+     r"""Contrast with the World Representation (ContraWR) from Yang et al. (2021) [Yang2021]_.
+
+     :bdg-success:`Convolution`
+
+     This model is a convolutional neural network that uses a spectral
+     representation with a series of convolutional layers and residual blocks.
+     The model is designed to learn a representation of the EEG signal that can
+     be used for sleep staging.
+
+     Parameters
+     ----------
+     steps : int, optional
+         Number of steps used to set the frequency-decomposition hop, i.e.
+         ``hop_length = n_fft // steps``, by default 20.
+     emb_size : int, optional
+         Embedding size for the final layer, by default 256.
+     res_channels : list[int], optional
+         Number of channels for each residual block, by default [32, 64, 128].
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+     drop_prob : float, default=0.5
+         The dropout rate for regularization. Values should be between 0 and 1.
+
+     .. versionadded:: 0.9
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors. The modifications relative to the
+     original code from [Code2023]_ are minimal, and the model is expected to
+     work as intended.
+
+     References
+     ----------
+     .. [Yang2021] Yang, C., Xiao, C., Westover, M. B., & Sun, J. (2023).
+        Self-supervised electroencephalogram representation learning for automatic
+        sleep staging: model development and evaluation study. JMIR AI, 2(1), e46769.
+     .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
+        Biosignal Transformer for Cross-data Learning in the Wild.
+        GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         sfreq=None,
+         emb_size: int = 256,
+         res_channels: list[int] = [32, 64, 128],
+         steps=20,
+         activation: type[nn.Module] = nn.ELU,
+         drop_prob: float = 0.5,
+         stride_res: int = 2,
+         kernel_size_res: int = 3,
+         padding_res: int = 1,
+         # Another way to pass the EEG parameters
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+         if not isinstance(res_channels, list):
+             raise ValueError("res_channels must be a list of integers.")
+
+         if self.input_window_seconds < 1.0:
+             warning_msg = (
+                 "The input window is less than 1 second, which may not be "
+                 "sufficient for the model to learn meaningful representations. "
+                 "Changing `n_fft` to `n_times`."
+             )
+             warn(warning_msg, UserWarning)
+             self.n_fft = self.n_times
+         else:
+             self.n_fft = int(self.sfreq)
+
+         self.steps = steps
+
+         res_channels = [self.n_chans] + res_channels + [emb_size]
+
+         self.torch_stft = _STFTModule(
+             n_fft=self.n_fft,
+             hop_length=int(self.n_fft // self.steps),
+         )
+
+         self.convs = nn.ModuleList(
+             [
+                 _ResBlock(
+                     in_channels=res_channels[i],
+                     out_channels=res_channels[i + 1],
+                     stride=stride_res,
+                     use_downsampling=True,
+                     pooling=True,
+                     drop_prob=drop_prob,
+                     kernel_size=kernel_size_res,
+                     padding=padding_res,
+                     activation=activation,
+                 )
+                 for i in range(len(res_channels) - 1)
+             ]
+         )
+         self.adaptative_pool = nn.AdaptiveAvgPool2d((1, 1))
+         self.flatten_layer = nn.Flatten()
+
+         self.activation_layer = activation()
+         self.final_layer = nn.Linear(emb_size, self.n_outputs)
+
+     def forward(self, X: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         X : Tensor
+             Input tensor of shape (batch_size, n_channels, n_times).
+
+         Returns
+         -------
+         Tensor
+             Output tensor of shape (batch_size, n_outputs).
+         """
+         X = self.torch_stft(X)
+
+         for conv in self.convs:
+             X = conv.forward(X)
+
+         emb = self.adaptative_pool(X)
+         emb = self.flatten_layer(emb)
+         emb = self.activation_layer(emb)
+
+         return self.final_layer(emb)
+
+
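+ # Hedged usage sketch (shapes illustrative): a 2 s window at 100 Hz, with
+ # 16 channels and 5 sleep stages:
+ #
+ # import torch
+ # model = ContraWR(n_chans=16, n_outputs=5, sfreq=100, n_times=200)
+ # model(torch.randn(8, 16, 200)).shape  # -> torch.Size([8, 5])
+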
+ class _ResBlock(nn.Module):
+     r"""Convolutional Residual Block 2D.
+
+     This block stacks two convolutional layers with batch normalization,
+     max pooling, dropout, and a residual connection.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels.
+     stride : int (default=1)
+         Stride of the convolutional layers.
+     use_downsampling : bool (default=True)
+         Whether to use a downsampling residual connection.
+     pooling : bool (default=True)
+         Whether to use max pooling.
+     kernel_size : int (default=3)
+         Kernel size of the convolutional layers.
+     padding : int (default=1)
+         Padding of the convolutional layers.
+     activation : nn.Module, default=nn.ReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+     drop_prob : float, default=0.5
+         The dropout rate for regularization. Values should be between 0 and 1.
+
+     Examples
+     --------
+     >>> import torch
+     >>> model = _ResBlock(6, 16, 1, True, True)
+     >>> input_ = torch.randn((16, 6, 28, 150))  # (batch, channel, height, width)
+     >>> output = model(input_)
+     >>> output.shape
+     torch.Size([16, 16, 28, 150])
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         stride=1,
+         use_downsampling=True,
+         pooling=True,
+         kernel_size=3,
+         padding=1,
+         drop_prob=0.5,
+         activation: type[nn.Module] = nn.ReLU,
+     ):
+         super().__init__()
+         self.conv1 = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+         )
+         self.bn1 = nn.BatchNorm2d(out_channels)
+         self.relu = activation()
+         self.conv2 = nn.Conv2d(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             padding=padding,
+         )
+         self.bn2 = nn.BatchNorm2d(out_channels)
+         self.maxpool = nn.MaxPool2d(
+             kernel_size=kernel_size, stride=stride, padding=padding
+         )
+         self.downsample = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+             ),
+             nn.BatchNorm2d(out_channels),
+         )
+         self.use_downsampling = use_downsampling
+         self.pooling = pooling
+         self.dropout = nn.Dropout(drop_prob)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (batch_size, n_channels, n_freqs, n_times).
+
+         Returns
+         -------
+         Tensor
+             Output tensor of shape (batch_size, n_channels, n_freqs, n_times).
+         """
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.bn2(out)
+         if self.use_downsampling:
+             residual = self.downsample(x)
+             out += residual
+         if self.pooling:
+             out = self.maxpool(out)
+         out = self.dropout(out)
+         return out
+
+
+ class _STFTModule(nn.Module):
+     r"""
+     A PyTorch module that computes the Short-Time Fourier Transform (STFT)
+     of an EEG batch tensor.
+
+     Expects input of shape (batch_size, n_channels, n_times) and returns
+     (batch_size, n_channels, n_freqs, n_times).
+     """
+
+     def __init__(
+         self,
+         n_fft: int,
+         hop_length: int,
+         center: bool = True,
+         onesided: bool = True,
+         return_complex: bool = True,
+         normalized: bool = True,
+     ):
+         """
+         Parameters
+         ----------
+         n_fft : int
+             Number of FFT points (window size).
+         hop_length : int
+             Hop length between adjacent windows, in samples
+             (typically ``n_fft // steps``).
+         """
+         super().__init__()
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.center = center
+         self.one_sided = onesided
+         self.return_complex = return_complex
+         self.normalized = normalized
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         window = torch.ones(self.n_fft, device=x.device)
+
+         # x: (B, C, T)
+         B, C, T = x.shape
+         # flatten batch & channel into one dim
+         x_flat = x.reshape(B * C, T)
+
+         # compute the STFT on the 2D tensor
+         spec_flat = torch.stft(
+             x_flat,
+             n_fft=self.n_fft,
+             hop_length=self.hop_length,
+             win_length=self.n_fft,
+             window=window,
+             normalized=self.normalized,
+             center=self.center,
+             onesided=self.one_sided,
+             return_complex=self.return_complex,
+         )
+
+         F, L = spec_flat.shape[-2], spec_flat.shape[-1]
+         spec = spec_flat.view(B, C, F, L)
+
+         return torch.abs(spec)
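+
+ # Shape sketch (values illustrative): with n_fft=100, hop_length=5 and
+ # center=True, an (8, 16, 200) input yields (8, 16, 51, 41) magnitudes, since
+ # n_freqs = n_fft // 2 + 1 and n_frames = n_times // hop_length + 1.
+ #
+ # _STFTModule(n_fft=100, hop_length=5)(torch.randn(8, 16, 200)).shape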