broccoli-ml 10.2.0__py3-none-any.whl → 12.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +46 -56
- broccoli/vit.py +0 -4
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-12.0.0.dist-info}/METADATA +1 -1
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-12.0.0.dist-info}/RECORD +6 -6
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-12.0.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-12.0.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -21,24 +21,6 @@ except ImportError:
|
|
|
21
21
|
FLASH_ATTN = False
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
class LayerScale(nn.Module):
|
|
25
|
-
def __init__(self, dim, decay=False, init_values=1e-4):
|
|
26
|
-
super().__init__()
|
|
27
|
-
self.decay = decay
|
|
28
|
-
if decay:
|
|
29
|
-
self.scale = nn.Parameter(init_values * torch.ones(dim))
|
|
30
|
-
self.nondecay_scale = None
|
|
31
|
-
else:
|
|
32
|
-
self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
|
|
33
|
-
self.scale = None
|
|
34
|
-
|
|
35
|
-
def forward(self, x):
|
|
36
|
-
if self.decay:
|
|
37
|
-
return x * self.scale
|
|
38
|
-
else:
|
|
39
|
-
return x * self.nondecay_scale
|
|
40
|
-
|
|
41
|
-
|
|
42
24
|
def drop_path(
|
|
43
25
|
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
|
44
26
|
):
|
|
@@ -522,7 +504,6 @@ class TransformerBlock(nn.Module):
|
|
|
522
504
|
post_norm=False,
|
|
523
505
|
normformer=False,
|
|
524
506
|
checkpoint_ff=True,
|
|
525
|
-
layerscale=True,
|
|
526
507
|
):
|
|
527
508
|
"""
|
|
528
509
|
Args:
|
|
@@ -540,16 +521,16 @@ class TransformerBlock(nn.Module):
|
|
|
540
521
|
|
|
541
522
|
self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
|
|
542
523
|
|
|
543
|
-
self.
|
|
544
|
-
|
|
545
|
-
|
|
524
|
+
if self.pre_norm:
|
|
525
|
+
self.pre_attention_norm = nn.LayerNorm(d_model)
|
|
526
|
+
self.pre_mlp_norm = nn.LayerNorm(d_model)
|
|
546
527
|
|
|
547
|
-
if
|
|
548
|
-
self.
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
self.
|
|
552
|
-
self.
|
|
528
|
+
if normformer:
|
|
529
|
+
self.normformer_norm = nn.LayerNorm(d_model)
|
|
530
|
+
|
|
531
|
+
if self.post_norm:
|
|
532
|
+
self.post_attention_norm = nn.LayerNorm(d_model)
|
|
533
|
+
self.post_mlp_norm = nn.LayerNorm(d_model)
|
|
553
534
|
|
|
554
535
|
if relative_position_embedding:
|
|
555
536
|
max_freq = int(max(source_size) / 2) # Suggested by Gemini!
|
|
@@ -613,20 +594,29 @@ class TransformerBlock(nn.Module):
|
|
|
613
594
|
def forward(self, x):
|
|
614
595
|
|
|
615
596
|
if self.pre_norm:
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
x =
|
|
629
|
-
|
|
597
|
+
process_x = self.pre_attention_norm(x)
|
|
598
|
+
else:
|
|
599
|
+
process_x = x
|
|
600
|
+
|
|
601
|
+
processed = self.drop_path(self.attn(process_x, process_x, process_x))
|
|
602
|
+
|
|
603
|
+
if self.normformer:
|
|
604
|
+
processed = self.normformer_norm(processed)
|
|
605
|
+
|
|
606
|
+
x = x + processed
|
|
607
|
+
|
|
608
|
+
if self.post_norm:
|
|
609
|
+
x = self.post_attention_norm(x)
|
|
610
|
+
|
|
611
|
+
if self.pre_norm:
|
|
612
|
+
process_x = self.pre_mlp_norm(x)
|
|
613
|
+
else:
|
|
614
|
+
process_x = x
|
|
615
|
+
|
|
616
|
+
x = x + self.drop_path(self.ff(process_x))
|
|
617
|
+
|
|
618
|
+
if self.post_norm:
|
|
619
|
+
x = self.post_mlp_norm(x)
|
|
630
620
|
|
|
631
621
|
return x
|
|
632
622
|
|
|
@@ -634,16 +624,26 @@ class TransformerBlock(nn.Module):
|
|
|
634
624
|
"""
|
|
635
625
|
Give back the attention scores used in this layer.
|
|
636
626
|
"""
|
|
627
|
+
# Fix: Use the correct attribute name 'pre_attention_norm'
|
|
637
628
|
if self.pre_norm:
|
|
638
|
-
|
|
629
|
+
# We must normalize the input before measuring attention logits
|
|
630
|
+
# to match what the model actually sees during forward()
|
|
631
|
+
x = self.pre_attention_norm(x)
|
|
639
632
|
return self.attn.attention_logits(x, x, x)
|
|
640
633
|
else:
|
|
641
634
|
return self.attn.attention_logits(x, x, x)
|
|
642
635
|
|
|
643
636
|
def reset_parameters(self):
|
|
644
|
-
self.
|
|
645
|
-
|
|
646
|
-
|
|
637
|
+
if self.pre_norm:
|
|
638
|
+
self.pre_attention_norm.reset_parameters()
|
|
639
|
+
self.pre_mlp_norm.reset_parameters()
|
|
640
|
+
|
|
641
|
+
if self.post_norm:
|
|
642
|
+
self.post_attention_norm.reset_parameters()
|
|
643
|
+
self.post_mlp_norm.reset_parameters()
|
|
644
|
+
|
|
645
|
+
if self.normformer:
|
|
646
|
+
self.normformer_norm.reset_parameters()
|
|
647
647
|
|
|
648
648
|
self.attn.reset_parameters()
|
|
649
649
|
self.ff.reset_parameters()
|
|
@@ -684,7 +684,6 @@ class TransformerEncoder(nn.Module):
|
|
|
684
684
|
normformer=False,
|
|
685
685
|
msa_scaling="d",
|
|
686
686
|
checkpoint_ff=True,
|
|
687
|
-
layerscale=True,
|
|
688
687
|
):
|
|
689
688
|
"""
|
|
690
689
|
Args:
|
|
@@ -719,12 +718,6 @@ class TransformerEncoder(nn.Module):
|
|
|
719
718
|
self._utility_tokens = utility_tokens
|
|
720
719
|
self.return_utility_tokens = return_utility_tokens
|
|
721
720
|
|
|
722
|
-
if layerscale:
|
|
723
|
-
rope_and_ape = absolute_position_embedding and relative_position_embedding
|
|
724
|
-
self.position_layerscale = LayerScale(d_model, decay=rope_and_ape)
|
|
725
|
-
else:
|
|
726
|
-
self.position_layerscale = None
|
|
727
|
-
|
|
728
721
|
# Initialise utility tokens with normal init, like usual Pytorch embeddings
|
|
729
722
|
if self._utility_tokens:
|
|
730
723
|
self._utility_token_embedding = nn.Parameter(
|
|
@@ -789,7 +782,6 @@ class TransformerEncoder(nn.Module):
|
|
|
789
782
|
post_norm=post_norm,
|
|
790
783
|
normformer=normformer,
|
|
791
784
|
checkpoint_ff=checkpoint_ff,
|
|
792
|
-
layerscale=layerscale,
|
|
793
785
|
)
|
|
794
786
|
for i in range(n_layers)
|
|
795
787
|
]
|
|
@@ -817,8 +809,6 @@ class TransformerEncoder(nn.Module):
|
|
|
817
809
|
0
|
|
818
810
|
) # to shape (1, seq_len) to broadcast over batch
|
|
819
811
|
)
|
|
820
|
-
if self.position_layerscale is not None:
|
|
821
|
-
position_embedding = self.position_layerscale(position_embedding)
|
|
822
812
|
x += position_embedding
|
|
823
813
|
|
|
824
814
|
return x
|
broccoli/vit.py
CHANGED
|
@@ -187,7 +187,6 @@ class ViTEncoder(nn.Module):
|
|
|
187
187
|
transformer_msa_dropout=0.1,
|
|
188
188
|
transformer_stochastic_depth=0.1,
|
|
189
189
|
transformer_checkpoint_ff=True,
|
|
190
|
-
transformer_layerscale=True,
|
|
191
190
|
linear_module=nn.Linear,
|
|
192
191
|
):
|
|
193
192
|
super().__init__()
|
|
@@ -353,7 +352,6 @@ class ViTEncoder(nn.Module):
|
|
|
353
352
|
normformer=transformer_normformer,
|
|
354
353
|
post_norm=transformer_post_norm,
|
|
355
354
|
checkpoint_ff=transformer_checkpoint_ff,
|
|
356
|
-
layerscale=transformer_layerscale,
|
|
357
355
|
)
|
|
358
356
|
else:
|
|
359
357
|
self.transformer = nn.Identity()
|
|
@@ -489,7 +487,6 @@ class ViT(nn.Module):
|
|
|
489
487
|
transformer_msa_dropout=0.1,
|
|
490
488
|
transformer_stochastic_depth=0.1,
|
|
491
489
|
transformer_checkpoint_ff=True,
|
|
492
|
-
transformer_layerscale=True,
|
|
493
490
|
head=SequencePoolClassificationHead,
|
|
494
491
|
batch_norm_logits=True,
|
|
495
492
|
logit_projection_layer=nn.Linear,
|
|
@@ -562,7 +559,6 @@ class ViT(nn.Module):
|
|
|
562
559
|
transformer_msa_dropout=transformer_msa_dropout,
|
|
563
560
|
transformer_stochastic_depth=transformer_stochastic_depth,
|
|
564
561
|
transformer_checkpoint_ff=transformer_checkpoint_ff,
|
|
565
|
-
transformer_layerscale=transformer_layerscale,
|
|
566
562
|
linear_module=linear_module,
|
|
567
563
|
)
|
|
568
564
|
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256=
|
|
7
|
+
broccoli/transformer.py,sha256=dKiqPZEBLK5e6hqEhz99Z68RCJZNAPkTwQQQfjFDhko,27611
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
|
-
broccoli/vit.py,sha256=
|
|
10
|
-
broccoli_ml-
|
|
11
|
-
broccoli_ml-
|
|
12
|
-
broccoli_ml-
|
|
13
|
-
broccoli_ml-
|
|
9
|
+
broccoli/vit.py,sha256=DvVpayMIcUhH7Xg6CiYyeedUuuMHrjsGxEdXfnTGa_Q,22428
|
|
10
|
+
broccoli_ml-12.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-12.0.0.dist-info/METADATA,sha256=xt43usBeTgqLUqnHqIsPf8xWN3Bkd8ckRP1GE69YID0,1369
|
|
12
|
+
broccoli_ml-12.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-12.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|