broccoli-ml 10.1.1__py3-none-any.whl → 11.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +76 -28
- {broccoli_ml-10.1.1.dist-info → broccoli_ml-11.0.0.dist-info}/METADATA +1 -1
- {broccoli_ml-10.1.1.dist-info → broccoli_ml-11.0.0.dist-info}/RECORD +5 -5
- {broccoli_ml-10.1.1.dist-info → broccoli_ml-11.0.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-10.1.1.dist-info → broccoli_ml-11.0.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -22,12 +22,26 @@ except ImportError:
 
 
 class LayerScale(nn.Module):
-    def __init__(self, dim, init_values=1e-4):
+    def __init__(self, dim, decay=False, init_values=1e-4):
         super().__init__()
-        self.
+        self.dim = dim
+        self.decay = decay
+        self.init_values = init_values
+        self.reset_parameters()
 
     def forward(self, x):
-
+        if self.decay:
+            return x * self.scale
+        else:
+            return x * self.nondecay_scale
+
+    def reset_parameters(self):
+        if self.decay:
+            self.scale = nn.Parameter(self.init_values * torch.ones(self.dim))
+            self.nondecay_scale = None
+        else:
+            self.nondecay_scale = nn.Parameter(self.init_values * torch.ones(self.dim))
+            self.scale = None
 
 
 def drop_path(
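LayerScale now registers its per-channel gain under one of two attribute names, `scale` or `nondecay_scale`, depending on the new `decay` flag. The naming appears designed so that a name-based optimizer filter can decide which gains receive weight decay. The helper below is a hypothetical sketch of that pattern under this assumption; it is not part of broccoli.

import torch.nn as nn

def build_param_groups(model: nn.Module, weight_decay: float = 0.01):
    """Hypothetical helper: parameters whose name contains 'nondecay' go into
    a zero-weight-decay group; everything else keeps the default decay."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (no_decay if "nondecay" in name else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Usage sketch: optimizer = torch.optim.AdamW(build_param_groups(model), lr=3e-4)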
@@ -531,10 +545,18 @@ class TransformerBlock(nn.Module):
 
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
 
+        if self.post_norm:
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)
+
+        self.layerscale = layerscale
         if layerscale:
             self.layerscale1 = LayerScale(d_model)
             self.layerscale2 = LayerScale(d_model)
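Because the pre-, post- and NormFormer norms are only constructed when the matching flag is set, the block's registered submodules, and therefore its state_dict keys, now depend on configuration. A standalone illustration of that registration pattern (this demo class is not broccoli code):

import torch.nn as nn

class NormToggleDemo(nn.Module):
    """Registers a LayerNorm submodule only when the matching flag is set."""
    def __init__(self, d_model: int, pre_norm: bool = True, post_norm: bool = False):
        super().__init__()
        self.pre_norm = pre_norm
        self.post_norm = post_norm
        if pre_norm:
            self.pre_attention_norm = nn.LayerNorm(d_model)
        if post_norm:
            self.post_attention_norm = nn.LayerNorm(d_model)

print(list(NormToggleDemo(8, pre_norm=True, post_norm=False).state_dict()))
# ['pre_attention_norm.weight', 'pre_attention_norm.bias']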
@@ -604,20 +626,31 @@ class TransformerBlock(nn.Module):
     def forward(self, x):
 
         if self.pre_norm:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(
+            self.layerscale1(self.attn(process_x, process_x, process_x))
+        )
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+
+        if self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        x = x + self.drop_path(self.layerscale2(self.ff(process_x)))
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)
 
         return x
 
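The rewritten forward pass threads pre-norm, LayerScale, DropPath, the optional NormFormer norm after attention, and the optional post-norms through two residual branches. With all the optional pieces removed, the underlying pre-norm wiring reduces to the sketch below, built from stock torch.nn modules rather than broccoli's own attention and feed-forward blocks:

import torch
import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    """Minimal pre-norm transformer block: norm -> sublayer -> residual add."""
    def __init__(self, d_model: int, n_heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        h = self.norm1(x)  # pre-attention norm
        x = x + self.attn(h, h, h, need_weights=False)[0]  # attention residual
        x = x + self.ff(self.norm2(x))  # pre-MLP norm, feed-forward residual
        return x

x = torch.randn(2, 16, 64)
print(PreNormBlockSketch(64)(x).shape)  # torch.Size([2, 16, 64])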
@@ -625,20 +658,34 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()
 
         self.attn.reset_parameters()
         self.ff.reset_parameters()
 
+        if self.layerscale:
+            self.layerscale1.reset_parameters()
+            self.layerscale2.reset_parameters()
+
 
 class TransformerEncoder(nn.Module):
     """
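reset_parameters now walks only the norms that were actually registered, plus the two LayerScale modules when layerscale is enabled. For reference, nn.LayerNorm.reset_parameters() restores the affine parameters to their defaults (weight 1, bias 0), so calling the block-level method reinitialises every conditional norm in place; a quick check of that PyTorch behaviour:

import torch
import torch.nn as nn

norm = nn.LayerNorm(4)
with torch.no_grad():
    norm.weight.fill_(2.0)  # perturb the affine gain
    norm.bias.fill_(0.5)    # perturb the affine bias

norm.reset_parameters()     # restores weight to ones and bias to zeros
print(norm.weight.tolist(), norm.bias.tolist())
# [1.0, 1.0, 1.0, 1.0] [0.0, 0.0, 0.0, 0.0]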
@@ -711,9 +758,10 @@ class TransformerEncoder(nn.Module):
         self.return_utility_tokens = return_utility_tokens
 
         if layerscale:
-
+            rope_and_ape = absolute_position_embedding and relative_position_embedding
+            self.position_layerscale = LayerScale(d_model, decay=rope_and_ape)
         else:
-            self.
+            self.position_layerscale = None
 
         # Initialise utility tokens with normal init, like usual Pytorch embeddings
         if self._utility_tokens:
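The encoder's positional LayerScale now sets decay=True only when absolute and relative position embeddings are used together, presumably so the gain on the absolute embedding is registered under the weight-decayed name (`scale`) and can shrink if the absolute signal turns out to be redundant next to RoPE. A usage sketch, assuming the LayerScale class from the first hunk is in scope:

# Hypothetical values; in the encoder these come from the constructor flags.
rope_and_ape = True  # both positional schemes enabled
position_layerscale = LayerScale(dim=64, decay=rope_and_ape)
print(position_layerscale.scale.shape)             # torch.Size([64])
print(position_layerscale.nondecay_scale is None)  # True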
@@ -807,8 +855,8 @@ class TransformerEncoder(nn.Module):
                 0
             )  # to shape (1, seq_len) to broadcast over batch
         )
-        if self.
-            position_embedding = self.
+        if self.position_layerscale is not None:
+            position_embedding = self.position_layerscale(position_embedding)
         x += position_embedding
 
         return x
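When the positional LayerScale is present, its per-channel gain multiplies the position embedding elementwise, broadcasting over the batch and sequence dimensions. A minimal shape check of that broadcast (shapes are illustrative):

import torch

d_model = 8
scale = 1e-4 * torch.ones(d_model)                 # what LayerScale registers at init
position_embedding = torch.randn(1, 16, d_model)   # (1, seq_len, d_model)

scaled = position_embedding * scale                # broadcasts over the last dim
print(scaled.shape)                                # torch.Size([1, 16, 8])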
{broccoli_ml-10.1.1.dist-info → broccoli_ml-11.0.0.dist-info}/RECORD
CHANGED
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
 broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=urW9KGxK47Mns7ZowPKTgdEyp4Yd21uw3PNCpKx04cI,29223
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
 broccoli/vit.py,sha256=EGbQb-atuzG3JAx7kdTaJEbWvQR-4XgyYvwjKkN5C38,22612
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
+broccoli_ml-11.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-11.0.0.dist-info/METADATA,sha256=7ixQROHXr3LyCMhT5hkKmvip8C8Bytmtz-IoQCbC_NQ,1369
+broccoli_ml-11.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-11.0.0.dist-info/RECORD,,
{broccoli_ml-10.1.1.dist-info → broccoli_ml-11.0.0.dist-info}/LICENSE
File without changes

{broccoli_ml-10.1.1.dist-info → broccoli_ml-11.0.0.dist-info}/WHEEL
File without changes