broccoli-ml 10.2.0.tar.gz → 12.0.0.tar.gz

This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 10.2.0
+ Version: 12.0.0
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -21,24 +21,6 @@ except ImportError:
      FLASH_ATTN = False


- class LayerScale(nn.Module):
-     def __init__(self, dim, decay=False, init_values=1e-4):
-         super().__init__()
-         self.decay = decay
-         if decay:
-             self.scale = nn.Parameter(init_values * torch.ones(dim))
-             self.nondecay_scale = None
-         else:
-             self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
-             self.scale = None
-
-     def forward(self, x):
-         if self.decay:
-             return x * self.scale
-         else:
-             return x * self.nondecay_scale
-
-
  def drop_path(
      x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  ):
@@ -522,7 +504,6 @@ class TransformerBlock(nn.Module):
          post_norm=False,
          normformer=False,
          checkpoint_ff=True,
-         layerscale=True,
      ):
          """
          Args:
@@ -540,16 +521,16 @@ class TransformerBlock(nn.Module):

          self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

-         self.layer_norm_1 = nn.LayerNorm(d_model)
-         self.layer_norm_2 = nn.LayerNorm(d_model)
-         self.layer_norm_3 = nn.LayerNorm(d_model)
+         if self.pre_norm:
+             self.pre_attention_norm = nn.LayerNorm(d_model)
+             self.pre_mlp_norm = nn.LayerNorm(d_model)

-         if layerscale:
-             self.layerscale1 = LayerScale(d_model)
-             self.layerscale2 = LayerScale(d_model)
-         else:
-             self.layerscale1 = nn.Identity()
-             self.layerscale2 = nn.Identity()
+         if normformer:
+             self.normformer_norm = nn.LayerNorm(d_model)
+
+         if self.post_norm:
+             self.post_attention_norm = nn.LayerNorm(d_model)
+             self.post_mlp_norm = nn.LayerNorm(d_model)

          if relative_position_embedding:
              max_freq = int(max(source_size) / 2) # Suggested by Gemini!
@@ -613,20 +594,29 @@ class TransformerBlock(nn.Module):
      def forward(self, x):

          if self.pre_norm:
-             x = self.layer_norm_1(x)
-             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-             x = self.layer_norm_2(x)
-             x = x + self.drop_path(self.layerscale2(self.ff(x)))
-             if self.post_norm: # i.e. in addition! Pre and post.
-                 x = self.layer_norm_3(x)
-         elif self.post_norm: # i.e. only, not prenorm, just post
-             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-             x = self.layer_norm_1(x)
-             x = x + self.drop_path(self.layerscale2(self.ff(x)))
-             x = self.layer_norm_2(x)
-         else: # Not pre or post norm. Stand well back.
-             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-             x = x + self.drop_path(self.layerscale2(self.ff(x)))
+             process_x = self.pre_attention_norm(x)
+         else:
+             process_x = x
+
+         processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+         if self.normformer:
+             processed = self.normformer_norm(processed)
+
+         x = x + processed
+
+         if self.post_norm:
+             x = self.post_attention_norm(x)
+
+         if self.pre_norm:
+             process_x = self.pre_mlp_norm(x)
+         else:
+             process_x = x
+
+         x = x + self.drop_path(self.ff(process_x))
+
+         if self.post_norm:
+             x = self.post_mlp_norm(x)

          return x

@@ -634,16 +624,26 @@ class TransformerBlock(nn.Module):
          """
          Give back the attention scores used in this layer.
          """
+         # Fix: Use the correct attribute name 'pre_attention_norm'
          if self.pre_norm:
-             x = self.layer_norm_1(x)
+             # We must normalize the input before measuring attention logits
+             # to match what the model actually sees during forward()
+             x = self.pre_attention_norm(x)
              return self.attn.attention_logits(x, x, x)
          else:
              return self.attn.attention_logits(x, x, x)

      def reset_parameters(self):
-         self.layer_norm_1.reset_parameters()
-         self.layer_norm_2.reset_parameters()
-         self.layer_norm_3.reset_parameters()
+         if self.pre_norm:
+             self.pre_attention_norm.reset_parameters()
+             self.pre_mlp_norm.reset_parameters()
+
+         if self.post_norm:
+             self.post_attention_norm.reset_parameters()
+             self.post_mlp_norm.reset_parameters()
+
+         if self.normformer:
+             self.normformer_norm.reset_parameters()

          self.attn.reset_parameters()
          self.ff.reset_parameters()
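
Taken together, the TransformerBlock changes above drop the LayerScale gating and leave a purely norm-based block: each sub-layer optionally pre-normalises its input, optionally applies a NormFormer-style LayerNorm to the sub-layer output before the residual add, and optionally post-normalises the summed stream. The sketch below restates that ordering as a small standalone helper; it is one reading of this diff, not part of the package's public API, and it folds the DropPath/stochastic-depth wrapper into the generic `sublayer` callable.

import torch
from torch import nn

def sublayer_step(x, sublayer, pre_norm=None, normformer_norm=None, post_norm=None):
    """One residual step in the 12.0.0 ordering shown in the diff above.

    pre_norm, normformer_norm and post_norm are optional nn.LayerNorm modules;
    pass None to skip that normalisation, mirroring the pre_norm / normformer /
    post_norm flags on TransformerBlock.
    """
    h = pre_norm(x) if pre_norm is not None else x   # pre-norm the input
    out = sublayer(h)                                # attention or MLP sub-layer
    if normformer_norm is not None:                  # NormFormer: norm the sub-layer output
        out = normformer_norm(out)
    x = x + out                                      # residual add
    if post_norm is not None:                        # post-norm the summed stream
        x = post_norm(x)
    return x

# Example: a pre-norm attention step followed by a pre-norm MLP step.
d_model = 64
attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

x = torch.randn(2, 16, d_model)
x = sublayer_step(x, lambda h: attn(h, h, h, need_weights=False)[0], pre_norm=nn.LayerNorm(d_model))
x = sublayer_step(x, mlp, pre_norm=nn.LayerNorm(d_model))
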
@@ -684,7 +684,6 @@ class TransformerEncoder(nn.Module):
          normformer=False,
          msa_scaling="d",
          checkpoint_ff=True,
-         layerscale=True,
      ):
          """
          Args:
@@ -719,12 +718,6 @@ class TransformerEncoder(nn.Module):
          self._utility_tokens = utility_tokens
          self.return_utility_tokens = return_utility_tokens

-         if layerscale:
-             rope_and_ape = absolute_position_embedding and relative_position_embedding
-             self.position_layerscale = LayerScale(d_model, decay=rope_and_ape)
-         else:
-             self.position_layerscale = None
-
          # Initialise utility tokens with normal init, like usual Pytorch embeddings
          if self._utility_tokens:
              self._utility_token_embedding = nn.Parameter(
@@ -789,7 +782,6 @@ class TransformerEncoder(nn.Module):
                      post_norm=post_norm,
                      normformer=normformer,
                      checkpoint_ff=checkpoint_ff,
-                     layerscale=layerscale,
                  )
                  for i in range(n_layers)
              ]
@@ -817,8 +809,6 @@ class TransformerEncoder(nn.Module):
                      0
                  ) # to shape (1, seq_len) to broadcast over batch
              )
-             if self.position_layerscale is not None:
-                 position_embedding = self.position_layerscale(position_embedding)
              x += position_embedding

          return x
@@ -187,7 +187,6 @@ class ViTEncoder(nn.Module):
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
-         transformer_layerscale=True,
          linear_module=nn.Linear,
      ):
          super().__init__()
@@ -353,7 +352,6 @@ class ViTEncoder(nn.Module):
                  normformer=transformer_normformer,
                  post_norm=transformer_post_norm,
                  checkpoint_ff=transformer_checkpoint_ff,
-                 layerscale=transformer_layerscale,
              )
          else:
              self.transformer = nn.Identity()
@@ -489,7 +487,6 @@ class ViT(nn.Module):
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
-         transformer_layerscale=True,
          head=SequencePoolClassificationHead,
          batch_norm_logits=True,
          logit_projection_layer=nn.Linear,
@@ -562,7 +559,6 @@ class ViT(nn.Module):
              transformer_msa_dropout=transformer_msa_dropout,
              transformer_stochastic_depth=transformer_stochastic_depth,
              transformer_checkpoint_ff=transformer_checkpoint_ff,
-             transformer_layerscale=transformer_layerscale,
              linear_module=linear_module,
          )

@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "10.2.0"
+ version = "12.0.0"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
      {name = "Nicholas Bailey"}
File without changes
File without changes