broccoli-ml 10.2.0__py3-none-any.whl → 11.0.0__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- broccoli/transformer.py +65 -27
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-11.0.0.dist-info}/METADATA +1 -1
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-11.0.0.dist-info}/RECORD +5 -5
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-11.0.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-10.2.0.dist-info → broccoli_ml-11.0.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -24,13 +24,10 @@ except ImportError:
|
|
|
24
24
|
class LayerScale(nn.Module):
|
|
25
25
|
def __init__(self, dim, decay=False, init_values=1e-4):
|
|
26
26
|
super().__init__()
|
|
27
|
+
self.dim = dim
|
|
27
28
|
self.decay = decay
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
self.nondecay_scale = None
|
|
31
|
-
else:
|
|
32
|
-
self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
|
|
33
|
-
self.scale = None
|
|
29
|
+
self.init_values = init_values
|
|
30
|
+
self.reset_parameters()
|
|
34
31
|
|
|
35
32
|
def forward(self, x):
|
|
36
33
|
if self.decay:
|
|
@@ -38,6 +35,14 @@ class LayerScale(nn.Module):
|
|
|
38
35
|
else:
|
|
39
36
|
return x * self.nondecay_scale
|
|
40
37
|
|
|
38
|
+
def reset_parameters(self):
|
|
39
|
+
if self.decay:
|
|
40
|
+
self.scale = nn.Parameter(self.init_values * torch.ones(self.dim))
|
|
41
|
+
self.nondecay_scale = None
|
|
42
|
+
else:
|
|
43
|
+
self.nondecay_scale = nn.Parameter(self.init_values * torch.ones(self.dim))
|
|
44
|
+
self.scale = None
|
|
45
|
+
|
|
41
46
|
|
|
42
47
|
def drop_path(
|
|
43
48
|
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
|
@@ -540,10 +545,18 @@ class TransformerBlock(nn.Module):
|
|
|
540
545
|
|
|
541
546
|
self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
|
|
542
547
|
|
|
543
|
-
self.
|
|
544
|
-
|
|
545
|
-
|
|
548
|
+
if self.pre_norm:
|
|
549
|
+
self.pre_attention_norm = nn.LayerNorm(d_model)
|
|
550
|
+
self.pre_mlp_norm = nn.LayerNorm(d_model)
|
|
551
|
+
|
|
552
|
+
if normformer:
|
|
553
|
+
self.normformer_norm = nn.LayerNorm(d_model)
|
|
546
554
|
|
|
555
|
+
if self.post_norm:
|
|
556
|
+
self.post_attention_norm = nn.LayerNorm(d_model)
|
|
557
|
+
self.post_mlp_norm = nn.LayerNorm(d_model)
|
|
558
|
+
|
|
559
|
+
self.layerscale = layerscale
|
|
547
560
|
if layerscale:
|
|
548
561
|
self.layerscale1 = LayerScale(d_model)
|
|
549
562
|
self.layerscale2 = LayerScale(d_model)
|
|
@@ -613,20 +626,31 @@ class TransformerBlock(nn.Module):
|
|
|
613
626
|
def forward(self, x):
|
|
614
627
|
|
|
615
628
|
if self.pre_norm:
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
629
|
+
process_x = self.pre_attention_norm(x)
|
|
630
|
+
else:
|
|
631
|
+
process_x = x
|
|
632
|
+
|
|
633
|
+
processed = self.drop_path(
|
|
634
|
+
self.layerscale1(self.attn(process_x, process_x, process_x))
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
if self.normformer:
|
|
638
|
+
processed = self.normformer_norm(processed)
|
|
639
|
+
|
|
640
|
+
x = x + processed
|
|
641
|
+
|
|
642
|
+
if self.post_norm:
|
|
643
|
+
x = self.post_attention_norm(x)
|
|
644
|
+
|
|
645
|
+
if self.pre_norm:
|
|
646
|
+
process_x = self.pre_mlp_norm(x)
|
|
647
|
+
else:
|
|
648
|
+
process_x = x
|
|
649
|
+
|
|
650
|
+
x = x + self.drop_path(self.layerscale2(self.ff(process_x)))
|
|
651
|
+
|
|
652
|
+
if self.post_norm:
|
|
653
|
+
x = self.post_mlp_norm(x)
|
|
630
654
|
|
|
631
655
|
return x
|
|
632
656
|
|
|
@@ -634,20 +658,34 @@ class TransformerBlock(nn.Module):
|
|
|
634
658
|
"""
|
|
635
659
|
Give back the attention scores used in this layer.
|
|
636
660
|
"""
|
|
661
|
+
# Fix: Use the correct attribute name 'pre_attention_norm'
|
|
637
662
|
if self.pre_norm:
|
|
638
|
-
|
|
663
|
+
# We must normalize the input before measuring attention logits
|
|
664
|
+
# to match what the model actually sees during forward()
|
|
665
|
+
x = self.pre_attention_norm(x)
|
|
639
666
|
return self.attn.attention_logits(x, x, x)
|
|
640
667
|
else:
|
|
641
668
|
return self.attn.attention_logits(x, x, x)
|
|
642
669
|
|
|
643
670
|
def reset_parameters(self):
|
|
644
|
-
self.
|
|
645
|
-
|
|
646
|
-
|
|
671
|
+
if self.pre_norm:
|
|
672
|
+
self.pre_attention_norm.reset_parameters()
|
|
673
|
+
self.pre_mlp_norm.reset_parameters()
|
|
674
|
+
|
|
675
|
+
if self.post_norm:
|
|
676
|
+
self.post_attention_norm.reset_parameters()
|
|
677
|
+
self.post_mlp_norm.reset_parameters()
|
|
678
|
+
|
|
679
|
+
if self.normformer:
|
|
680
|
+
self.normformer_norm.reset_parameters()
|
|
647
681
|
|
|
648
682
|
self.attn.reset_parameters()
|
|
649
683
|
self.ff.reset_parameters()
|
|
650
684
|
|
|
685
|
+
if self.layerscale:
|
|
686
|
+
self.layerscale1.reset_parameters()
|
|
687
|
+
self.layerscale2.reset_parameters()
|
|
688
|
+
|
|
651
689
|
|
|
652
690
|
class TransformerEncoder(nn.Module):
|
|
653
691
|
"""
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256=
|
|
7
|
+
broccoli/transformer.py,sha256=urW9KGxK47Mns7ZowPKTgdEyp4Yd21uw3PNCpKx04cI,29223
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
9
|
broccoli/vit.py,sha256=EGbQb-atuzG3JAx7kdTaJEbWvQR-4XgyYvwjKkN5C38,22612
|
|
10
|
-
broccoli_ml-
|
|
11
|
-
broccoli_ml-
|
|
12
|
-
broccoli_ml-
|
|
13
|
-
broccoli_ml-
|
|
10
|
+
broccoli_ml-11.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-11.0.0.dist-info/METADATA,sha256=7ixQROHXr3LyCMhT5hkKmvip8C8Bytmtz-IoQCbC_NQ,1369
|
|
12
|
+
broccoli_ml-11.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-11.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|