broccoli-ml 10.2.0__tar.gz → 11.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 10.2.0
+Version: 11.0.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -24,13 +24,10 @@ except ImportError:
 class LayerScale(nn.Module):
     def __init__(self, dim, decay=False, init_values=1e-4):
         super().__init__()
+        self.dim = dim
         self.decay = decay
-        if decay:
-            self.scale = nn.Parameter(init_values * torch.ones(dim))
-            self.nondecay_scale = None
-        else:
-            self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
-            self.scale = None
+        self.init_values = init_values
+        self.reset_parameters()
 
     def forward(self, x):
         if self.decay:
@@ -38,6 +35,14 @@ class LayerScale(nn.Module):
         else:
             return x * self.nondecay_scale
 
+    def reset_parameters(self):
+        if self.decay:
+            self.scale = nn.Parameter(self.init_values * torch.ones(self.dim))
+            self.nondecay_scale = None
+        else:
+            self.nondecay_scale = nn.Parameter(self.init_values * torch.ones(self.dim))
+            self.scale = None
+
 
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
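
For orientation, and not part of the package diff itself: the change above moves LayerScale's parameter construction out of __init__ and into the new reset_parameters() method, so the scaling parameters can be rebuilt in place. Below is a self-contained sketch of the class as it stands after this diff, plus a hypothetical usage snippet; the body of the decay branch of forward() is not visible in the hunks and is inferred here.

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # Mirror of the post-diff class, reproduced for illustration.
    def __init__(self, dim, decay=False, init_values=1e-4):
        super().__init__()
        self.dim = dim
        self.decay = decay
        self.init_values = init_values
        self.reset_parameters()

    def forward(self, x):
        if self.decay:
            return x * self.scale  # inferred: this branch body is not shown in the hunks
        else:
            return x * self.nondecay_scale

    def reset_parameters(self):
        if self.decay:
            self.scale = nn.Parameter(self.init_values * torch.ones(self.dim))
            self.nondecay_scale = None
        else:
            self.nondecay_scale = nn.Parameter(self.init_values * torch.ones(self.dim))
            self.scale = None

# Hypothetical usage: the per-channel scale starts at init_values and can now be
# re-initialised without reconstructing the module.
ls = LayerScale(dim=512)
x = torch.randn(8, 16, 512)
assert torch.allclose(ls(x), x * 1e-4)
ls.reset_parameters()  # parameters are rebuilt in place
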
@@ -540,10 +545,18 @@ class TransformerBlock(nn.Module):
 
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
-        self.layer_norm_1 = nn.LayerNorm(d_model)
-        self.layer_norm_2 = nn.LayerNorm(d_model)
-        self.layer_norm_3 = nn.LayerNorm(d_model)
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
 
+        if self.post_norm:
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)
+
+        self.layerscale = layerscale
         if layerscale:
             self.layerscale1 = LayerScale(d_model)
             self.layerscale2 = LayerScale(d_model)
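
A practical note that the diff does not spell out: because the three LayerNorm submodules are renamed, the state_dict keys of a 10.2.0 TransformerBlock (layer_norm_1.*, layer_norm_2.*, layer_norm_3.*) no longer exist in 11.0.0, so old checkpoints will not load directly. The sketch below shows one hedged way to remap the keys for a pre-norm block, using the correspondence implied by the forward() changes further down (layer_norm_1 preceded attention, layer_norm_2 preceded the MLP). Note that this restores name compatibility only, not exact numerical equivalence, since the old code normalised the residual stream itself while the new code normalises only the branch input; layer_norm_3, the extra norm of the old pre-plus-post mode, has no counterpart and is dropped here.

# Illustrative only: remap 10.2.0 checkpoint keys to the 11.0.0 names for a
# pre-norm TransformerBlock. All names outside the LayerNorm renames are
# assumed unchanged.
def remap_state_dict(old_state_dict):
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if "layer_norm_3." in key:
            continue  # no counterpart in 11.0.0 for the old pre+post extra norm
        new_key = key.replace("layer_norm_1.", "pre_attention_norm.")
        new_key = new_key.replace("layer_norm_2.", "pre_mlp_norm.")
        new_state_dict[new_key] = value
    return new_state_dict
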
@@ -613,20 +626,31 @@ class TransformerBlock(nn.Module):
     def forward(self, x):
 
         if self.pre_norm:
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-            x = self.layer_norm_2(x)
-            x = x + self.drop_path(self.layerscale2(self.ff(x)))
-            if self.post_norm:  # i.e. in addition! Pre and post.
-                x = self.layer_norm_3(x)
-        elif self.post_norm:  # i.e. only, not prenorm, just post
-            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.layerscale2(self.ff(x)))
-            x = self.layer_norm_2(x)
-        else:  # Not pre or post norm. Stand well back.
-            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(
+            self.layerscale1(self.attn(process_x, process_x, process_x))
+        )
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+
+        if self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        x = x + self.drop_path(self.layerscale2(self.ff(process_x)))
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)
 
         return x
 
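
To make the new control flow easier to follow outside the diff: the rewritten forward() applies one uniform residual pattern to both the attention and MLP branches, with optional norms in three slots, before the branch input (pre_norm), on the branch output before the residual addition (normformer, attention branch only), and on the stream after the addition (post_norm). A minimal sketch of that pattern, with plain callables standing in for the real submodules:

def residual_branch(x, sublayer, pre=None, inner=None, post=None):
    # pre   -> norm applied to the branch input (pre_norm)
    # inner -> norm applied to the branch output before the add (normformer)
    # post  -> norm applied to the stream after the add (post_norm)
    h = pre(x) if pre is not None else x
    h = sublayer(h)
    if inner is not None:
        h = inner(h)
    x = x + h
    return post(x) if post is not None else x

In the actual block, the attention branch passes drop_path(layerscale1(attn(...))) as the sublayer and the MLP branch passes drop_path(layerscale2(ff(...))); only the attention branch has an inner NormFormer-style norm, matching the hunks above.
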
 
@@ -634,20 +658,34 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-            x = self.layer_norm_1(x)
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
-        self.layer_norm_1.reset_parameters()
-        self.layer_norm_2.reset_parameters()
-        self.layer_norm_3.reset_parameters()
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()
 
         self.attn.reset_parameters()
         self.ff.reset_parameters()
 
+        if self.layerscale:
+            self.layerscale1.reset_parameters()
+            self.layerscale2.reset_parameters()
+
 
 class TransformerEncoder(nn.Module):
     """
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "10.2.0"
+version = "11.0.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}
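
Returning to the TransformerBlock.reset_parameters() rewrite above: it mirrors the guarded construction added to __init__, in that each norm is reset only if its configuration flag caused it to be created, and the LayerScale modules, which now expose reset_parameters() of their own, are reset as well. A self-contained toy example of that register-and-reset-under-the-same-flag pattern (ToyBlock is purely illustrative and not part of broccoli-ml):

import torch.nn as nn

class ToyBlock(nn.Module):
    # Submodules are created only for the configurations that need them, and
    # reset_parameters() guards on the same flags so it never touches a module
    # that was never registered.
    def __init__(self, d_model, pre_norm=True, post_norm=False):
        super().__init__()
        self.pre_norm = pre_norm
        self.post_norm = post_norm
        if pre_norm:
            self.pre = nn.LayerNorm(d_model)
        if post_norm:
            self.post = nn.LayerNorm(d_model)
        self.proj = nn.Linear(d_model, d_model)

    def reset_parameters(self):
        if self.pre_norm:
            self.pre.reset_parameters()
        if self.post_norm:
            self.post.reset_parameters()
        self.proj.reset_parameters()

blk = ToyBlock(64, pre_norm=True, post_norm=False)
blk.reset_parameters()            # only resets the modules that actually exist
assert not hasattr(blk, "post")   # the post-norm was never registered
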