braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective registries, and is provided for informational purposes only.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (108):
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
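
The listing alone shows a package-level reorganization: braindecode/models/modules.py and braindecode/models/functions.py are deleted, and their contents appear to move into the new braindecode/modules/ and braindecode/functional/ subpackages. A minimal migration sketch for downstream imports (the specific symbols Ensure4d and safe_log are assumptions inferred from the old module names, not verified against the wheel):

    # Hypothetical import migration; Ensure4d and safe_log are assumed
    # to have moved along with their modules (not verified against the wheel).

    # braindecode 0.8.1:
    # from braindecode.models.modules import Ensure4d
    # from braindecode.models.functions import safe_log

    # braindecode 1.1.0:
    from braindecode.modules import Ensure4d     # shared layers: attention, filters, ...
    from braindecode.functional import safe_log  # stateless helper functions

The hunks below show one file from the list in detail: braindecode/models/usleep.py (entry 64, +190/-177).
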
@@ -3,122 +3,33 @@
 #
 # License: BSD (3-clause)

+
 import numpy as np
 import torch
 from torch import nn
-from .base import EEGModuleMixin, deprecated_args
-
-
-def _crop_tensors_to_match(x1, x2, axis=-1):
-    """Crops two tensors to their lowest-common-dimension along an axis."""
-    dim_cropped = min(x1.shape[axis], x2.shape[axis])
-
-    x1_cropped = torch.index_select(
-        x1, dim=axis,
-        index=torch.arange(dim_cropped).to(device=x1.device)
-    )
-    x2_cropped = torch.index_select(
-        x2, dim=axis,
-        index=torch.arange(dim_cropped).to(device=x1.device)
-    )
-    return x1_cropped, x2_cropped
-
-
-class _EncoderBlock(nn.Module):
-    """Encoding block for a timeseries x of shape (B, C, T)."""
-
-    def __init__(self,
-                 in_channels=2,
-                 out_channels=2,
-                 kernel_size=9,
-                 downsample=2):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.downsample = downsample
-
-        self.block_prepool = nn.Sequential(
-            nn.Conv1d(in_channels=in_channels,
-                      out_channels=out_channels,
-                      kernel_size=kernel_size,
-                      padding='same'),
-            nn.ELU(),
-            nn.BatchNorm1d(num_features=out_channels),
-        )
-
-        self.pad = nn.ConstantPad1d(padding=1, value=0)
-        self.maxpool = nn.MaxPool1d(
-            kernel_size=self.downsample, stride=self.downsample)
-
-    def forward(self, x):
-        x = self.block_prepool(x)
-        residual = x
-        if x.shape[-1] % 2:
-            x = self.pad(x)
-        x = self.maxpool(x)
-        return x, residual
-
-
-class _DecoderBlock(nn.Module):
-    """Decoding block for a timeseries x of shape (B, C, T)."""
-
-    def __init__(self,
-                 in_channels=2,
-                 out_channels=2,
-                 kernel_size=9,
-                 upsample=2,
-                 with_skip_connection=True):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.upsample = upsample
-        self.with_skip_connection = with_skip_connection
-
-        self.block_preskip = nn.Sequential(
-            nn.Upsample(scale_factor=upsample),
-            nn.Conv1d(in_channels=in_channels,
-                      out_channels=out_channels,
-                      kernel_size=2,
-                      padding='same'),
-            nn.ELU(),
-            nn.BatchNorm1d(num_features=out_channels),
-        )
-        self.block_postskip = nn.Sequential(
-            nn.Conv1d(
-                in_channels=(
-                    2 * out_channels if with_skip_connection else out_channels),
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-                padding='same'),
-            nn.ELU(),
-            nn.BatchNorm1d(num_features=out_channels),
-        )

-    def forward(self, x, residual):
-        x = self.block_preskip(x)
-        if self.with_skip_connection:
-            x, residual = _crop_tensors_to_match(x, residual, axis=-1)  # in case of mismatch
-            x = torch.cat([x, residual], axis=1)  # (B, 2 * C, T)
-        x = self.block_postskip(x)
-        return x
+from braindecode.models.base import EEGModuleMixin


 class USleep(EEGModuleMixin, nn.Module):
-    """Sleep staging architecture from Perslev et al 2021.
+    """
+    Sleep staging architecture from Perslev et al. (2021) [1]_.
+
+    .. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41746-021-00440-5/MediaObjects/41746_2021_440_Fig2_HTML.png
+       :align: center
+       :alt: USleep Architecture

     U-Net (autoencoder with skip connections) feature-extractor for sleep
     staging described in [1]_.

     For the encoder ('down'):
-        -- the temporal dimension shrinks (via maxpooling in the time-domain)
-        -- the spatial dimension expands (via more conv1d filters in the
-           time-domain)
+        - the temporal dimension shrinks (via maxpooling in the time-domain)
+        - the spatial dimension expands (via more conv1d filters in the time-domain)
+
     For the decoder ('up'):
-        -- the temporal dimension expands (via upsampling in the time-domain)
-        -- the spatial dimension shrinks (via fewer conv1d filters in the
-           time-domain)
+        - the temporal dimension expands (via upsampling in the time-domain)
+        - the spatial dimension shrinks (via fewer conv1d filters in the time-domain)
+
     Both do so at exponential rates.

     Parameters
@@ -128,12 +39,12 @@ class USleep(EEGModuleMixin, nn.Module):
     sfreq : float
         EEG sampling frequency. Set to 128 in [1]_.
     depth : int
-        Number of conv blocks in encoding layer (number of 2x2 max pools)
-        Note: each block halve the spatial dimensions of the features.
+        Number of conv blocks in encoding layer (number of 2x2 max pools).
+        Note: each block halves the spatial dimensions of the features.
     n_time_filters : int
         Initial number of convolutional filters. Set to 5 in [1]_.
     complexity_factor : float
-        Multiplicative factor for number of channels at each layer of the U-Net.
+        Multiplicative factor for the number of channels at each layer of the U-Net.
         Set to 2 in [1]_.
     with_skip_connection : bool
         If True, use skip connections in decoder blocks.
@@ -147,47 +58,35 @@ class USleep(EEGModuleMixin, nn.Module):
     ensure_odd_conv_size : bool
         If True and the size of the convolutional kernel is an even number, one
         will be added to it to ensure it is odd, so that the decoder blocks can
-        work. This can ne useful when using different sampling rates from 128
+        work. This can be useful when using different sampling rates from 128
         or 100 Hz.
-    in_chans : int
-        Alias for n_chans.
-    n_classes : int
-        Alias for n_outputs.
-    input_size_s : float
-        Alias for input_window_seconds.
+    activation : nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

     References
     ----------
     .. [1] Perslev M, Darkner S, Kempfner L, Nikolic M, Jennum PJ, Igel C.
-       U-Sleep: resilient high-frequency sleep staging. npj Digit. Med. 4, 72 (2021).
-       https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
+       U-Sleep: resilient high-frequency sleep staging. *npj Digit. Med.* 4, 72 (2021).
+       https://github.com/perslev/U-Time/blob/master/utime/models/usleep.py
     """

     def __init__(
-            self,
-            n_chans=2,
-            sfreq=128,
-            depth=12,
-            n_time_filters=5,
-            complexity_factor=1.67,
-            with_skip_connection=True,
-            n_outputs=5,
-            input_window_seconds=30,
-            time_conv_size_s=9 / 128,
-            ensure_odd_conv_size=False,
-            chs_info=None,
-            n_times=None,
-            in_chans=None,
-            n_classes=None,
-            input_size_s=None,
-            add_log_softmax=False,
+        self,
+        n_chans=None,
+        sfreq=None,
+        depth=12,
+        n_time_filters=5,
+        complexity_factor=1.67,
+        with_skip_connection=True,
+        n_outputs=5,
+        input_window_seconds=None,
+        time_conv_size_s=9 / 128,
+        ensure_odd_conv_size=False,
+        activation: nn.Module = nn.ELU,
+        chs_info=None,
+        n_times=None,
     ):
-        n_chans, n_outputs, input_window_seconds = deprecated_args(
-            self,
-            ("in_chans", "n_chans", in_chans, n_chans),
-            ("n_classes", "n_outputs", n_classes, n_outputs),
-            ("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -195,16 +94,14 @@ class USleep(EEGModuleMixin, nn.Module):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del in_chans, n_classes, input_size_s

         self.mapping = {
-            'clf.3.weight': 'final_layer.0.weight',
-            'clf.3.bias': 'final_layer.0.bias',
-            'clf.5.weight': 'final_layer.2.weight',
-            'clf.5.bias': 'final_layer.2.bias'
+            "clf.3.weight": "final_layer.0.weight",
+            "clf.3.bias": "final_layer.0.bias",
+            "clf.5.weight": "final_layer.2.weight",
+            "clf.5.bias": "final_layer.2.bias",
         }

         max_pool_size = 2  # Hardcoded to avoid dimensional errors
@@ -214,8 +111,9 @@ class USleep(EEGModuleMixin, nn.Module):
                 time_conv_size += 1
             else:
                 raise ValueError(
-                    'time_conv_size must be an odd number to accommodate the '
-                    'upsampling step in the decoder blocks.')
+                    "time_conv_size must be an odd number to accommodate the "
+                    "upsampling step in the decoder blocks."
+                )

         channels = [self.n_chans]
         n_filters = n_time_filters
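
Worth noting while reading this hunk: the kernel size in samples is presumably computed just above it as int(np.round(time_conv_size_s * self.sfreq)); that line is outside the hunk, so treat the formula as an assumption. A quick arithmetic check of when ensure_odd_conv_size matters:

    import numpy as np

    time_conv_size_s = 9 / 128  # default from the new signature

    int(np.round(time_conv_size_s * 128))  # -> 9, odd: fine
    int(np.round(time_conv_size_s * 100))  # -> 7, odd: fine
    int(np.round(time_conv_size_s * 256))  # -> 18, even: raises the ValueError
    # above unless ensure_odd_conv_size=True, which bumps the kernel to 19
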
@@ -225,38 +123,42 @@ class USleep(EEGModuleMixin, nn.Module):
         self.channels = channels

         # Instantiate encoder
-        encoder = list()
-        for idx in range(depth):
-            encoder += [
-                _EncoderBlock(in_channels=channels[idx],
-                              out_channels=channels[idx + 1],
-                              kernel_size=time_conv_size,
-                              downsample=max_pool_size)
-            ]
-        self.encoder = nn.Sequential(*encoder)
+        self.encoder_blocks = nn.ModuleList(
+            _EncoderBlock(
+                in_channels=channels[idx],
+                out_channels=channels[idx + 1],
+                kernel_size=time_conv_size,
+                downsample=max_pool_size,
+                activation=activation,
+            )
+            for idx in range(depth)
+        )

         # Instantiate bottom (channels increase, temporal dim stays the same)
         self.bottom = nn.Sequential(
-            nn.Conv1d(in_channels=channels[-2],
-                      out_channels=channels[-1],
-                      kernel_size=time_conv_size,
-                      padding=(time_conv_size - 1) // 2),  # preserves dimension
-            nn.ELU(),
+            nn.Conv1d(
+                in_channels=channels[-2],
+                out_channels=channels[-1],
+                kernel_size=time_conv_size,
+                padding=(time_conv_size - 1) // 2,
+            ),  # preserves dimension
+            activation(),
             nn.BatchNorm1d(num_features=channels[-1]),
         )

         # Instantiate decoder
-        decoder = list()
         channels_reverse = channels[::-1]
-        for idx in range(depth):
-            decoder += [
-                _DecoderBlock(in_channels=channels_reverse[idx],
-                              out_channels=channels_reverse[idx + 1],
-                              kernel_size=time_conv_size,
-                              upsample=max_pool_size,
-                              with_skip_connection=with_skip_connection)
-            ]
-        self.decoder = nn.Sequential(*decoder)
+        self.decoder_blocks = nn.ModuleList(
+            _DecoderBlock(
+                in_channels=channels_reverse[idx],
+                out_channels=channels_reverse[idx + 1],
+                kernel_size=time_conv_size,
+                upsample=max_pool_size,
+                with_skip_connection=with_skip_connection,
+                activation=activation,
+            )
+            for idx in range(depth)
+        )

         # The temporal dimension remains unchanged
         # (except through the AvgPooling which collapses it to 1)
@@ -282,7 +184,7 @@ class USleep(EEGModuleMixin, nn.Module):
                 stride=1,
                 padding=0,
             ),  # output is (B, n_classes, S)
-            nn.ELU(),
+            activation(),
             nn.Conv1d(
                 in_channels=self.n_outputs,
                 out_channels=self.n_outputs,
@@ -290,11 +192,11 @@ class USleep(EEGModuleMixin, nn.Module):
                 stride=1,
                 padding=0,
             ),
-            nn.LogSoftmax(dim=1) if self.add_log_softmax else nn.Identity(),
+            nn.Identity(),
             # output is (B, n_classes, S)
         )

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """If input x has shape (B, S, C, T), return y_pred of shape (B, n_classes, S).
         If input x has shape (B, C, T), return y_pred of shape (B, n_classes).
         """
@@ -305,7 +207,7 @@ class USleep(EEGModuleMixin, nn.Module):

         # encoder
         residuals = []
-        for down in self.encoder:
+        for down in self.encoder_blocks:
             x, res = down(x)
             residuals.append(res)

@@ -313,9 +215,11 @@ class USleep(EEGModuleMixin, nn.Module):
         x = self.bottom(x)

         # decoder
-        residuals = residuals[::-1]  # flip order
-        for up, res in zip(self.decoder, residuals):
-            x = up(x, res)
+        num_blocks = len(self.decoder_blocks)  # statically known
+        for idx, dec in enumerate(self.decoder_blocks):
+            # pick the matching residual in reverse order
+            res = residuals[num_blocks - 1 - idx]
+            x = dec(x, res)

         # classifier
         x = self.clf(x)
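
The new decoder loop is behaviorally identical to the old zip over the reversed residual list: block idx consumes the residual from encoder level num_blocks - 1 - idx. It just avoids building a reversed copy on every forward pass, and len(self.decoder_blocks) is fixed at construction time, which keeps the loop friendly to tracing (hence the "statically known" comment). An illustrative pairing check, not part of the diff:

    residuals = ["res0", "res1", "res2"]

    # old pairing: reverse, then zip with the decoder blocks
    old_pairs = list(zip(["dec0", "dec1", "dec2"], residuals[::-1]))

    # new pairing: index the residuals in reverse
    num_blocks = 3
    new_pairs = [(f"dec{i}", residuals[num_blocks - 1 - i]) for i in range(num_blocks)]

    assert old_pairs == new_pairs  # dec0<->res2, dec1<->res1, dec2<->res0
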
@@ -325,3 +229,112 @@ class USleep(EEGModuleMixin, nn.Module):
             y_pred = y_pred[:, :, 0]

         return y_pred
+
+
+class _EncoderBlock(nn.Module):
+    """Encoding block for a timeseries x of shape (B, C, T)."""
+
+    def __init__(
+        self,
+        in_channels=2,
+        out_channels=2,
+        kernel_size=9,
+        downsample=2,
+        activation: nn.Module = nn.ELU,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.downsample = downsample
+
+        self.block_prepool = nn.Sequential(
+            nn.Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                padding="same",
+            ),
+            activation(),
+            nn.BatchNorm1d(num_features=out_channels),
+        )
+
+        self.pad = nn.ConstantPad1d(padding=1, value=0.0)
+        self.maxpool = nn.MaxPool1d(kernel_size=self.downsample, stride=self.downsample)
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        x = self.block_prepool(x)
+        residual = x
+        if x.shape[-1] % 2:
+            x = self.pad(x)
+        x = self.maxpool(x)
+        return x, residual
+
+
+class _DecoderBlock(nn.Module):
+    """Decoding block for a timeseries x of shape (B, C, T)."""
+
+    def __init__(
+        self,
+        in_channels=2,
+        out_channels=2,
+        kernel_size=9,
+        upsample=2,
+        with_skip_connection=True,
+        activation: nn.Module = nn.ELU,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.upsample = upsample
+        self.with_skip_connection = with_skip_connection
+
+        self.block_preskip = nn.Sequential(
+            nn.Upsample(scale_factor=upsample),
+            nn.Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=2,
+                padding="same",
+            ),
+            activation(),
+            nn.BatchNorm1d(num_features=out_channels),
+        )
+        self.block_postskip = nn.Sequential(
+            nn.Conv1d(
+                in_channels=(
+                    2 * out_channels if with_skip_connection else out_channels
+                ),
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                padding="same",
+            ),
+            activation(),
+            nn.BatchNorm1d(num_features=out_channels),
+        )
+
+    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+        x = self.block_preskip(x)
+        if self.with_skip_connection:
+            x, residual = self._crop_tensors_to_match(
+                x, residual, axis=-1
+            )  # in case of mismatch
+            x = torch.cat([x, residual], dim=1)  # (B, 2 * C, T)
+        x = self.block_postskip(x)
+        return x
+
+    @staticmethod
+    def _crop_tensors_to_match(
+        x1: torch.Tensor, x2: torch.Tensor, axis: int = -1
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Crops two tensors to their lowest-common-dimension along an axis."""
+        dim_cropped = min(x1.shape[axis], x2.shape[axis])
+
+        x1_cropped = torch.index_select(
+            x1, dim=axis, index=torch.arange(dim_cropped).to(device=x1.device)
+        )
+        x2_cropped = torch.index_select(
+            x2, dim=axis, index=torch.arange(dim_cropped).to(device=x1.device)
+        )
+        return x1_cropped, x2_cropped
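
In summary, usleep.py drops the deprecated constructor aliases (in_chans, n_classes, input_size_s) and the add_log_softmax flag, makes the activation configurable, and moves the private encoder/decoder blocks below the public class, with _crop_tensors_to_match now a static method of _DecoderBlock. A hedged usage sketch against the new signature (the shape values are illustrative and match the defaults the old signature used to hard-code):

    import torch
    from torch import nn
    from braindecode.models import USleep

    # 1.1.0-style construction: the old in_chans / n_classes / input_size_s
    # keywords no longer exist, and outputs are raw logits (no LogSoftmax).
    model = USleep(
        n_chans=2,
        sfreq=128,
        n_outputs=5,
        input_window_seconds=30,
        activation=nn.ELU,  # pass the activation class, not an instance
    )

    x = torch.randn(4, 2, 30 * 128)  # (batch, channels, time)
    y = model(x)                     # (4, 5) per the forward docstring
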