braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
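
Beyond the per-file line counts, the list shows a reorganization of the public API: braindecode/datautil/{mne,preprocess,windowers,xy}.py and braindecode/models/{modules,functions}.py are deleted, while the new braindecode/modules/ and braindecode/functional/ packages take their place. A hedged migration sketch, inferred from the added/removed files above rather than from release notes:

# 0.8.x import paths whose files are removed in this diff:
#   from braindecode.datautil.preprocess import preprocess, Preprocessor
#   from braindecode.models.modules import Ensure4d
# Plausible 1.1.0 equivalents (the target files appear in the list above;
# the exact re-exported names are an assumption, not confirmed release notes):
from braindecode.preprocessing import preprocess, Preprocessor
from braindecode.modules import Ensure4d
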
braindecode/models/fbmsnet.py
@@ -0,0 +1,324 @@
+ from __future__ import annotations
+
+ from typing import Optional, Sequence
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.models.fbcnet import _valid_layers
+ from braindecode.modules import (
+     Conv2dWithConstraint,
+     FilterBankLayer,
+     LinearWithConstraint,
+ )
+
+
+ class FBMSNet(EEGModuleMixin, nn.Module):
+     """FBMSNet from Liu et al. (2022) [fbmsnet]_.
+
+     .. figure:: https://raw.githubusercontent.com/Want2Vanish/FBMSNet/refs/heads/main/FBMSNet.png
+         :align: center
+         :alt: FBMSNet Architecture
+
+     0. **FilterBank Layer**: Applies a filter bank to transform the input.
+
+     1. **Temporal Convolution Block**: Uses mixed depthwise convolution
+        (MixConv) to extract multiscale temporal features from multiview EEG
+        representations. The input is split into groups corresponding to
+        different views, each convolved with kernels of varying sizes.
+        Kernel sizes are set relative to the EEG sampling rate, with ratio
+        coefficients [0.5, 0.25, 0.125, 0.0625], dividing the input into
+        four groups.
+
+     2. **Spatial Convolution Block**: Applies depthwise convolution with a
+        kernel size of (n_chans, 1) to span all EEG channels, effectively
+        learning spatial filters. This is followed by batch normalization and
+        the Swish activation function. A maximum norm constraint of 2 is
+        imposed on the convolution weights to regularize the model.
+
+     3. **Temporal Log-Variance Block**: Computes the log-variance.
+
+     4. **Classification Layer**: A fully connected layer with a weight
+        constraint.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it has only been reimplemented from the
+     paper description and source code [fbmsnetcode]_. There is an extra layer
+     here to compute the filter bank at batch time instead of during data
+     preprocessing. This avoids data leakage and allows the model to follow
+     the braindecode convention.
+
+     Parameters
+     ----------
+     n_bands : int, default=9
+         Number of input channels (e.g., number of frequency bands).
+     stride_factor : int, default=4
+         Stride factor for temporal segmentation.
+     temporal_layer : str, default='LogVarLayer'
+         Temporal aggregation layer to use.
+     n_filters_spat : int, default=36
+         Number of output channels from the MixedConv2d layer.
+     dilatability : int, default=8
+         Expansion factor for the spatial convolution block.
+     activation : nn.Module, default=nn.SiLU
+         Activation function class to apply.
+     verbose : bool, default False
+         Verbose parameter to create the filter using mne.
+
+     References
+     ----------
+     .. [fbmsnet] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
+        FBMSNet: A filter-bank multi-scale convolutional neural network for
+        EEG-based motor imagery decoding. IEEE Transactions on Biomedical
+        Engineering, 70(2), 436-445.
+     .. [fbmsnetcode] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
+        FBMSNet: A filter-bank multi-scale convolutional neural network for
+        EEG-based motor imagery decoding.
+        https://github.com/Want2Vanish/FBMSNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         n_bands: int = 9,
+         n_filters_spat: int = 36,
+         temporal_layer: str = "LogVarLayer",
+         n_dim: int = 3,
+         stride_factor: int = 4,
+         dilatability: int = 8,
+         activation: nn.Module = nn.SiLU,
+         kernels_weights: Sequence[int] = (15, 31, 63, 125),
+         cnn_max_norm: float = 2,
+         linear_max_norm: float = 0.5,
+         verbose: bool = False,
+         filter_parameters: Optional[dict] = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Parameters
+         self.n_bands = n_bands
+         self.n_filters_spat = n_filters_spat
+         self.n_dim = n_dim
+         self.stride_factor = stride_factor
+         self.activation = activation
+         self.dilatability = dilatability
+         self.kernels_weights = kernels_weights
+         self.filter_parameters = filter_parameters or {}
+         self.out_channels_spatial = self.n_filters_spat * self.dilatability
+
+         # Checks
+         if temporal_layer not in _valid_layers:
+             raise NotImplementedError(
+                 f"Temporal layer '{temporal_layer}' is not implemented."
+             )
+
+         if self.n_times % self.stride_factor != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" stride_factor ({self.stride_factor}). Input will be padded.",
+                 UserWarning,
+             )
+
+         # Layers
+         # Following the paper's nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.n_bands,
+             verbose=verbose,
+             **self.filter_parameters,
+         )
+         # As we have an internal process to create the bands,
+         # we get the values from the filter bank
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # MixedConv2d layer
+         self.mix_conv = nn.Sequential(
+             _MixedConv2d(
+                 in_channels=self.n_bands,
+                 out_channels=self.n_filters_spat,
+                 stride=1,
+                 dilation=1,
+                 depthwise=False,
+                 kernels_weights=kernels_weights,
+             ),
+             nn.BatchNorm2d(self.n_filters_spat),
+         )
+
+         # Spatial Convolution Block (SCB)
+         self.spatial_conv = nn.Sequential(
+             Conv2dWithConstraint(
+                 in_channels=self.n_filters_spat,
+                 out_channels=self.out_channels_spatial,
+                 kernel_size=(self.n_chans, 1),
+                 groups=self.n_filters_spat,
+                 max_norm=cnn_max_norm,
+                 padding=0,
+             ),
+             nn.BatchNorm2d(self.out_channels_spatial),
+             self.activation(),
+         )
+
+         # Padding layer
+         if self.n_times % self.stride_factor != 0:
+             self.padding_size = stride_factor - (self.n_times % stride_factor)
+             self.n_times_padded = self.n_times + self.padding_size
+             self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
+         else:
+             self.padding_layer = nn.Identity()
+             self.n_times_padded = self.n_times
+
+         # Temporal aggregation layer
+         self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore
+
+         self.flatten_layer = Rearrange("batch ... -> batch (...)")
+
+         # Final fully connected layer
+         self.final_layer = LinearWithConstraint(
+             in_features=self.out_channels_spatial * self.stride_factor,
+             out_features=self.n_outputs,
+             max_norm=linear_max_norm,
+         )
+
+     def forward(self, x):
+         """
+         Forward pass of the FBMSNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         batch, _, _ = x.shape
+
+         # shape: (batch, n_chans, n_times)
+         x = self.spectral_filtering(x)
+         # shape: (batch, n_bands, n_chans, n_times)
+
+         # Mixed convolution
+         x = self.mix_conv(x)
+         # shape: (batch, self.n_filters_spat, n_chans, n_times)
+
+         # Spatial convolution block
+         x = self.spatial_conv(x)
+         # shape: (batch, self.out_channels_spatial, 1, n_times)
+
+         # Pad the input so it is divisible by the stride factor
+         x = self.padding_layer(x)
+         # shape: (batch, self.out_channels_spatial, 1, n_times_padded)
+
+         # Reshape for the temporal layer
+         x = x.view(batch, self.out_channels_spatial, self.stride_factor, -1)
+         # shape: (batch, self.out_channels_spatial, self.stride_factor,
+         #         n_times_padded/self.stride_factor)
+
+         # Temporal aggregation
+         x = self.temporal_layer(x)
+         # shape: (batch, self.out_channels_spatial, self.stride_factor, 1)
+
+         # Flatten and classify
+         x = self.flatten_layer(x)
+         # shape: (batch, self.out_channels_spatial*self.stride_factor)
+
+         x = self.final_layer(x)
+         # shape: (batch, n_outputs)
+         return x
+
+
+ class _MixedConv2d(nn.Module):
+     """Mixed grouped convolution for multiscale feature extraction."""
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernels_weights=(15, 31, 63, 125),
+         stride=1,
+         dilation=1,
+         depthwise=False,
+     ):
+         super().__init__()
+
+         num_groups = len(kernels_weights)
+         in_splits = self._split_channels(in_channels, num_groups)
+         out_splits = self._split_channels(out_channels, num_groups)
+         self.splits = in_splits
+
+         self.convs = nn.ModuleList()
+         # Create a convolutional layer for each kernel size
+         for k, in_ch, out_ch in zip(kernels_weights, in_splits, out_splits):
+             conv_groups = out_ch if depthwise else 1
+             conv = nn.Conv2d(
+                 in_channels=in_ch,
+                 out_channels=out_ch,
+                 kernel_size=(1, k),
+                 stride=stride,
+                 padding="same",
+                 dilation=dilation,
+                 groups=conv_groups,
+                 bias=False,
+             )
+             self.convs.append(conv)
+
+     @staticmethod
+     def _split_channels(num_chan, num_groups):
+         """
+         Split the total number of channels into a specified number of groups
+         as evenly as possible.
+
+         Parameters
+         ----------
+         num_chan : int
+             The total number of channels to split.
+         num_groups : int
+             The number of groups to split the channels into.
+
+         Returns
+         -------
+         list of int
+             A list containing the number of channels in each group.
+             The first group may have more channels if the division is not even.
+         """
+         split = [num_chan // num_groups for _ in range(num_groups)]
+         split[0] += num_chan - sum(split)
+         return split
+
+     def forward(self, x):
+         # Split the input tensor `x` along the channel dimension (dim=1) into
+         # groups whose sizes are given by `self.splits`, computed from the
+         # number of input channels and the number of kernel sizes.
+         x_split = torch.split(x, self.splits, 1)
+
+         # Apply the corresponding convolutional layer to each group;
+         # `self.convs` holds the layers in the order they were added.
+         x_out = [conv(x_split[i]) for i, conv in enumerate(self.convs)]
+
+         # Concatenate the outputs from all groups along the channel
+         # dimension (dim=1) to form a single output tensor.
+         x = torch.cat(x_out, 1)
+         return x
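
Since fbmsnet.py is new in this release, a quick smoke test may help. This is a minimal sketch: it assumes FBMSNet is re-exported from braindecode.models (the file list shows models/__init__.py gaining 84 lines; otherwise import from braindecode.models.fbmsnet), and the shapes are illustrative, not taken from the package docs.

import torch
from braindecode.models import FBMSNet

# With the defaults above (n_bands=9, four mixed-conv kernels 15/31/63/125),
# _split_channels spreads the 9 filter-bank bands over the 4 kernel sizes
# as [3, 2, 2, 2] input groups.
model = FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250.0)

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times); 1000 % stride_factor == 0
y = model(x)
print(y.shape)  # expected: torch.Size([8, 4])
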
braindecode/models/hybrid.py
@@ -6,14 +6,12 @@ import torch
  from torch import nn
  from torch.nn import ConstantPad2d
 
- from .deep4 import Deep4Net
- from .util import to_dense_prediction_model
- from .shallow_fbcsp import ShallowFBCSPNet
- from .base import EEGModuleMixin, deprecated_args
+ from braindecode.models.deep4 import Deep4Net
+ from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
 
 
- class HybridNet(EEGModuleMixin, nn.Module):
-     """Hybrid ConvNet model from Schirrmeister et al 2017.
+ class HybridNet(nn.Module):
+     """Hybrid ConvNet model from Schirrmeister, R T et al (2017) [Schirrmeister2017]_.
 
      See [Schirrmeister2017]_ for details.
 
@@ -28,25 +26,21 @@ class HybridNet(EEGModuleMixin, nn.Module):
      Online: http://dx.doi.org/10.1002/hbm.23730
      """
 
-     def __init__(self, n_chans=None, n_outputs=None, n_times=None,
-                  in_chans=None, n_classes=None, input_window_samples=None,
-                  add_log_softmax=True):
-
-         n_chans, n_outputs, n_times = deprecated_args(
-             self,
-             ('in_chans', 'n_chans', in_chans, n_chans),
-             ('n_classes', 'n_outputs', n_classes, n_outputs),
-             ('input_window_samples', 'n_times', input_window_samples, n_times),
-         )
-         super().__init__(
-             n_outputs=n_outputs,
-             n_chans=n_chans,
-             n_times=n_times,
-             add_log_softmax=add_log_softmax,
-         )
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         chs_info=None,
+         activation: nn.Module = nn.ELU,
+         drop_prob: float = 0.5,
+     ):
+         super().__init__()
          self.mapping = {
-             'final_conv.weight': 'final_layer.weight',
-             'final_conv.bias': 'final_layer.bias'
+             "final_conv.weight": "final_layer.weight",
+             "final_conv.bias": "final_layer.bias",
          }
 
          deep_model = Deep4Net(
@@ -58,61 +52,52 @@ class HybridNet(EEGModuleMixin, nn.Module):
              n_filters_3=50,
              n_filters_4=60,
              n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             chs_info=chs_info,
              final_conv_length=2,
+             activation_first_conv_nonlin=activation,
+             activation_later_conv_nonlin=activation,
+             drop_prob=drop_prob,
          )
          shallow_model = ShallowFBCSPNet(
              n_chans=n_chans,
              n_outputs=n_outputs,
              n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             chs_info=chs_info,
              n_filters_time=30,
              n_filters_spat=40,
              filter_time_length=28,
              final_conv_length=29,
+             drop_prob=drop_prob,
+         )
+         new_conv_layer = nn.Conv2d(
+             deep_model.final_layer.conv_classifier.in_channels,
+             60,
+             kernel_size=deep_model.final_layer.conv_classifier.kernel_size,
+             stride=deep_model.final_layer.conv_classifier.stride,
          )
+         deep_model.final_layer = new_conv_layer
 
-         del n_outputs, n_chans, n_times
-         del in_chans, n_classes, input_window_samples
-
-         reduced_deep_model = nn.Sequential()
-         for name, module in deep_model.named_children():
-             if name == "final_layer":
-                 new_conv_layer = nn.Conv2d(
-                     module.conv_classifier.in_channels,
-                     60,
-                     kernel_size=module.conv_classifier.kernel_size,
-                     stride=module.conv_classifier.stride,
-                 )
-                 reduced_deep_model.add_module("deep_final_conv", new_conv_layer)
-                 break
-             reduced_deep_model.add_module(name, module)
-
-         reduced_shallow_model = nn.Sequential()
-         for name, module in shallow_model.named_children():
-             if name == "final_layer":
-                 new_conv_layer = nn.Conv2d(
-                     module.conv_classifier.in_channels,
-                     40,
-                     kernel_size=module.conv_classifier.kernel_size,
-                     stride=module.conv_classifier.stride,
-                 )
-                 reduced_shallow_model.add_module(
-                     "shallow_final_conv", new_conv_layer
-                 )
-                 break
-             reduced_shallow_model.add_module(name, module)
-
-         to_dense_prediction_model(reduced_deep_model)
-         to_dense_prediction_model(reduced_shallow_model)
-         self.reduced_deep_model = reduced_deep_model
-         self.reduced_shallow_model = reduced_shallow_model
+         new_conv_layer = nn.Conv2d(
+             shallow_model.final_layer.conv_classifier.in_channels,
+             40,
+             kernel_size=shallow_model.final_layer.conv_classifier.kernel_size,
+             stride=shallow_model.final_layer.conv_classifier.stride,
+         )
+         shallow_model.final_layer = new_conv_layer
+
+         deep_model.to_dense_prediction_model()
+         shallow_model.to_dense_prediction_model()
+         self.reduced_deep_model = deep_model
+         self.reduced_shallow_model = shallow_model
 
          self.final_layer = nn.Sequential(
-             nn.Conv2d(
-                 100,
-                 self.n_outputs,
-                 kernel_size=(1, 1),
-                 stride=1),
-             nn.LogSoftmax(dim=1) if self.add_log_softmax else nn.Identity())
+             nn.Conv2d(100, n_outputs, kernel_size=(1, 1), stride=1),
+             nn.Identity(),
+         )
 
      def forward(self, x):
          """Forward pass.
@@ -128,13 +113,9 @@ class HybridNet(EEGModuleMixin, nn.Module):
          n_diff_deep_shallow = deep_out.size()[2] - shallow_out.size()[2]
 
          if n_diff_deep_shallow < 0:
-             deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(
-                 deep_out
-             )
+             deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(deep_out)
          elif n_diff_deep_shallow > 0:
-             shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(
-                 shallow_out
-             )
+             shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(shallow_out)
 
          merged_out = torch.cat((deep_out, shallow_out), dim=1)
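
The HybridNet refactor above drops the named_children() loops in favor of directly reassigning each submodel's final_layer and calling the models' own to_dense_prediction_model() method. A hedged usage sketch under the new signature; the output-shape comment is an assumption based on the dense-prediction design shown above, not a measured value:

import torch
from braindecode.models import HybridNet

model = HybridNet(n_chans=22, n_outputs=4, n_times=1000)

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
y = model(x)  # deep (60-ch) and shallow (40-ch) feature maps are padded to equal
              # time length, concatenated into 100 channels, then mapped to
              # n_outputs by the 1x1 conv in final_layer,
              # giving roughly (batch, n_outputs, time, 1)
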