braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,439 @@
+ # Authors: Theo Gnassounou <theo.gnassounou@inria.fr>
+ #          Omar Chehab <l-emir-omar.chehab@inria.fr>
+ #
+ # License: BSD (3-clause)
+
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class USleep(EEGModuleMixin, nn.Module):
+     r"""
+     Sleep staging architecture from Perslev et al. (2021) [1]_.
+
+     :bdg-success:`Convolution`
+
+     .. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41746-021-00440-5/MediaObjects/41746_2021_440_Fig2_HTML.png
+         :align: center
+         :alt: USleep Architecture
+
+         Figure: U-Sleep consists of an encoder (left) which encodes the input
+         signals into dense feature representations, a decoder (middle) which
+         projects the learned features back into the input space to generate a
+         dense sleep stage representation, and a specially designed segment
+         classifier (right) which generates sleep stages at a chosen temporal
+         resolution.
+
+     .. rubric:: Architectural Overview
+
+     U-Sleep is a **fully convolutional**, feed-forward encoder-decoder with a
+     *segment classifier* head for time-series **segmentation** (sleep staging).
+     It maps multi-channel PSG (EEG+EOG) to a *dense, high-frequency* per-sample
+     representation, then aggregates it into fixed-length stage labels (e.g., 30 s).
+     The network processes arbitrarily long inputs in **one forward pass** (after
+     resampling to 128 Hz), producing whole-night hypnograms in seconds.
+
+     - (i) :class:`_EncoderBlock` extracts progressively deeper temporal features at lower resolution;
+     - (ii) :class:`_DecoderBlock` upsamples and fuses encoder features via U-Net-style skips to recover a per-sample stage map;
+     - (iii) the segment classifier mean-pools over the target epoch length and applies two pointwise convolutions to yield per-epoch predictions; this head is built directly into :class:`USleep`.
+
+     .. rubric:: Macro Components
+
+     - Encoder :class:`_EncoderBlock` **(multi-scale temporal feature extractor; downsampling x2 per block)**
+
+       - *Operations.*
+
+         - **Conv1d** (:class:`torch.nn.Conv1d`) with kernel ``9`` (stride ``1``, no dilation)
+         - **ELU** (:class:`torch.nn.ELU`)
+         - **Batch Norm** (:class:`torch.nn.BatchNorm1d`)
+         - **Max Pool 1d**, :class:`torch.nn.MaxPool1d` (``kernel=2, stride=2``).
+
+       Filters grow with depth by a factor of ``sqrt(2)`` (starting from
+       ``n_time_filters = 5``); each block exposes a **skip** (the pre-pooling
+       activation) to the matching decoder block (see the channel-growth sketch
+       after this list).
+
+       **Rationale.** Slow, uniform downsampling (x2 at each level) preserves
+       information in early layers while expanding the effective temporal context
+       over minutes, which is foundational for robust cross-cohort staging.
+
+     - Decoder :class:`_DecoderBlock` **(progressive upsampling + skip fusion to a high-frequency map; 12 blocks, upsampling x2 per block)**
+
+       - *Operations.*
+
+         - **Nearest-neighbor upsample**, :class:`torch.nn.Upsample` (x2)
+         - **Convolution** (k=2), :class:`torch.nn.Conv1d`
+         - ELU, :class:`torch.nn.ELU`
+         - Batch Norm, :class:`torch.nn.BatchNorm1d`
+         - **Concatenate** with the encoder skip at the same temporal scale, ``torch.cat``
+         - **Convolution**, :class:`torch.nn.Conv1d`
+         - ELU, :class:`torch.nn.ELU`
+         - Batch Norm, :class:`torch.nn.BatchNorm1d`.
+
+       **Output**: a multi-class, **high-frequency** per-sample representation aligned to the input rate (128 Hz).
+
+     - **Segment classifier, built into :class:`braindecode.models.USleep` (aggregation to fixed epochs)**
+
+       - *Operations.*
+
+         - **Mean-pool**, :class:`torch.nn.AvgPool1d`, with kernel = epoch length *i* and stride *i*
+         - **1x1 conv**, :class:`torch.nn.Conv1d`
+         - ELU, :class:`torch.nn.ELU`
+         - **1x1 conv**, :class:`torch.nn.Conv1d`, yielding ``(T, K)`` (epochs x stages).
+
+       **Role**: learns a **non-linear** weighted combination over each 30-s window (unlike U-Time's linear combiner).
+
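+     To make the filter scaling concrete, here is a minimal sketch that mirrors
+     the constructor's channel bookkeeping (defaults ``n_time_filters=5``,
+     ``complexity_factor=1.67``, ``depth=12``; ``n_chans=2`` is assumed):
+
+     .. code-block:: python
+
+         import numpy as np
+
+         n_filters, complexity_factor, depth = 5, 1.67, 12
+         channels = [2]  # n_chans: one EEG + one EOG channel
+         for _ in range(depth + 1):
+             channels.append(int(n_filters * np.sqrt(complexity_factor)))
+             n_filters = int(n_filters * np.sqrt(2))
+         # channels[1:] lists the per-level widths used by encoder, bottom, decoder
+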
+     .. rubric:: Convolutional Details
+
+     - **Temporal (where time-domain patterns are learned).**
+
+       All convolutions are **1-D along time**; depth (12 levels) plus pooling
+       yields an extensive receptive field (reported sensitivity to ±6.75 min
+       around each epoch; theoretical field ≈ 9.6 min at the deepest layer; see
+       the recurrence sketch after this list). The decoder restores sample-level
+       resolution before epoch aggregation.
+
+     - **Spatial (how channels are processed).**
+
+       Convolutions mix across the *channel* dimension jointly with time (no
+       separate spatial operator). The system is **montage-agnostic** (any
+       reasonable EEG/EOG pair) and was trained across diverse cohorts/protocols,
+       supporting robustness to channel placement and hardware differences.
+
+     - **Spectral (how frequency content is captured).**
+
+       No explicit Fourier/wavelet transform is used; the **stack of temporal
+       convolutions** acts as a learned filter bank whose effective bandwidth
+       grows with depth. The high-frequency decoder output (128 Hz) retains fine
+       temporal detail for the segment classifier.
+
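+     A back-of-the-envelope recurrence for that receptive field (encoder plus
+     bottom conv only; the decoder widens it further, so treat the result as a
+     lower bound):
+
+     .. code-block:: python
+
+         rf, jump = 1, 1  # receptive field (samples) and cumulative stride
+         for _ in range(12):         # encoder levels
+             rf += (9 - 1) * jump    # conv, kernel 9, stride 1
+             rf += (2 - 1) * jump    # max pool, kernel 2
+             jump *= 2               # pooling halves the resolution
+         rf += (9 - 1) * jump        # bottom conv
+         print(rf / 128 / 60)        # ~9 min of context at 128 Hz
+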
+     .. rubric:: Attention / Sequential Modules
+
+     U-Sleep contains **no attention or recurrent units**; it is a *pure*
+     feed-forward, fully convolutional segmentation network inspired by
+     U-Net/U-Time, favoring training stability and cross-dataset portability.
+
+     .. rubric:: Additional Mechanisms
+
+     - **U-Net lineage with task-specific head.** U-Sleep extends U-Time by being
+       **deeper** (12 vs. 4 levels), switching ReLU to **ELU**, using uniform
+       pooling (2) at all depths, and replacing the linear combiner with a
+       **two-layer** pointwise head, improving capacity and resilience across
+       datasets.
+     - **Arbitrary-length inference.** Thanks to the fully convolutional,
+       tiling-free design, entire nights can be staged in a single pass on
+       commodity hardware. Inputs shorter than ≈ 17.5 min may reduce performance
+       by limiting long-range context.
+     - **Complexity scaling (alpha).** Filter counts can be adjusted by a global
+       **complexity factor** (``complexity_factor``) to trade off accuracy and
+       memory (as described in the paper's topology table).
+
+     .. rubric:: Usage and Configuration
+
+     - **Practice.** Resample PSG to **128 Hz** and provide at least two channels
+       (one EEG, one EOG). Choose epoch length *i* (often 30 s); ensure windows
+       are long enough to exploit the model's receptive field (e.g., training on
+       ≥ 17.5 min chunks), as in the example below.
+
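+     A minimal usage sketch (shapes assumed from the defaults documented below;
+     35 epochs of 30 s at 128 Hz comfortably exceed the ≈ 17.5 min context):
+
+     .. code-block:: python
+
+         import torch
+         from braindecode.models import USleep
+
+         model = USleep(n_chans=2, sfreq=128, input_window_seconds=30, n_outputs=5)
+         x = torch.randn(4, 35, 2, 30 * 128)  # (batch, epochs, channels, time)
+         y = model(x)                         # (4, 5, 35): per-epoch class scores
+
+     With sampling rates at which ``time_conv_size_s * sfreq`` rounds to an even
+     kernel size, pass ``ensure_odd_conv_size=True``.
+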
+     Parameters
+     ----------
+     n_chans : int
+         Number of EEG or EOG channels. Set to 2 in [1]_ (1 EEG, 1 EOG).
+     sfreq : float
+         EEG sampling frequency. Set to 128 in [1]_.
+     depth : int
+         Number of conv blocks in the encoder (one max pool of size 2 per block).
+         Note: each block halves the temporal dimension of the features.
+     n_time_filters : int
+         Initial number of convolutional filters. Set to 5 in [1]_.
+     complexity_factor : float
+         Multiplicative factor for the number of channels at each layer of the
+         U-Net. Set to 2 in [1]_.
+     with_skip_connection : bool
+         If True, use skip connections in decoder blocks.
+     n_outputs : int
+         Number of outputs/classes. Set to 5.
+     input_window_seconds : float
+         Size of the input, in seconds. Set to 30 in [1]_.
+     time_conv_size_s : float
+         Size of the temporal convolution kernel, in seconds. Set to 9 / 128 in
+         [1]_.
+     ensure_odd_conv_size : bool
+         If True and the size of the convolutional kernel is an even number, one
+         will be added to it to ensure it is odd, so that the decoder blocks can
+         work. This can be useful when using sampling rates other than 128 or
+         100 Hz.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     References
+     ----------
+     .. [1] Perslev, M., Darkner, S., Kempfner, L., Nikolic, M., Jennum, P. J., &
+        Igel, C. (2021). U-Sleep: resilient high-frequency sleep staging.
+        npj Digital Medicine, 4, 72. https://doi.org/10.1038/s41746-021-00440-5
+        Code: https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         sfreq=None,
+         depth=12,
+         n_time_filters=5,
+         complexity_factor=1.67,
+         with_skip_connection=True,
+         n_outputs=5,
+         input_window_seconds=None,
+         time_conv_size_s=9 / 128,
+         ensure_odd_conv_size=False,
+         activation: type[nn.Module] = nn.ELU,
+         chs_info=None,
+         n_times=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
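+         # Rename map for the final layers (presumably applied when loading older
+         # pretrained checkpoints whose parameters were stored under ``clf.*``).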
+         self.mapping = {
+             "clf.3.weight": "final_layer.0.weight",
+             "clf.3.bias": "final_layer.0.bias",
+             "clf.5.weight": "final_layer.2.weight",
+             "clf.5.bias": "final_layer.2.bias",
+         }
+
+         max_pool_size = 2  # Hardcoded to avoid dimensional errors
+         time_conv_size = int(np.round(time_conv_size_s * self.sfreq))
+         if time_conv_size % 2 == 0:
+             if ensure_odd_conv_size:
+                 time_conv_size += 1
+             else:
+                 raise ValueError(
+                     "time_conv_size must be an odd number to accommodate the "
+                     "upsampling step in the decoder blocks."
+                 )
+
+         channels = [self.n_chans]
+         n_filters = n_time_filters
+         for _ in range(depth + 1):
+             channels.append(int(n_filters * np.sqrt(complexity_factor)))
+             n_filters = int(n_filters * np.sqrt(2))
+         self.channels = channels
+
+         # Instantiate encoder
+         self.encoder_blocks = nn.ModuleList(
+             _EncoderBlock(
+                 in_channels=channels[idx],
+                 out_channels=channels[idx + 1],
+                 kernel_size=time_conv_size,
+                 downsample=max_pool_size,
+                 activation=activation,
+             )
+             for idx in range(depth)
+         )
+
+         # Instantiate bottom (channels increase, temporal dim stays the same)
+         self.bottom = nn.Sequential(
+             nn.Conv1d(
+                 in_channels=channels[-2],
+                 out_channels=channels[-1],
+                 kernel_size=time_conv_size,
+                 padding=(time_conv_size - 1) // 2,
+             ),  # preserves dimension
+             activation(),
+             nn.BatchNorm1d(num_features=channels[-1]),
+         )
+
+         # Instantiate decoder
+         channels_reverse = channels[::-1]
+         self.decoder_blocks = nn.ModuleList(
+             _DecoderBlock(
+                 in_channels=channels_reverse[idx],
+                 out_channels=channels_reverse[idx + 1],
+                 kernel_size=time_conv_size,
+                 upsample=max_pool_size,
+                 with_skip_connection=with_skip_connection,
+                 activation=activation,
+             )
+             for idx in range(depth)
+         )
+
+         # The temporal dimension remains unchanged
+         # (except through the AvgPool1d, which collapses each epoch to one value).
+         # The channel dimension is preserved from the end of the U-Net and is
+         # mapped to n_classes.
+
+         self.clf = nn.Sequential(
+             nn.Conv1d(
+                 in_channels=channels[1],
+                 out_channels=channels[1],
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ),  # output is (B, C, S * T)
+             nn.Tanh(),
+             nn.AvgPool1d(self.n_times),  # output is (B, C, S)
+         )
+
+         self.final_layer = nn.Sequential(
+             nn.Conv1d(
+                 in_channels=channels[1],
+                 out_channels=self.n_outputs,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ),  # output is (B, n_classes, S)
+             activation(),
+             nn.Conv1d(
+                 in_channels=self.n_outputs,
+                 out_channels=self.n_outputs,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ),
+             nn.Identity(),
+             # output is (B, n_classes, S)
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """If input x has shape (B, S, C, T), return y_pred of shape (B, n_classes, S).
+         If input x has shape (B, C, T), return y_pred of shape (B, n_classes).
+         """
+         # reshape input
+         if x.ndim == 4:  # input x has shape (B, S, C, T)
+             x = x.permute(0, 2, 1, 3)  # (B, C, S, T)
+             x = x.flatten(start_dim=2)  # (B, C, S * T)
+
+         # encoder
+         residuals = []
+         for down in self.encoder_blocks:
+             x, res = down(x)
+             residuals.append(res)
+
+         # bottom
+         x = self.bottom(x)
+
+         # decoder
+         num_blocks = len(self.decoder_blocks)  # statically known
+         for idx, dec in enumerate(self.decoder_blocks):
+             # pick the matching residual in reverse order
+             res = residuals[num_blocks - 1 - idx]
+             x = dec(x, res)
+
+         # classifier
+         x = self.clf(x)
+         y_pred = self.final_layer(x)  # (B, n_classes, seq_length)
+
+         if y_pred.shape[-1] == 1:  # seq_length of 1
+             y_pred = y_pred[:, :, 0]
+
+         return y_pred
+
+
+ class _EncoderBlock(nn.Module):
+     r"""Encoding block for a timeseries x of shape (B, C, T)."""
+
+     def __init__(
+         self,
+         in_channels=2,
+         out_channels=2,
+         kernel_size=9,
+         downsample=2,
+         activation: type[nn.Module] = nn.ELU,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.downsample = downsample
+
+         self.block_prepool = nn.Sequential(
+             nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 padding="same",
+             ),
+             activation(),
+             nn.BatchNorm1d(num_features=out_channels),
+         )
+
+         self.pad = nn.ConstantPad1d(padding=1, value=0.0)
+         self.maxpool = nn.MaxPool1d(kernel_size=self.downsample, stride=self.downsample)
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         x = self.block_prepool(x)
+         residual = x
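+         # For odd-length inputs, zero-pad one sample on each side so the 2x
+         # max-pool effectively rounds up rather than dropping the trailing sample.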
+         if x.shape[-1] % 2:
+             x = self.pad(x)
+         x = self.maxpool(x)
+         return x, residual
+
+
+ class _DecoderBlock(nn.Module):
+     r"""Decoding block for a timeseries x of shape (B, C, T)."""
+
+     def __init__(
+         self,
+         in_channels=2,
+         out_channels=2,
+         kernel_size=9,
+         upsample=2,
+         with_skip_connection=True,
+         activation: type[nn.Module] = nn.ELU,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.upsample = upsample
+         self.with_skip_connection = with_skip_connection
+
+         self.block_preskip = nn.Sequential(
+             nn.Upsample(scale_factor=upsample),
+             nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=2,
+                 padding="same",
+             ),
+             activation(),
+             nn.BatchNorm1d(num_features=out_channels),
+         )
+         self.block_postskip = nn.Sequential(
+             nn.Conv1d(
+                 in_channels=(
+                     2 * out_channels if with_skip_connection else out_channels
+                 ),
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 padding="same",
+             ),
+             activation(),
+             nn.BatchNorm1d(num_features=out_channels),
+         )
+
+     def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+         x = self.block_preskip(x)
+         if self.with_skip_connection:
+             x, residual = self._crop_tensors_to_match(
+                 x, residual, axis=-1
+             )  # in case of mismatch
+             x = torch.cat([x, residual], dim=1)  # (B, 2 * C, T)
+         x = self.block_postskip(x)
+         return x
+
+     @staticmethod
+     def _crop_tensors_to_match(
+         x1: torch.Tensor, x2: torch.Tensor, axis: int = -1
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Crop two tensors to the minimum of their sizes along an axis."""
+         dim_cropped = min(x1.shape[axis], x2.shape[axis])
+
+         x1_cropped = torch.index_select(
+             x1, dim=axis, index=torch.arange(dim_cropped).to(device=x1.device)
+         )
+         x2_cropped = torch.index_select(
+             x2, dim=axis, index=torch.arange(dim_cropped).to(device=x2.device)
+         )
+         return x1_cropped, x2_cropped