broccoli-ml 5.1.0__py3-none-any.whl → 5.1.2__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of each package version as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -535,9 +535,9 @@ class TransformerBlock(nn.Module):
535
535
  """
536
536
  if self.pre_norm:
537
537
  x = self.layer_norm_1(x)
538
- return self.attn(x, x, x)
538
+ return self.attn.attention_scores(x, x, x)
539
539
  else:
540
- return self.attn(x, x, x)
540
+ return self.attn.attention_scores(x, x, x)
541
541
 
542
542
  def reset_parameters(self):
543
543
  self.layer_norm_1.reset_parameters()
@@ -683,6 +683,8 @@ class TransformerEncoder(nn.Module):
683
683
  ) # to shape (1, seq_len) to broadcast over batch
684
684
  )
685
685
 
686
+ return x
687
+
686
688
  def forward(self, x):
687
689
 
688
690
  x = self.preprocess(x)
broccoli/vit.py CHANGED
@@ -553,7 +553,7 @@ class ViT(nn.Module):
553
553
  all_attention = self.attention_scores(x)
554
554
  batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
555
555
  sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
556
- n_bos_tokens = self.encoder.encoder._bos_tokens
556
+ n_bos_tokens = self.encoder.encoder[-1]._bos_tokens
557
557
  just_bos = sequence_averages[:, :, :n_bos_tokens]
558
558
  return F.softmax(just_bos, dim=-1) # (layer, head, bos_token)
559
559
 
broccoli_ml-5.1.0.dist-info/METADATA → broccoli_ml-5.1.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 5.1.0
3
+ Version: 5.1.2
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
broccoli_ml-5.1.0.dist-info/RECORD → broccoli_ml-5.1.2.dist-info/RECORD CHANGED
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
4
  broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=x3Mo6_1x6fGG6lPDPx9srxn6UdwKEpvjFAO8zoMwAMI,23052
7
+ broccoli/transformer.py,sha256=Gn8fhwgSq-izRc2nrsZ4JC5NsTIzQWFrol8pJ2pfzL4,23104
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
- broccoli/vit.py,sha256=tUYQyoDsBc5ZR_M5_J0huj0T3OAy-vn1f19hCGVDCrM,20425
10
- broccoli_ml-5.1.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-5.1.0.dist-info/METADATA,sha256=3986lqn1iuWJ53O8ckM9LVU3tTjr32i19SeIXauWDXw,1368
12
- broccoli_ml-5.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-5.1.0.dist-info/RECORD,,
9
+ broccoli/vit.py,sha256=R73GTxx41FjVwAu4KDlYBAhytFo_9xVbbRbtBjAIW0s,20429
10
+ broccoli_ml-5.1.2.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-5.1.2.dist-info/METADATA,sha256=iYvlWeqDeDaGDO_up_xC0NBLo3AhJOi6C_5BqWEgNYI,1368
12
+ broccoli_ml-5.1.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-5.1.2.dist-info/RECORD,,