broccoli-ml 9.6.0__tar.gz → 13.0.0__tar.gz
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/PKG-INFO +1 -1
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/transformer.py +152 -49
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/vit.py +15 -12
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/pyproject.toml +1 -1
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/LICENSE +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/README.md +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/linear.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/utils.py +0 -0
{broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/transformer.py

@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -20,6 +21,12 @@ except ImportError:
     FLASH_ATTN = False
 
 
+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
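The new scale_parameters helper simply rescales every parameter of a module in place. A minimal standalone sketch of how it behaves (the layer below is hypothetical, not package code):

import torch
import torch.nn as nn

def scale_parameters(torch_module: nn.Module, factor: float):
    # In-place multiply of every parameter by `factor`; no gradients are tracked.
    with torch.no_grad():
        for param in torch_module.parameters():
            param.mul_(factor)

layer = nn.Linear(16, 16)
before = layer.weight.detach().clone()
scale_parameters(layer, 0.5)
print(torch.allclose(layer.weight, 0.5 * before))  # True: weights halved in place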
@@ -78,9 +85,11 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -96,10 +105,20 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta
 
         self.head_dim = self.embed_dim // self.n_heads
 
@@ -181,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-
-
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )
 
         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
 
@@ -202,17 +230,20 @@ class MHAttention(nn.Module):
 
         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
 
         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)
 
+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k
 
     def project_qkv(
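The changes above apply the rotary embedding per head rather than across the packed embedding dimension: heads are split out before the spatial rearranges and merged again afterwards. A small shape-only sketch of the einops round trip (dummy tensors, not package code):

import torch
from einops import rearrange

b, t, h, d = 2, 10, 4, 8              # batch, tokens, heads, head dim
q = torch.randn(b, t, h * d)          # packed projection output

q = rearrange(q, "b t (h d) -> b t h d", h=h)   # (2, 10, 4, 8): split heads
# ...the rotary embedding would act on the trailing head dimension here...
q = rearrange(q, "b t h d -> b t (h d)")        # (2, 10, 32): merge heads back
print(q.shape)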
@@ -243,7 +274,7 @@ class MHAttention(nn.Module):
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +302,22 @@ class MHAttention(nn.Module):
 
             qk_scores *= self.scaling_factor
 
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+                )
+
             # Apply mask if causal (must come before softmax)
             if self.causal:
                 qk_scores.masked_fill_(self.mask, float("-inf"))
 
             qk_scores = F.softmax(qk_scores, dim=-1)
 
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+                )
+
             qk_scores = self.dropout(qk_scores)
 
             output_with_heads = qk_scores @ v
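The two einsum calls implement "talking heads" attention (Shazeer et al., 2020): scores are linearly mixed across the head dimension once before and once after the softmax. A minimal standalone sketch of the idea, with identity mixing matrices so it reduces to ordinary multi-head attention (shapes are illustrative, not the package's exact code path):

import torch
import torch.nn.functional as F

b, h, t, d = 2, 4, 10, 8
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))

pre_softmax_mix = torch.eye(h)    # plays the role of head_projection.weight
post_softmax_mix = torch.eye(h)   # plays the role of sample_projection.weight

scores = (q @ k.transpose(-2, -1)) / d**0.5                        # (b, h, t, t)
scores = torch.einsum("b h i j, o h -> b o i j", scores, pre_softmax_mix)
weights = F.softmax(scores, dim=-1)
weights = torch.einsum("b h i j, o h -> b o i j", weights, post_softmax_mix)
out = weights @ v                                                   # (b, h, t, d)
print(out.shape)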
@@ -309,7 +350,14 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
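The scale_parameters calls tagged "per Microsoft DeepNet" follow the DeepNet recipe (Wang et al., 2022), which shrinks selected projection weights by a factor beta at initialisation while the residual stream is scaled by alpha inside the blocks further down. The package leaves both at 1.0 by default; if you want the commonly quoted encoder-only constants from the paper, a sketch is:

def deepnet_encoder_constants(n_layers: int) -> tuple[float, float]:
    # Encoder-only DeepNet: alpha = (2N)^(1/4), beta = (8N)^(-1/4).
    # Other architectures (decoder-only, encoder-decoder) use different exponents.
    alpha = (2 * n_layers) ** 0.25
    beta = (8 * n_layers) ** -0.25
    return alpha, beta

alpha, beta = deepnet_encoder_constants(12)
print(round(alpha, 2), round(beta, 2))  # roughly 2.21 and 0.32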
@@ -329,17 +377,14 @@ class FeedforwardBlock(nn.Module):
         outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()
 
         self.checkpoint = checkpoint
-        self.
-        self.post_norm = post_norm
+        self.beta = beta
         self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
@@ -365,19 +410,26 @@ class FeedforwardBlock(nn.Module):
         )
 
         self.max_features = (
-            2 * ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
         )
 
         self.linear_in = linear_module_up(input_features, self.max_features)
-        self.linear_out = linear_module_down(
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
+        )
 
         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
-
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
                 self.linear_out,
                 self.outer_dropout,
             ]
@@ -419,12 +471,7 @@ class FeedforwardBlock(nn.Module):
         else:
             processed = self.process(x)
 
-
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed
 
     def reset_parameters(self):
         if self.post_norm:
@@ -435,8 +482,11 @@ class FeedforwardBlock(nn.Module):
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()
 
+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
+
 
-class TransformerBlock(nn.Module):
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -453,6 +503,7 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -470,6 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -481,15 +534,29 @@ class TransformerBlock(nn.Module):
 
         super().__init__()
 
+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer
 
+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
+
+        if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)
 
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -513,7 +580,9 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )
 
         # Submodule for the feedforward process
@@ -536,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )
 
         self.reset_parameters()
@@ -550,22 +617,34 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
+        if self.post_norm:
+            x = self.input_norm(x)
 
         if self.pre_norm:
-
-
-
-
-
-
-
-
-
-
-
-
-            x =
-
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.ff(process_x))
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)
 
         return x
 
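The rewritten forward scales the residual stream by alpha before adding each sublayer's output, i.e. x ← alpha·x + sublayer(norm(x)) in the pre-norm case, which is the DeepNet-style update the new alpha/beta parameters exist for. A condensed sketch of just the pre-norm path (attribute names as in the diff; the normformer and post-norm branches are omitted, illustration only):

def encoder_block_prenorm_forward(block, x):
    # `block` stands for an EncoderBlock built with pre_norm=True, normformer=False.
    h = block.pre_attention_norm(x)
    x = block.alpha * x + block.drop_path(block.attn(h, h, h))
    h = block.pre_mlp_norm(x)
    x = block.alpha * x + block.drop_path(block.ff(h))
    return x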
@@ -573,16 +652,26 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()
 
         self.attn.reset_parameters()
         self.ff.reset_parameters()
@@ -616,12 +705,15 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -644,6 +736,13 @@ class TransformerEncoder(nn.Module):
         )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
@@ -688,13 +787,14 @@ class TransformerEncoder(nn.Module):
 
         self.blocks = nn.ModuleList(
             [
-
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
                     relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
                     utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
@@ -712,6 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
@@ -732,13 +834,14 @@ class TransformerEncoder(nn.Module):
         x = x
 
         if self.absolute_position_embedding is not None:
-
+            position_embedding = self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
                 ).unsqueeze(
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
            )
+            x += position_embedding
 
         return x
 
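The refactor above only splits the embedding lookup and the addition into two statements; the broadcasting pattern itself is the usual one for absolute position embeddings. Standalone sketch (dummy sizes, not package code):

import torch
import torch.nn as nn

seq_len, d_model = 6, 16
embedding = nn.Embedding(seq_len, d_model)

x = torch.randn(2, seq_len, d_model)                                 # (batch, seq, dim)
positions = torch.arange(0, seq_len, dtype=torch.long).unsqueeze(0)  # (1, seq)
x = x + embedding(positions)                                         # (1, seq, dim) broadcasts over batch
print(x.shape)  # torch.Size([2, 6, 16])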
{broccoli_ml-9.6.0 → broccoli_ml-13.0.0}/broccoli/vit.py

@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -174,6 +173,7 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -345,11 +345,14 @@ class ViTEncoder(nn.Module):
                 causal=False,
                 linear_module=linear_module,
                 utility_tokens=transformer_utility_tokens,
+                talking_heads=transformer_talking_heads,
                 return_utility_tokens=transformer_return_utility_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -391,16 +394,14 @@ class ViTEncoder(nn.Module):
                     or transformer_ff_linear_module_down
                     or linear_module
                 ),
-                pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                 checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
             )
         else:
             self.initial_ff = nn.Identity()
 
-        self.
+        self.preprocess = nn.Sequential(
             *[
                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
@@ -410,19 +411,21 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-
-                self.transformer,
+                nn.LayerNorm(),
             ]
         )
 
         self.reset_parameters()
 
     def forward(self, x):
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)
 
     def attention_logits(self, x):
-        x = self.
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)
 
     def reset_parameters(self):
         for module in self.encoder:
@@ -456,7 +459,6 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -472,6 +474,7 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -527,7 +530,6 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -543,6 +545,7 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,