braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,247 @@
1
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops.layers.torch import Rearrange
11
+
12
+ from braindecode.models.base import EEGModuleMixin
13
+ from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint
14
+
15
+
16
class EEGNeX(EEGModuleMixin, nn.Module):
    """EEGNeX model from Chen et al. (2024) [eegnex]_.

    .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
        :align: center
        :alt: EEGNeX Architecture

    Parameters
    ----------
    activation : type[nn.Module], optional
        Activation function class (not an instance); it is instantiated
        once in block 3 and once in block 5. Default is `nn.ELU`.
    depth_multiplier : int, optional
        Depth multiplier for the depthwise convolution. Default is 2.
    filter_1 : int, optional
        Number of filters in the first convolutional layer (also the number
        of output filters of block 5). Default is 8.
    filter_2 : int, optional
        Number of filters in the second convolutional layer (also the number
        of output filters of block 4). Default is 32.
    drop_prob: float, optional
        Dropout rate. Default is 0.5.
    kernel_block_1_2 : int, optional
        Temporal kernel length shared by blocks 1 and 2; used as the 2D
        kernel ``(1, kernel_block_1_2)``. Default is 64.
    kernel_block_4 : int, optional
        Temporal kernel length for block 4; used as ``(1, kernel_block_4)``.
        Default is 16.
    dilation_block_4 : int, optional
        Temporal dilation for block 4; used as ``(1, dilation_block_4)``.
        Default is 2.
    avg_pool_block4 : int, optional
        Temporal pooling length of the average pooling that closes block 3
        (i.e. right before block 4); used as ``(1, avg_pool_block4)``.
        Default is 4.
    kernel_block_5 : int, optional
        Temporal kernel length for block 5; used as ``(1, kernel_block_5)``.
        Default is 16.
    dilation_block_5 : int, optional
        Temporal dilation for block 5; used as ``(1, dilation_block_5)``.
        Default is 4.
    avg_pool_block5 : int, optional
        Temporal pooling length for block 5; used as
        ``(1, avg_pool_block5)``. Default is 8.
    max_norm_conv : float, optional
        ``max_norm`` passed to the constrained depthwise convolution of
        block 3. Default is 1.0.
    max_norm_linear : float, optional
        ``max_norm`` passed to the constrained final linear layer.
        Default is 0.25.

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors, only reimplemented from the paper description and
    source code in tensorflow [EEGNexCode]_.

    References
    ----------
    .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. Biomedical Signal Processing and Control, 87, 105475.
    .. [EEGNexCode] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. https://github.com/chenxiachan/EEGNeX
    """

    # Zero-padding applied on each side of the time axis by both average
    # pooling layers. Kept in one place so the layers and
    # ``_calculate_output_length`` cannot drift apart.
    _POOL_TIME_PADDING = 1

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        activation: type[nn.Module] = nn.ELU,
        depth_multiplier: int = 2,
        filter_1: int = 8,
        filter_2: int = 32,
        drop_prob: float = 0.5,
        kernel_block_1_2: int = 64,
        kernel_block_4: int = 16,
        dilation_block_4: int = 2,
        avg_pool_block4: int = 4,
        kernel_block_5: int = 16,
        dilation_block_5: int = 4,
        avg_pool_block5: int = 8,
        max_norm_conv: float = 1.0,
        max_norm_linear: float = 0.25,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on only the values resolved by the mixin (self.n_chans,
        # self.n_times, self.n_outputs) must be used.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.depth_multiplier = depth_multiplier
        self.filter_1 = filter_1
        self.filter_2 = filter_2
        self.filter_3 = self.filter_2 * self.depth_multiplier
        self.drop_prob = drop_prob
        self.activation = activation
        # All temporal hyper-parameters are scalars; expand them to 2D
        # (height=1, width=time) tuples as expected by the 2D layers.
        self.kernel_block_1_2 = (1, kernel_block_1_2)
        self.kernel_block_4 = (1, kernel_block_4)
        self.dilation_block_4 = (1, dilation_block_4)
        self.avg_pool_block4 = (1, avg_pool_block4)
        self.kernel_block_5 = (1, kernel_block_5)
        self.dilation_block_5 = (1, dilation_block_5)
        self.avg_pool_block5 = (1, avg_pool_block5)

        pool_padding = (0, self._POOL_TIME_PADDING)

        # Number of flattened features entering the final linear layer.
        self.in_features = self._calculate_output_length()

        # Following paper nomenclature
        self.block_1 = nn.Sequential(
            Rearrange("batch ch time -> batch 1 ch time"),
            nn.Conv2d(
                in_channels=1,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_1,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        # Depthwise (groups=filter_2) convolution spanning all EEG channels,
        # which collapses the channel axis to size 1.
        self.block_3 = nn.Sequential(
            Conv2dWithConstraint(
                in_channels=self.filter_2,
                out_channels=self.filter_3,
                max_norm=max_norm_conv,
                kernel_size=(self.n_chans, 1),
                groups=self.filter_2,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_3),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block4,
                padding=pool_padding,
            ),
            nn.Dropout(p=self.drop_prob),
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_3,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_4,
                dilation=self.dilation_block_4,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_5 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_2,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_5,
                dilation=self.dilation_block_5,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block5,
                padding=pool_padding,
            ),
            nn.Dropout(p=self.drop_prob),
            nn.Flatten(),
        )

        self.final_layer = LinearWithConstraint(
            in_features=self.in_features,
            out_features=self.n_outputs,
            max_norm=max_norm_linear,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the EEGNeX model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # x: (batch_size, n_chans, n_times)
        x = self.block_1(x)
        # (batch_size, filter_1, n_chans, n_times)
        x = self.block_2(x)
        # (batch_size, filter_2, n_chans, n_times)
        x = self.block_3(x)
        # (batch_size, filter_3, 1, T3), T3 = time length after first pooling
        x = self.block_4(x)
        # (batch_size, filter_2, 1, T3)
        x = self.block_5(x)
        # (batch_size, filter_1 * T5), T5 = time length after second pooling
        x = self.final_layer(x)
        # (batch_size, n_outputs)
        return x

    def _calculate_output_length(self) -> int:
        """Return the number of features produced by ``block_5``'s flatten.

        Only the two average-pooling layers change the time length (every
        convolution along time uses ``padding="same"``). Each pooling uses a
        stride equal to its kernel size, so its output length is
        ``floor((L_in + 2 * padding - kernel) / kernel) + 1``.
        """
        pad = self._POOL_TIME_PADDING
        n_steps = self.n_times
        for pool_len in (self.avg_pool_block4[1], self.avg_pool_block5[1]):
            n_steps = math.floor((n_steps + 2 * pad - pool_len) / pool_len) + 1

        # The channel axis has size 1 after block 3, so the flattened size is
        # (number of feature maps before flatten) * (remaining time steps).
        return self.filter_1 * n_steps
@@ -0,0 +1,362 @@
1
+ # Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
2
+ # Tonio Ball
3
+ #
4
+ # License: BSD-3
5
+
6
+ import numpy as np
7
+ import torch
8
+ from einops.layers.torch import Rearrange
9
+ from torch import nn
10
+ from torch.nn import init
11
+
12
+ from braindecode.models.base import EEGModuleMixin
13
+ from braindecode.modules import (
14
+ AvgPool2dWithConv,
15
+ Ensure4d,
16
+ SqueezeFinalOutput,
17
+ )
18
+
19
+
20
class EEGResNet(EEGModuleMixin, nn.Sequential):
    """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
        :align: center
        :alt: EEGResNet Architecture

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    final_pool_length : int | str, default="auto"
        With ``"auto"`` the network ends in a global average pooling
        (``nn.AdaptiveAvgPool2d((1, 1))``) and ``n_times`` must be known.
        Otherwise an ``AvgPool2dWithConv`` of length ``final_pool_length``
        along time is used.
    n_first_filters : int, default=20
        Number of filters of the first convolution(s).
    n_layers_per_block : int, default=2
        Number of residual layers in each of the seven residual stages.
    first_filter_length : int, default=3
        Temporal kernel length of the first convolution. Must be odd.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
        Used after the first batch norm and inside every residual block.
    split_first_layer : bool, default=True
        If True, the first layer is split into a temporal convolution
        followed by a spatial convolution over all channels. If False, a
        single temporal convolution with ``n_chans`` input channels is used.
    batch_norm_alpha : float, default=0.1
        Momentum of the batch normalization layers.
    batch_norm_epsilon : float, default=1e-4
        Epsilon of the batch normalization layers inside the residual
        blocks. The first batch norm (``bnorm``) uses a fixed ``1e-5``.
    conv_weight_init_fn : callable
        Called with the weight tensor of every convolution to initialize it
        in place. Default is Kaiming normal initialization with ``a=0``.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017). Deep learning with convolutional neural networks for
       EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    # Filter growth factor applied by the first layer of residual stages
    # 2..7 (stage 1 keeps the width of the first convolution). The temporal
    # dilation doubles at the start of each of these stages.
    _STAGE_FILTER_GROWTH = (2, 1.5, 1, 1, 1, 1)

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_pool_length="auto",
        n_first_filters=20,
        n_layers_per_block=2,
        first_filter_length=3,
        activation=nn.ELU,
        split_first_layer=True,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
        conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
        chs_info=None,
        input_window_seconds=None,
        sfreq=250,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on only the values resolved by the mixin must be used.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        if final_pool_length == "auto":
            assert self.n_times is not None
        # An odd length is needed so that padding by length // 2 keeps the
        # number of time steps unchanged.
        assert first_filter_length % 2 == 1
        self.final_pool_length = final_pool_length
        self.n_first_filters = n_first_filters
        self.n_layers_per_block = n_layers_per_block
        self.first_filter_length = first_filter_length
        self.nonlinearity = activation
        self.split_first_layer = split_first_layer
        self.batch_norm_alpha = batch_norm_alpha
        self.batch_norm_epsilon = batch_norm_epsilon
        self.conv_weight_init_fn = conv_weight_init_fn

        # Old parameter names -> current names (classifier now lives inside
        # the ``final_layer`` sub-module).
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        self.add_module("ensuredims", Ensure4d())
        time_padding = (self.first_filter_length // 2, 0)
        if self.split_first_layer:
            # Move time to axis 2 and EEG channels to axis 3 so that the
            # temporal and spatial convolutions can be applied separately.
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    1,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=1,
                    padding=time_padding,
                ),
            )
            self.add_module(
                "conv_spat",
                nn.Conv2d(
                    self.n_first_filters,
                    self.n_first_filters,
                    (1, self.n_chans),
                    stride=(1, 1),
                    bias=False,
                ),
            )
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=(1, 1),
                    padding=time_padding,
                    bias=False,
                ),
            )
        n_filters_conv = self.n_first_filters
        # NOTE: eps is deliberately left at the historical hard-coded 1e-5
        # (not ``batch_norm_epsilon``) so that results with default
        # arguments stay numerically identical.
        self.add_module(
            "bnorm",
            nn.BatchNorm2d(
                n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
            ),
        )
        self.add_module("conv_nonlin", self.nonlinearity())

        # (time, channel) dilation; only the time entry ever grows.
        cur_dilation = np.array([1, 1])
        n_cur_filters = n_filters_conv

        # Stage 1: constant width, no dilation.
        for i_layer in range(self.n_layers_per_block):
            self.add_module(
                f"res_1_{i_layer:d}",
                self._make_res_block(n_cur_filters, n_cur_filters, cur_dilation),
            )

        # Stages 2..7: double the dilation, let the first layer (always
        # present) change the width, then add the remaining layers.
        for i_block, growth in enumerate(self._STAGE_FILTER_GROWTH, start=2):
            cur_dilation[0] *= 2
            n_out_filters = int(growth * n_cur_filters)
            self.add_module(
                f"res_{i_block:d}_0",
                self._make_res_block(n_cur_filters, n_out_filters, cur_dilation),
            )
            n_cur_filters = n_out_filters
            for i_layer in range(1, self.n_layers_per_block):
                self.add_module(
                    f"res_{i_block:d}_{i_layer:d}",
                    self._make_res_block(n_cur_filters, n_cur_filters, cur_dilation),
                )

        if self.final_pool_length == "auto":
            self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
        else:
            pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
            self.add_module(
                "mean_pool",
                AvgPool2dWithConv(
                    (self.final_pool_length, 1), (1, 1), dilation=pool_dilation
                ),
            )

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_cur_filters,
                self.n_outputs,
                (1, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialize all weights
        self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))

        # Start in train mode
        self.train()

    def _make_res_block(self, n_in_filters, n_out_filters, dilation):
        """Build one residual block that honors this model's configuration.

        The activation class and batch-norm settings given to the model are
        forwarded explicitly; otherwise the block would silently fall back
        to its own defaults.
        """
        return _ResidualBlock(
            n_in_filters,
            n_out_filters,
            dilation=dilation,
            nonlinearity=self.nonlinearity,
            batch_norm_alpha=self.batch_norm_alpha,
            batch_norm_epsilon=self.batch_norm_epsilon,
        )

    @staticmethod
    def _weights_init(module, conv_weight_init_fn):
        """Initialize the parameters of a single sub-module in place.

        Modules whose class name contains ``"Conv"`` (except
        ``AvgPool2dWithConv``) get ``conv_weight_init_fn`` applied to their
        weight and a zero bias; batch-norm modules get weight 1 and bias 0.
        All other modules are left untouched.
        """
        classname = module.__class__.__name__
        if "Conv" in classname and classname != "AvgPool2dWithConv":
            conv_weight_init_fn(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif "BatchNorm" in classname:
            init.constant_(module.weight, 1)
            init.constant_(module.bias, 0)
294
+
295
+
296
class _ResidualBlock(nn.Module):
    """Residual block with two stacked temporal convolutions.

    Each convolution uses a ``(filter_time_length, 1)`` kernel and is
    followed by batch normalization. The input is added to the output of
    the second batch norm before the final nonlinearity. When the block
    widens the feature maps, the input is zero-padded symmetrically along
    the filter axis so both tensors have the same shape.

    Parameters
    ----------
    in_filters : int
        Number of input feature maps.
    out_num_filters : int
        Number of output feature maps. ``out_num_filters - in_filters``
        must be even.
    dilation : sequence of two ints
        Dilation along (time, channel). ``(filter_time_length - 1) *
        dilation[0]`` must be even so the padding is symmetric.
    filter_time_length : int, default=3
        Temporal kernel length of both convolutions.
    nonlinearity : type[nn.Module], default=nn.ELU
        Activation class (not an instance); instantiated once and shared by
        both activations of the block.
    batch_norm_alpha : float, default=0.1
        Momentum of both batch normalization layers.
    batch_norm_epsilon : float, default=1e-4
        Epsilon of both batch normalization layers.
    """

    def __init__(
        self,
        in_filters,
        out_num_filters,
        dilation,
        filter_time_length=3,
        nonlinearity: "type[nn.Module]" = nn.ELU,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
    ):
        super().__init__()
        # Total padding that keeps the time length unchanged; it has to be
        # split evenly between both ends of the time axis.
        total_time_padding = int((filter_time_length - 1) * dilation[0])
        assert total_time_padding % 2 == 0
        time_padding = total_time_padding // 2
        # Convert eagerly to plain ints: callers may pass (and later
        # mutate) a numpy array.
        dilation = (int(dilation[0]), int(dilation[1]))
        assert (out_num_filters - in_filters) % 2 == 0, (
            "Need even number of extra channels in order to be able to pad correctly"
        )
        self.n_pad_chans = out_num_filters - in_filters

        self.conv_1 = nn.Conv2d(
            in_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn1 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        self.conv_2 = nn.Conv2d(
            out_num_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn2 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        self.nonlinearity = nonlinearity()

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor with
        ``out_num_filters`` feature maps and unchanged remaining axes."""
        stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
        stack_2 = self.bn2(self.conv_2(stack_1))  # next nonlin after sum
        if self.n_pad_chans != 0:
            # Zero-pad the shortcut on both sides of the filter axis so it
            # can be summed with the wider residual branch.
            zeros_for_padding = x.new_zeros(
                (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
            )
            x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
        return self.nonlinearity(x + stack_2)