braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,473 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ from typing import Dict, Optional
7
+
8
+ from einops.layers.torch import Rearrange
9
+ from mne.utils import warn
10
+ from torch import nn
11
+
12
+ from braindecode.functional import glorot_weight_zero_bias
13
+ from braindecode.models.base import EEGModuleMixin
14
+ from braindecode.modules import (
15
+ Conv2dWithConstraint,
16
+ Ensure4d,
17
+ Expression,
18
+ LinearWithConstraint,
19
+ SqueezeFinalOutput,
20
+ )
21
+
22
+
23
class EEGNetv4(EEGModuleMixin, nn.Sequential):
    """EEGNet v4 model from Lawhern et al. (2018) [EEGNet4]_.

    .. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
        :align: center
        :alt: EEGNet4 Architecture

    See details in [EEGNet4]_.

    Parameters
    ----------
    final_conv_length : int or "auto", default="auto"
        Length of the final convolution layer. If "auto", it is set to the
        time length of the feature map produced by the preceding layers
        (which requires ``n_times`` to be known).
    pool_mode : {"mean", "max"}, default="mean"
        Pooling method to use in pooling layers.
    F1 : int, default=8
        Number of temporal filters in the first convolutional layer.
    D : int, default=2
        Depth multiplier for the depthwise convolution.
    F2 : int or None, default=None
        Number of pointwise filters in the separable convolution.
        If ``None``, it is set to ``F1 * D``.
    kernel_length : int, default=64
        Length of the temporal convolution kernel.
    depthwise_kernel_length : int, default=16
        Length of the depthwise convolution kernel in the separable convolution.
    pool1_kernel_size : int, default=4
        Kernel size of the first pooling layer.
    pool2_kernel_size : int, default=8
        Kernel size of the second pooling layer.
    conv_spatial_max_norm : float, default=1
        Maximum norm constraint for the spatial (depthwise) convolution.
    activation : nn.Module, default=nn.ELU
        Activation function *class* (not an instance); it is instantiated
        once per activation layer.
    batch_norm_momentum : float, default=0.01
        Momentum of the batch norm layers.
    batch_norm_affine : bool, default=True
        If True, batch norm has learnable affine parameters.
    batch_norm_eps : float, default=1e-3
        Epsilon for numeric stability in batch norm layers.
    drop_prob : float, default=0.25
        Dropout probability.
    final_layer_with_constraint : bool, default=False
        If ``False``, uses a convolution-based classification layer
        (deprecated, a ``DeprecationWarning`` is emitted). If ``True``,
        applies a flattened linear layer with a constraint on the weights
        norm as the final classification step.
    norm_rate : float, default=0.25
        Max-norm constraint value for the linear layer (used if
        ``final_layer_with_constraint=True``).
    **kwargs
        Only the deprecated ``third_kernel_size`` is tolerated (it is ignored
        and triggers a warning); any other extra keyword raises ``TypeError``.

    References
    ----------
    .. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
        Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
        neural network for EEG-based brain–computer interfaces. Journal of
        neural engineering, 15(5), 056013.
    """

    def __init__(
        self,
        # signal's parameters
        n_chans: Optional[int] = None,
        n_outputs: Optional[int] = None,
        n_times: Optional[int] = None,
        # model's parameters
        final_conv_length: str | int = "auto",
        pool_mode: str = "mean",
        F1: int = 8,
        D: int = 2,
        F2: Optional[int] = None,
        kernel_length: int = 64,
        *,
        depthwise_kernel_length: int = 16,
        pool1_kernel_size: int = 4,
        pool2_kernel_size: int = 8,
        conv_spatial_max_norm: float = 1,
        activation: nn.Module = nn.ELU,
        batch_norm_momentum: float = 0.01,
        batch_norm_affine: bool = True,
        batch_norm_eps: float = 1e-3,
        drop_prob: float = 0.25,
        final_layer_with_constraint: bool = False,
        norm_rate: float = 0.25,
        # Other ways to construct the signal related parameters
        chs_info: Optional[list[Dict]] = None,
        input_window_seconds=None,
        sfreq=None,
        **kwargs,
    ):
        # The signal-related arguments are forwarded unchanged to the mixin,
        # which exposes them back as ``self.n_chans``, ``self.n_outputs``, ...
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # Drop the locals so that only the attributes set by the mixin can be
        # used below.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if final_conv_length == "auto":
            assert self.n_times is not None

        if not final_layer_with_constraint:
            # The message must name a parameter that actually exists in this
            # signature: any other keyword is rejected with a TypeError below.
            warn(
                "Parameter 'final_layer_with_constraint=False' is deprecated and will be "
                "removed in a future release. Please use "
                "`final_layer_with_constraint=True`.",
                DeprecationWarning,
            )

        if "third_kernel_size" in kwargs:
            warn(
                "The parameter `third_kernel_size` is deprecated "
                "and will be removed in a future version.",
            )
        unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
        if unexpected_kwargs:
            raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")

        self.final_conv_length = final_conv_length
        self.pool_mode = pool_mode
        self.F1 = F1
        self.D = D

        if F2 is None:
            F2 = self.F1 * self.D
        self.F2 = F2

        self.kernel_length = kernel_length
        self.depthwise_kernel_length = depthwise_kernel_length
        self.pool1_kernel_size = pool1_kernel_size
        self.pool2_kernel_size = pool2_kernel_size
        self.drop_prob = drop_prob
        self.activation = activation
        self.batch_norm_momentum = batch_norm_momentum
        self.batch_norm_affine = batch_norm_affine
        self.batch_norm_eps = batch_norm_eps
        self.conv_spatial_max_norm = conv_spatial_max_norm
        self.norm_rate = norm_rate

        # For the load_state_dict:
        # maps the old parameter names to the current ones. Once all layers
        # are standardized, add the old parameter names here.
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        # An unknown ``pool_mode`` raises KeyError here.
        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        self.add_module("ensuredims", Ensure4d())

        # Move the EEG channels to the third axis and add a single "image"
        # channel so that 2D convolutions can be applied.
        self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
        self.add_module(
            "conv_temporal",
            nn.Conv2d(
                1,
                self.F1,
                (1, self.kernel_length),
                bias=False,
                padding=(0, self.kernel_length // 2),
            ),
        )
        self.add_module(
            "bnorm_temporal",
            nn.BatchNorm2d(
                self.F1,
                momentum=self.batch_norm_momentum,
                affine=self.batch_norm_affine,
                eps=self.batch_norm_eps,
            ),
        )
        # Depthwise spatial convolution: the kernel spans all EEG channels,
        # so the channel axis collapses to size 1 afterwards.
        self.add_module(
            "conv_spatial",
            Conv2dWithConstraint(
                in_channels=self.F1,
                out_channels=self.F1 * self.D,
                kernel_size=(self.n_chans, 1),
                max_norm=self.conv_spatial_max_norm,
                bias=False,
                groups=self.F1,
            ),
        )

        self.add_module(
            "bnorm_1",
            nn.BatchNorm2d(
                self.F1 * self.D,
                momentum=self.batch_norm_momentum,
                affine=self.batch_norm_affine,
                eps=self.batch_norm_eps,
            ),
        )
        self.add_module("elu_1", self.activation())

        self.add_module(
            "pool_1",
            pool_class(
                kernel_size=(1, self.pool1_kernel_size),
            ),
        )
        self.add_module("drop_1", nn.Dropout(p=self.drop_prob))

        # Separable convolution = depthwise convolution followed by a
        # pointwise (1x1) convolution, see
        # https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
        self.add_module(
            "conv_separable_depth",
            nn.Conv2d(
                self.F1 * self.D,
                self.F1 * self.D,
                (1, self.depthwise_kernel_length),
                bias=False,
                groups=self.F1 * self.D,
                padding=(0, self.depthwise_kernel_length // 2),
            ),
        )
        self.add_module(
            "conv_separable_point",
            nn.Conv2d(
                self.F1 * self.D,
                self.F2,
                kernel_size=(1, 1),
                bias=False,
            ),
        )

        self.add_module(
            "bnorm_2",
            nn.BatchNorm2d(
                self.F2,
                momentum=self.batch_norm_momentum,
                affine=self.batch_norm_affine,
                eps=self.batch_norm_eps,
            ),
        )
        self.add_module("elu_2", self.activation())
        self.add_module(
            "pool_2",
            pool_class(
                kernel_size=(1, self.pool2_kernel_size),
            ),
        )
        self.add_module("drop_2", nn.Dropout(p=self.drop_prob))

        # Shape of the feature map produced by the layers added so far,
        # indexed below as (batch, filters, virtual channels, time).
        output_shape = self.get_output_shape()
        n_out_virtual_chans = output_shape[2]
        n_out_time = output_shape[3]

        if self.final_conv_length == "auto":
            self.final_conv_length = n_out_time

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()
        if not final_layer_with_constraint:
            module.add_module(
                "conv_classifier",
                nn.Conv2d(
                    self.F2,
                    self.n_outputs,
                    (n_out_virtual_chans, self.final_conv_length),
                    bias=True,
                ),
            )

            # Transpose back to the logic of braindecode,
            # so time in third dimension (axis=2)
            module.add_module(
                "permute_back",
                Rearrange("batch x y z -> batch x z y"),
            )

            module.add_module("squeeze", SqueezeFinalOutput())
        else:
            module.add_module("flatten", nn.Flatten())
            module.add_module(
                "linearconstraint",
                LinearWithConstraint(
                    # Size the layer from the real feature map: the flattened
                    # tensor always has F2 * virtual channels * time features,
                    # whatever ``final_conv_length`` the caller passed. With
                    # "auto" this equals ``F2 * final_conv_length``.
                    in_features=self.F2 * n_out_virtual_chans * n_out_time,
                    out_features=self.n_outputs,
                    max_norm=self.norm_rate,
                ),
            )
        self.add_module("final_layer", module)

        glorot_weight_zero_bias(self)
301
+
302
+
303
class EEGNetv1(EEGModuleMixin, nn.Sequential):
    """EEGNet model from Lawhern et al. 2016 from [EEGNet]_.

    See details in [EEGNet]_.

    Parameters
    ----------
    final_conv_length : int or "auto", default="auto"
        Length of the final convolution layer. If "auto", it is set to the
        time length of the feature map produced by the preceding layers
        (which requires ``n_times`` to be known).
    pool_mode : {"max", "mean"}, default="max"
        Pooling method to use in pooling layers.
    second_kernel_size : tuple of int, default=(2, 32)
        Kernel size of the second convolution.
    third_kernel_size : tuple of int, default=(8, 4)
        Kernel size of the third convolution.
    drop_prob : float, default=0.25
        Dropout probability.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors, only reimplemented from the paper description.

    This class is deprecated; instantiating it emits a
    ``DeprecationWarning``.

    References
    ----------
    .. [EEGNet] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
       S. M., Hung, C. P., & Lance, B. J. (2016).
       EEGNet: A Compact Convolutional Network for EEG-based
       Brain-Computer Interfaces.
       arXiv preprint arXiv:1611.08024.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_conv_length="auto",
        pool_mode="max",
        second_kernel_size=(2, 32),
        third_kernel_size=(8, 4),
        drop_prob=0.25,
        activation: nn.Module = nn.ELU,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        # The signal-related arguments are forwarded unchanged to the mixin,
        # which exposes them back as ``self.n_chans``, ``self.n_outputs``, ...
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # Drop the locals so that only the attributes set by the mixin can be
        # used below.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        # This class still ships in release 1.0.0, so the message must not
        # announce its removal "in the release 1.0".
        warn(
            "The class EEGNetv1 is deprecated and will be removed in a "
            "future release of braindecode. Please use "
            "braindecode.models.EEGNetv4 instead.",
            DeprecationWarning,
        )
        if final_conv_length == "auto":
            assert self.n_times is not None
        self.final_conv_length = final_conv_length
        self.pool_mode = pool_mode
        self.second_kernel_size = second_kernel_size
        self.third_kernel_size = third_kernel_size
        self.drop_prob = drop_prob
        # For the load_state_dict:
        # maps the old parameter names to the current ones. Once all layers
        # are standardized, add the old parameter names here.
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        # An unknown ``pool_mode`` raises KeyError here.
        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        self.add_module("ensuredims", Ensure4d())
        n_filters_1 = 16
        # 1x1 convolution that mixes the EEG channels into 16 virtual channels.
        self.add_module(
            "conv_1",
            nn.Conv2d(self.n_chans, n_filters_1, (1, 1), stride=1, bias=True),
        )
        self.add_module(
            "bnorm_1",
            nn.BatchNorm2d(n_filters_1, momentum=0.01, affine=True, eps=1e-3),
        )
        self.add_module("elu_1", activation())
        # transpose to examples x 1 x (virtual, not EEG) channels x time
        self.add_module("permute_1", Rearrange("batch x y z -> batch z x y"))

        self.add_module("drop_1", nn.Dropout(p=self.drop_prob))

        n_filters_2 = 4
        # keras pads unequal padding more in front, so padding
        # too large should be ok.
        # Not padding in time so that cropped training makes sense
        # https://stackoverflow.com/questions/43994604/padding-with-even-kernel-size-in-a-convolutional-layer-in-keras-theano

        self.add_module(
            "conv_2",
            nn.Conv2d(
                1,
                n_filters_2,
                self.second_kernel_size,
                stride=1,
                padding=(self.second_kernel_size[0] // 2, 0),
                bias=True,
            ),
        )
        self.add_module(
            "bnorm_2",
            nn.BatchNorm2d(n_filters_2, momentum=0.01, affine=True, eps=1e-3),
        )
        self.add_module("elu_2", activation())
        self.add_module("pool_2", pool_class(kernel_size=(2, 4), stride=(2, 4)))
        self.add_module("drop_2", nn.Dropout(p=self.drop_prob))

        n_filters_3 = 4
        self.add_module(
            "conv_3",
            nn.Conv2d(
                n_filters_2,
                n_filters_3,
                self.third_kernel_size,
                stride=1,
                padding=(self.third_kernel_size[0] // 2, 0),
                bias=True,
            ),
        )
        self.add_module(
            "bnorm_3",
            nn.BatchNorm2d(n_filters_3, momentum=0.01, affine=True, eps=1e-3),
        )
        self.add_module("elu_3", activation())
        self.add_module("pool_3", pool_class(kernel_size=(2, 4), stride=(2, 4)))
        self.add_module("drop_3", nn.Dropout(p=self.drop_prob))

        # Shape of the feature map produced by the layers added so far,
        # indexed below as (batch, filters, virtual channels, time).
        output_shape = self.get_output_shape()
        n_out_virtual_chans = output_shape[2]

        if self.final_conv_length == "auto":
            n_out_time = output_shape[3]
            self.final_conv_length = n_out_time

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_filters_3,
                self.n_outputs,
                (n_out_virtual_chans, self.final_conv_length),
                bias=True,
            ),
        )

        # Transpose back to the logic of braindecode,
        # so time in third dimension (axis=2)
        module.add_module(
            "permute_2",
            Rearrange("batch x y z -> batch x z y"),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        glorot_weight_zero_bias(self)