braindecode 1.3.0.dev177628147__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,836 +0,0 @@
1
- """This implementation is adapted from ETH Zurich's BioFoundation repository.
2
-
3
- Döner, B., Ingolfsson, T. M., Benini, L., & Li, Y. (2025).
4
- LUNA: Efficient and Topology-Agnostic Foundation Model for EEG Signal Analysis.
5
- The Thirty-Ninth Annual Conference on Neural Information Processing Systems, NeurIPS.
6
- Retrieved from https://openreview.net/forum?id=uazfjnFL0G
7
-
8
- Original Authors: Berkay Döner, Thorir Mar Ingolfsson
9
- Braindecode Adaptation: Bruno Aristimunha
10
-
11
- The license of this file is Apache-2.0.
12
- """
13
-
14
- import math
15
- from typing import Any, Dict, Optional, Tuple, Type
16
-
17
- import torch
18
- import torch.fft as fft
19
- import torch.nn as nn
20
- import torch.nn.functional as F
21
- from einops import rearrange
22
- from rotary_embedding_torch import RotaryEmbedding
23
-
24
- from braindecode.models.base import EEGModuleMixin
25
- from braindecode.models.util import extract_channel_locations_from_chs_info
26
- from braindecode.modules.layers import DropPath
27
-
28
-
29
- class LUNA(EEGModuleMixin, nn.Module):
30
- """LUNA from Döner et al. [LUNA]_.
31
-
32
- :bdg-success:`Convolution` :bdg-danger:`Large Brain Model` :bdg-dark-line:`Channel`
33
-
34
- .. figure:: https://arxiv.org/html/2510.22257v1/x1.png
35
- :align: center
36
- :alt: LUNA Architecture.
37
-
38
- LUNA is a topology-invariant EEG model that processes signals from varying
39
- numbers of channels using a channel-unification mechanism with learned queries.
40
-
41
- The architecture consists of:
42
- 1. Patch Feature Extraction (temporal CNN + FFT-based features)
43
- 2. Channel-Unification Module (cross-attention with learned queries)
44
- 3. Patch-wise Temporal Encoder (RoPE-based transformer)
45
- 4. Decoder Heads (classification or reconstruction)
46
-
47
- Parameters
48
- ----------
49
- patch_size : int
50
- Number of time samples per patch. Default: 40.
51
- num_queries : int
52
- Number of learned queries for channel unification.
53
- Paper uses: 4 (Base), 6 (Large), 8 (Huge). Default: 4.
54
- embed_dim : int
55
- Embedding dimension for patch features.
56
- Paper uses: 64 (Base), 96 (Large), 128 (Huge). Default: 64.
57
- depth : int
58
- Number of transformer encoder blocks.
59
- Paper uses: 8 (Base), 10 (Large), 24 (Huge). Default: 8.
60
- num_heads : int
61
- Number of attention heads in channel unification.
62
- Default: 2.
63
- mlp_ratio : float
64
- Ratio of MLP hidden dimension to embedding dimension. Default: 4.0.
65
- norm_layer : nn.Module
66
- Normalization layer class. Default: nn.LayerNorm.
67
- drop_path : float
68
- Stochastic depth rate. Default: 0.0.
69
-
70
- References
71
- ----------
72
- .. [LUNA] Döner, B., Ingolfsson, T. M., Benini, L., & Li, Y. (2025).
73
- LUNA: Efficient and Topology-Agnostic Foundation Model for EEG Signal Analysis.
74
- The Thirty-Ninth Annual Conference on Neural Information Processing Systems, NeurIPS.
75
- Retrieved from https://openreview.net/forum?id=uazfjnFL0G
76
- """
77
-
78
- def __init__(
79
- self,
80
- # Braindecode EEGModuleMixin parameters
81
- n_outputs: Optional[int] = None,
82
- n_chans: Optional[int] = None,
83
- n_times: Optional[int] = None,
84
- sfreq: Optional[float] = None,
85
- chs_info: Optional[Any] = None,
86
- input_window_seconds: Optional[float] = None,
87
- # Model-specific parameters
88
- patch_size: int = 40,
89
- num_queries: int = 4,
90
- embed_dim: int = 64,
91
- depth: int = 8,
92
- num_heads: int = 2,
93
- mlp_ratio: float = 4.0,
94
- norm_layer: Type[nn.Module] = nn.LayerNorm,
95
- drop_path: float = 0.0,
96
- drop_prob_chan: float = 0.0,
97
- attn_drop: float = 0.0,
98
- activation: Type[nn.Module] = nn.GELU,
99
- ):
100
- super().__init__(
101
- n_outputs=n_outputs,
102
- n_chans=n_chans,
103
- n_times=n_times,
104
- sfreq=sfreq,
105
- chs_info=chs_info,
106
- input_window_seconds=input_window_seconds,
107
- )
108
- del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
109
-
110
- # Mapping for loading pretrained weights
111
- self.mapping = {
112
- "channel_location_embedder.0.fc1.weight": "channel_location_embedder.fc1.weight",
113
- "channel_location_embedder.0.fc1.bias": "channel_location_embedder.fc1.bias",
114
- "channel_location_embedder.0.fc2.weight": "channel_location_embedder.fc2.weight",
115
- "channel_location_embedder.0.fc2.bias": "channel_location_embedder.fc2.bias",
116
- "channel_location_embedder.0.norm.weight": "channel_location_embedder.norm.weight",
117
- "channel_location_embedder.0.norm.bias": "channel_location_embedder.norm.bias",
118
- }
119
-
120
- # Model parameters
121
- self.num_classes = self.n_outputs if self.n_outputs else 0
122
- self.embed_dim = embed_dim
123
- self.num_queries = num_queries
124
- self.patch_size = patch_size
125
- self.patch_embed_size = embed_dim
126
- self.num_heads = num_heads
127
- self.depth = depth
128
- self.drop_path = drop_path
129
- self.attn_drop = attn_drop
130
- self.drop_prob_chan = drop_prob_chan
131
- self.mlp_ratio = mlp_ratio
132
- self.activation = activation
133
-
134
- # Layers
135
- self.patch_embed = _PatchEmbedNetwork(
136
- embed_dim=self.embed_dim, patch_size=self.patch_size
137
- )
138
- self.freq_embed = _FrequencyFeatureEmbedder(
139
- embed_dim=self.embed_dim, patch_size=self.patch_size
140
- )
141
- # For weight loading, we omit the normalization here to match parameter count
142
- self.channel_location_embedder = _Mlp(
143
- in_features=int(self.patch_embed_size),
144
- out_features=int(self.patch_embed_size),
145
- hidden_features=int(self.patch_embed_size * 2),
146
- act_layer=self.activation,
147
- drop=self.drop_prob_chan,
148
- )
149
- self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
150
- self.cross_attn = _CrossAttentionBlock(
151
- num_queries=self.num_queries,
152
- input_embed_dim=self.embed_dim,
153
- output_embed_dim=self.embed_dim,
154
- num_heads=self.num_heads,
155
- ff_dim=int(self.mlp_ratio * self.embed_dim),
156
- )
157
- self.blocks = nn.ModuleList(
158
- [
159
- _RotaryTransformerBlock(
160
- dim=int(self.embed_dim * self.num_queries),
161
- num_heads=int(self.num_heads * self.num_queries),
162
- mlp_ratio=self.mlp_ratio,
163
- qkv_bias=True,
164
- drop=0.0,
165
- attn_drop=0.0,
166
- drop_path=self.drop_path,
167
- norm_layer=norm_layer,
168
- )
169
- for i in range(depth)
170
- ]
171
- )
172
- self.norm = norm_layer(int(self.embed_dim * self.num_queries))
173
-
174
- self._channel_location_cache: Dict[int, torch.Tensor] = {}
175
-
176
- if self.num_classes == 0:
177
- self.decoder_head = _PatchReconstructionHeadWithQueries(
178
- input_dim=self.patch_size,
179
- embed_dim=self.embed_dim,
180
- num_heads=self.num_heads,
181
- num_queries=self.num_queries,
182
- )
183
- self.channel_emb = _ChannelEmbeddings(self.embed_dim)
184
- else:
185
- self.final_layer = _ClassificationHeadWithQueries(
186
- input_dim=self.patch_size,
187
- num_queries=self.num_queries,
188
- embed_dim=self.embed_dim,
189
- num_classes=self.num_classes,
190
- num_heads=self.num_heads,
191
- )
192
- self.mask_token.requires_grad = (
193
- False # no use of mask token for classification
194
- )
195
-
196
- self.initialize_weights()
197
-
198
- def initialize_weights(self) -> None:
199
- self.cross_attn.initialize_weights()
200
- trunc_normal_(self.mask_token, std=0.02)
201
- self.apply(self._init_weights)
202
- self.fix_init_weight()
203
-
204
- def _init_weights(self, m: nn.Module) -> None:
205
- if isinstance(m, nn.Linear):
206
- torch.nn.init.xavier_normal_(m.weight)
207
- if isinstance(m, nn.Linear) and m.bias is not None:
208
- nn.init.constant_(m.bias, 0)
209
- elif isinstance(m, nn.LayerNorm):
210
- nn.init.constant_(m.bias, 0)
211
- nn.init.constant_(m.weight, 1.0)
212
-
213
- def fix_init_weight(self) -> None:
214
- def rescale(param: torch.Tensor, layer_id: int) -> None:
215
- param.div_(math.sqrt(2.0 * layer_id))
216
-
217
- for layer_id, layer in enumerate(self.blocks):
218
- rescale(layer.attn.proj.weight.data, layer_id + 1)
219
- rescale(layer.mlp.fc2.weight.data, layer_id + 1)
220
-
221
- def prepare_tokens(
222
- self,
223
- x_signal: torch.Tensor,
224
- channel_locations: torch.Tensor,
225
- mask: Optional[torch.Tensor] = None,
226
- ) -> Tuple[torch.Tensor, torch.Tensor]:
227
- num_channels = channel_locations.shape[1]
228
- num_patches_per_channel = x_signal.shape[-1] // self.patch_size
229
- x_patched = self.patch_embed(x_signal)
230
- freq_embed = self.freq_embed(x_signal)
231
- x_patched = x_patched + freq_embed
232
- x_masked = x_patched.clone() # (B, N, D), N = C * num_patches_per_channel
233
-
234
- if mask is not None:
235
- mask_tokens = self.mask_token.repeat(
236
- x_masked.shape[0], x_masked.shape[1], 1
237
- ) # (B, N, D) N = C * num_patches_per_channel
238
- mask = rearrange(
239
- mask, "B C (S P) -> B (C S) P", P=self.patch_size
240
- ) # (B, C, T) -> (B, N, P)
241
- mask = (
242
- (mask.sum(dim=-1) > 0).unsqueeze(-1).float()
243
- ) # (B, N, 1), since a patch is either fully masked or not
244
- x_masked = torch.where(mask.bool(), mask_tokens, x_masked)
245
-
246
- channel_min = torch.min(channel_locations, dim=1, keepdim=True)[0]
247
- channel_max = torch.max(channel_locations, dim=1, keepdim=True)[0]
248
- channel_locations = (channel_locations - channel_min) / (
249
- channel_max - channel_min + 1e-8
250
- )
251
-
252
- if mask is not None:
253
- channel_locations = (
254
- channel_locations + torch.randn_like(channel_locations) * 0.02
255
- )
256
-
257
- channel_locations = nerf_positional_encoding(
258
- channel_locations, self.patch_embed_size
259
- )
260
- channel_locations_emb = self.channel_location_embedder(channel_locations)
261
-
262
- x_tokenized = rearrange(x_masked, "B (C t) D -> (B t) C D", C=num_channels)
263
- channel_locations_emb = channel_locations_emb.repeat(
264
- num_patches_per_channel, 1, 1
265
- )
266
- x_tokenized = x_tokenized + channel_locations_emb
267
-
268
- return x_tokenized, channel_locations_emb
269
-
270
- def forward(
271
- self,
272
- X: torch.Tensor,
273
- mask: Optional[torch.Tensor] = None,
274
- channel_locations: Optional[torch.Tensor] = None,
275
- channel_names: Optional[torch.Tensor] = None,
276
- ) -> torch.Tensor:
277
- """Forward pass."""
278
- x_signal = X
279
- B, C, _ = x_signal.shape
280
-
281
- if channel_locations is None:
282
- channel_locations = self.get_default_channel_locations(
283
- batch_size=B,
284
- num_channels=C,
285
- device=x_signal.device,
286
- dtype=x_signal.dtype,
287
- )
288
-
289
- x, channel_locations_emb = self.prepare_tokens(
290
- x_signal, channel_locations, mask=mask
291
- )
292
- x, _ = self.cross_attn(x)
293
- x = rearrange(x, "(B t) Q D -> B t (Q D)", B=B)
294
- num_patches = x.shape[1]
295
-
296
- for blk in self.blocks:
297
- x = blk(x)
298
- x_latent = self.norm(x)
299
-
300
- if self.num_classes > 0:
301
- return self.final_layer(x_latent)
302
-
303
- if channel_names is None:
304
- raise ValueError("channel_names must be provided for reconstruction tasks.")
305
- channel_emb = self.channel_emb(channel_names)
306
- channel_emb = channel_emb.repeat(num_patches, 1, 1)
307
- decoder_queries = channel_locations_emb + channel_emb
308
- return self.decoder_head(x_latent, decoder_queries)
309
-
310
- def get_default_channel_locations(
311
- self,
312
- batch_size: int,
313
- num_channels: int,
314
- device: torch.device,
315
- dtype: torch.dtype,
316
- ) -> torch.Tensor:
317
- if num_channels not in self._channel_location_cache:
318
- template = self.build_channel_location_template(num_channels)
319
- self._channel_location_cache[num_channels] = template
320
- template = self._channel_location_cache[num_channels].to(
321
- device=device, dtype=dtype
322
- )
323
- return template.unsqueeze(0).repeat(batch_size, 1, 1)
324
-
325
- def build_channel_location_template(self, num_channels: int) -> torch.Tensor:
326
- """Build channel location template for the model.
327
-
328
- Attempts to extract channel locations from chs_info. Falls back to a default
329
- linear spacing along the x-axis if real locations are unavailable.
330
-
331
- Parameters
332
- ----------
333
- num_channels : int
334
- Number of channels to generate locations for.
335
-
336
- Returns
337
- -------
338
- torch.Tensor
339
- Tensor of shape (num_channels, 3) with channel locations in 3D space.
340
- """
341
- # Try to extract channel locations from chs_info using the unified utility
342
- channel_info = getattr(self, "_chs_info", None)
343
- if channel_info is not None:
344
- locs = extract_channel_locations_from_chs_info(
345
- channel_info, num_channels=num_channels
346
- )
347
- if locs is not None and len(locs) == num_channels:
348
- return torch.from_numpy(locs).float()
349
-
350
- # Fallback: generate default linear spacing along x-axis
351
- positions = torch.linspace(-1.0, 1.0, steps=num_channels, dtype=torch.float32)
352
- zeros = torch.zeros_like(positions)
353
- locs_tensor = torch.stack([positions, zeros, zeros], dim=-1)
354
- return locs_tensor
355
-
356
- def _get_default_channel_locations(
357
- self,
358
- batch_size: int,
359
- num_channels: int,
360
- device: torch.device,
361
- dtype: torch.dtype,
362
- ) -> torch.Tensor:
363
- return self.get_default_channel_locations(
364
- batch_size=batch_size,
365
- num_channels=num_channels,
366
- device=device,
367
- dtype=dtype,
368
- )
369
-
370
- def _build_channel_location_template(self, num_channels: int) -> torch.Tensor:
371
- return self.build_channel_location_template(num_channels)
372
-
373
-
374
- def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> None:
375
- nn.init.trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
376
-
377
-
378
- def nerf_positional_encoding(coords: torch.Tensor, embed_size: int) -> torch.Tensor:
379
- """
380
- coords: (N, C, 3)
381
- Returns: (N, C, embed_size)
382
- """
383
- N, C, dim = coords.shape
384
- device = coords.device
385
- freqs = embed_size // (2 * dim)
386
- leftover = embed_size - freqs * 2 * dim
387
- freq_bands = 2.0 ** torch.arange(freqs, device=device).float()
388
- scaled_coords = coords.unsqueeze(-1) * freq_bands.view(
389
- 1, 1, 1, -1
390
- ) # (N, C, dim, freqs)
391
- sin_enc = torch.sin(scaled_coords) # (N, C, dim, freqs)
392
- cos_enc = torch.cos(scaled_coords) # (N, C, dim, freqs)
393
- encoded = (
394
- torch.stack([sin_enc, cos_enc], dim=-1)
395
- .permute(0, 1, 3, 2, 4)
396
- .reshape(N, C, freqs * dim * 2)
397
- )
398
- if leftover > 0:
399
- pad = torch.zeros(N, C, leftover, device=device, dtype=coords.dtype)
400
- encoded = torch.cat([encoded, pad], dim=-1)
401
- return encoded
402
-
403
-
404
- class _ChannelEmbeddings(nn.Module):
405
- """
406
- This class creates embeddings for each EEG channel based on a predefined
407
- mapping of channel names to indices.
408
-
409
- The number of unique channels is determined by the union of channels
410
- from the SEED Pretraining, TUEG, and Siena datasets.
411
-
412
- Parameters
413
- ----------
414
- embed_dim : int
415
- Dimension of the channel embeddings.
416
- number_channels : int
417
- Number of unique EEG channels. Default is 90.
418
-
419
- """
420
-
421
- def __init__(self, embed_dim: int, number_channels: int = 90) -> None:
422
- super().__init__()
423
- self.embeddings = nn.Embedding(number_channels, embed_dim)
424
-
425
- def forward(self, indices: torch.Tensor) -> torch.Tensor:
426
- return self.embeddings(indices)
427
-
428
- def initialize_weights(self) -> None:
429
- nn.init.normal_(self.embeddings.weight, std=2.0)
430
-
431
-
432
- class _FrequencyFeatureEmbedder(nn.Module):
433
- """
434
- This class takes input of shape (B, C, T) and patches it along the time
435
- dimension into patches of size P (patch_size), then embeds FFT magnitude
436
- and phase features of each patch to shape (B, C*S, embed_dim), where S = T // P.
437
- """
438
-
439
- def __init__(self, patch_size: int, embed_dim: int) -> None:
440
- super().__init__()
441
- self.patch_size = patch_size
442
- self.embed_dim = embed_dim
443
- in_features = 2 * (patch_size // 2 + 1)
444
- self.frequency_to_embed = _Mlp(
445
- in_features=in_features,
446
- hidden_features=int(4 * in_features),
447
- out_features=embed_dim,
448
- )
449
-
450
- def forward(self, x: torch.Tensor) -> torch.Tensor:
451
- B, C, T = x.size()
452
- S = T // self.patch_size
453
- # There is a chance that the input tensor is not divisible by the patch size
454
- # In this case we need to pad the tensor with zeros
455
- if T % self.patch_size != 0:
456
- # Pad last dimension with zeros to make it divisible by patch size
457
- pad_size = self.patch_size - (T % self.patch_size)
458
- x = F.pad(x, (0, pad_size))
459
- T = x.size(-1)
460
- S = T // self.patch_size
461
- x = x.view(B, C, S, self.patch_size)
462
-
463
- freq_representation = fft.rfft(
464
- x, dim=-1
465
- ) # (B, C, num_patches, patch_size // 2 + 1)
466
- magnitude = torch.abs(freq_representation)
467
- phase = torch.angle(freq_representation)
468
-
469
- # Concatenate magnitude and phase along the frequency axis (last dimension)
470
- freq_features = torch.cat((magnitude, phase), dim=-1)
471
- # Map frequency features to embedding dimension
472
- embedded = self.frequency_to_embed(
473
- freq_features
474
- ) # (B, C, num_patches, embed_dim)
475
- embedded = rearrange(embedded, "B C t D -> B (C t) D")
476
- return embedded
477
-
478
-
479
- class _RotarySelfAttentionBlock(nn.Module):
480
- def __init__(
481
- self,
482
- dim: int,
483
- num_heads: int = 8,
484
- qkv_bias: bool = True,
485
- qk_scale: Optional[float] = None,
486
- attn_drop: float = 0.0,
487
- proj_drop: float = 0.0,
488
- ) -> None:
489
- super().__init__()
490
- self.dim = dim
491
- self.num_heads = num_heads
492
- head_dim = dim // num_heads
493
- self.rotary_emb = RotaryEmbedding(dim=head_dim, learned_freq=False)
494
-
495
- self.scale = qk_scale or head_dim**-0.5
496
-
497
- self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
498
- self.attn_drop = attn_drop
499
- self.attn_drop_fn = nn.Dropout(attn_drop)
500
- self.proj = nn.Linear(dim, dim)
501
- self.proj_drop = nn.Dropout(proj_drop)
502
-
503
- def forward(self, x: torch.Tensor) -> torch.Tensor:
504
- B, N, C = x.shape
505
- qkv = (
506
- self.qkv_proj(x)
507
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
508
- .permute(2, 0, 3, 1, 4)
509
- ) # (K, B, H, N, D)
510
- q, k, v = qkv[0], qkv[1], qkv[2]
511
- q = self.rotary_emb.rotate_queries_or_keys(q)
512
- k = self.rotary_emb.rotate_queries_or_keys(k)
513
- # Calculate attention scores
514
- attn_weights = (q @ k.transpose(-2, -1)) * self.scale # (B, H, N, N)
515
-
516
- # Apply softmax to get attention probabilities
517
- attn_weights = torch.softmax(attn_weights, dim=-1)
518
-
519
- # Apply dropout
520
- attn_weights = self.attn_drop_fn(attn_weights)
521
-
522
- # Apply attention weights to values
523
- attn = attn_weights @ v # (B, H, N, D)
524
- attn = rearrange(attn, "B H N D -> B N (H D)")
525
- return self.proj_drop(self.proj(attn))
526
-
527
-
528
- class _FeedForwardBlock(nn.Module):
529
- def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
530
- super().__init__()
531
- self.fc1 = nn.Linear(dim, hidden_dim)
532
- self.activation = nn.GELU()
533
- self.dropout1 = nn.Dropout(dropout)
534
- self.dropout2 = nn.Dropout(dropout)
535
- self.fc2 = nn.Linear(hidden_dim, dim)
536
- self.norm = nn.LayerNorm(hidden_dim)
537
-
538
- def forward(self, x: torch.Tensor) -> torch.Tensor:
539
- x = self.fc1(x)
540
- x = self.activation(x)
541
- x = self.dropout1(x)
542
- x = self.norm(x)
543
- x = self.fc2(x)
544
- x = self.dropout2(x)
545
- return x
546
-
547
-
548
- class _RotaryTransformerBlock(nn.Module):
549
- def __init__(
550
- self,
551
- dim: int,
552
- num_heads: int,
553
- mlp_ratio: float = 4.0,
554
- qkv_bias: bool = False,
555
- qk_scale: Optional[float] = None,
556
- drop: float = 0.0,
557
- attn_drop: float = 0.0,
558
- drop_path: float = 0.0,
559
- norm_layer: Type[nn.Module] = nn.LayerNorm,
560
- ) -> None:
561
- super().__init__()
562
- self.norm1 = norm_layer(dim)
563
- self.attn = _RotarySelfAttentionBlock(
564
- dim=dim,
565
- num_heads=num_heads,
566
- qkv_bias=qkv_bias,
567
- qk_scale=qk_scale,
568
- attn_drop=attn_drop,
569
- proj_drop=drop,
570
- )
571
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
572
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
573
- self.norm2 = norm_layer(dim)
574
- self.mlp = _FeedForwardBlock(
575
- dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=drop
576
- )
577
-
578
- def forward(self, x: torch.Tensor) -> torch.Tensor:
579
- x = x + self.drop_path1(self.attn(self.norm1(x)))
580
- x = x + self.drop_path2(self.mlp(self.norm2(x)))
581
- return x
582
-
583
-
584
- class _PatchReconstructionHeadWithQueries(nn.Module):
585
- def __init__(
586
- self,
587
- input_dim: int = 8,
588
- embed_dim: int = 768,
589
- num_heads: int = 8,
590
- num_queries: int = 4,
591
- ) -> None:
592
- super().__init__()
593
- self.input_dim = input_dim
594
- self.embed_dim = embed_dim
595
- self.reconstruction_shape = self.input_dim
596
- self.num_queries = num_queries
597
- # Projection from embed space to pixel space, according to type of input
598
- self.decoder_pred = nn.TransformerDecoder(
599
- nn.TransformerDecoderLayer(
600
- embed_dim,
601
- num_heads,
602
- dropout=0.0,
603
- batch_first=True,
604
- activation="gelu",
605
- dim_feedforward=int(embed_dim * 4),
606
- norm_first=True,
607
- ),
608
- num_layers=1,
609
- )
610
- self.norm = nn.LayerNorm(embed_dim)
611
- self.decoder_linear = _Mlp(
612
- embed_dim, int(embed_dim * 4), input_dim, act_layer=nn.GELU, drop=0.0
613
- ) # nn.Linear(embed_dim, input_dim, bias=True)
614
-
615
- def forward(self, enc: torch.Tensor, decoder_queries: torch.Tensor) -> torch.Tensor:
616
- """
617
- enc: [B, num_patches, embed_dim], embed_dim = Q*D
618
- decoder_queries: [B*num_patches, num_channels, embed_dim]
619
- """
620
-
621
- B, num_patches, embed_dim = enc.shape
622
- enc = rearrange(enc, "B t (Q D) -> (B t) Q D", Q=self.num_queries)
623
- out = self.decoder_pred(decoder_queries, enc) # (B*t, C, D)
624
- out = self.norm(out)
625
- out = self.decoder_linear(out) # (B*t, C, patch_size)
626
- out = rearrange(out, "(B t) C P -> B C (t P)", B=B)
627
- return out
628
-
629
-
630
- class _ClassificationHeadWithQueries(nn.Module):
631
- def __init__(
632
- self,
633
- input_dim: int = 8,
634
- embed_dim: int = 768,
635
- num_queries: int = 8,
636
- num_heads: int = 8,
637
- num_classes: int = 2,
638
- drop_decoder: float = 0.15,
639
- drop_ffn: float = 0.15,
640
- ) -> None:
641
- super().__init__()
642
- self.input_dim = input_dim
643
- self.embed_dim = int(embed_dim * num_queries)
644
- self.reconstruction_shape = self.input_dim
645
- self.decoder_attn = nn.MultiheadAttention(
646
- self.embed_dim, num_heads, batch_first=True, dropout=drop_decoder
647
- )
648
- self.decoder_ffn = _Mlp(
649
- in_features=self.embed_dim,
650
- hidden_features=int(self.embed_dim * 4),
651
- out_features=num_classes,
652
- act_layer=nn.GELU,
653
- drop=drop_ffn,
654
- )
655
-
656
- self.learned_agg = nn.Parameter(
657
- torch.randn(1, 1, self.embed_dim), requires_grad=True
658
- )
659
-
660
- def forward(self, x: torch.Tensor) -> torch.Tensor:
661
- """
662
- Args:
664
- x: [B, num_patches, embed_dim] encoder output, with embed_dim = Q * D.
665
-
666
- Returns:
667
- Tensor of shape [B, num_classes] containing class logits.
667
- """
668
- B, num_patches, embed_dim = x.shape
669
- decoder_queries = self.learned_agg.repeat(x.shape[0], 1, 1)
670
-
671
- x = self.decoder_attn(query=decoder_queries, key=x, value=x)[0]
672
- x = x[:, 0, :]
673
- x = self.decoder_ffn(x)
674
- return x
675
-
676
-
677
- class _CrossAttentionBlock(nn.Module):
678
- def __init__(
679
- self,
680
- num_queries: int,
681
- input_embed_dim: int,
682
- output_embed_dim: int,
683
- num_heads: int,
684
- dropout_p: float = 0.1,
685
- ff_dim: int = 2048,
686
- ) -> None:
687
- super().__init__()
688
- self.num_queries = num_queries
689
- self.dropout_p = dropout_p
690
- self.query_embed = nn.Parameter(
691
- torch.randn(1, num_queries, input_embed_dim), requires_grad=True
692
- ) # Learnable queries
693
- self.cross_attention = nn.MultiheadAttention(
694
- embed_dim=input_embed_dim,
695
- num_heads=num_heads,
696
- dropout=dropout_p,
697
- batch_first=True,
698
- )
699
- self.temperature = nn.Parameter(torch.tensor(1.0), requires_grad=False)
700
-
701
- self.ffn = _Mlp(
702
- input_embed_dim,
703
- ff_dim,
704
- output_embed_dim,
705
- act_layer=nn.GELU,
706
- drop=dropout_p,
707
- )
708
- self.keys_norm = nn.LayerNorm(input_embed_dim)
709
- self.values_norm = nn.LayerNorm(input_embed_dim)
710
- self.queries_norm = nn.LayerNorm(input_embed_dim)
711
- self.query_self_attn = nn.TransformerEncoder(
712
- nn.TransformerEncoderLayer(
713
- input_embed_dim,
714
- nhead=num_heads,
715
- activation="gelu",
716
- dim_feedforward=ff_dim,
717
- batch_first=True,
718
- norm_first=True,
719
- ),
720
- num_layers=3,
721
- )
722
-
723
- def initialize_weights(self) -> None:
724
- torch.nn.init.orthogonal_(self.query_embed, gain=1.0)
725
- self.apply(self._init_weights)
726
-
727
- def _init_weights(self, m: nn.Module) -> None:
728
- if isinstance(m, nn.Linear):
729
- # xavier normal initialization for linear layers (adapted from the official JAX ViT init)
730
- torch.nn.init.xavier_normal_(m.weight)
731
- if isinstance(m, nn.Linear) and m.bias is not None:
732
- nn.init.constant_(m.bias, 0)
733
- elif isinstance(m, nn.LayerNorm):
734
- nn.init.constant_(m.bias, 0)
735
- nn.init.constant_(m.weight, 1.0)
736
-
737
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
738
- # x is the input with shape (batch_size*num_patches, num_channels, embed_dim)
739
- batch_size, _, _ = x.size()
740
- queries = self.query_embed.repeat(batch_size, 1, 1)
741
- queries = self.queries_norm(queries)
742
- keys = self.keys_norm(x)
743
- values = self.values_norm(x)
744
-
745
- attention_out, attention_scores = self.cross_attention(
746
- query=queries, key=keys, value=values
747
- ) # Shape: (batch_size*num_patches, num_queries, embed_dim)
748
- attention_out = self.ffn(attention_out) + attention_out
749
-
750
- attention_out = self.query_self_attn(attention_out)
751
- return (
752
- attention_out,
753
- attention_scores,
754
- ) # Shape: (batch_size*num_patches, num_queries, embed_dim)
755
-
756
-
757
- class _PatchEmbedNetwork(nn.Module):
758
- def __init__(self, embed_dim: int = 64, patch_size: int = 40) -> None:
759
- super().__init__()
760
- self.patch_size = patch_size
761
- self.embed_dim = embed_dim
762
- self.in_channels = 1
763
- self.out_channels = int(embed_dim // 4)
764
- self.groups = 4
765
- self.kernel_size = int(patch_size // 2)
766
- self.proj_in = nn.Sequential(
767
- nn.Conv2d(
768
- in_channels=self.in_channels,
769
- out_channels=self.out_channels,
770
- kernel_size=(1, self.kernel_size - 1),
771
- stride=(1, self.kernel_size // 2),
772
- padding=(0, self.kernel_size // 2 - 1),
773
- ),
774
- nn.GroupNorm(self.groups, self.out_channels),
775
- nn.GELU(),
776
- nn.Conv2d(
777
- in_channels=self.out_channels,
778
- out_channels=self.out_channels,
779
- kernel_size=(1, 3),
780
- stride=(1, 1),
781
- padding=(0, 1),
782
- ),
783
- nn.GroupNorm(self.groups, self.out_channels),
784
- nn.GELU(),
785
- nn.Conv2d(
786
- in_channels=self.out_channels,
787
- out_channels=self.out_channels,
788
- kernel_size=(1, 3),
789
- stride=(1, 1),
790
- padding=(0, 1),
791
- ),
792
- nn.GroupNorm(self.groups, self.out_channels),
793
- nn.GELU(),
794
- )
795
-
796
- def forward(self, x: torch.Tensor) -> torch.Tensor:
797
- """
798
- x: (B, C, T)
799
- output: (B, C*S, D) where S = T//patch_size, D = embed_dim
800
- """
801
- x = rearrange(x, "B C (S P) -> B (C S) P", P=self.patch_size)
802
- x = x.unsqueeze(1)
803
- x = self.proj_in(x)
804
- x = rearrange(x, "B E CS D -> B CS (D E)")
805
- return x
806
-
807
-
808
- class _Mlp(nn.Module):
809
- """MLP as used in Vision Transformer, MLP-Mixer and related networks.
810
-
811
- Code copied from timm's Mlp layer.
812
- """
813
-
814
- def __init__(
815
- self,
816
- in_features,
817
- hidden_features=None,
818
- out_features=None,
819
- act_layer=nn.GELU,
820
- drop=0.0,
821
- ):
822
- super().__init__()
823
- out_features = out_features or in_features
824
- hidden_features = hidden_features or in_features
825
- self.fc1 = nn.Linear(in_features, hidden_features)
826
- self.act = act_layer()
827
- self.fc2 = nn.Linear(hidden_features, out_features)
828
- self.drop = nn.Dropout(drop)
829
-
830
- def forward(self, x):
831
- x = self.fc1(x)
832
- x = self.act(x)
833
- x = self.drop(x)
834
- x = self.fc2(x)
835
- x = self.drop(x)
836
- return x