braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/fbcnet.py
@@ -0,0 +1,221 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import Tensor, nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     Conv2dWithConstraint,
+     FilterBankLayer,
+     LinearWithConstraint,
+     LogVarLayer,
+     MaxLayer,
+     MeanLayer,
+     StdLayer,
+     VarLayer,
+ )
+
+ _valid_layers = {
+     "VarLayer": VarLayer,
+     "StdLayer": StdLayer,
+     "LogVarLayer": LogVarLayer,
+     "MeanLayer": MeanLayer,
+     "MaxLayer": MaxLayer,
+ }
+
+
+ class FBCNet(EEGModuleMixin, nn.Module):
+     """FBCNet from Mane, R et al (2021) [fbcnet2021]_.
+
+     .. figure:: https://raw.githubusercontent.com/ravikiran-mane/FBCNet/refs/heads/master/FBCNet-V2.png
+        :align: center
+        :alt: FBCNet Architecture
+
+     The FBCNet model applies spatial convolution and variance calculation along
+     the time axis, inspired by the Filter Bank Common Spatial Pattern (FBCSP)
+     algorithm.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been checked
+     by the original authors; it has only been reimplemented from the paper
+     description and source code [fbcnetcode2021]_. The activation function
+     differs between the two: the paper uses ELU, but the original code uses
+     SiLU. We followed the code.
+
+     Parameters
+     ----------
+     n_bands : int or None or list[tuple[int, int]], default=9
+         Number of frequency bands or a list of frequency band tuples. If a
+         list of tuples is provided, each tuple defines the lower and upper
+         bounds of a frequency band.
+     n_filters_spat : int, default=32
+         Number of spatial filters for the first convolution.
+     n_dim : int, default=3
+         Number of dimensions for the temporal reduction layer.
+     temporal_layer : str, default='LogVarLayer'
+         Type of temporal aggregator layer. Options: 'VarLayer', 'StdLayer',
+         'LogVarLayer', 'MeanLayer', 'MaxLayer'.
+     stride_factor : int, default=4
+         Stride factor for reshaping.
+     activation : nn.Module, default=nn.SiLU
+         Activation function class to apply in the Spatial Convolution Block.
+     cnn_max_norm : float, default=2.0
+         Maximum norm for the spatial convolution layer.
+     linear_max_norm : float, default=0.5
+         Maximum norm for the final linear layer.
+     filter_parameters : dict, default=None
+         Parameters for the FilterBankLayer.
+
+     References
+     ----------
+     .. [fbcnet2021] Mane, R., Chew, E., Chua, K., Ang, K. K., Robinson, N.,
+        Vinod, A. P., ... & Guan, C. (2021). FBCNet: A multi-view convolutional
+        neural network for brain-computer interface. preprint arXiv:2104.01233.
+     .. [fbcnetcode2021] Link to source code:
+        https://github.com/ravikiran-mane/FBCNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         n_bands=9,
+         n_filters_spat: int = 32,
+         temporal_layer: str = "LogVarLayer",
+         n_dim: int = 3,
+         stride_factor: int = 4,
+         activation: nn.Module = nn.SiLU,
+         linear_max_norm: float = 0.5,
+         cnn_max_norm: float = 2.0,
+         filter_parameters: dict[Any, Any] | None = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Parameters
+         self.n_bands = n_bands
+         self.n_filters_spat = n_filters_spat
+         self.n_dim = n_dim
+         self.stride_factor = stride_factor
+         self.activation = activation
+         self.filter_parameters = filter_parameters or {}
+
+         # Checkers
+         if temporal_layer not in _valid_layers:
+             raise NotImplementedError(
+                 f"Temporal layer '{temporal_layer}' is not implemented."
+             )
+
+         if self.n_times % self.stride_factor != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" stride_factor ({self.stride_factor}). Input will be padded.",
+                 UserWarning,
+             )
+
+         # Layers
+         # Following paper nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.n_bands,
+             verbose=False,
+             **self.filter_parameters,
+         )
+         # As we have an internal process to create the bands,
+         # we get the values from the filterbank
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # Spatial Convolution Block (SCB)
+         self.spatial_conv = nn.Sequential(
+             Conv2dWithConstraint(
+                 in_channels=self.n_bands,
+                 out_channels=self.n_filters_spat * self.n_bands,
+                 kernel_size=(self.n_chans, 1),
+                 groups=self.n_bands,
+                 max_norm=cnn_max_norm,
+                 padding=0,
+             ),
+             nn.BatchNorm2d(self.n_filters_spat * self.n_bands),
+             self.activation(),
+         )
+
+         # Padding layer
+         if self.n_times % self.stride_factor != 0:
+             self.padding_size = stride_factor - (self.n_times % stride_factor)
+             self.n_times_padded = self.n_times + self.padding_size
+             self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
+         else:
+             self.padding_layer = nn.Identity()
+             self.n_times_padded = self.n_times
+
+         # Temporal aggregator
+         self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore
+
+         # Flatten layer
+         self.flatten_layer = Rearrange("batch ... -> batch (...)")
+
+         # Final fully connected layer
+         self.final_layer = LinearWithConstraint(
+             in_features=self.n_filters_spat * self.n_bands * self.stride_factor,
+             out_features=self.n_outputs,
+             max_norm=linear_max_norm,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the FBCNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         # x shape: (batch_size, n_chans, n_times)
+         x = self.spectral_filtering(x)
+
+         # x shape: (batch_size, n_bands, n_chans, n_times)
+         x = self.spatial_conv(x)
+         batch_size, channels, _, _ = x.shape
+
+         # x shape: (batch_size, n_filters_spat * n_bands, 1, n_times)
+         x = self.padding_layer(x)
+
+         # x shape: (batch_size, n_filters_spat * n_bands, 1, n_times_padded)
+         x = x.view(
+             batch_size,
+             channels,
+             self.stride_factor,
+             self.n_times_padded // self.stride_factor,
+         )
+         # x shape: (batch_size, n_filters_spat * n_bands, stride_factor,
+         #           n_times_padded // stride_factor)
+         x = self.temporal_layer(x)  # type: ignore[operator]
+
+         # x shape: (batch_size, n_filters_spat * n_bands, stride_factor, 1)
+         x = self.flatten_layer(x)
+
+         # x shape: (batch_size, n_filters_spat * n_bands * stride_factor)
+         x = self.final_layer(x)
+         # x shape: (batch_size, n_outputs)
+         return x
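
As a quick shape check of the forward pass above, here is a minimal usage sketch. It assumes FBCNet is importable from braindecode.models (consistent with braindecode/models/__init__.py in the file list); the trial count, channel count, and sampling rate are illustrative, not taken from the diff.

import torch
from braindecode.models import FBCNet

# Hypothetical data: 4 trials, 22 EEG channels, 1000 samples at 250 Hz.
# 1000 is divisible by the default stride_factor=4, so no padding is applied.
model = FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
x = torch.randn(4, 22, 1000)
y = model(x)
print(y.shape)  # expected: torch.Size([4, 4])

With a length such as n_times=1001, the padding branch would instead add stride_factor - (1001 % 4) = 3 zero samples, giving n_times_padded = 1004.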
braindecode/models/fblightconvnet.py
@@ -0,0 +1,313 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     FilterBankLayer,
+     LogVarLayer,
+ )
+
+
+ class FBLightConvNet(EEGModuleMixin, nn.Module):
+     """LightConvNet from Ma, X et al (2023) [lightconvnet]_.
+
+     .. figure:: https://raw.githubusercontent.com/Ma-Xinzhi/LightConvNet/refs/heads/main/network_architecture.png
+        :align: center
+        :alt: LightConvNet Neural Network
+
+     A lightweight convolutional neural network incorporating temporal
+     dependency learning and attention mechanisms. The architecture is
+     designed to efficiently capture spatial and temporal features through
+     specialized convolutional layers and **multi-head attention**.
+
+     The network architecture consists of four main modules:
+
+     1. **Spatial and Spectral Information Learning**:
+        Applies filterbank and spatial convolutions.
+        This module is followed by batch normalization and
+        an activation function to enhance feature representation.
+
+     2. **Temporal Segmentation and Feature Extraction**:
+        Divides the processed data into non-overlapping temporal windows.
+        Within each window, a variance-based layer extracts discriminative
+        features, which are then log-transformed to stabilize variance before
+        being passed to the attention module.
+
+     3. **Temporal Attention Module**: Utilizes a multi-head attention
+        mechanism with depthwise separable convolutions to capture dependencies
+        across different temporal segments. The attention weights are normalized
+        using softmax and aggregated to form a comprehensive temporal
+        representation.
+
+     4. **Final Layer**: Flattens the aggregated features and passes them
+        through a linear layer to generate the final output predictions.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it is a braindecode adaptation of the
+     PyTorch source code [lightconvnetcode]_.
+
+     Parameters
+     ----------
+     n_bands : int or None or list of tuple of int, default=9
+         Number of frequency bands or a list of frequency band tuples. If a
+         list of tuples is provided, each tuple defines the lower and upper
+         bounds of a frequency band.
+     n_filters_spat : int, default=32
+         Number of spatial filters in the depthwise convolutional layer.
+     n_dim : int, default=3
+         Number of dimensions for the temporal reduction layer.
+     stride_factor : int, default=4
+         Stride factor used for reshaping the temporal dimension.
+     win_len : int, default=250
+         Length, in samples, of the non-overlapping temporal windows.
+     heads : int, default=8
+         Number of attention heads in the multi-head attention mechanism.
+     weight_softmax : bool, default=True
+         If True, applies softmax to the attention weights.
+     bias : bool, default=False
+         If True, includes a bias term in the convolutional layers.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply after convolutional layers.
+     verbose : bool, default=False
+         If True, enables verbose output during filter creation using mne.
+     filter_parameters : dict, default=None
+         Additional parameters for the FilterBankLayer.
+
+     References
+     ----------
+     .. [lightconvnet] Ma, X., Chen, W., Pei, Z., Liu, J., Huang, B., & Chen, J.
+        (2023). A temporal dependency learning CNN with attention mechanism
+        for MI-EEG decoding. IEEE Transactions on Neural Systems and
+        Rehabilitation Engineering.
+     .. [lightconvnetcode] Link to source code:
+        https://github.com/Ma-Xinzhi/LightConvNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         n_bands=9,
+         n_filters_spat: int = 32,
+         n_dim: int = 3,
+         stride_factor: int = 4,
+         win_len: int = 250,
+         heads: int = 8,
+         weight_softmax: bool = True,
+         bias: bool = False,
+         activation: nn.Module = nn.ELU,
+         verbose: bool = False,
+         filter_parameters: Optional[dict] = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Parameters
+         self.n_bands = n_bands
+         self.n_filters_spat = n_filters_spat
+         self.n_dim = n_dim
+         self.stride_factor = stride_factor
+         self.win_len = win_len
+         self.activation = activation
+         self.heads = heads
+         self.weight_softmax = weight_softmax
+         self.bias = bias
+         self.filter_parameters = filter_parameters or {}
+
+         # Checkers
+         self.n_times_truncated = self.n_times
+         if self.n_times % self.win_len != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" win_len ({self.win_len}). Input will be truncated by"
+                 f" {self.n_times % self.win_len} temporal samples.",
+                 UserWarning,
+             )
+             self.n_times_truncated = self.n_times - (self.n_times % self.win_len)
+
+         # Layers
+         # Following the paper nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.n_bands,
+             verbose=verbose,
+             **self.filter_parameters,
+         )
+         # As we have an internal process to create the bands,
+         # we get the values from the filterbank
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # Unlike FBCNet, the spatial convolution here is a plain Conv2d:
+         # no per-band grouping and no max-norm constraint.
+         self.spatial_conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=self.n_bands,
+                 out_channels=self.n_filters_spat,
+                 kernel_size=(self.n_chans, 1),
+             ),
+             nn.BatchNorm2d(self.n_filters_spat),
+             self.activation(),
+         )
+
+         # Temporal aggregator
+         self.temporal_layer = LogVarLayer(self.n_dim, False)
+
+         self.flatten_layer = Rearrange("batch ... -> batch (...)")
+
+         # LightweightConv1d over the window dimension
+         self.attn_conv = _LightweightConv1d(
+             self.n_filters_spat,
+             (self.n_times // self.win_len),
+             heads=self.heads,
+             weight_softmax=weight_softmax,
+             bias=bias,
+         )
+
+         self.final_layer = nn.Linear(
+             in_features=self.n_filters_spat,
+             out_features=self.n_outputs,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the FBLightConvNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         batch_size, _, _ = x.shape
+         # x.shape: (batch, n_chans, n_times)
+
+         x = self.spectral_filtering(x)
+         # x.shape: (batch, n_bands, n_chans, n_times)
+
+         x = self.spatial_conv(x)
+         # x.shape: (batch, n_filters_spat, 1, n_times)
+
+         x = x[:, :, :, : self.n_times_truncated]
+         # x.shape: (batch, n_filters_spat, 1, n_times_truncated)
+
+         x = x.reshape([batch_size, self.n_filters_spat, -1, self.win_len])
+         # x.shape: (batch, n_filters_spat, n_windows, win_len),
+         # where n_windows = n_times_truncated / win_len
+         # and win_len = 250 by default
+
+         x = self.temporal_layer(x)
+         # x.shape: (batch, n_filters_spat, n_windows)
+
+         x = self.attn_conv(x)
+         # x.shape: (batch, n_filters_spat, 1)
+
+         x = self.flatten_layer(x)
+         # x.shape: (batch, n_filters_spat)
+
+         x = self.final_layer(x)
+         # x.shape: (batch, n_outputs)
+         return x
+
+
+ class _LightweightConv1d(nn.Module):
+     """Lightweight 1D Convolution Module.
+
+     Applies a convolution operation with multiple heads, allowing for
+     parallel filter applications. Optionally applies a softmax normalization
+     to the convolution weights.
+
+     Parameters
+     ----------
+     input_size : int
+         Number of channels of the input and output.
+     kernel_size : int, optional
+         Size of the convolution kernel. Default is `1`.
+     padding : int, optional
+         Amount of zero-padding added to both sides of the input. Default is `0`.
+     heads : int, optional
+         Number of attention heads used. The weight has shape
+         `(heads, 1, kernel_size)`. Default is `1`.
+     weight_softmax : bool, optional
+         If `True`, normalizes the convolution weights with softmax before
+         applying the convolution. Default is `False`.
+     bias : bool, optional
+         If `True`, adds a learnable bias to the output. Default is `False`.
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         kernel_size: int = 1,
+         padding: int = 0,
+         heads: int = 1,
+         weight_softmax: bool = False,
+         bias: bool = False,
+     ):
+         super().__init__()
+         self.input_size = input_size
+         self.kernel_size = kernel_size
+         self.heads = heads
+         self.padding = padding
+         self.weight_softmax = weight_softmax
+         self.weight = nn.Parameter(torch.Tensor(heads, 1, kernel_size))
+
+         if bias:
+             self.bias = nn.Parameter(torch.Tensor(input_size))
+         else:
+             self.bias = None
+
+         self._init_parameters()
+
+     def _init_parameters(self):
+         nn.init.xavier_uniform_(self.weight)
+         if self.bias is not None:
+             nn.init.constant_(self.bias, 0.0)
+
+     def forward(self, input):
+         # input shape: (batch, n_filters_spat, n_windows)
+         B, C, T = input.size()
+
+         H = self.heads
+
+         weight = self.weight
+         if self.weight_softmax:
+             weight = F.softmax(weight, dim=-1)
+         # weight shape: (heads, 1, kernel_size)
+
+         # Fold the channel dimension into the batch dimension so that each
+         # group of H consecutive channels shares the same H per-head filters:
+         # (B, C, T) -> (B * C/H, H, T), with C assumed divisible by H.
+         input = input.view(-1, H, T)
+         output = F.conv1d(input, weight, padding=self.padding, groups=self.heads)
+         # output shape: (B * C/H, H, T_out)
+         output = output.view(B, C, -1)
+         # output shape: (B, C, T_out); T_out == 1 when the kernel spans all windows
+         if self.bias is not None:
+             # Add bias if it exists
+             output = output + self.bias.view(1, -1, 1)
+         # output shape: (B, C, T_out)
+         return output
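
To make the head-folding in `_LightweightConv1d.forward` concrete, a small sketch follows. The shapes are illustrative, and the class is private, so this mirrors internal usage inside FBLightConvNet rather than a public API:

import torch

# Hypothetical shapes: B=4 trials, C=32 spatial filters, H=8 heads,
# T=8 temporal windows, with the kernel spanning all windows (kernel_size=T).
conv = _LightweightConv1d(input_size=32, kernel_size=8, heads=8, weight_softmax=True)
x = torch.randn(4, 32, 8)
# Internally: (4, 32, 8) -> view(-1, 8, 8) = (16, 8, 8) -> conv1d(groups=8)
# -> (16, 8, 1) -> view(4, 32, 1), so all 32/8 = 4 channel groups share
# the same 8 per-head temporal filters.
y = conv(x)
print(y.shape)  # expected: torch.Size([4, 32, 1])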