broccoli-ml 11.0.0.tar.gz → 12.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/PKG-INFO +1 -1
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/transformer.py +37 -66
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/vit.py +0 -4
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/pyproject.toml +1 -1
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/LICENSE +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/README.md +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/activation.py +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/linear.py +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/rope.py +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-11.0.0 → broccoli_ml-12.1.0}/broccoli/utils.py +0 -0
```diff
--- broccoli_ml-11.0.0/broccoli/transformer.py
+++ broccoli_ml-12.1.0/broccoli/transformer.py
@@ -21,29 +21,6 @@ except ImportError:
     FLASH_ATTN = False
 
 
-class LayerScale(nn.Module):
-    def __init__(self, dim, decay=False, init_values=1e-4):
-        super().__init__()
-        self.dim = dim
-        self.decay = decay
-        self.init_values = init_values
-        self.reset_parameters()
-
-    def forward(self, x):
-        if self.decay:
-            return x * self.scale
-        else:
-            return x * self.nondecay_scale
-
-    def reset_parameters(self):
-        if self.decay:
-            self.scale = nn.Parameter(self.init_values * torch.ones(self.dim))
-            self.nondecay_scale = None
-        else:
-            self.nondecay_scale = nn.Parameter(self.init_values * torch.ones(self.dim))
-            self.scale = None
-
-
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
```
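Version 12.1.0 removes the `LayerScale` module outright (along with every `layerscale` / `transformer_layerscale` option built on it, as the later hunks show). In 11.0.0 it scaled each channel of its input by a learnable vector initialised at `1e-4`; the `decay` flag only changed whether that vector was registered as `scale` or `nondecay_scale`. The snippet below is a minimal reconstruction for orientation only, not the package's code:

```python
import torch
import torch.nn as nn

# Minimal sketch of what the removed module computed: multiply each channel
# by a learnable scale initialised near zero, so a freshly initialised
# residual branch contributes almost nothing at the start of training.
class LayerScaleSketch(nn.Module):
    def __init__(self, dim, init_values=1e-4):
        super().__init__()
        self.scale = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.scale

branch = LayerScaleSketch(dim=64)
tokens = torch.randn(2, 10, 64)        # (batch, sequence, d_model)
print(branch(tokens).abs().max())      # roughly 1e-4 of the input's magnitude
```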
```diff
@@ -527,7 +504,6 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
-        layerscale=True,
     ):
         """
         Args:
```
```diff
@@ -556,14 +532,6 @@ class TransformerBlock(nn.Module):
         self.post_attention_norm = nn.LayerNorm(d_model)
         self.post_mlp_norm = nn.LayerNorm(d_model)
 
-        self.layerscale = layerscale
-        if layerscale:
-            self.layerscale1 = LayerScale(d_model)
-            self.layerscale2 = LayerScale(d_model)
-        else:
-            self.layerscale1 = nn.Identity()
-            self.layerscale2 = nn.Identity()
-
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
             if d_model < 16:
```
```diff
@@ -624,35 +592,52 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
-
         if self.pre_norm:
-            process_x = self.pre_attention_norm(x)
-        else:
-            process_x = x
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_2(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            if self.post_norm:  # i.e. in addition! Pre and post.
+                x = self.layer_norm_3(x)
+        elif self.post_norm:  # i.e. only, not prenorm, just post
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            x = self.layer_norm_2(x)
+        else:  # Not pre or post norm. Stand well back.
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
 
-        processed = self.drop_path(
-            self.attn(process_x, process_x, process_x)
-        )
+        # if self.pre_norm:
+        #     process_x = self.pre_attention_norm(x)
+        # else:
+        #     process_x = x
 
-        if self.normformer:
-            processed = self.normformer_norm(processed)
+        # processed = self.drop_path(self.attn(process_x, process_x, process_x))
 
-        x = x + processed
+        # if self.normformer:
+        #     processed = self.normformer_norm(processed)
 
-        if self.post_norm:
-            x = self.post_attention_norm(x)
+        # if self.residual_path:
+        #     x = x + processed
 
-        if self.pre_norm:
-            process_x = self.pre_mlp_norm(x)
-        else:
-            process_x = x
+        # if self.post_norm:
+        #     x = self.post_attention_norm(x)
 
-        processed = self.drop_path(self.ff(process_x))
+        # if self.pre_norm:
+        #     process_x = self.pre_mlp_norm(x)
+        # else:
+        #     process_x = x
 
-        if self.post_norm:
-            x = self.post_mlp_norm(x)
+        # processed = self.drop_path(self.ff(process_x))
 
-        return x
+        # if self.residual_path:
+        #     x = x + processed
+
+        # if self.post_norm:
+        #     x = self.post_mlp_norm(x)
+
+        # return x
 
     def attention_logits(self, x):
         """
```
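The rewritten `forward` replaces the old `process_x` / `processed` bookkeeping (kept above as comments) with three explicit branches: pre-norm, post-norm only, and no normalisation at all, with `layer_norm_3` applied only when both flags are set. Note that, at least in the hunks shown here, the new branches still call `self.layerscale1` / `self.layerscale2` even though this version no longer constructs those attributes in `__init__`. The sketch below reproduces only the branching logic; `TinyBlock` and its sublayers are hypothetical stand-ins (`torch.nn.MultiheadAttention` instead of the package's `attn`, and no `drop_path` or layerscale wrappers), not broccoli-ml's API:

```python
import torch
import torch.nn as nn

# Illustrative only: a stripped-down block with the same normalisation
# branching as the new TransformerBlock.forward shown in the hunk above.
class TinyBlock(nn.Module):
    def __init__(self, d_model, pre_norm=True, post_norm=False):
        super().__init__()
        self.pre_norm, self.post_norm = pre_norm, post_norm
        self.attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                nn.Linear(4 * d_model, d_model))
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x):
        if self.pre_norm:
            x = self.norm1(x)
            x = x + self.attn(x, x, x, need_weights=False)[0]
            x = self.norm2(x)
            x = x + self.ff(x)
            if self.post_norm:          # pre- and post-norm together
                x = self.norm3(x)
        elif self.post_norm:            # post-norm only
            x = x + self.attn(x, x, x, need_weights=False)[0]
            x = self.norm1(x)
            x = x + self.ff(x)
            x = self.norm2(x)
        else:                           # no normalisation at all
            x = x + self.attn(x, x, x, need_weights=False)[0]
            x = x + self.ff(x)
        return x

block = TinyBlock(d_model=64)
print(block(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```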
```diff
@@ -682,10 +667,6 @@ class TransformerBlock(nn.Module):
         self.attn.reset_parameters()
         self.ff.reset_parameters()
 
-        if self.layerscale:
-            self.layerscale1.reset_parameters()
-            self.layerscale2.reset_parameters()
-
 
 class TransformerEncoder(nn.Module):
     """
```
```diff
@@ -722,7 +703,6 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
-        layerscale=True,
     ):
         """
         Args:
```
```diff
@@ -757,12 +737,6 @@ class TransformerEncoder(nn.Module):
         self._utility_tokens = utility_tokens
         self.return_utility_tokens = return_utility_tokens
 
-        if layerscale:
-            rope_and_ape = absolute_position_embedding and relative_position_embedding
-            self.position_layerscale = LayerScale(d_model, decay=rope_and_ape)
-        else:
-            self.position_layerscale = None
-
         # Initialise utility tokens with normal init, like usual Pytorch embeddings
         if self._utility_tokens:
             self._utility_token_embedding = nn.Parameter(
```
```diff
@@ -827,7 +801,6 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
-                    layerscale=layerscale,
                 )
                 for i in range(n_layers)
             ]
```
```diff
@@ -855,8 +828,6 @@ class TransformerEncoder(nn.Module):
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
-            if self.position_layerscale is not None:
-                position_embedding = self.position_layerscale(position_embedding)
             x += position_embedding
 
         return x
```
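In 11.0.0 the encoder also wrapped its absolute position embedding in a `position_layerscale`, with `decay=True` only when absolute and relative position embeddings were combined; as the removed class shows, that flag merely chose between the parameter names `scale` and `nondecay_scale`. Presumably a training script could use that naming to keep `nondecay_` parameters out of the weight-decay group. The helper below is a hypothetical sketch of that pattern, not part of broccoli-ml:

```python
import torch
import torch.nn as nn

# Hypothetical training-script pattern: split parameters into weight-decay
# and no-decay groups by name, which is the only behaviour the removed
# `decay` flag appears to have influenced.
def param_groups(model: nn.Module, weight_decay: float = 0.05):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (no_decay if "nondecay" in name else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = nn.Linear(8, 8)  # stand-in for a TransformerEncoder
optimizer = torch.optim.AdamW(param_groups(model), lr=3e-4)
```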
```diff
--- broccoli_ml-11.0.0/broccoli/vit.py
+++ broccoli_ml-12.1.0/broccoli/vit.py
@@ -187,7 +187,6 @@ class ViTEncoder(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
```
```diff
@@ -353,7 +352,6 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
-                layerscale=transformer_layerscale,
             )
         else:
             self.transformer = nn.Identity()
```
```diff
@@ -489,7 +487,6 @@ class ViT(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
```
```diff
@@ -562,7 +559,6 @@ class ViT(nn.Module):
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
-            transformer_layerscale=transformer_layerscale,
             linear_module=linear_module,
         )
 
```
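For downstream code, the practical effect of these vit.py hunks is that `transformer_layerscale` (like `layerscale` on `TransformerBlock` / `TransformerEncoder`) is no longer an accepted keyword in 12.1.0, so call sites that still pass it will raise a `TypeError`. Below is a hedged sketch of a version-tolerant call site, assuming only that `broccoli.vit.ViT` is importable; `supported_kwargs` is a hypothetical helper, not part of the package:

```python
import inspect
from broccoli.vit import ViT

def supported_kwargs(cls, **kwargs):
    """Drop keyword arguments the installed broccoli-ml version does not accept."""
    params = inspect.signature(cls.__init__).parameters
    return {k: v for k, v in kwargs.items() if k in params}

# `transformer_layerscale` exists in 11.0.0 but not in 12.1.0; merge the result
# into the rest of your ViT(...) arguments. On 12.1.0 this dict is empty.
kwargs = supported_kwargs(ViT, transformer_layerscale=False)
```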
The remaining nine files (LICENSE, README.md, broccoli/__init__.py, broccoli/activation.py, broccoli/cnn.py, broccoli/linear.py, broccoli/rope.py, broccoli/tensor.py, broccoli/utils.py) are unchanged between the two versions.