broccoli-ml 11.0.0__py3-none-any.whl → 12.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
broccoli/transformer.py CHANGED
@@ -21,29 +21,6 @@ except ImportError:
     FLASH_ATTN = False


-class LayerScale(nn.Module):
-    def __init__(self, dim, decay=False, init_values=1e-4):
-        super().__init__()
-        self.dim = dim
-        self.decay = decay
-        self.init_values = init_values
-        self.reset_parameters()
-
-    def forward(self, x):
-        if self.decay:
-            return x * self.scale
-        else:
-            return x * self.nondecay_scale
-
-    def reset_parameters(self):
-        if self.decay:
-            self.scale = nn.Parameter(self.init_values * torch.ones(self.dim))
-            self.nondecay_scale = None
-        else:
-            self.nondecay_scale = nn.Parameter(self.init_values * torch.ones(self.dim))
-            self.scale = None
-
-
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
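The hunk above deletes LayerScale entirely: a per-channel learnable gain on the residual branch (as in CaiT), initialised to 1e-4, with the parameter registered as either `scale` or `nondecay_scale` depending on the `decay` flag — presumably so name-based optimizer parameter groups can exempt it from weight decay. A minimal self-contained sketch of the behaviour being removed (`decay` flag omitted):

```python
# Sketch of the deleted LayerScale behaviour: the residual branch is
# multiplied elementwise by a learnable per-channel gain initialised to a
# small constant, so each block starts close to the identity function.
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-4):
        super().__init__()
        self.scale = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.scale  # broadcasts over (batch, seq, dim)

x = torch.randn(2, 16, 64)                    # (batch, seq, d_model)
print(LayerScale(64)(x).abs().max().item())   # ~1e-4 of the input scale at init
```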
@@ -527,7 +504,6 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
-        layerscale=True,
     ):
         """
         Args:
@@ -556,14 +532,6 @@ class TransformerBlock(nn.Module):
         self.post_attention_norm = nn.LayerNorm(d_model)
         self.post_mlp_norm = nn.LayerNorm(d_model)

-        self.layerscale = layerscale
-        if layerscale:
-            self.layerscale1 = LayerScale(d_model)
-            self.layerscale2 = LayerScale(d_model)
-        else:
-            self.layerscale1 = nn.Identity()
-            self.layerscale2 = nn.Identity()
-
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
             if d_model < 16:
@@ -624,35 +592,52 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
-
         if self.pre_norm:
-            process_x = self.pre_attention_norm(x)
-        else:
-            process_x = x
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_2(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            if self.post_norm:  # i.e. in addition! Pre and post.
+                x = self.layer_norm_3(x)
+        elif self.post_norm:  # i.e. only, not prenorm, just post
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            x = self.layer_norm_2(x)
+        else:  # Not pre or post norm. Stand well back.
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))

-        processed = self.drop_path(
-            self.layerscale1(self.attn(process_x, process_x, process_x))
-        )
+        # if self.pre_norm:
+        #     process_x = self.pre_attention_norm(x)
+        # else:
+        #     process_x = x

-        if self.normformer:
-            processed = self.normformer_norm(processed)
+        # processed = self.drop_path(self.attn(process_x, process_x, process_x))

-        x = x + processed
+        # if self.normformer:
+        #     processed = self.normformer_norm(processed)

-        if self.post_norm:
-            x = self.post_attention_norm(x)
+        # if self.residual_path:
+        #     x = x + processed

-        if self.pre_norm:
-            process_x = self.pre_mlp_norm(x)
-        else:
-            process_x = x
+        # if self.post_norm:
+        #     x = self.post_attention_norm(x)

-        x = x + self.drop_path(self.layerscale2(self.ff(process_x)))
+        # if self.pre_norm:
+        #     process_x = self.pre_mlp_norm(x)
+        # else:
+        #     process_x = x

-        if self.post_norm:
-            x = self.post_mlp_norm(x)
+        # processed = self.drop_path(self.ff(process_x))

-        return x
+        # if self.residual_path:
+        #     x = x + processed
+
+        # if self.post_norm:
+        #     x = self.post_mlp_norm(x)
+
+        # return x

     def attention_logits(self, x):
         """
@@ -682,10 +667,6 @@ class TransformerBlock(nn.Module):
         self.attn.reset_parameters()
         self.ff.reset_parameters()

-        if self.layerscale:
-            self.layerscale1.reset_parameters()
-            self.layerscale2.reset_parameters()
-

 class TransformerEncoder(nn.Module):
     """
@@ -722,7 +703,6 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
-        layerscale=True,
     ):
         """
         Args:
@@ -757,12 +737,6 @@
         self._utility_tokens = utility_tokens
         self.return_utility_tokens = return_utility_tokens

-        if layerscale:
-            rope_and_ape = absolute_position_embedding and relative_position_embedding
-            self.position_layerscale = LayerScale(d_model, decay=rope_and_ape)
-        else:
-            self.position_layerscale = None
-
         # Initialise utility tokens with normal init, like usual Pytorch embeddings
         if self._utility_tokens:
             self._utility_token_embedding = nn.Parameter(
@@ -827,7 +801,6 @@
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
-                    layerscale=layerscale,
                 )
                 for i in range(n_layers)
             ]
@@ -855,8 +828,6 @@
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
-            if self.position_layerscale is not None:
-                position_embedding = self.position_layerscale(position_embedding)
             x += position_embedding

         return x
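Combined with the `__init__` hunk above, this stops damping absolute position embeddings before they join the residual stream: in 11.0.0 they passed through a LayerScale initialised to 1e-4, in 12.1.0 they are added at full magnitude. A small illustration of the difference at initialisation (shapes illustrative, not taken from the package):

```python
# At init, 11.0.0 scaled the absolute position embedding by ~1e-4 via
# position_layerscale; 12.1.0 adds it unscaled. Illustrative only.
import torch

position_embedding = torch.randn(1, 10, 64)   # (1, seq_len, d_model)
old_contribution = 1e-4 * position_embedding  # 11.0.0 behaviour at init
new_contribution = position_embedding         # 12.1.0 behaviour
print((old_contribution.norm() / new_contribution.norm()).item())  # ~1e-4
```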
broccoli/vit.py CHANGED
@@ -187,7 +187,6 @@ class ViTEncoder(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -353,7 +352,6 @@
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
-                layerscale=transformer_layerscale,
             )
         else:
             self.transformer = nn.Identity()
@@ -489,7 +487,6 @@
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -562,7 +559,6 @@
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
-            transformer_layerscale=transformer_layerscale,
             linear_module=linear_module,
         )

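Every constructor in the chain (TransformerBlock, TransformerEncoder, ViTEncoder, ViT) loses its `layerscale`/`transformer_layerscale` keyword, so 11.x call sites that passed it will now raise TypeError. One way to support both majors is to probe the signature; a sketch using only the import path shown in this diff:

```python
# Compatibility probe: pass the removed keyword only if the installed
# broccoli still accepts it (i.e. an 11.x install). Sketch, not package code.
import inspect
from broccoli.transformer import TransformerEncoder

accepts_layerscale = (
    "layerscale" in inspect.signature(TransformerEncoder).parameters
)
extra_kwargs = {"layerscale": True} if accepts_layerscale else {}
# encoder = TransformerEncoder(..., **extra_kwargs)  # remaining args omitted
```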
broccoli_ml-{11.0.0 → 12.1.0}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 11.0.0
+Version: 12.1.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-{11.0.0 → 12.1.0}.dist-info/RECORD CHANGED
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
 broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=urW9KGxK47Mns7ZowPKTgdEyp4Yd21uw3PNCpKx04cI,29223
+broccoli/transformer.py,sha256=cPXHGlfkvJAXgr_ONyESQ0RAEzn7yqZEoyZ3l4N3EX8,28571
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=EGbQb-atuzG3JAx7kdTaJEbWvQR-4XgyYvwjKkN5C38,22612
-broccoli_ml-11.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-11.0.0.dist-info/METADATA,sha256=7ixQROHXr3LyCMhT5hkKmvip8C8Bytmtz-IoQCbC_NQ,1369
-broccoli_ml-11.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-11.0.0.dist-info/RECORD,,
+broccoli/vit.py,sha256=DvVpayMIcUhH7Xg6CiYyeedUuuMHrjsGxEdXfnTGa_Q,22428
+broccoli_ml-12.1.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-12.1.0.dist-info/METADATA,sha256=2j4nt5aHVcPiW_YF2Sx4Yu_l3MKe2fa0S9v3aQ1PZE0,1369
+broccoli_ml-12.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-12.1.0.dist-info/RECORD,,
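RECORD digests follow the wheel convention (PEP 427): urlsafe-base64-encoded SHA-256 with the `=` padding stripped, followed by the file size in bytes. A quick way to check an extracted file against its RECORD row:

```python
# Recompute a RECORD-style entry (sha256 digest, urlsafe base64, padding
# stripped, plus byte size) for a file unpacked from the wheel.
import base64
import hashlib
from pathlib import Path

def record_entry(path):
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return f"{path},sha256={digest.rstrip(b'=').decode()},{len(data)}"

# Expect: broccoli/transformer.py,sha256=cPXHGlfkvJAXgr_ONyESQ0RAEzn7yqZEoyZ3l4N3EX8,28571
print(record_entry("broccoli/transformer.py"))
```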