broccoli-ml 12.0.0.tar.gz → 12.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 12.0.0
+Version: 12.2.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -592,33 +592,44 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
-
         if self.pre_norm:
-            process_x = self.pre_attention_norm(x)
-        else:
-            process_x = x
+            x = self.layer_norm_1(x)
+        x = x + self.drop_path(self.attn(x, x, x))
+        x = self.layer_norm_2(x)
+        x = x + self.drop_path(self.ff(x))
+        if self.post_norm:  # i.e. in addition! Pre and post.
+            x = self.layer_norm_3(x)
 
-        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+        # if self.pre_norm:
+        #     process_x = self.pre_attention_norm(x)
+        # else:
+        #     process_x = x
 
-        if self.normformer:
-            processed = self.normformer_norm(processed)
+        # processed = self.drop_path(self.attn(process_x, process_x, process_x))
 
-        x = x + processed
+        # if self.normformer:
+        #     processed = self.normformer_norm(processed)
 
-        if self.post_norm:
-            x = self.post_attention_norm(x)
+        # if self.residual_path:
+        #     x = x + processed
 
-        if self.pre_norm:
-            process_x = self.pre_mlp_norm(x)
-        else:
-            process_x = x
+        # if self.post_norm:
+        #     x = self.post_attention_norm(x)
 
-        x = x + self.drop_path(self.ff(process_x))
+        # if self.pre_norm:
+        #     process_x = self.pre_mlp_norm(x)
+        # else:
+        #     process_x = x
 
-        if self.post_norm:
-            x = self.post_mlp_norm(x)
+        # processed = self.drop_path(self.ff(process_x))
 
-        return x
+        # if self.residual_path:
+        #     x = x + processed
+
+        # if self.post_norm:
+        #     x = self.post_mlp_norm(x)
+
+        # return x
 
     def attention_logits(self, x):
         """
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "12.0.0"
+version = "12.2.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}
File without changes