braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -4,62 +4,67 @@
 # License: BSD-3
 
 import numpy as np
-
 import torch
+from einops.layers.torch import Rearrange
 from torch import nn
 from torch.nn import init
-from torch.nn.functional import elu
-from einops.layers.torch import Rearrange
 
-from .functions import squeeze_final_output
-from .modules import Expression, AvgPool2dWithConv, Ensure4d
-from .base import EEGModuleMixin, deprecated_args
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import (
+    AvgPool2dWithConv,
+    Ensure4d,
+    SqueezeFinalOutput,
+)
 
 
 class EEGResNet(EEGModuleMixin, nn.Sequential):
-    """Residual Network for EEG.
+    """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.
+
+    .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
+       :align: center
+       :alt: EEGResNet Architecture
 
-    XXX missing reference
+    Model described in [Schirrmeister2017]_.
 
     Parameters
     ----------
-    in_chans :
-        Alias for `n_chans`.
-    n_classes :
-        Alias for `n_outputs`.
-    input_window_samples :
-        Alias for `n_times`.
+    in_chans :
+        Alias for ``n_chans``.
+    n_classes :
+        Alias for ``n_outputs``.
+    input_window_samples :
+        Alias for ``n_times``.
+    activation: nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
 
+    References
+    ----------
+    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
+       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
+       & Ball, T. (2017). Deep learning with convolutional neural networks for
+       EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
+       Online: http://dx.doi.org/10.1002/hbm.23730
     """
 
     def __init__(
-        self,
-        n_chans=None,
-        n_outputs=None,
-        n_times=None,
-        final_pool_length=None,
-        n_first_filters=None,
-        n_layers_per_block=2,
-        first_filter_length=3,
-        nonlinearity=elu,
-        split_first_layer=True,
-        batch_norm_alpha=0.1,
-        batch_norm_epsilon=1e-4,
-        conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
-        chs_info=None,
-        input_window_seconds=None,
-        sfreq=None,
-        in_chans=None,
-        n_classes=None,
-        input_window_samples=None,
-        add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        final_pool_length="auto",
+        n_first_filters=20,
+        n_layers_per_block=2,
+        first_filter_length=3,
+        activation=nn.ELU,
+        split_first_layer=True,
+        batch_norm_alpha=0.1,
+        batch_norm_epsilon=1e-4,
+        conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=250,
     ):
-        n_chans, n_outputs, n_times = deprecated_args(
-            self,
-            ("in_chans", "n_chans", in_chans, n_chans),
-            ("n_classes", "n_outputs", n_classes, n_outputs),
-            ("input_window_samples", "n_times", input_window_samples, n_times),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
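The hunk above carries the release's main API break for this model: the deprecated aliases (in_chans, n_classes, input_window_samples) and the add_log_softmax flag are dropped from the signature, final_pool_length and n_first_filters gain concrete defaults, the activation is now passed as a module class (activation=nn.ELU) rather than a function (nonlinearity=elu), and the helper layers are imported from the new braindecode.modules package instead of braindecode.models.modules. A minimal migration sketch using the canonical keyword names; the concrete values are illustrative only, not taken from the diff:

    from torch import nn
    from braindecode.models import EEGResNet

    # 0.8.1 style, now removed:
    #     EEGResNet(in_chans=22, n_classes=4, input_window_samples=1000,
    #               nonlinearity=elu)
    # 1.0.0 style:
    model = EEGResNet(
        n_chans=22,          # was in_chans
        n_outputs=4,         # was n_classes
        n_times=1000,        # was input_window_samples
        activation=nn.ELU,   # was nonlinearity=elu (a plain function)
    )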
@@ -67,19 +72,17 @@ class EEGResNet(EEGModuleMixin, nn.Sequential):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del in_chans, n_classes, input_window_samples
 
-        if final_pool_length == 'auto':
+        if final_pool_length == "auto":
             assert self.n_times is not None
         assert first_filter_length % 2 == 1
         self.final_pool_length = final_pool_length
         self.n_first_filters = n_first_filters
         self.n_layers_per_block = n_layers_per_block
         self.first_filter_length = first_filter_length
-        self.nonlinearity = nonlinearity
+        self.nonlinearity = activation
         self.split_first_layer = split_first_layer
         self.batch_norm_alpha = batch_norm_alpha
         self.batch_norm_epsilon = batch_norm_epsilon
@@ -87,147 +90,207 @@ class EEGResNet(EEGModuleMixin, nn.Sequential):
 
         self.mapping = {
             "conv_classifier.weight": "final_layer.conv_classifier.weight",
-            "conv_classifier.bias": "final_layer.conv_classifier.bias"
+            "conv_classifier.bias": "final_layer.conv_classifier.bias",
         }
 
         self.add_module("ensuredims", Ensure4d())
         if self.split_first_layer:
-            self.add_module('dimshuffle',
-                            Rearrange("batch C T 1 -> batch 1 T C"))
-            self.add_module('conv_time', nn.Conv2d(1, self.n_first_filters,
-                                                   (self.first_filter_length, 1),
-                                                   stride=1,
-                                                   padding=(self.first_filter_length // 2, 0)))
-            self.add_module('conv_spat',
-                            nn.Conv2d(self.n_first_filters, self.n_first_filters,
-                                      (1, self.n_chans),
-                                      stride=(1, 1),
-                                      bias=False))
+            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
+            self.add_module(
+                "conv_time",
+                nn.Conv2d(
+                    1,
+                    self.n_first_filters,
+                    (self.first_filter_length, 1),
+                    stride=1,
+                    padding=(self.first_filter_length // 2, 0),
+                ),
+            )
+            self.add_module(
+                "conv_spat",
+                nn.Conv2d(
+                    self.n_first_filters,
+                    self.n_first_filters,
+                    (1, self.n_chans),
+                    stride=(1, 1),
+                    bias=False,
+                ),
+            )
         else:
-            self.add_module('conv_time',
-                            nn.Conv2d(self.n_chans, self.n_first_filters,
-                                      (self.first_filter_length, 1),
-                                      stride=(1, 1),
-                                      padding=(self.first_filter_length // 2, 0),
-                                      bias=False, ))
+            self.add_module(
+                "conv_time",
+                nn.Conv2d(
+                    self.n_chans,
+                    self.n_first_filters,
+                    (self.first_filter_length, 1),
+                    stride=(1, 1),
+                    padding=(self.first_filter_length // 2, 0),
+                    bias=False,
+                ),
+            )
         n_filters_conv = self.n_first_filters
-        self.add_module('bnorm',
-                        nn.BatchNorm2d(n_filters_conv,
-                                       momentum=self.batch_norm_alpha,
-                                       affine=True,
-                                       eps=1e-5), )
-        self.add_module('conv_nonlin', Expression(self.nonlinearity))
+        self.add_module(
+            "bnorm",
+            nn.BatchNorm2d(
+                n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
+            ),
+        )
+        self.add_module("conv_nonlin", self.nonlinearity())
         cur_dilation = np.array([1, 1])
         n_cur_filters = n_filters_conv
         i_block = 1
         for i_layer in range(self.n_layers_per_block):
-            self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
-                            _ResidualBlock(n_cur_filters, n_cur_filters,
-                                           dilation=cur_dilation))
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
         i_block += 1
         cur_dilation[0] *= 2
         n_out_filters = int(2 * n_cur_filters)
-        self.add_module('res_{:d}_{:d}'.format(i_block, 0),
-                        _ResidualBlock(n_cur_filters, n_out_filters,
-                                       dilation=cur_dilation, ))
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_out_filters,
+                dilation=cur_dilation,
+            ),
+        )
         n_cur_filters = n_out_filters
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
-                            _ResidualBlock(n_cur_filters, n_cur_filters,
-                                           dilation=cur_dilation))
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
 
         i_block += 1
         cur_dilation[0] *= 2
         n_out_filters = int(1.5 * n_cur_filters)
-        self.add_module('res_{:d}_{:d}'.format(i_block, 0),
-                        _ResidualBlock(n_cur_filters, n_out_filters,
-                                       dilation=cur_dilation, ))
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_out_filters,
+                dilation=cur_dilation,
+            ),
+        )
         n_cur_filters = n_out_filters
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
-                            _ResidualBlock(n_cur_filters, n_cur_filters,
-                                           dilation=cur_dilation))
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
 
         i_block += 1
         cur_dilation[0] *= 2
-        self.add_module('res_{:d}_{:d}'.format(i_block, 0),
-                        _ResidualBlock(n_cur_filters, n_cur_filters,
-                                       dilation=cur_dilation, ))
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_cur_filters,
+                dilation=cur_dilation,
+            ),
+        )
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
-                            _ResidualBlock(n_cur_filters, n_cur_filters,
-                                           dilation=cur_dilation))
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
 
         i_block += 1
         cur_dilation[0] *= 2
-        self.add_module('res_{:d}_{:d}'.format(i_block, 0),
-                        _ResidualBlock(n_cur_filters, n_cur_filters,
-                                       dilation=cur_dilation, ))
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_cur_filters,
+                dilation=cur_dilation,
+            ),
+        )
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
-                            _ResidualBlock(n_cur_filters, n_cur_filters,
-                                           dilation=cur_dilation))
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
         i_block += 1
         cur_dilation[0] *= 2
-        self.add_module('res_{:d}_{:d}'.format(i_block, 0),
-                        _ResidualBlock(n_cur_filters, n_cur_filters,
-                                       dilation=cur_dilation, ))
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_cur_filters,
+                dilation=cur_dilation,
+            ),
+        )
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
-                            _ResidualBlock(n_cur_filters, n_cur_filters,
-                                           dilation=cur_dilation))
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
 
         self.eval()
-        if self.final_pool_length == 'auto':
-            self.add_module('mean_pool', nn.AdaptiveAvgPool2d((1, 1)))
+        if self.final_pool_length == "auto":
+            self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
         else:
             pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
-            self.add_module('mean_pool', AvgPool2dWithConv(
-                (self.final_pool_length, 1), (1, 1),
-                dilation=pool_dilation))
+            self.add_module(
+                "mean_pool",
+                AvgPool2dWithConv(
+                    (self.final_pool_length, 1), (1, 1), dilation=pool_dilation
+                ),
+            )
 
         # Incorporating classification module and subsequent ones in one final layer
         module = nn.Sequential()
 
-        module.add_module("conv_classifier",
-                          nn.Conv2d(n_cur_filters, self.n_outputs, (1, 1), bias=True, ))
-
-        if self.add_log_softmax:
-            module.add_module('logsoftmax', nn.LogSoftmax(dim=1))
+        module.add_module(
+            "conv_classifier",
+            nn.Conv2d(
+                n_cur_filters,
+                self.n_outputs,
+                (1, 1),
+                bias=True,
+            ),
+        )
 
-        module.add_module("squeeze", Expression(squeeze_final_output))
+        module.add_module("squeeze", SqueezeFinalOutput())
 
         self.add_module("final_layer", module)
 
         # Initialize all weights
-        self.apply(lambda module: _weights_init(module, self.conv_weight_init_fn))
-
-        # Start in eval mode
-        self.eval()
+        self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))
 
+        # Start in train mode
+        self.train()
 
-def _weights_init(module, conv_weight_init_fn):
-    """
-    initialize weights
-    """
-    classname = module.__class__.__name__
-    if 'Conv' in classname and classname != "AvgPool2dWithConv":
-        conv_weight_init_fn(module.weight)
-        if module.bias is not None:
+    @staticmethod
+    def _weights_init(module, conv_weight_init_fn):
+        """
+        initialize weights
+        """
+        classname = module.__class__.__name__
+        if "Conv" in classname and classname != "AvgPool2dWithConv":
+            conv_weight_init_fn(module.weight)
+            if module.bias is not None:
+                init.constant_(module.bias, 0)
+        elif "BatchNorm" in classname:
+            init.constant_(module.weight, 1)
             init.constant_(module.bias, 0)
-    elif 'BatchNorm' in classname:
-        init.constant_(module.weight, 1)
-        init.constant_(module.bias, 0)
 
 
 class _ResidualBlock(nn.Module):
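Two behavioural changes in this hunk are easy to miss. First, the LogSoftmax head that add_log_softmax=True used to append is gone, so the network now emits raw logits. Second, construction now finishes with self.train() instead of self.eval(). A short sketch of the consequence for the loss function, with illustrative shapes that are not taken from the diff:

    import torch
    from torch import nn
    from braindecode.models import EEGResNet

    model = EEGResNet(n_chans=22, n_outputs=4, n_times=1000)
    logits = model(torch.randn(8, 22, 1000))  # raw scores, shape (8, 4)
    loss = nn.CrossEntropyLoss()(logits, torch.randint(0, 4, (8,)))
    # Under the 0.8.1 default the output was log-probabilities,
    # which paired with nn.NLLLoss instead.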
@@ -235,46 +298,65 @@ class _ResidualBlock(nn.Module):
     create a residual learning building block with two stacked 3x3 convlayers as in paper
     """
 
-    def __init__(self, in_filters,
-                 out_num_filters,
-                 dilation,
-                 filter_time_length=3,
-                 nonlinearity=elu,
-                 batch_norm_alpha=0.1, batch_norm_epsilon=1e-4):
+    def __init__(
+        self,
+        in_filters,
+        out_num_filters,
+        dilation,
+        filter_time_length=3,
+        nonlinearity: nn.Module = nn.ELU,
+        batch_norm_alpha=0.1,
+        batch_norm_epsilon=1e-4,
+    ):
         super(_ResidualBlock, self).__init__()
         time_padding = int((filter_time_length - 1) * dilation[0])
         assert time_padding % 2 == 0
         time_padding = int(time_padding // 2)
         dilation = (int(dilation[0]), int(dilation[1]))
         assert (out_num_filters - in_filters) % 2 == 0, (
-            "Need even number of extra channels in order to be able to "
-            "pad correctly")
+            "Need even number of extra channels in order to be able to pad correctly"
+        )
         self.n_pad_chans = out_num_filters - in_filters
 
         self.conv_1 = nn.Conv2d(
-            in_filters, out_num_filters, (filter_time_length, 1), stride=(1, 1),
+            in_filters,
+            out_num_filters,
+            (filter_time_length, 1),
+            stride=(1, 1),
             dilation=dilation,
-            padding=(time_padding, 0))
+            padding=(time_padding, 0),
+        )
         self.bn1 = nn.BatchNorm2d(
-            out_num_filters, momentum=batch_norm_alpha, affine=True,
-            eps=batch_norm_epsilon)
+            out_num_filters,
+            momentum=batch_norm_alpha,
+            affine=True,
+            eps=batch_norm_epsilon,
+        )
         self.conv_2 = nn.Conv2d(
-            out_num_filters, out_num_filters, (filter_time_length, 1), stride=(1, 1),
+            out_num_filters,
+            out_num_filters,
+            (filter_time_length, 1),
+            stride=(1, 1),
             dilation=dilation,
-            padding=(time_padding, 0))
+            padding=(time_padding, 0),
+        )
         self.bn2 = nn.BatchNorm2d(
-            out_num_filters, momentum=batch_norm_alpha,
-            affine=True, eps=batch_norm_epsilon)
+            out_num_filters,
+            momentum=batch_norm_alpha,
+            affine=True,
+            eps=batch_norm_epsilon,
+        )
         # also see https://mail.google.com/mail/u/0/#search/ilya+joos/1576137dd34c3127
         # for resnet options as ilya used them
-        self.nonlinearity = nonlinearity
+        self.nonlinearity = nonlinearity()
 
     def forward(self, x):
         stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
         stack_2 = self.bn2(self.conv_2(stack_1))  # next nonlin after sum
         if self.n_pad_chans != 0:
-            zeros_for_padding = torch.autograd.Variable(
-                x.new_zeros((x.size()[0], self.n_pad_chans // 2, x.size()[2], x.size()[3])))
+            zeros_for_padding = x.new_zeros(
+                (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
+            )
             x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
         out = self.nonlinearity(x + stack_2)
         return out
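The rewritten forward keeps the original zero-padded identity shortcut: when a block widens the channels, the input is padded symmetrically with zeros along the channel axis before the residual sum, which is why the constructor asserts an even number of extra channels. A standalone restatement of that padding arithmetic (illustrative sizes, not the library code itself):

    import torch

    x = torch.randn(8, 20, 1000, 1)  # (batch, in_filters, time, 1)
    n_pad_chans = 40 - 20            # out_num_filters - in_filters, must be even
    zeros = x.new_zeros(x.shape[0], n_pad_chans // 2, x.shape[2], x.shape[3])
    shortcut = torch.cat((zeros, x, zeros), dim=1)
    assert shortcut.shape[1] == 40   # channel count now matches the conv branch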