braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124):
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,280 @@
1
+ # Authors: Chun-Shu Wei
2
+ # Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+ import math
7
+ from warnings import warn
8
+
9
+ import torch
10
+ from einops.layers.torch import Rearrange
11
+ from torch import nn
12
+
13
+ from braindecode.models.base import EEGModuleMixin
14
+ from braindecode.modules import LogActivation
15
+
16
+
17
class SCCNet(EEGModuleMixin, nn.Module):
    r"""SCCNet from Wei, C S (2019) [sccnet]_.

    :bdg-success:`Convolution`

    Spatial component-wise convolutional network (SCCNet) for motor-imagery EEG
    classification.

    .. figure:: https://dt5vp8kor0orz.cloudfront.net/6e3ec5d729cd51fe8acc5a978db27d02a5df9e05/2-Figure1-1.png
        :align: center
        :alt: Spatial component-wise convolutional network
        :width: 680px

    .. rubric:: Architectural Overview

    SCCNet is a spatial-first convolutional network whose temporal kernels are
    defined in seconds, so that its filters correspond to neurophysiologically
    aligned windows. The model comprises four stages:

    1. **Spatial Component Analysis**: a convolution spanning all EEG channels
       extracts spatial components, effectively reducing the channel dimension.
    2. **Spatio-Temporal Filtering**: a convolution across the spatial
       components and the temporal domain captures spatio-temporal patterns.
    3. **Temporal Smoothing (Pooling)**: average pooling over time smooths the
       features and reduces the temporal dimension, focusing on longer-term
       patterns.
    4. **Classification**: the features are flattened and passed to a fully
       connected layer.

    .. rubric:: Macro Components

    - `SCCNet.spatial_conv` **(spatial component analysis)**

        - *Operations.*
        - :class:`~torch.nn.Conv2d` with kernel `(n_chans, 1)` (`N_t=1` in the
          paper's notation) on an input reshaped to `(B, 1, n_chans, T)`: a pure
          across-channel projection (montage-wide linear spatial filter).
        - :class:`~torch.nn.BatchNorm2d`; the output holds `N_u` component
          signals shaped `(B, 1, N_u, T)` after a permute step.

      *Interpretability/robustness.* Mimics CSP-like spatial filtering: each
      learned filter is a channel-weighted component, easing inspection and
      reducing channel noise.

    - `SCCNet.spatial_filt_conv` **(spatio-temporal filtering)**

        - *Operations.*
        - :class:`~torch.nn.Conv2d` without bias, with kernel
          `(N_u, floor(0.1 * sfreq))` over components and time (12 samples
          ~ 0.1 s at 125 Hz),
        - :class:`~torch.nn.BatchNorm2d`;
        - **Square** nonlinearity, as in the original paper and similar to
          :class:`~braindecode.models.ShallowFBCSPNet`;
        - :class:`~torch.nn.Dropout` with rate `drop_prob`.

        - *Role.* Learns frequency-selective energy features and inter-component
          interactions within a 0.1 s context (beta/alpha cycle scale).

    - `SCCNet.temporal_smoothing` **(aggregation + readout)**

        - *Operations.*
        - :class:`~torch.nn.AvgPool2d` with size `(1, floor(0.5 * sfreq))`
          (62 samples ~ 0.5 s at 125 Hz) and a stride equal to the temporal
          kernel length of the second convolution, for temporal smoothing and
          downsampling
        - `activation` (:class:`~braindecode.modules.LogActivation` by default)
        - flattening of all non-batch dimensions
        - :class:`~torch.nn.Linear` to `n_outputs`.


    .. rubric:: Convolutional Details

    * **Temporal (where time-domain patterns are learned).**
        The second block's kernel length is set to 0.1 s (12 samples at 125 Hz)
        and slides with stride 1; average pooling over 0.5 s (62 samples at
        125 Hz) integrates power over longer spans. These choices bake in
        short-cycle detection followed by half-second trend smoothing.

    * **Spatial (how electrodes are processed).**
        The first block's kernel spans **all electrodes** `(n_chans, 1)`: a
        montage-wide linear projection, mapping channels → `N_u` components.
        The second block mixes **across components** via kernel height `N_u`.

    * **Spectral (how frequency information is captured).**
        No explicit transform is used; learned **temporal kernels** serve as
        bandpass-like filters, and the **square/log power** nonlinearity plus
        0.5 s averaging approximate band-power estimation (ERD/ERS-style
        features).

    .. rubric:: Attention / Sequential Modules

    This model contains **no attention** and **no recurrent units**.

    .. rubric:: Additional Mechanisms

    - :class:`~torch.nn.BatchNorm2d` follows both convolutions. This
      implementation applies no padding, so the second convolution shortens the
      time axis. L2 weight decay was used in the original paper; dropout
      combats overfitting.
    - When `n_times` is smaller than the sum of the convolution and pooling
      kernel lengths, a ``UserWarning`` is emitted and both lengths are scaled
      down proportionally (never below one sample).
    - In contrast with other compact networks: EEGNet performs a temporal conv
      followed by a **depthwise spatial** conv (separable), learning temporal
      filters first. SCCNet inverts this order: it performs a **full spatial
      projection first** (CSP-like), then a short **spatio-temporal** conv with
      an explicit 0.1 s kernel, followed by **power-like** nonlinearity and
      longer temporal averaging. EEGNet's ELU and separable design favor
      parameter efficiency; SCCNet's second-scale kernels and square/log
      emphasize interpretable **band-power** features.

    - Reference implementation: see [sccnetcode]_.

    .. rubric:: Usage and Configuration

    * **Training from the original authors.**

        * Match window length so that `T` is comfortably larger than pooling
          length (e.g., > 1.5-2 s for MI).
        * Start with standard MI augmentations (channel dropout/shuffle, time
          reverse) and tune `n_spatial_filters` before deeper changes.

    Parameters
    ----------
    n_spatial_filters : int, optional
        Number of spatial filters in the first convolutional layer, variable
        `N_u` from the original paper. Default is 22.
    n_spatial_filters_smooth : int, optional
        Number of filters produced by the second convolutional layer.
        Default is 20.
    drop_prob : float, optional
        Dropout probability. Default is 0.5.
    activation : nn.Module, optional
        Activation class, instantiated without arguments and applied after the
        temporal average pooling. Default is logarithm activation.
    batch_norm_momentum : float, optional
        Momentum given to both :class:`~torch.nn.BatchNorm2d` layers.
        Default is 0.1.

    References
    ----------
    .. [sccnet] Wei, C. S., Koike-Akino, T., & Wang, Y. (2019, March). Spatial
        component-wise convolutional network (SCCNet) for motor-imagery EEG
        classification. In 2019 9th International IEEE/EMBS Conference on
        Neural Engineering (NER) (pp. 328-331). IEEE.
    .. [sccnetcode] Hsieh, C. Y., Chou, J. L., Chang, Y. H., & Wei, C. S.
        XBrainLab: An Open-Source Software for Explainable Artificial
        Intelligence-Based EEG Analysis. In NeurIPS 2023 AI for
        Science Workshop.

    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model related parameters
        n_spatial_filters: int = 22,
        n_spatial_filters_smooth: int = 20,
        drop_prob: float = 0.5,
        activation: type[nn.Module] = LogActivation,
        batch_norm_momentum: float = 0.1,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            n_times=n_times,
            chs_info=chs_info,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on the signal properties are read through ``self`` only.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Hyper-parameters
        self.n_spatial_filters = n_spatial_filters
        self.n_spatial_filters_smooth = n_spatial_filters_smooth
        self.drop_prob = drop_prob

        # Durations from the original SCCNet design (100 ms convolution,
        # 500 ms pooling), converted to whole samples.
        n_conv, n_pool = (
            int(math.floor(self.sfreq * seconds)) for seconds in (0.1, 0.5)
        )

        # A window shorter than both kernels together cannot be processed
        # with the default sizes: shrink the two kernels proportionally.
        n_combined = n_conv + n_pool
        if n_combined > self.n_times:
            warn(
                f"Input window seconds ({self.input_window_seconds:.2f}s) is smaller than the "
                f"model's combined kernel sizes ({(n_combined / self.sfreq):.2f}s). "
                "Scaling temporal parameters down proportionally.",
                UserWarning,
                stacklevel=2,
            )
            shrink = self.n_times / n_combined
            n_conv = int(math.floor(n_conv * shrink))
            n_pool = int(math.floor(n_pool * shrink))

        # Kernels must cover at least one sample.
        self.samples_100ms = max(1, n_conv)
        self.kernel_size_pool = max(1, n_pool)

        n_linear_inputs = self._calc_num_features()

        # Layers (creation order is kept stable so that seeded initialisation
        # and state-dict key order are unaffected).
        self.ensure_dim = Rearrange("batch nchan times -> batch 1 nchan times")

        activation_cls = LogActivation if activation is None else activation
        self.activation = activation_cls()

        self.spatial_conv = nn.Conv2d(
            1, self.n_spatial_filters, kernel_size=(self.n_chans, 1)
        )
        self.spatial_batch_norm = nn.BatchNorm2d(
            self.n_spatial_filters, momentum=batch_norm_momentum
        )

        # Move the learned components onto the "height" axis so that the next
        # convolution can span all of them.
        self.permute = Rearrange(
            "batch filspat nchans time -> batch nchans filspat time"
        )

        self.spatial_filt_conv = nn.Conv2d(
            1,
            self.n_spatial_filters_smooth,
            kernel_size=(self.n_spatial_filters, self.samples_100ms),
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(
            self.n_spatial_filters_smooth, momentum=batch_norm_momentum
        )

        self.dropout = nn.Dropout(self.drop_prob)
        self.temporal_smoothing = nn.AvgPool2d(
            kernel_size=(1, self.kernel_size_pool),
            stride=(1, self.samples_100ms),
        )

        self.final_layer = nn.Linear(n_linear_inputs, self.n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the model outputs for a batch of EEG windows.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch_size, n_chans, n_times)``.

        Returns
        -------
        torch.Tensor
            Output of shape ``(batch_size, n_outputs)``.
        """
        # (batch, n_chans, n_times) -> (batch, 1, n_chans, n_times)
        # -> spatial projection (batch, n_filters, 1, n_times)
        # -> components on the height axis (batch, 1, n_filters, n_times)
        components = self.permute(
            self.spatial_batch_norm(self.spatial_conv(self.ensure_dim(x)))
        )
        # Spatio-temporal filtering: (batch, n_filters_filt, 1, n_times_reduced)
        filtered = self.batch_norm(self.spatial_filt_conv(components))
        # Square nonlinearity followed by dropout; shape unchanged.
        power = self.dropout(filtered.pow(2))
        # Temporal smoothing then activation:
        # (batch, n_filters_filt, 1, n_times_reduced_avg_pool)
        pooled = self.activation(self.temporal_smoothing(power))
        # Flatten to (batch, n_filters_filt * n_times_reduced_avg_pool) and
        # read out (batch, n_outputs).
        return self.final_layer(pooled.view(pooled.size(0), -1))

    def _calc_num_features(self) -> int:
        """Return the number of inputs of the final linear layer.

        The value is derived from ``n_times``, ``samples_100ms``,
        ``kernel_size_pool`` and ``n_spatial_filters_smooth``.
        """
        # Length after the second (unpadded, stride-1) convolution.
        conv_out_len = self.n_times - self.samples_100ms + 1
        # Number of pooling windows, the pooling stride being ``samples_100ms``.
        n_pool_windows = (
            conv_out_len - self.kernel_size_pool
        ) // self.samples_100ms + 1
        return self.n_spatial_filters_smooth * n_pool_windows
@@ -0,0 +1,212 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+
5
+ from typing import Callable
6
+
7
+ from einops.layers.torch import Rearrange
8
+ from torch import nn
9
+ from torch.nn import init
10
+
11
+ from braindecode.functional import square
12
+ from braindecode.models.base import EEGModuleMixin
13
+ from braindecode.modules import (
14
+ CombinedConv,
15
+ Ensure4d,
16
+ Expression,
17
+ SafeLog,
18
+ SqueezeFinalOutput,
19
+ )
20
+
21
+
22
class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
    r"""Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    :bdg-success:`Convolution`

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
        :align: center
        :alt: ShallowNet Architecture

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    n_filters_time: int
        Number of temporal filters.
    filter_time_length: int
        Length of the temporal filter.
    n_filters_spat: int
        Number of spatial filters.
    pool_time_length: int
        Length of temporal pooling filter.
    pool_time_stride: int
        Length of stride between temporal pooling filters.
    final_conv_length: int | str
        Length of the final convolution layer.
        If set to "auto", length of the input signal must be specified.
    conv_nonlin: callable
        Non-linear function to be used after convolution layers.
    pool_mode: str
        Method to use on pooling layers. "max" or "mean".
    activation_pool_nonlin: type[nn.Module]
        Module class, instantiated without arguments, applied after the
        pooling layer.
    split_first_layer: bool
        Split first layer into temporal and spatial layers (True) or just use
        temporal (False). There would be no non-linearity between the split
        layers.
    batch_norm: bool
        Whether to use batch normalisation.
    batch_norm_alpha: float
        Momentum for BatchNorm2d.
    drop_prob: float
        Dropout probability.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        n_filters_time=40,
        filter_time_length=25,
        n_filters_spat=40,
        pool_time_length=75,
        pool_time_stride=15,
        final_conv_length="auto",
        conv_nonlin: Callable = square,
        pool_mode="mean",
        activation_pool_nonlin: type[nn.Module] = SafeLog,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        drop_prob=0.5,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on the signal properties are read through ``self`` only.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if final_conv_length == "auto":
            assert self.n_times is not None
        self.n_filters_time = n_filters_time
        self.filter_time_length = filter_time_length
        self.n_filters_spat = n_filters_spat
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.final_conv_length = final_conv_length
        self.conv_nonlin = conv_nonlin
        self.pool_mode = pool_mode
        self.pool_nonlin = activation_pool_nonlin
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.drop_prob = drop_prob

        # Old parameter names -> current parameter names.
        # NOTE(review): the targets assume the split first layer
        # (``conv_time_spat``); confirm how this mapping is consumed when
        # ``split_first_layer=False``.
        self.mapping = {
            "conv_time.weight": "conv_time_spat.conv_time.weight",
            "conv_spat.weight": "conv_time_spat.conv_spat.weight",
            "conv_time.bias": "conv_time_spat.conv_time.bias",
            "conv_spat.bias": "conv_time_spat.conv_spat.bias",
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        self.add_module("ensuredims", Ensure4d())
        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        if self.split_first_layer:
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time_spat",
                CombinedConv(
                    in_chans=self.n_chans,
                    n_filters_time=self.n_filters_time,
                    n_filters_spat=self.n_filters_spat,
                    filter_time_length=filter_time_length,
                    bias_time=True,
                    # Batch norm has its own shift, making this bias redundant.
                    bias_spat=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_spat
        else:
            # Single convolution registered as ``conv_time``; EEG channels are
            # used as the convolution's input channels.
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=1,
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_time
        if self.batch_norm:
            self.add_module(
                "bnorm",
                nn.BatchNorm2d(
                    n_filters_conv, momentum=self.batch_norm_alpha, affine=True
                ),
            )
        self.add_module("conv_nonlin_exp", Expression(self.conv_nonlin))
        self.add_module(
            "pool",
            pool_class(
                kernel_size=(self.pool_time_length, 1),
                stride=(self.pool_time_stride, 1),
            ),
        )
        self.add_module("pool_nonlin_exp", self.pool_nonlin())
        self.add_module("drop", nn.Dropout(p=self.drop_prob))
        # Evaluation mode while the output shape of the layers added so far
        # is queried; training mode is restored at the end of ``__init__``.
        self.eval()
        if self.final_conv_length == "auto":
            self.final_conv_length = self.get_output_shape()[2]

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_filters_conv,
                self.n_outputs,
                (self.final_conv_length, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialization, xavier is same as in paper...
        # The temporal convolution sits inside ``conv_time_spat`` only when
        # the first layer is split; otherwise it is the ``conv_time`` module
        # itself. Always reading ``self.conv_time_spat`` raised an
        # AttributeError for ``split_first_layer=False``.
        if self.split_first_layer:
            conv_time = self.conv_time_spat.conv_time
        else:
            conv_time = self.conv_time
        init.xavier_uniform_(conv_time.weight, gain=1)
        # maybe no bias in case of no split layer and batch norm
        if self.split_first_layer or (not self.batch_norm):
            init.constant_(conv_time.bias, 0)
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
            if not self.batch_norm:
                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
        init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
        init.constant_(self.final_layer.conv_classifier.bias, 0)

        self.train()