broccoli-ml 12.0.0__tar.gz → 12.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/PKG-INFO +1 -1
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/transformer.py +29 -18
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/pyproject.toml +1 -1
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/LICENSE +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/README.md +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/activation.py +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/linear.py +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/rope.py +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/utils.py +0 -0
- {broccoli_ml-12.0.0 → broccoli_ml-12.2.0}/broccoli/vit.py +0 -0
```diff
broccoli/transformer.py
@@ -592,33 +592,44 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
-
         if self.pre_norm:
-            process_x = self.pre_attention_norm(x)
-        else:
-            process_x = x
+            x = self.layer_norm_1(x)
+        x = x + self.drop_path(self.attn(x, x, x))
+        x = self.layer_norm_2(x)
+        x = x + self.drop_path(self.ff(x))
+        if self.post_norm: # i.e. in addition! Pre and post.
+            x = self.layer_norm_3(x)
 
-        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+        # if self.pre_norm:
+        #     process_x = self.pre_attention_norm(x)
+        # else:
+        #     process_x = x
 
-        if self.normformer:
-            processed = self.normformer_norm(processed)
+        # processed = self.drop_path(self.attn(process_x, process_x, process_x))
 
-        x = x + processed
+        # if self.normformer:
+        #     processed = self.normformer_norm(processed)
 
-        if self.post_norm:
-            x = self.post_attention_norm(x)
+        # if self.residual_path:
+        #     x = x + processed
 
-        if self.pre_norm:
-            process_x = self.pre_mlp_norm(x)
-        else:
-            process_x = x
+        # if self.post_norm:
+        #     x = self.post_attention_norm(x)
 
-        processed = self.drop_path(self.ff(process_x))
+        # if self.pre_norm:
+        #     process_x = self.pre_mlp_norm(x)
+        # else:
+        #     process_x = x
 
-        x = x + processed
-        x = self.post_mlp_norm(x)
+        # processed = self.drop_path(self.ff(process_x))
 
-        return x
+        # if self.residual_path:
+        #     x = x + processed
+
+        # if self.post_norm:
+        #     x = self.post_mlp_norm(x)
+
+        # return x
 
     def attention_logits(self, x):
         """
```
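Functionally, 12.2.0 replaces the flag-per-sublayer forward pass (separate pre/post norms per sublayer, plus `normformer` and `residual_path` switches) with a plain residual block: an optional input norm, attention plus residual, a norm, feed-forward plus residual, and an optional final norm. Below is a minimal runnable sketch of that new control flow, assuming PyTorch. The submodule names (`layer_norm_1/2/3`, `attn`, `ff`, `drop_path`) follow the diff; the constructor, the `nn.MultiheadAttention` stand-in, and the `nn.Identity` drop-path are illustrative assumptions, not the package's actual implementation.

```python
import torch
from torch import nn


class PrePostNormBlock(nn.Module):
    """Sketch of the control flow of TransformerBlock.forward in 12.2.0."""

    def __init__(self, dim: int, n_heads: int,
                 pre_norm: bool = True, post_norm: bool = False):
        super().__init__()
        self.pre_norm = pre_norm
        self.post_norm = post_norm
        self.layer_norm_1 = nn.LayerNorm(dim)
        self.layer_norm_2 = nn.LayerNorm(dim)
        self.layer_norm_3 = nn.LayerNorm(dim)
        # broccoli's attn takes (q, k, v) and returns a tensor; here we
        # approximate it with nn.MultiheadAttention and unpack the tuple.
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.drop_path = nn.Identity()  # stand-in for stochastic depth

    def forward(self, x):
        if self.pre_norm:
            x = self.layer_norm_1(x)
        # Attention sublayer with residual connection.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = x + self.drop_path(attn_out)
        x = self.layer_norm_2(x)
        # Feed-forward sublayer with residual connection.
        x = x + self.drop_path(self.ff(x))
        if self.post_norm:  # in addition to pre_norm, per the diff's comment
            x = self.layer_norm_3(x)
        return x


block = PrePostNormBlock(dim=64, n_heads=8, post_norm=True)
out = block(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 64])
```

One detail worth noting: in the published hunk, `return x` survives only inside the commented-out block, so the explicit `return x` above is added for usability rather than taken from 12.2.0.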