broccoli-ml 13.0.6__tar.gz → 14.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 13.0.6
3
+ Version: 14.0.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -612,15 +612,18 @@ class EncoderBlock(nn.Module):
612
612
  else:
613
613
  process_x = x
614
614
 
615
- processed = self.drop_path(self.attn(process_x, process_x, process_x))
615
+ processed = self.attn(process_x, process_x, process_x)
616
616
 
617
617
  if self.normformer:
618
618
  processed = self.normformer_norm(processed)
619
619
 
620
+ processed = self.drop_path(processed)
621
+
620
622
  x = self.alpha * x + processed
621
623
 
622
624
  if self.post_norm:
623
625
  x = self.post_attention_norm(x)
626
+ process_x = x
624
627
  elif self.pre_norm:
625
628
  process_x = self.pre_mlp_norm(x)
626
629
  else:
@@ -638,15 +641,15 @@ class EncoderBlock(nn.Module):
638
641
  def attention_logits(self, x):
639
642
  """
640
643
  Give back the attention scores used in this layer.
644
+ Needs to match what the model actually sees during forward()
645
+ by applying the correct normalisations.
641
646
  """
642
- # Fix: Use the correct attribute name 'pre_attention_norm'
643
647
  if self.pre_norm:
644
- # We must normalize the input before measuring attention logits
645
- # to match what the model actually sees during forward()
646
648
  x = self.pre_attention_norm(x)
647
- return self.attn.attention_logits(x, x, x)
648
- else:
649
- return self.attn.attention_logits(x, x, x)
649
+ elif self.post_norm:
650
+ x = self.input_norm(x)
651
+
652
+ return self.attn.attention_logits(x, x, x)
650
653
 
651
654
  def reset_parameters(self):
652
655
  if self.pre_norm:
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "13.0.6"
3
+ version = "14.0.1"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes