braindecode 1.3.0.dev177069446__py3-none-any.whl

Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0

braindecode/models/hybrid.py
@@ -0,0 +1,126 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import torch
+ from torch import nn
+ from torch.nn import ConstantPad2d
+
+ from braindecode.models.deep4 import Deep4Net
+ from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
+
+
+ class HybridNet(nn.Module):
+     r"""Hybrid ConvNet model from Schirrmeister, R. T. et al (2017) [Schirrmeister2017]_.
+
+     See [Schirrmeister2017]_ for details.
+
+     References
+     ----------
+     .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
+        L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
+        & Ball, T. (2017).
+        Deep learning with convolutional neural networks for EEG decoding and
+        visualization.
+        Human Brain Mapping, Aug. 2017.
+        Online: http://dx.doi.org/10.1002/hbm.23730
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         chs_info=None,
+         activation: type[nn.Module] = nn.ELU,
+         drop_prob: float = 0.5,
+     ):
+         super().__init__()
+         self.mapping = {
+             "final_conv.weight": "final_layer.weight",
+             "final_conv.bias": "final_layer.bias",
+         }
+
+         deep_model = Deep4Net(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             n_filters_time=20,
+             n_filters_spat=30,
+             n_filters_2=40,
+             n_filters_3=50,
+             n_filters_4=60,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             chs_info=chs_info,
+             final_conv_length=2,
+             activation_first_conv_nonlin=activation,
+             activation_later_conv_nonlin=activation,
+             drop_prob=drop_prob,
+         )
+         shallow_model = ShallowFBCSPNet(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             chs_info=chs_info,
+             n_filters_time=30,
+             n_filters_spat=40,
+             filter_time_length=28,
+             final_conv_length=29,
+             drop_prob=drop_prob,
+         )
+         new_conv_layer = nn.Conv2d(
+             deep_model.final_layer.conv_classifier.in_channels,
+             60,
+             kernel_size=deep_model.final_layer.conv_classifier.kernel_size,
+             stride=deep_model.final_layer.conv_classifier.stride,
+         )
+         deep_model.final_layer = new_conv_layer
+
+         new_conv_layer = nn.Conv2d(
+             shallow_model.final_layer.conv_classifier.in_channels,
+             40,
+             kernel_size=shallow_model.final_layer.conv_classifier.kernel_size,
+             stride=shallow_model.final_layer.conv_classifier.stride,
+         )
+         shallow_model.final_layer = new_conv_layer
+
+         deep_model.to_dense_prediction_model()
+         shallow_model.to_dense_prediction_model()
+         self.reduced_deep_model = deep_model
+         self.reduced_shallow_model = shallow_model
+
+         self.final_layer = nn.Sequential(
+             nn.Conv2d(100, n_outputs, kernel_size=(1, 1), stride=1),
+             nn.Identity(),
+         )
+
+     def forward(self, x):
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Batch of EEG windows of shape (batch_size, n_channels, n_times).
+         """
+         deep_out = self.reduced_deep_model(x)
+         shallow_out = self.reduced_shallow_model(x)
+
+         n_diff_deep_shallow = deep_out.size()[2] - shallow_out.size()[2]
+
+         if n_diff_deep_shallow < 0:
+             deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(deep_out)
+         elif n_diff_deep_shallow > 0:
+             shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(shallow_out)
+
+         merged_out = torch.cat((deep_out, shallow_out), dim=1)
+
+         output = self.final_layer(merged_out)
+
+         squeezed = output.squeeze(3)
+
+         return squeezed
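
A minimal usage sketch for the HybridNet file above (not part of the package
diff; the input shape and hyperparameters are illustrative assumptions):

    import torch
    from braindecode.models.hybrid import HybridNet

    # Assumed example: 4-class decoding, 22 channels, 1125 time samples.
    model = HybridNet(n_chans=22, n_outputs=4, n_times=1125)
    x = torch.randn(8, 22, 1125)  # (batch, channels, time)
    out = model(x)  # dense predictions of shape (batch, n_outputs, n_out_times)

Because both sub-networks are converted with to_dense_prediction_model(), the
forward pass returns one prediction per output time step; the shorter of the
two feature maps is zero-padded before concatenation, and the 1x1 convolution
mixes the 60 deep and 40 shallow feature channels into class scores.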

braindecode/models/ifnet.py
@@ -0,0 +1,443 @@
+ """IFNet Neural Network.
+
+ Authors: Jiaheng Wang
+          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ License: MIT (https://github.com/Jiaheng-Wang/IFNet/blob/main/LICENSE)
+
+ J. Wang, L. Yao and Y. Wang, "IFNet: An Interactive Frequency Convolutional
+ Neural Network for Enhancing Motor Imagery Decoding from EEG," in IEEE
+ Transactions on Neural Systems and Rehabilitation Engineering,
+ doi: 10.1109/TNSRE.2023.3257319.
+ """
+
+ from __future__ import annotations
+
+ from typing import Optional, Sequence
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+ from torch.nn.init import trunc_normal_
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     FilterBankLayer,
+     LinearWithConstraint,
+     LogPowerLayer,
+ )
+
+
+ class IFNet(EEGModuleMixin, nn.Module):
+     r"""IFNetV2 from Wang J et al (2023) [ifnet]_.
+
+     :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
+     .. figure:: https://raw.githubusercontent.com/Jiaheng-Wang/IFNet/main/IFNet.png
+         :align: center
+         :alt: IFNetV2 Architecture
+
+         Overview of the Interactive Frequency Convolutional Neural Network architecture.
+
+     IFNetV2 is designed to effectively capture spectro-spatial-temporal
+     features for motor imagery decoding from EEG data. The model consists of
+     three stages: Spectro-Spatial Feature Representation, Cross-Frequency
+     Interactions, and Classification.
+
+     - **Spectro-Spatial Feature Representation**: The raw EEG signals are
+       filtered into two characteristic frequency bands: low (4-16 Hz) and
+       high (16-40 Hz), covering the most relevant motor imagery bands.
+       Spectro-spatial features are then extracted through 1D point-wise
+       spatial convolution followed by temporal convolution.
+
+     - **Cross-Frequency Interactions**: The extracted spectro-spatial
+       features from each frequency band are combined through an element-wise
+       summation operation, which enhances feature representation while
+       preserving distinct characteristics.
+
+     - **Classification**: The aggregated spectro-spatial features are further
+       reduced through temporal average pooling and passed through a fully
+       connected layer followed by a softmax operation to generate output
+       probabilities for each class.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct: it has not been
+     checked by the original authors and was reimplemented from the paper
+     description and the PyTorch source code [ifnetv2code]_. Version 2 is
+     present only in the repository; its main difference is one pooling
+     layer, described in Table VII of the paper:
+     https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10070810
+
+     Parameters
+     ----------
+     bands : list[tuple[float, float]] or int or None, default=[(4.0, 16.0), (16.0, 40.0)]
+         Frequency bands for filtering.
+     n_filters_spat : int, default=64
+         Number of output channels of the point-wise spatial convolution.
+     kernel_sizes : tuple of int, default=(63, 31)
+         Kernel sizes for the temporal convolutions, one per band.
+     stride_factor : int, default=8
+         Stride factor for temporal segmentation.
+     drop_prob : float, default=0.5
+         Dropout probability.
+     linear_max_norm : float, default=0.5
+         Maximum norm constraint of the final linear layer.
+     activation : nn.Module, default=nn.GELU
+         Activation function after the InterFrequency Layer.
+     verbose : bool, default=False
+         Controls the verbosity of the filter bank layer.
+     filter_parameters : dict, default=None
+         Additional parameters for the filter bank layer.
+
+     References
+     ----------
+     .. [ifnet] Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive
+        frequency convolutional neural network for enhancing motor imagery
+        decoding from EEG. IEEE Transactions on Neural Systems and
+        Rehabilitation Engineering, 31, 1900-1911.
+     .. [ifnetv2code] Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive
+        frequency convolutional neural network for enhancing motor imagery
+        decoding from EEG.
+        https://github.com/Jiaheng-Wang/IFNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model-specific parameters
+         bands: list[tuple[float, float]] | int | None = [(4.0, 16.0), (16.0, 40.0)],
+         n_filters_spat: int = 64,
+         kernel_sizes: tuple[int, int] = (63, 31),
+         stride_factor: int = 8,
+         drop_prob: float = 0.5,
+         linear_max_norm: float = 0.5,
+         activation: type[nn.Module] = nn.GELU,
+         verbose: bool = False,
+         filter_parameters: Optional[dict] = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.bands = bands
+         self.n_filters_spat = n_filters_spat
+         self.stride_factor = stride_factor
+         self.kernel_sizes = kernel_sizes
+         self.verbose = verbose
+         self.drop_prob = drop_prob
+         self.activation = activation
+         self.linear_max_norm = linear_max_norm
+         self.filter_parameters = filter_parameters or {}
+
+         # Layers
+         # Following paper nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.bands,
+             verbose=verbose,
+             **self.filter_parameters,
+         )
+         # The bands are created internally, so read the resulting
+         # number of bands back from the filter bank.
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # Interpretation of the Table VII IFNet architecture from the paper.
+         self.ensuredim = Rearrange(
+             "batch nbands chans time -> batch (nbands chans) time"
+         )
+
+         # SpatioTemporal Feature Block
+         self.feature_block = _SpatioTemporalFeatureBlock(
+             in_channels=self.n_chans * self.n_bands,
+             out_channels=self.n_filters_spat,
+             kernel_sizes=self.kernel_sizes,
+             stride_factor=self.stride_factor,
+             n_bands=self.n_bands,
+             drop_prob=self.drop_prob,
+             activation=self.activation,
+             n_times=self.n_times,
+         )
+
+         # Final classification layer
+         self.final_layer = LinearWithConstraint(
+             in_features=self.n_filters_spat * stride_factor,
+             out_features=self.n_outputs,
+             max_norm=self.linear_max_norm,
+         )
+
+         self.flatten = Rearrange("batch ... -> batch (...)")
+
+         # Initialize parameters
+         self._initialize_weights(self)
+
+     @staticmethod
+     def _initialize_weights(m):
+         """Initializes weights of the network.
+
+         Parameters
+         ----------
+         m : nn.Module
+             Module to initialize.
+         """
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=0.01)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d)):
+             if m.weight is not None:
+                 nn.init.constant_(m.weight, 1.0)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
+             trunc_normal_(m.weight, std=0.01)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass of IFNet.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         # Pass through the spectral filtering layer
+         x = self.spectral_filtering(x)
+         # x is now of shape (batch_size, n_bands, n_chans, n_times)
+         x = self.ensuredim(x)
+         # x is now of shape (batch_size, n_bands * n_chans, n_times)
+
+         # Pass through the feature block
+         x = self.feature_block(x)
+
+         # Flatten and pass through the final layer
+         x = self.flatten(x)
+
+         x = self.final_layer(x)
+
+         return x
+
+
+ class _InterFrequencyModule(nn.Module):
+     r"""Module that combines outputs from different frequency bands."""
+
+     def __init__(self, activation: type[nn.Module] = nn.GELU):
+         """
+         Parameters
+         ----------
+         activation : nn.Module
+             Activation function for the InterFrequency Module.
+         """
+         super().__init__()
+
+         self.activation = activation()
+
+     def forward(self, x_list: list[torch.Tensor]) -> torch.Tensor:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x_list : list of torch.Tensor
+             List of tensors to be combined.
+
+         Returns
+         -------
+         torch.Tensor
+             Combined tensor after element-wise summation and activation.
+         """
+         x = torch.stack(x_list, dim=0).sum(dim=0)
+         x = self.activation(x)
+         return x
+
+
+ class _SpatioTemporalFeatureBlock(nn.Module):
+     r"""SpatioTemporal Feature Block consisting of spatial and temporal convolutions."""
+
+     def __init__(
+         self,
+         n_times: int,
+         in_channels: int,
+         out_channels: int = 64,
+         kernel_sizes: Sequence[int] = (63, 31),
+         stride_factor: int = 8,
+         n_bands: int = 2,
+         drop_prob: float = 0.5,
+         activation: type[nn.Module] = nn.GELU,
+         dim: int = 3,
+     ):
+         """
+         Parameters
+         ----------
+         n_times : int
+             Number of time samples in the input window.
+         in_channels : int
+             Number of input channels.
+         out_channels : int, default=64
+             Number of output channels.
+         kernel_sizes : sequence of int, default=(63, 31)
+             Kernel sizes for the temporal convolutions.
+         stride_factor : int, default=8
+             Stride factor for temporal segmentation.
+         n_bands : int, default=2
+             Number of frequency bands or groups.
+         drop_prob : float, default=0.5
+             Dropout probability.
+         activation : nn.Module, default=nn.GELU
+             Activation function after the InterFrequency Layer.
+         dim : int, default=3
+             Dimension over which the LogPowerLayer is applied.
+         """
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.n_bands = n_bands
+         self.stride_factor = stride_factor
+         self.drop_prob = drop_prob
+         self.activation = activation
+         self.dim = dim
+         self.n_times = n_times
+         self.kernel_sizes = kernel_sizes
+
+         if self.n_bands != len(self.kernel_sizes):
+             warn(
+                 f"Got {self.n_bands} bands but {len(self.kernel_sizes)} kernel "
+                 "sizes for the temporal convolutions; "
+                 "min(n_bands, len(kernel_sizes)) will be used.",
+                 UserWarning,
+             )
+             if self.n_bands > len(self.kernel_sizes):
+                 self.n_bands = len(self.kernel_sizes)
+                 warn(
+                     f"Reducing number of bands to {len(self.kernel_sizes)} "
+                     "to match the number of kernels.",
+                     UserWarning,
+                 )
+             elif self.n_bands < len(self.kernel_sizes):
+                 self.kernel_sizes = self.kernel_sizes[: self.n_bands]
+                 warn(
+                     f"Reducing number of kernels to {self.n_bands} "
+                     "to match the number of bands.",
+                     UserWarning,
+                 )
+
+         if self.n_times % self.stride_factor != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" stride_factor ({self.stride_factor}). Input will be padded.",
+                 UserWarning,
+             )
+
+         out_channels_spatial = self.out_channels * self.n_bands
+
+         # Spatial convolution
+         self.spatial_conv = nn.Conv1d(
+             in_channels=self.in_channels,
+             out_channels=out_channels_spatial,
+             kernel_size=1,
+             groups=self.n_bands,
+             bias=False,
+         )
+         self.spatial_bn = nn.BatchNorm1d(out_channels_spatial)
+
+         self.unpack_bands = nn.Unflatten(
+             dim=1, unflattened_size=(self.n_bands, self.out_channels)
+         )
+
+         # Temporal convolutions, one per frequency band
+         self.temporal_convs = nn.ModuleList()
+         for kernel_size in self.kernel_sizes:
+             self.temporal_convs.append(
+                 nn.Sequential(
+                     nn.Conv1d(
+                         in_channels=self.out_channels,
+                         out_channels=self.out_channels,
+                         kernel_size=kernel_size,
+                         padding=kernel_size // 2,
+                         groups=self.out_channels,
+                         bias=False,
+                     ),
+                     nn.BatchNorm1d(self.out_channels),
+                 )
+             )
+         # Inter-frequency module
+         self.inter_frequency = _InterFrequencyModule(activation=self.activation)
+
+         if self.n_times % self.stride_factor != 0:
+             self.padding_size = stride_factor - (self.n_times % stride_factor)
+             self.n_times_padded = self.n_times + self.padding_size
+             self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
+         else:
+             self.padding_layer = nn.Identity()
+             self.n_times_padded = self.n_times
+
+         # Log-Power layer
+         self.log_power = LogPowerLayer(dim=self.dim)  # type: ignore
+         # Dropout
+         self.dropout_layer = nn.Dropout(self.drop_prob)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, in_channels, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor after processing.
+         """
+         batch_size, _, _ = x.shape
+
+         # Spatial convolution
+         x = self.spatial_conv(x)
+
+         x = self.spatial_bn(x)
+
+         # Split the output by bands for each frequency
+         x_split = self.unpack_bands(x)
+
+         x_t = []
+         for idx, conv in enumerate(self.temporal_convs):
+             x_t.append(conv(x_split[:, idx]))
+
+         # Inter-frequency interaction
+         x = self.inter_frequency(x_t)
+
+         # Reshape for temporal segmentation
+         x = self.padding_layer(x)
+         # x is now of shape (batch_size, ..., n_times_padded)
+
+         # Reshape for log-power computation
+         x = x.view(
+             batch_size,
+             self.out_channels,
+             self.stride_factor,
+             self.n_times_padded // self.stride_factor,
+         )
+
+         # Log-Power layer
+         x = self.log_power(x)
+
+         # Dropout
+         x = self.dropout_layer(x)
+
+         return x
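
A minimal usage sketch for the IFNet file above (not part of the package diff;
the input shape and hyperparameters are illustrative assumptions):

    import torch
    from braindecode.models.ifnet import IFNet

    # Assumed example: 4-class motor imagery, 22 channels, 4 s at 250 Hz.
    model = IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250.0)
    x = torch.randn(8, 22, 1000)  # (batch, channels, time)
    out = model(x)  # logits of shape (batch, n_outputs)

With the defaults, the filter bank produces two bands; each is spatially
convolved to 64 feature channels and temporally convolved, the bands are
summed element-wise, and the 1000 samples are split into stride_factor=8
segments of 125 samples whose log-power feeds the max-norm-constrained
linear classifier (in_features = 64 * 8 = 512).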