broccoli-ml 5.1.2__py3-none-any.whl → 5.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -284,7 +284,7 @@ class MHAttention(nn.Module):
284
284
 
285
285
  return self.out_proj(output_without_heads)
286
286
 
287
- def attention_scores(self, q, k, v):
287
+ def attention_logits(self, q, k, v):
288
288
 
289
289
  q, k, v = self.project_qkv(q, k, v)
290
290
 
@@ -301,8 +301,6 @@ class MHAttention(nn.Module):
301
301
  if self.causal:
302
302
  qk_scores.masked_fill_(self.mask, float("-inf"))
303
303
 
304
- qk_scores = F.softmax(qk_scores, dim=-1)
305
-
306
304
  return qk_scores # (batch, head, seq_len, seq_len)
307
305
 
308
306
  def reset_parameters(self):
@@ -529,15 +527,15 @@ class TransformerBlock(nn.Module):
529
527
 
530
528
  return x
531
529
 
532
- def attention_scores(self, x):
530
+ def attention_logits(self, x):
533
531
  """
534
532
  Give back the attention scores used in this layer.
535
533
  """
536
534
  if self.pre_norm:
537
535
  x = self.layer_norm_1(x)
538
- return self.attn.attention_scores(x, x, x)
536
+ return self.attn.attention_logits(x, x, x)
539
537
  else:
540
- return self.attn.attention_scores(x, x, x)
538
+ return self.attn.attention_logits(x, x, x)
541
539
 
542
540
  def reset_parameters(self):
543
541
  self.layer_norm_1.reset_parameters()
@@ -697,7 +695,7 @@ class TransformerEncoder(nn.Module):
697
695
  else:
698
696
  return x
699
697
 
700
- def attention_scores(self, x):
698
+ def attention_logits(self, x):
701
699
 
702
700
  x = self.preprocess(x)
703
701
 
@@ -705,8 +703,8 @@ class TransformerEncoder(nn.Module):
705
703
 
706
704
  for block in self.blocks:
707
705
  # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
708
- layer_attention_scores = block.attention_scores(x).unsqueeze(1)
709
- layer_scores.append(layer_attention_scores)
706
+ layer_attention_logits = block.attention_logits(x).unsqueeze(1)
707
+ layer_scores.append(layer_attention_logits)
710
708
  x = block(x)
711
709
 
712
710
  return torch.cat(layer_scores, dim=1) # (batch, layer, head, seq_len, seq_len)
broccoli/vit.py CHANGED
@@ -400,9 +400,9 @@ class ViTEncoder(nn.Module):
400
400
  def forward(self, x):
401
401
  return self.encoder(x)
402
402
 
403
- def attention_scores(self, x):
403
+ def attention_logits(self, x):
404
404
  x = self.encoder[:-1](x)
405
- return self.encoder[-1].attention_scores(x)
405
+ return self.encoder[-1].attention_logits(x)
406
406
 
407
407
  def reset_parameters(self):
408
408
  for module in self.encoder:
@@ -546,16 +546,16 @@ class ViT(nn.Module):
546
546
  def forward(self, x):
547
547
  return self.pool(self.encoder(x))
548
548
 
549
- def attention_scores(self, x):
550
- return self.encoder.attention_scores(x)
549
+ def attention_logits(self, x):
550
+ return self.encoder.attention_logits(x)
551
551
 
552
552
  def head_to_bos_token_attention(self, x):
553
- all_attention = self.attention_scores(x)
553
+ all_attention = self.attention_logits(x)
554
554
  batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
555
555
  sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
556
556
  n_bos_tokens = self.encoder.encoder[-1]._bos_tokens
557
557
  just_bos = sequence_averages[:, :, :n_bos_tokens]
558
- return F.softmax(just_bos, dim=-1) # (layer, head, bos_token)
558
+ return F.softmax(just_bos, dim=1) # (layer, head, bos_token)
559
559
 
560
560
  def reset_parameters(self):
561
561
  self.encoder.reset_parameters()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 5.1.2
3
+ Version: 5.1.4
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
4
  broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=Gn8fhwgSq-izRc2nrsZ4JC5NsTIzQWFrol8pJ2pfzL4,23104
7
+ broccoli/transformer.py,sha256=hLk4p-aS2xJiKHellkFKow5bgkAinoMlS8xTT5aQnII,23054
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
- broccoli/vit.py,sha256=R73GTxx41FjVwAu4KDlYBAhytFo_9xVbbRbtBjAIW0s,20429
10
- broccoli_ml-5.1.2.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-5.1.2.dist-info/METADATA,sha256=iYvlWeqDeDaGDO_up_xC0NBLo3AhJOi6C_5BqWEgNYI,1368
12
- broccoli_ml-5.1.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-5.1.2.dist-info/RECORD,,
9
+ broccoli/vit.py,sha256=Dy1egtEPft0l3812BENooc7SNogEibRPamXzy4jaJxA,20428
10
+ broccoli_ml-5.1.4.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-5.1.4.dist-info/METADATA,sha256=s59fAtdmIjKhdP29kTPEgMp5buRGJzj5WgaBxsW83yg,1368
12
+ broccoli_ml-5.1.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-5.1.4.dist-info/RECORD,,