braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,379 @@
1
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
2
+ # Cedric Rommel <cedric.rommel@inria.fr>
3
+ #
4
+ # License: BSD (3-clause)
5
+ import math
6
+
7
+ from einops.layers.torch import Rearrange
8
+ from torch import nn
9
+
10
+ from braindecode.functional import glorot_weight_zero_bias
11
+ from braindecode.models.base import EEGModuleMixin
12
+ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
13
+
14
+
15
class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
    r"""EEG-Inception for ERP-based BCIs from Santamaria-Vazquez et al. (2020) [santamaria2020]_.

    :bdg-success:`Convolution`

    .. figure:: https://braindecode.org/dev/_static/model/eeginceptionerp.jpg
       :align: center
       :alt: EEGInceptionERP Architecture

       Figure: Overview of EEG-Inception architecture. 2D convolution blocks and depthwise 2D convolution blocks include batch normalization, activation and dropout regularization. The kernel size is displayed for convolutional and average pooling layers.

    .. rubric:: Architectural Overview

    A two-stage, multi-scale CNN tailored to ERP detection from short (0-1000 ms)
    single-trial epochs. In this implementation the input
    ``(batch, n_chans, n_times)`` is rearranged to
    ``(batch, 1, n_chans, n_times)`` and mapped through

    * (i) ``inception_block_1``: multi-scale temporal feature extraction plus
      per-branch spatial (depthwise) mixing, followed by ``avg_pool_1``;
    * (ii) ``inception_block_2``: multi-scale refinement at a reduced temporal
      resolution, followed by ``avg_pool_2``; and
    * (iii) ``final_block`` + ``final_layer``: two short convolutions with
      pooling, then flattening and a linear readout.

    .. rubric:: Macro Components

    - ``inception_block_1`` **(multi-scale temporal + spatial mixing)**

      - *Operations.* One branch per entry of ``scales_samples_s`` (64, 32 and
        16-sample kernels with the defaults at ``sfreq=128``). Each branch is

        - :class:`torch.nn.Conv2d` ``1 -> n_filters``, ``kernel_size=(1, k)``,
          *same* padding → BatchNorm → activation → dropout;
        - a depthwise convolution (:class:`~braindecode.modules.DepthwiseConv2d`)
          with ``kernel_size=(n_chans, 1)``, *valid* padding and
          ``depth_multiplier`` → BatchNorm → activation → dropout.

        Branch outputs are concatenated along the feature dimension
        (``len(scales_samples_s) * n_filters * depth_multiplier`` maps).
      - ``avg_pool_1``: :class:`torch.nn.AvgPool2d` ``(1, pooling_sizes[0])``
        (default ``(1, 4)``) for temporal downsampling.

      *Interpretability/robustness.* The depthwise ``n_chans x 1`` kernels act as
      learnable montage-wide spatial filters per temporal scale; pooling
      stabilizes against jitter.

    - ``inception_block_2`` **(refinement at a coarser timebase)**

      - *Operations.* One branch per scale: :class:`torch.nn.Conv2d` with
        ``n_filters`` outputs and ``kernel_size=(1, k // 4)`` (16, 8 and 4
        samples with the defaults), *same* padding → BatchNorm → activation →
        dropout. Branch outputs are concatenated
        (``len(scales_samples_s) * n_filters`` maps).
      - ``avg_pool_2``: :class:`torch.nn.AvgPool2d` ``(1, pooling_sizes[1])``.

      *Role.* Adds higher-level evidence while progressively compressing the
      temporal dimension.

    - ``final_block`` and ``final_layer`` **(aggregation + readout)**

      - *Operations.*

        - :class:`torch.nn.Conv2d` ``kernel_size=(1, 8)`` halving the number of
          maps → BatchNorm → activation → dropout →
          :class:`torch.nn.AvgPool2d` ``(1, pooling_sizes[2])``.
        - :class:`torch.nn.Conv2d` ``kernel_size=(1, 4)`` halving the maps
          again → BatchNorm → activation → dropout →
          :class:`torch.nn.AvgPool2d` ``(1, pooling_sizes[3])``.
        - :class:`torch.nn.Flatten`, then :class:`torch.nn.Linear`
          ``(features → n_outputs)``.

    .. rubric:: Convolutional Details

    - **Temporal (where time-domain patterns are learned).**

      The first module uses temporal kernels of ``int(s * sfreq)`` samples for
      each scale ``s``: ``64``, ``32``, ``16`` with the defaults (500, 250,
      125 ms at 128 Hz). The second module uses kernels a quarter of that length
      (``16``, ``8``, ``4``). With the default first pooling factor of 4 they span
      the same ≈500, 250, 125 ms at the pooled 32 Hz rate. All convolution
      strides are ``1``; temporal resolution changes only via average pooling.

    - **Spatial (how electrodes are processed).**

      Depthwise convolutions with ``kernel_size=(n_chans, 1)`` span all channels
      and are applied **per temporal branch**, yielding scale-specific channel
      projections (no cross-branch mixing until concatenation). There is no
      full 2D mixing kernel; spatial mixing is factorized and lightweight.

    - **Spectral (how frequency information is captured).**

      No explicit transform; multiple temporal kernels form a *learned filter
      bank* over ERP-relevant bands. Successive pooling acts as low-pass
      integration to emphasize sustained post-stimulus components.

    .. rubric:: Additional Mechanisms

    - Every convolution is followed by **BatchNorm** (momentum
      ``batch_norm_alpha``), the ``activation`` nonlinearity (the paper
      grid-searched it) and **dropout**.
    - Two Inception stages followed by short convolutions and pooling keep the
      parameter count small (≈15k reported in the paper) while preserving
      multi-scale evidence.
    - Expected input: ``(batch, n_chans, n_times)``. It passes through
      ``ensuredims`` (:class:`~braindecode.modules.Ensure4d`) and is rearranged
      to ``(batch, 1, n_chans, n_times)``, i.e. a single-feature 2D map of
      channels x time.

    .. rubric:: Usage and Configuration

    - **Key knobs.** Number of filters per branch (``n_filters``); temporal
      scales (``scales_samples_s``), which set the kernel lengths of both
      Inception modules; ``depth_multiplier`` of the depthwise convolution
      (whose kernel always spans ``n_chans``); ``pooling_sizes``;
      ``drop_prob``; choice of ``activation``.
    - **Training tips.** Use 0-1000 ms windows at 128 Hz with CAR; tune the
      activation and dropout (they strongly affect performance); early-stop on
      validation loss when overfitting emerges.

    .. rubric:: Implementation Details

    The model is strongly based on the original InceptionNet for images. The
    main goal is to extract features in parallel at different scales. The
    authors extracted three scales proportional to the window sample size. The
    network has three parts: 1) a larger inception block, 2) a smaller inception
    block, followed by 3) a bottleneck for classification.

    One advantage of the EEG-Inception block is that it allows a network to
    learn simultaneous low- and high-frequency components of the signal. The
    winners of the BEETL Competition/NeurIPS 2021 used parts of the model
    [beetl]_.

    The code for the paper and this model is also available at [santamaria2020]_
    and an adaptation for PyTorch [2]_.


    Parameters
    ----------
    n_times : int, optional
        Size of the input, in number of samples. Defaults to 1000; the paper
        [santamaria2020]_ used 128 samples (1 s at 128 Hz). The flattened
        feature size is ``n_times // prod(pooling_sizes)`` time steps.
    sfreq : float, optional
        EEG sampling frequency. Defaults to 128 as in [santamaria2020]_.
    drop_prob : float, optional
        Dropout rate inside all the network. Defaults to 0.5 as in
        [santamaria2020]_.
    scales_samples_s: list(float), optional
        Temporal scales (s) of the convolutions in each Inception module; one
        branch is built per entry. The first module uses kernels of
        ``int(scale * sfreq)`` samples and the second module a quarter of that;
        both are clamped to at least 1 sample. Defaults to 0.5, 0.25,
        0.125 seconds, as in [santamaria2020]_.
    n_filters : int, optional
        Initial number of convolutional filters. Defaults to 8 as in
        [santamaria2020]_.
    activation: nn.Module, optional
        Activation function class (instantiated for every block). Defaults to
        ELU activation as in [santamaria2020]_.
    batch_norm_alpha: float, optional
        Momentum for BatchNorm2d. Defaults to 0.01.
    depth_multiplier: int, optional
        Depth multiplier for the depthwise convolution. Defaults to 2 as in
        [santamaria2020]_.
    pooling_sizes: list(int), optional
        Pooling sizes for the inception blocks and the final block; must
        contain exactly 4 values. Defaults to 4, 2, 2 and 2, as in
        [santamaria2020]_.

    Raises
    ------
    ValueError
        If ``scales_samples_s`` is empty or ``pooling_sizes`` does not contain
        exactly 4 values.


    References
    ----------
    .. [santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
       Vaquerizo-Villar, F., & Hornero, R. (2020).
       EEG-inception: A novel deep convolutional neural network for assistive
       ERP-based brain-computer interfaces.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering, v. 28.
       Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
    .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
       Online: https://github.com/Grifcc/EEG/
    .. [beetl] Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S.,
       Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S., Barmpas, K., Bahri, M., Panagakis,
       Y., Laskaris, N., Adamos, D.A., Zafeiriou, S., Duong, W.C., Gordon, S.M.,
       Lawhern, V.J., Śliwowski, M., Rouanne, V. & Tempczyk, P. (2022).
       2021 BEETL Competition: Advancing Transfer Learning for Subject Independence &
       Heterogeneous EEG Data Sets. Proceedings of the NeurIPS 2021 Competitions and
       Demonstrations Track, in Proceedings of Machine Learning Research
       176:205-219 Available from https://proceedings.mlr.press/v176/wei22a.html.

    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=1000,
        sfreq=128,
        drop_prob=0.5,
        scales_samples_s=(0.5, 0.25, 0.125),
        n_filters=8,
        activation: type[nn.Module] = nn.ELU,
        batch_norm_alpha=0.01,
        depth_multiplier=2,
        pooling_sizes=(4, 2, 2, 2),
        chs_info=None,
        input_window_seconds=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Fail early with a clear message instead of an IndexError or a shape
        # mismatch at forward time.
        if len(scales_samples_s) == 0:
            raise ValueError("scales_samples_s must contain at least one scale.")
        if len(pooling_sizes) != 4:
            raise ValueError(
                "pooling_sizes must contain exactly 4 values, got "
                f"{len(pooling_sizes)}."
            )

        self.drop_prob = drop_prob
        self.n_filters = n_filters
        self.scales_samples_s = scales_samples_s
        # First-stage kernel lengths in samples. Clamped to 1 so that short
        # scales or a low sampling frequency never yield a zero-width kernel.
        self.scales_samples = tuple(
            max(1, int(size_s * self.sfreq)) for size_s in self.scales_samples_s
        )
        self.activation = activation
        self.alpha_momentum = batch_norm_alpha
        self.depth_multiplier = depth_multiplier
        self.pooling_sizes = pooling_sizes

        # Old -> new parameter names (presumably used when loading checkpoints
        # saved with the previous head layout).
        self.mapping = {
            "classification.1.weight": "final_layer.fc.weight",
            "classification.1.bias": "final_layer.fc.bias",
        }

        self.add_module("ensuredims", Ensure4d())

        # (batch, n_chans, n_times, 1) -> (batch, 1, n_chans, n_times)
        self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))

        # ======== Inception module 1: one branch per temporal scale ==========
        # Branches are built in scale order, so the default three-scale model
        # has the same modules (and initialization order) as before.
        block_1_branches = tuple(
            self._get_inception_branch_1(
                in_channels=self.n_chans,
                out_channels=self.n_filters,
                kernel_length=kernel_length,
                alpha_momentum=self.alpha_momentum,
                activation=self.activation,
                drop_prob=self.drop_prob,
                depth_multiplier=self.depth_multiplier,
            )
            for kernel_length in self.scales_samples
        )
        self.add_module("inception_block_1", InceptionBlock(block_1_branches))

        self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))

        # ======== Inception module 2: same scales at a coarser timebase =======
        n_concat_filters = len(self.scales_samples) * self.n_filters
        n_concat_dw_filters = n_concat_filters * self.depth_multiplier
        # Kernels are a quarter of the first-stage lengths (matching the
        # default first pooling factor of 4), clamped to at least 1 sample.
        # NOTE(review): the divisor is fixed at 4 and does not follow
        # pooling_sizes[0]; kept as-is to preserve existing architectures.
        block_2_branches = tuple(
            self._get_inception_branch_2(
                in_channels=n_concat_dw_filters,
                out_channels=self.n_filters,
                kernel_length=max(1, kernel_length // 4),
                alpha_momentum=self.alpha_momentum,
                activation=self.activation,
                drop_prob=self.drop_prob,
            )
            for kernel_length in self.scales_samples
        )
        self.add_module("inception_block_2", InceptionBlock(block_2_branches))

        self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))

        # ======== Final block: two short convolutions, halving maps each time ==
        n_half_filters = n_concat_filters // 2
        n_quarter_filters = n_concat_filters // 4
        self.add_module(
            "final_block",
            nn.Sequential(
                nn.Conv2d(
                    n_concat_filters,
                    n_half_filters,
                    (1, 8),
                    padding="same",
                    bias=False,
                ),
                nn.BatchNorm2d(n_half_filters, momentum=self.alpha_momentum),
                self.activation(),
                nn.Dropout(self.drop_prob),
                nn.AvgPool2d((1, self.pooling_sizes[2])),
                nn.Conv2d(
                    n_half_filters,
                    n_quarter_filters,
                    (1, 4),
                    padding="same",
                    bias=False,
                ),
                nn.BatchNorm2d(n_quarter_filters, momentum=self.alpha_momentum),
                self.activation(),
                nn.Dropout(self.drop_prob),
                nn.AvgPool2d((1, self.pooling_sizes[3])),
            ),
        )

        # Each non-overlapping AvgPool floors the length; successive floor
        # divisions equal a single floor division by the product of the sizes.
        n_times_last_layer = self.n_times // math.prod(self.pooling_sizes)
        n_channels_last_layer = n_quarter_filters

        self.add_module("flat", nn.Flatten())

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()
        module.add_module(
            "fc",
            nn.Linear(n_times_last_layer * n_channels_last_layer, self.n_outputs),
        )
        module.add_module("identity", nn.Identity())
        self.add_module("final_layer", module)

        glorot_weight_zero_bias(self)

    @staticmethod
    def _get_inception_branch_1(
        in_channels,
        out_channels,
        kernel_length,
        alpha_momentum,
        drop_prob,
        activation,
        depth_multiplier,
    ):
        """Build one first-stage Inception branch.

        A temporal ``(1, kernel_length)`` convolution from a single input map to
        ``out_channels`` maps, followed by a depthwise ``(in_channels, 1)``
        convolution over the electrode axis producing
        ``depth_multiplier * out_channels`` maps. Each convolution is followed by
        BatchNorm, the activation and dropout.

        Parameters
        ----------
        in_channels : int
            Number of EEG channels (height of the depthwise kernel).
        out_channels : int
            Number of temporal filters.
        kernel_length : int
            Temporal kernel length in samples.
        alpha_momentum : float
            BatchNorm momentum.
        drop_prob : float
            Dropout probability.
        activation : type[nn.Module]
            Activation class, instantiated twice.
        depth_multiplier : int
            Depth multiplier of the depthwise convolution.

        Returns
        -------
        nn.Sequential
            The branch.
        """
        return nn.Sequential(
            nn.Conv2d(
                1,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=True,
            ),
            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
            activation(),
            nn.Dropout(drop_prob),
            # "valid" padding with a full-height kernel collapses the channel axis.
            DepthwiseConv2d(
                out_channels,
                kernel_size=(in_channels, 1),
                depth_multiplier=depth_multiplier,
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(depth_multiplier * out_channels, momentum=alpha_momentum),
            activation(),
            nn.Dropout(drop_prob),
        )

    @staticmethod
    def _get_inception_branch_2(
        in_channels, out_channels, kernel_length, alpha_momentum, drop_prob, activation
    ):
        """Build one second-stage Inception branch.

        A bias-free temporal ``(1, kernel_length)`` convolution with *same*
        padding, followed by BatchNorm, the activation and dropout.

        Parameters
        ----------
        in_channels : int
            Number of input feature maps.
        out_channels : int
            Number of output feature maps.
        kernel_length : int
            Temporal kernel length in samples.
        alpha_momentum : float
            BatchNorm momentum.
        drop_prob : float
            Dropout probability.
        activation : type[nn.Module]
            Activation class.

        Returns
        -------
        nn.Sequential
            The branch.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
            activation(),
            nn.Dropout(drop_prob),
        )