braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/functional.py +0 -101
- braindecode/augmentation/transforms.py +0 -74
- braindecode/datasets/base.py +3 -18
- braindecode/datautil/serialization.py +1 -0
- braindecode/models/__init__.py +1 -8
- braindecode/models/summary.csv +0 -1
- braindecode/models/util.py +0 -84
- braindecode/preprocessing/__init__.py +0 -5
- braindecode/preprocessing/eegprep_preprocess.py +19 -134
- braindecode/preprocessing/mne_preprocess.py +25 -56
- braindecode/preprocessing/preprocess.py +41 -126
- braindecode/preprocessing/util.py +0 -11
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +5 -11
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/RECORD +19 -24
- braindecode/datasets/hub.py +0 -962
- braindecode/datasets/hub_validation.py +0 -113
- braindecode/datasets/registry.py +0 -120
- braindecode/datautil/hub_formats.py +0 -180
- braindecode/models/luna.py +0 -836
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177628147.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
braindecode/models/luna.py
DELETED
|
@@ -1,836 +0,0 @@
|
|
|
1
|
-
"""This implementation is adapted from ETH Zurich's BioFoundation repository.
|
|
2
|
-
|
|
3
|
-
Döner, B., Ingolfsson, T. M., Benini, L., & Li, Y. (2025).
|
|
4
|
-
LUNA: Efficient and Topology-Agnostic Foundation Model for EEG Signal Analysis.
|
|
5
|
-
The Thirty-Ninth Annual Conference on Neural Information Processing Systems, NeurIPS.
|
|
6
|
-
Retrieved from https://openreview.net/forum?id=uazfjnFL0G
|
|
7
|
-
|
|
8
|
-
Original Authors: Berkay Döner, Thorir Mar Ingolfsson
|
|
9
|
-
Braindecode Adaptation: Bruno Aristimunha
|
|
10
|
-
|
|
11
|
-
The license of this file is Apache-2.0.
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
import math
|
|
15
|
-
from typing import Any, Dict, Optional, Tuple, Type
|
|
16
|
-
|
|
17
|
-
import torch
|
|
18
|
-
import torch.fft as fft
|
|
19
|
-
import torch.nn as nn
|
|
20
|
-
import torch.nn.functional as F
|
|
21
|
-
from einops import rearrange
|
|
22
|
-
from rotary_embedding_torch import RotaryEmbedding
|
|
23
|
-
|
|
24
|
-
from braindecode.models.base import EEGModuleMixin
|
|
25
|
-
from braindecode.models.util import extract_channel_locations_from_chs_info
|
|
26
|
-
from braindecode.modules.layers import DropPath
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class LUNA(EEGModuleMixin, nn.Module):
    """LUNA from Döner et al. [LUNA]_.

    :bdg-success:`Convolution` :bdg-danger:`Large Brain Model` :bdg-dark-line:`Channel`

    .. figure:: https://arxiv.org/html/2510.22257v1/x1.png
       :align: center
       :alt: LUNA Architecture.

    LUNA is a topology-invariant EEG model that processes signals from varying
    numbers of channels using a channel-unification mechanism with learned queries.

    The architecture consists of:

    1. Patch Feature Extraction (temporal CNN + FFT-based features)
    2. Channel-Unification Module (cross-attention with learned queries)
    3. Patch-wise Temporal Encoder (RoPE-based transformer)
    4. Decoder Heads (classification or reconstruction)

    When ``n_outputs`` is falsy the model is built for reconstruction (a
    patch-reconstruction head plus per-channel embeddings); otherwise a
    classification head producing ``n_outputs`` logits is used.

    Parameters
    ----------
    patch_size : int
        Number of time samples per patch. Default: 40.
    num_queries : int
        Number of learned queries for channel unification.
        Paper uses: 4 (Base), 6 (Large), 8 (Huge). Default: 4.
    embed_dim : int
        Embedding dimension for patch features.
        Paper uses: 64 (Base), 96 (Large), 128 (Huge). Default: 64.
    depth : int
        Number of transformer encoder blocks.
        Paper uses: 8 (Base), 10 (Large), 24 (Huge). Default: 8.
    num_heads : int
        Number of attention heads in channel unification and in the decoder
        heads. The temporal encoder blocks use ``num_heads * num_queries``
        heads. Default: 2.
    mlp_ratio : float
        Ratio of MLP hidden dimension to embedding dimension. Default: 4.0.
    norm_layer : type[nn.Module]
        Normalization layer class. Default: nn.LayerNorm.
    drop_path : float
        Stochastic depth rate. Default: 0.0.
    drop_prob_chan : float
        Dropout probability inside the channel-location embedding MLP.
        Default: 0.0.
    attn_drop : float
        Dropout probability applied to the attention weights of the temporal
        encoder blocks. Default: 0.0.
    activation : type[nn.Module]
        Activation class used in the channel-location embedding MLP.
        Default: nn.GELU.

    References
    ----------
    .. [LUNA] Döner, B., Ingolfsson, T. M., Benini, L., & Li, Y. (2025).
       LUNA: Efficient and Topology-Agnostic Foundation Model for EEG Signal Analysis.
       The Thirty-Ninth Annual Conference on Neural Information Processing Systems - NeurIPS.
       Retrieved from https://openreview.net/forum?id=uazfjnFL0G
    """

    def __init__(
        self,
        # Braindecode EEGModuleMixin parameters
        n_outputs: Optional[int] = None,
        n_chans: Optional[int] = None,
        n_times: Optional[int] = None,
        sfreq: Optional[float] = None,
        chs_info: Optional[Any] = None,
        input_window_seconds: Optional[float] = None,
        # Model-specific parameters
        patch_size: int = 40,
        num_queries: int = 4,
        embed_dim: int = 64,
        depth: int = 8,
        num_heads: int = 2,
        mlp_ratio: float = 4.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        drop_path: float = 0.0,
        drop_prob_chan: float = 0.0,
        attn_drop: float = 0.0,
        activation: Type[nn.Module] = nn.GELU,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            n_times=n_times,
            sfreq=sfreq,
            chs_info=chs_info,
            input_window_seconds=input_window_seconds,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Mapping for loading pretrained weights
        self.mapping = {
            "channel_location_embedder.0.fc1.weight": "channel_location_embedder.fc1.weight",
            "channel_location_embedder.0.fc1.bias": "channel_location_embedder.fc1.bias",
            "channel_location_embedder.0.fc2.weight": "channel_location_embedder.fc2.weight",
            "channel_location_embedder.0.fc2.bias": "channel_location_embedder.fc2.bias",
            "channel_location_embedder.0.norm.weight": "channel_location_embedder.norm.weight",
            "channel_location_embedder.0.norm.bias": "channel_location_embedder.norm.bias",
        }

        # Model parameters
        self.num_classes = self.n_outputs if self.n_outputs else 0
        self.embed_dim = embed_dim
        self.num_queries = num_queries
        self.patch_size = patch_size
        self.patch_embed_size = embed_dim
        self.num_heads = num_heads
        self.depth = depth
        self.drop_path = drop_path
        self.attn_drop = attn_drop
        self.drop_prob_chan = drop_prob_chan
        self.mlp_ratio = mlp_ratio
        self.activation = activation

        # Layers
        self.patch_embed = _PatchEmbedNetwork(
            embed_dim=self.embed_dim, patch_size=self.patch_size
        )
        self.freq_embed = _FrequencyFeatureEmbedder(
            embed_dim=self.embed_dim, patch_size=self.patch_size
        )
        # For weight loading, we omit the normalization here to match parameter count
        self.channel_location_embedder = _Mlp(
            in_features=int(self.patch_embed_size),
            out_features=int(self.patch_embed_size),
            hidden_features=int(self.patch_embed_size * 2),
            act_layer=self.activation,
            drop=self.drop_prob_chan,
        )
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.cross_attn = _CrossAttentionBlock(
            num_queries=self.num_queries,
            input_embed_dim=self.embed_dim,
            output_embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            ff_dim=int(self.mlp_ratio * self.embed_dim),
        )
        # The temporal encoder works on the concatenation of all query tokens
        # of a patch, hence the ``* num_queries`` on width and head count.
        self.blocks = nn.ModuleList(
            [
                _RotaryTransformerBlock(
                    dim=int(self.embed_dim * self.num_queries),
                    num_heads=int(self.num_heads * self.num_queries),
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=True,
                    drop=0.0,
                    # Previously hard-coded to 0.0, silently ignoring the
                    # ``attn_drop`` constructor argument.
                    attn_drop=self.attn_drop,
                    drop_path=self.drop_path,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(int(self.embed_dim * self.num_queries))

        # Default channel-location templates, keyed by number of channels.
        self._channel_location_cache: Dict[int, torch.Tensor] = {}

        if self.num_classes == 0:
            self.decoder_head = _PatchReconstructionHeadWithQueries(
                input_dim=self.patch_size,
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                num_queries=self.num_queries,
            )
            self.channel_emb = _ChannelEmbeddings(self.embed_dim)
        else:
            self.final_layer = _ClassificationHeadWithQueries(
                input_dim=self.patch_size,
                num_queries=self.num_queries,
                embed_dim=self.embed_dim,
                num_classes=self.num_classes,
                num_heads=self.num_heads,
            )
            self.mask_token.requires_grad = (
                False  # no use of mask token for classification
            )

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Initialize all parameters (queries, mask token, linear and norm layers)."""
        self.cross_attn.initialize_weights()
        trunc_normal_(self.mask_token, std=0.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def _init_weights(self, m: nn.Module) -> None:
        """Xavier-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Guards allow LayerNorm variants without affine parameters.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def fix_init_weight(self) -> None:
        """Scale residual-branch output projections by 1/sqrt(2 * layer_index)."""

        def rescale(param: torch.Tensor, layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def prepare_tokens(
        self,
        x_signal: torch.Tensor,
        channel_locations: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed signal patches and add channel-location embeddings.

        Parameters
        ----------
        x_signal : torch.Tensor
            Signal of shape (B, C, T).
        channel_locations : torch.Tensor
            Channel coordinates of shape (B, C, dim) (``dim`` is 3 for the
            default template).
        mask : torch.Tensor | None
            Optional mask of shape (B, C, T). A patch is replaced by the mask
            token if any of its mask values is positive. When a mask is given,
            small Gaussian noise is also added to the normalized locations.

        Returns
        -------
        tuple of torch.Tensor
            Tokens of shape (B * S, C, D) and the channel-location embeddings
            of the same shape, both ordered batch-major (all patches of
            sample 0, then sample 1, ...).
        """
        num_channels = channel_locations.shape[1]
        num_patches_per_channel = x_signal.shape[-1] // self.patch_size
        x_patched = self.patch_embed(x_signal)
        freq_embed = self.freq_embed(x_signal)
        x_patched = x_patched + freq_embed
        x_masked = x_patched.clone()  # (B, N, D), N = C * num_patches_per_channel

        if mask is not None:
            mask_tokens = self.mask_token.repeat(
                x_masked.shape[0], x_masked.shape[1], 1
            )  # (B, N, D) N = C * num_patches_per_channel
            mask = rearrange(
                mask, "B C (S P) -> B (C S) P", P=self.patch_size
            )  # (B, C, T) -> (B, N, P)
            mask = (
                (mask.sum(dim=-1) > 0).unsqueeze(-1).float()
            )  # (B, N, 1), since a patch is either fully masked or not
            x_masked = torch.where(mask.bool(), mask_tokens, x_masked)

        # Min-max normalize coordinates per sample and per axis.
        channel_min = torch.min(channel_locations, dim=1, keepdim=True)[0]
        channel_max = torch.max(channel_locations, dim=1, keepdim=True)[0]
        channel_locations = (channel_locations - channel_min) / (
            channel_max - channel_min + 1e-8
        )

        if mask is not None:
            channel_locations = (
                channel_locations + torch.randn_like(channel_locations) * 0.02
            )

        channel_locations = nerf_positional_encoding(
            channel_locations, self.patch_embed_size
        )
        channel_locations_emb = self.channel_location_embedder(channel_locations)

        x_tokenized = rearrange(x_masked, "B (C t) D -> (B t) C D", C=num_channels)
        # ``x_tokenized`` rows are batch-major: row b * t + s is patch s of
        # sample b. ``repeat_interleave`` reproduces that ordering; a plain
        # ``repeat`` would tile the batch instead and pair patch tokens with
        # the channel locations of *other* samples whenever locations differ
        # across the batch (per-sample montages, or the mask-time noise above).
        channel_locations_emb = channel_locations_emb.repeat_interleave(
            num_patches_per_channel, dim=0
        )
        x_tokenized = x_tokenized + channel_locations_emb

        return x_tokenized, channel_locations_emb

    def forward(
        self,
        X: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        channel_locations: Optional[torch.Tensor] = None,
        channel_names: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        X : torch.Tensor
            EEG signal of shape (B, C, T); T is expected to be a multiple of
            ``patch_size``.
        mask : torch.Tensor | None
            Optional (B, C, T) mask; see :meth:`prepare_tokens`.
        channel_locations : torch.Tensor | None
            Channel coordinates of shape (B, C, dim). If None, a template is
            built from ``chs_info`` or, failing that, a linear layout.
        channel_names : torch.Tensor | None
            Channel indices for the per-channel embeddings. Required in
            reconstruction mode.

        Returns
        -------
        torch.Tensor
            Logits of shape (B, n_outputs) in classification mode, or the
            reconstructed signal of shape (B, C, S * patch_size) in
            reconstruction mode.

        Raises
        ------
        ValueError
            If ``channel_names`` is None in reconstruction mode.
        """
        x_signal = X
        B, C, _ = x_signal.shape

        if channel_locations is None:
            channel_locations = self.get_default_channel_locations(
                batch_size=B,
                num_channels=C,
                device=x_signal.device,
                dtype=x_signal.dtype,
            )

        x, channel_locations_emb = self.prepare_tokens(
            x_signal, channel_locations, mask=mask
        )
        x, _ = self.cross_attn(x)
        x = rearrange(x, "(B t) Q D -> B t (Q D)", B=B)
        num_patches = x.shape[1]

        for blk in self.blocks:
            x = blk(x)
        x_latent = self.norm(x)

        if self.num_classes > 0:
            return self.final_layer(x_latent)

        if channel_names is None:
            raise ValueError("channel_names must be provided for reconstruction tasks.")
        channel_emb = self.channel_emb(channel_names)
        if channel_emb.dim() == 2:
            # Unbatched channel indices: treat as a batch of one, as before.
            channel_emb = channel_emb.unsqueeze(0)
        # Same batch-major ordering as ``channel_locations_emb`` (see prepare_tokens).
        channel_emb = channel_emb.repeat_interleave(num_patches, dim=0)
        decoder_queries = channel_locations_emb + channel_emb
        return self.decoder_head(x_latent, decoder_queries)

    def get_default_channel_locations(
        self,
        batch_size: int,
        num_channels: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return a (batch_size, num_channels, 3) channel-location tensor.

        The per-``num_channels`` template is built once and cached on the CPU
        copy; each call moves it to ``device``/``dtype``.
        """
        if num_channels not in self._channel_location_cache:
            template = self.build_channel_location_template(num_channels)
            self._channel_location_cache[num_channels] = template
        template = self._channel_location_cache[num_channels].to(
            device=device, dtype=dtype
        )
        return template.unsqueeze(0).repeat(batch_size, 1, 1)

    def build_channel_location_template(self, num_channels: int) -> torch.Tensor:
        """Build channel location template for the model.

        Attempts to extract channel locations from chs_info. Falls back to a default
        linear spacing along the x-axis if real locations are unavailable.

        Parameters
        ----------
        num_channels : int
            Number of channels to generate locations for.

        Returns
        -------
        torch.Tensor
            Tensor of shape (num_channels, 3) with channel locations in 3D space.
        """
        # Try to extract channel locations from chs_info using the unified utility
        channel_info = getattr(self, "_chs_info", None)
        if channel_info is not None:
            locs = extract_channel_locations_from_chs_info(
                channel_info, num_channels=num_channels
            )
            if locs is not None and len(locs) == num_channels:
                return torch.from_numpy(locs).float()

        # Fallback: generate default linear spacing along x-axis
        positions = torch.linspace(-1.0, 1.0, steps=num_channels, dtype=torch.float32)
        zeros = torch.zeros_like(positions)
        locs_tensor = torch.stack([positions, zeros, zeros], dim=-1)
        return locs_tensor

    def _get_default_channel_locations(
        self,
        batch_size: int,
        num_channels: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Private alias of :meth:`get_default_channel_locations`."""
        return self.get_default_channel_locations(
            batch_size=batch_size,
            num_channels=num_channels,
            device=device,
            dtype=dtype,
        )

    def _build_channel_location_template(self, num_channels: int) -> torch.Tensor:
        """Private alias of :meth:`build_channel_location_template`."""
        return self.build_channel_location_template(num_channels)
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
def trunc_normal_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: Optional[float] = None,
    b: Optional[float] = None,
) -> None:
    """Fill ``tensor`` in place with values from a truncated normal distribution.

    Values follow N(mean, std**2) restricted to ``[a, b]``. By default the
    window is one standard deviation on either side of ``mean``
    (``[mean - std, mean + std]``). Previously the window was ``[-std, std]``
    regardless of ``mean``, which is identical for ``mean == 0`` but wrong
    otherwise, because ``nn.init.trunc_normal_`` takes *absolute* bounds.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to fill in place.
    mean : float
        Mean of the underlying normal distribution.
    std : float
        Standard deviation of the underlying normal distribution.
    a, b : float | None
        Optional absolute lower/upper cut-offs overriding the defaults.
    """
    lower = mean - std if a is None else a
    upper = mean + std if b is None else b
    nn.init.trunc_normal_(tensor, mean=mean, std=std, a=lower, b=upper)
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
def nerf_positional_encoding(coords: torch.Tensor, embed_size: int) -> torch.Tensor:
    """Encode coordinates with NeRF-style sinusoids at octave frequencies.

    Every coordinate is multiplied by ``2**k`` for ``k = 0 .. n_freqs - 1``,
    where ``n_freqs = embed_size // (2 * dim)``. The sine and cosine of each
    product are interleaved, ordered by frequency first, then by coordinate
    axis. Any dimensions left over up to ``embed_size`` are zero-filled.

    Parameters
    ----------
    coords : torch.Tensor
        Coordinates of shape (N, C, dim).
    embed_size : int
        Size of the output encoding.

    Returns
    -------
    torch.Tensor
        Encoding of shape (N, C, embed_size).
    """
    n_batch, n_chans, n_dims = coords.shape
    n_freqs = embed_size // (2 * n_dims)
    n_encoded = 2 * n_dims * n_freqs
    n_pad = embed_size - n_encoded

    octaves = 2.0 ** torch.arange(n_freqs, device=coords.device).float()
    # (N, C, dim, 1) * (n_freqs,) -> (N, C, dim, n_freqs)
    angles = coords[..., None] * octaves

    sincos = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
    # Reorder to (N, C, n_freqs, dim, 2) so frequency is the slowest axis.
    encoded = sincos.permute(0, 1, 3, 2, 4).reshape(n_batch, n_chans, n_encoded)

    if n_pad <= 0:
        return encoded
    filler = torch.zeros(
        n_batch, n_chans, n_pad, device=coords.device, dtype=coords.dtype
    )
    return torch.cat([encoded, filler], dim=-1)
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
class _ChannelEmbeddings(nn.Module):
    """Learned embedding vector per EEG channel index.

    Channel names are expected to be mapped to integer indices beforehand.
    The default number of unique channels (90) is the union of channels
    from SEED Pretraining, TUEG, and Siena datasets.

    Parameters
    ----------
    embed_dim : int
        Dimension of the channel embeddings.
    number_channels : int
        Number of unique EEG channels. Default is 90.
    """

    def __init__(self, embed_dim: int, number_channels: int = 90) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(number_channels, embed_dim)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up embeddings; output shape is ``indices.shape + (embed_dim,)``."""
        return self.embeddings(indices)

    def initialize_weights(self) -> None:
        """Re-draw the embedding table from N(0, 2.0**2)."""
        # The initializers live in ``torch.nn.init``; ``torch.init`` does not
        # exist and raised AttributeError here.
        nn.init.normal_(self.embeddings.weight, std=2.0)
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
class _FrequencyFeatureEmbedder(nn.Module):
    """Embed per-patch spectral features of a multichannel signal.

    The input of shape (B, C, T) is zero-padded on the right when T is not a
    multiple of ``patch_size``, then split into S = ceil(T / patch_size)
    non-overlapping patches per channel. For every patch, the magnitude and
    phase of its real FFT are concatenated and projected to ``embed_dim`` by
    an MLP. The output has shape (B, C * S, embed_dim), channel-major.

    Parameters
    ----------
    patch_size : int
        Number of samples per patch.
    embed_dim : int
        Output embedding dimension.
    """

    def __init__(self, patch_size: int, embed_dim: int) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # rfft of a length-P patch has P // 2 + 1 bins; magnitude + phase doubles it.
        in_features = 2 * (patch_size // 2 + 1)
        self.frequency_to_embed = _Mlp(
            in_features=in_features,
            hidden_features=int(4 * in_features),
            out_features=embed_dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, C, T) to (B, C * ceil(T / patch_size), embed_dim)."""
        B, C, T = x.size()
        remainder = T % self.patch_size
        if remainder != 0:
            # Pad last dimension with zeros to make it divisible by patch size
            x = F.pad(x, (0, self.patch_size - remainder))
        S = x.size(-1) // self.patch_size
        x = x.view(B, C, S, self.patch_size)

        freq_representation = fft.rfft(
            x, dim=-1
        )  # (B, C, num_patches, patch_size // 2 + 1)
        magnitude = torch.abs(freq_representation)
        phase = torch.angle(freq_representation)

        # Concatenate magnitude and phase along the frequency axis (last dimension)
        freq_features = torch.cat((magnitude, phase), dim=-1)
        # Map frequency features to embedding dimension
        embedded = self.frequency_to_embed(
            freq_features
        )  # (B, C, num_patches, embed_dim)
        embedded = rearrange(embedded, "B C t D -> B (C t) D")
        return embedded
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
class _RotarySelfAttentionBlock(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Queries and keys are rotated per head by a ``RotaryEmbedding`` before the
    scaled dot-product. The heads are then merged and projected back to
    ``dim``.

    Parameters
    ----------
    dim : int
        Token width; split evenly across ``num_heads`` heads.
    num_heads : int
        Number of attention heads.
    qkv_bias : bool
        Whether the fused QKV projection has a bias.
    qk_scale : float | None
        Score scaling; falls back to ``head_dim ** -0.5`` when falsy.
    attn_drop : float
        Dropout on the attention probabilities.
    proj_drop : float
        Dropout after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.rotary_emb = RotaryEmbedding(dim=per_head, learned_freq=False)

        self.scale = qk_scale or per_head**-0.5

        self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.attn_drop_fn = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the token axis of a (batch, tokens, dim) input."""
        n_batch, n_tokens, width = x.shape
        head_width = width // self.num_heads

        fused = self.qkv_proj(x).reshape(n_batch, n_tokens, 3, self.num_heads, head_width)
        # -> (3, batch, heads, tokens, head_width); unpack q/k/v along axis 0.
        queries, keys, values = fused.permute(2, 0, 3, 1, 4)

        queries = self.rotary_emb.rotate_queries_or_keys(queries)
        keys = self.rotary_emb.rotate_queries_or_keys(keys)

        logits = (queries @ keys.transpose(-2, -1)) * self.scale  # (B, H, N, N)
        probs = self.attn_drop_fn(torch.softmax(logits, dim=-1))

        mixed = rearrange(probs @ values, "B H N D -> B N (H D)")
        return self.proj_drop(self.proj(mixed))
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
class _FeedForwardBlock(nn.Module):
    """Token-wise feed-forward network with a normalized hidden layer.

    Applies Linear -> GELU -> Dropout -> LayerNorm -> Linear -> Dropout. The
    LayerNorm acts on the expanded hidden representation, before the second
    projection.

    Parameters
    ----------
    dim : int
        Input and output width.
    hidden_dim : int
        Width of the hidden layer.
    dropout : float
        Dropout probability used after the activation and after ``fc2``.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a tensor with the same shape as ``x``."""
        hidden = self.dropout1(self.activation(self.fc1(x)))
        return self.dropout2(self.fc2(self.norm(hidden)))
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
class _RotaryTransformerBlock(nn.Module):
    """Pre-norm transformer encoder block built on rotary self-attention.

    Computes ``x + DropPath(Attn(Norm1(x)))`` and then
    ``x + DropPath(FFN(Norm2(x)))``. Drop-path is replaced by ``nn.Identity``
    when ``drop_path`` is 0.

    Parameters
    ----------
    dim : int
        Token width.
    num_heads : int
        Number of attention heads.
    mlp_ratio : float
        Hidden width of the feed-forward network relative to ``dim``.
    qkv_bias : bool
        Whether the QKV projection has a bias.
    qk_scale : float | None
        Optional override of the attention score scale.
    drop : float
        Dropout after the attention projection and inside the feed-forward network.
    attn_drop : float
        Dropout on attention probabilities.
    drop_path : float
        Stochastic depth rate for both residual branches.
    norm_layer : type[nn.Module]
        Normalization layer class.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = _RotarySelfAttentionBlock(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        use_drop_path = drop_path > 0.0
        self.drop_path1 = DropPath(drop_path) if use_drop_path else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if use_drop_path else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = _FeedForwardBlock(dim=dim, hidden_dim=hidden_width, dropout=drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward residual branches in sequence."""
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path1(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path2(transformed)
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
class _PatchReconstructionHeadWithQueries(nn.Module):
    """Decode latent query tokens back into per-channel signal patches.

    For every patch position, one decoder query per channel attends to the
    ``num_queries`` latent tokens through a single pre-norm
    ``nn.TransformerDecoderLayer``. The result is layer-normalized and mapped
    by an MLP to ``input_dim`` samples.

    Parameters
    ----------
    input_dim : int
        Number of samples reconstructed per patch.
    embed_dim : int
        Per-query embedding dimension.
    num_heads : int
        Number of attention heads in the decoder layer.
    num_queries : int
        Number of latent query tokens per patch.
    """

    def __init__(
        self,
        input_dim: int = 8,
        embed_dim: int = 768,
        num_heads: int = 8,
        num_queries: int = 4,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.reconstruction_shape = self.input_dim
        self.num_queries = num_queries
        ffn_width = int(embed_dim * 4)
        # Projection from embed space to signal space.
        decoder_layer = nn.TransformerDecoderLayer(
            embed_dim,
            num_heads,
            dropout=0.0,
            batch_first=True,
            activation="gelu",
            dim_feedforward=ffn_width,
            norm_first=True,
        )
        self.decoder_pred = nn.TransformerDecoder(decoder_layer, num_layers=1)
        self.norm = nn.LayerNorm(embed_dim)
        self.decoder_linear = _Mlp(
            embed_dim, ffn_width, input_dim, act_layer=nn.GELU, drop=0.0
        )

    def forward(self, enc: torch.Tensor, decoder_queries: torch.Tensor) -> torch.Tensor:
        """Reconstruct the signal.

        Parameters
        ----------
        enc : torch.Tensor
            Latent tokens of shape (B, num_patches, num_queries * embed_dim).
        decoder_queries : torch.Tensor
            Queries of shape (B * num_patches, num_channels, embed_dim), in the
            same batch-major order as the rearranged ``enc``.

        Returns
        -------
        torch.Tensor
            Signal of shape (B, num_channels, num_patches * input_dim).
        """
        batch_size, _, _ = enc.shape
        memory = rearrange(enc, "B t (Q D) -> (B t) Q D", Q=self.num_queries)
        decoded = self.norm(self.decoder_pred(decoder_queries, memory))  # (B*t, C, D)
        patches = self.decoder_linear(decoded)  # (B*t, C, input_dim)
        return rearrange(patches, "(B t) C P -> B C (t P)", B=batch_size)
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
class _ClassificationHeadWithQueries(nn.Module):
    """Classify a token sequence through one learned aggregation query.

    A learned query token attends (multi-head attention) over all latent patch
    tokens. The attended vector is mapped to class logits by an MLP.

    Parameters
    ----------
    input_dim : int
        Stored as ``input_dim`` / ``reconstruction_shape``; unused in forward.
    embed_dim : int
        Per-query embedding dimension. The head operates on tokens of width
        ``embed_dim * num_queries``.
    num_queries : int
        Number of query tokens concatenated into each latent token.
    num_heads : int
        Number of attention heads.
    num_classes : int
        Number of output logits.
    drop_decoder : float
        Dropout on the attention weights.
    drop_ffn : float
        Dropout inside the output MLP.
    """

    def __init__(
        self,
        input_dim: int = 8,
        embed_dim: int = 768,
        num_queries: int = 8,
        num_heads: int = 8,
        num_classes: int = 2,
        drop_decoder: float = 0.15,
        drop_ffn: float = 0.15,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = int(embed_dim * num_queries)
        self.reconstruction_shape = self.input_dim
        self.decoder_attn = nn.MultiheadAttention(
            self.embed_dim, num_heads, batch_first=True, dropout=drop_decoder
        )
        self.decoder_ffn = _Mlp(
            in_features=self.embed_dim,
            hidden_features=int(self.embed_dim * 4),
            out_features=num_classes,
            act_layer=nn.GELU,
            drop=drop_ffn,
        )

        self.learned_agg = nn.Parameter(
            torch.randn(1, 1, self.embed_dim), requires_grad=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Parameters
        ----------
        x : torch.Tensor
            Latent tokens of shape (B, num_patches, embed_dim * num_queries).

        Returns
        -------
        torch.Tensor
            Logits of shape (B, num_classes).
        """
        batch_size, _, _ = x.shape
        decoder_queries = self.learned_agg.repeat(batch_size, 1, 1)

        attended = self.decoder_attn(query=decoder_queries, key=x, value=x)[0]
        # Only one aggregation query: drop the singleton query axis.
        return self.decoder_ffn(attended[:, 0, :])
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
class _CrossAttentionBlock(nn.Module):
    """Channel-unification block: learned queries cross-attend to channel tokens.

    A fixed set of ``num_queries`` learned query vectors attends, via
    multi-head cross-attention, over the per-channel tokens of each patch.
    This is followed by a residual MLP and a 3-layer pre-norm transformer
    encoder over the queries. The output size therefore does not depend on
    the number of input channels.

    Parameters
    ----------
    num_queries : int
        Number of learned queries.
    input_embed_dim : int
        Width of the channel tokens and of the queries.
    output_embed_dim : int
        Output width of the MLP. The residual ``ffn(out) + out`` requires it
        to equal ``input_embed_dim``.
    num_heads : int
        Number of attention heads.
    dropout_p : float
        Dropout used in the cross-attention and the MLP. Default: 0.1.
    ff_dim : int
        Hidden width of the MLP and of the encoder feed-forward layers.
    """

    def __init__(
        self,
        num_queries: int,
        input_embed_dim: int,
        output_embed_dim: int,
        num_heads: int,
        dropout_p: float = 0.1,
        ff_dim: int = 2048,
    ) -> None:
        super().__init__()
        self.num_queries = num_queries
        self.dropout_p = dropout_p
        self.query_embed = nn.Parameter(
            torch.randn(1, num_queries, input_embed_dim), requires_grad=True
        )  # Learnable queries
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=input_embed_dim,
            num_heads=num_heads,
            dropout=dropout_p,
            batch_first=True,
        )
        # NOTE(review): not used in forward; presumably kept so state dicts
        # that contain it still load — confirm before removing.
        self.temperature = nn.Parameter(torch.tensor(1.0), requires_grad=False)

        self.ffn = _Mlp(
            input_embed_dim,
            ff_dim,
            output_embed_dim,
            act_layer=nn.GELU,
            drop=dropout_p,
        )
        self.keys_norm = nn.LayerNorm(input_embed_dim)
        self.values_norm = nn.LayerNorm(input_embed_dim)
        self.queries_norm = nn.LayerNorm(input_embed_dim)
        self.query_self_attn = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                input_embed_dim,
                nhead=num_heads,
                activation="gelu",
                dim_feedforward=ff_dim,
                batch_first=True,
                norm_first=True,
            ),
            num_layers=3,
        )

    def initialize_weights(self) -> None:
        """Orthogonal init for the queries, then Xavier/LayerNorm init for submodules."""
        torch.nn.init.orthogonal_(self.query_embed, gain=1.0)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Xavier-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            # Xavier *normal* initialization (not xavier_uniform).
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Unify channels into ``num_queries`` tokens.

        Parameters
        ----------
        x : torch.Tensor
            Channel tokens of shape (batch_size * num_patches, num_channels, embed_dim).

        Returns
        -------
        tuple of torch.Tensor
            Query tokens of shape (batch_size * num_patches, num_queries,
            embed_dim) and the cross-attention weights returned by
            ``nn.MultiheadAttention``.
        """
        batch_size, _, _ = x.size()
        queries = self.query_embed.repeat(batch_size, 1, 1)
        queries = self.queries_norm(queries)
        keys = self.keys_norm(x)
        values = self.values_norm(x)

        attention_out, attention_scores = self.cross_attention(
            query=queries, key=keys, value=values
        )  # Shape: (batch_size*num_patches, num_queries, embed_dim)
        attention_out = self.ffn(attention_out) + attention_out

        attention_out = self.query_self_attn(attention_out)
        return (
            attention_out,
            attention_scores,
        )  # Shape: (batch_size*num_patches, num_queries, embed_dim)
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
class _PatchEmbedNetwork(nn.Module):
    """Temporal-convolution patch embedding.

    Splits each channel of a (B, C, T) signal into non-overlapping patches of
    ``patch_size`` samples, treats the patches as a one-channel image of shape
    (C * S, patch_size), and applies three Conv2d -> GroupNorm -> GELU stages
    along the time axis. The convolution channels and the remaining temporal
    positions are flattened into the embedding.

    Parameters
    ----------
    embed_dim : int
        Target embedding size; the convolutions use ``embed_dim // 4`` channels.
    patch_size : int
        Number of samples per patch.
    """

    def __init__(self, embed_dim: int = 64, patch_size: int = 40) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.in_channels = 1
        self.out_channels = int(embed_dim // 4)
        self.groups = 4
        self.kernel_size = int(patch_size // 2)

        width = self.out_channels
        half = self.kernel_size // 2
        # (in_channels, kernel width, stride, padding) of each conv stage.
        stages = (
            (self.in_channels, self.kernel_size - 1, half, half - 1),
            (width, 3, 1, 1),
            (width, 3, 1, 1),
        )
        layers = []
        for stage_in, k_w, s_w, p_w in stages:
            layers.append(
                nn.Conv2d(
                    in_channels=stage_in,
                    out_channels=width,
                    kernel_size=(1, k_w),
                    stride=(1, s_w),
                    padding=(0, p_w),
                )
            )
            layers.append(nn.GroupNorm(self.groups, width))
            layers.append(nn.GELU())
        self.proj_in = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, C, T) to (B, C * S, D) with S = T // patch_size.

        T must be a multiple of ``patch_size`` for the patch rearrangement.
        """
        patches = rearrange(x, "B C (S P) -> B (C S) P", P=self.patch_size)
        feature_map = self.proj_in(patches[:, None])  # (B, E, C*S, D')
        return rearrange(feature_map, "B E CS D -> B CS (D E)")
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
class _Mlp(nn.Module):
    """Two-layer perceptron as used in ViT / MLP-Mixer style networks.

    Computes Linear -> activation -> Dropout -> Linear -> Dropout, with the
    same dropout module used at both positions. ``hidden_features`` and
    ``out_features`` fall back to ``in_features`` when not given (or falsy).
    Adapted from ``timm.models.mlp.Mlp``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
|