broccoli-ml 5.2.0__tar.gz → 6.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 5.2.0
+Version: 6.0.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -76,7 +76,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-        bos_tokens=0,
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -129,7 +129,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.bos_tokens = bos_tokens
+        self.utility_tokens = utility_tokens

         self.reset_parameters()

@@ -156,7 +156,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except BOS tokens
+        Apply Axial RoPE to all tokens except utility tokens
         """

         if len(self.source_size) == 1:
@@ -180,8 +180,8 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )

-        q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-        k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]

         q_img = rearrange(
             q_img,
@@ -208,9 +208,9 @@ class MHAttention(nn.Module):
             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
         )

-        # Re-combine the BOS tokens and the RoPE-enhanced image tokens
-        q = torch.cat([q_bos, q_img], dim=1)
-        k = torch.cat([k_bos, k_img], dim=1)
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)

         return q, k

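Note on the hunks above: the rename does not change the mechanics; rotary position embedding is still applied only to the spatial (image/sequence) tokens, while the leading utility tokens bypass it. A minimal standalone sketch of that split-apply-recombine pattern, with a placeholder rope_fn standing in for the package's rotary_embedding (not the real API):

import torch

def apply_rope_excluding_utility(q, k, n_utility, rope_fn):
    # Split off the leading utility tokens, which carry no spatial position.
    q_util, q_seq = q[:, :n_utility, :], q[:, n_utility:, :]
    k_util, k_seq = k[:, :n_utility, :], k[:, n_utility:, :]

    # Rotate only the sequence/image tokens.
    q_seq, k_seq = rope_fn(q_seq, k_seq)

    # Re-attach the untouched utility tokens at the front.
    return torch.cat([q_util, q_seq], dim=1), torch.cat([k_util, k_seq], dim=1)

# Placeholder rotary function (identity) just to exercise the shapes.
q = torch.randn(2, 3 + 16, 64)  # 3 utility tokens + 16 sequence tokens
k = torch.randn(2, 3 + 16, 64)
q_out, k_out = apply_rope_excluding_utility(q, k, 3, lambda a, b: (a, b))
assert q_out.shape == q.shape and k_out.shape == k.shape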
@@ -414,7 +414,7 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-        bos_tokens=0,
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -472,7 +472,7 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-            bos_tokens=bos_tokens,
+            utility_tokens=utility_tokens,
             scaling=msa_scaling,
         )

@@ -571,8 +571,8 @@ class TransformerEncoder(nn.Module):
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-        bos_tokens=0,
-        return_bos_tokens=False,
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
@@ -596,16 +596,18 @@ class TransformerEncoder(nn.Module):
         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self._bos_tokens = bos_tokens
-        self.return_bos_tokens = return_bos_tokens
-
-        # Initialise BOS tokens with normal init, like usual Pytorch embeddings
-        if self._bos_tokens:
-            self._bos_embedding = nn.Parameter(torch.empty(self._bos_tokens, d_model))
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
-            self.full_sequence_length = self.seq_len + self._bos_tokens
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
+            self._utility_token_embedding = None
             self.full_sequence_length = self.seq_len

         self.d_model = d_model
@@ -639,7 +641,7 @@ class TransformerEncoder(nn.Module):
                 n_heads,
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
-                bos_tokens=bos_tokens,
+                utility_tokens=utility_tokens,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
@@ -667,8 +669,10 @@ class TransformerEncoder(nn.Module):
         return ",".join([str(block._kv_distance) for block in self.blocks])

     def preprocess(self, x):
-        if self._bos_tokens:
-            x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x

@@ -690,8 +694,8 @@ class TransformerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x)

-        if self._bos_tokens and not self.return_bos_tokens:
-            return x[:, self._bos_tokens :, :]
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x

@@ -710,8 +714,8 @@ class TransformerEncoder(nn.Module):
         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)

     def reset_parameters(self):
-        if self._bos_embedding is not None:
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)

         if self.absolute_position_embedding is not None:
             self.absolute_position_embedding.reset_parameters()
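The TransformerEncoder hunks keep the same behaviour under the new names: a learned (utility_tokens, d_model) parameter is expanded across the batch and concatenated in front of the input in preprocess, and forward strips those positions again unless return_utility_tokens is set. A rough self-contained sketch of that prepend/strip pattern (the wrapper class and names below are illustrative, not part of broccoli-ml):

import torch
import torch.nn as nn

class UtilityTokenWrapper(nn.Module):
    """Prepend learned utility tokens to a sequence and drop them on output."""

    def __init__(self, d_model, n_utility):
        super().__init__()
        self.n_utility = n_utility
        self.embedding = nn.Parameter(torch.empty(n_utility, d_model))
        nn.init.normal_(self.embedding, mean=0.0, std=1.0)

    def forward(self, x, body, return_utility_tokens=False):
        # x: (batch, seq, d_model); body: any sequence-to-sequence module.
        tokens = self.embedding.expand(x.size(0), -1, -1)
        x = body(torch.cat([tokens, x], dim=1))
        return x if return_utility_tokens else x[:, self.n_utility:, :]

wrapper = UtilityTokenWrapper(d_model=32, n_utility=2)
out = wrapper(torch.randn(4, 10, 32), nn.Identity())
assert out.shape == (4, 10, 32)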
@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange

 import torch
 import torch.nn as nn
-import torch.nn.functional as F


 class GetCLSToken(nn.Module):
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")

+    def attention_scores(self, x):
+        return self.attention(x)
+
     def reset_parameters(self):
         # Iterate over modules in the sequential block
         for module in self.attention:
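The only functional addition in this file is SequencePool.attention_scores, which exposes the per-token weights the module already computes before its weighted sum. A small sketch of that style of attention pooling; the Linear-plus-softmax scorer is a guess for illustration, but the shapes follow the einsum shown above:

import torch
import torch.nn as nn
from einops import einsum

class TinySequencePool(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Illustrative scorer: one logit per token, softmaxed over the sequence.
        self.attention = nn.Sequential(
            nn.Linear(d_model, 1),
            nn.Flatten(start_dim=1),   # (batch, seq, 1) -> (batch, seq)
            nn.Softmax(dim=-1),
        )

    def attention_scores(self, x):
        return self.attention(x)       # (batch, seq)

    def forward(self, x):
        weights = self.attention(x)
        return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")

pool = TinySequencePool(16)
x = torch.randn(2, 8, 16)
assert pool(x).shape == (2, 16)
assert pool.attention_scores(x).shape == (2, 8)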
@@ -169,8 +171,8 @@ class ViTEncoder(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -336,8 +338,8 @@ class ViTEncoder(nn.Module):
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
             linear_module=linear_module,
-            bos_tokens=transformer_bos_tokens,
-            return_bos_tokens=transformer_return_bos_tokens,
+            utility_tokens=transformer_utility_tokens,
+            return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
             post_norm=transformer_post_norm,
@@ -449,8 +451,8 @@ class ViT(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -516,8 +518,8 @@ class ViT(nn.Module):
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
-            transformer_bos_tokens=transformer_bos_tokens,
-            transformer_return_bos_tokens=transformer_return_bos_tokens,
+            transformer_utility_tokens=transformer_utility_tokens,
+            transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
@@ -549,12 +551,23 @@ class ViT(nn.Module):
     def attention_logits(self, x):
         return self.encoder.attention_logits(x)

-    def head_to_bos_token_attention_logits(self, x):
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )
+
+    def head_to_utility_token_attention_logits(self, x):
         all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-        n_bos_tokens = self.encoder.encoder[-1]._bos_tokens
-        return sequence_averages[:, :, :n_bos_tokens]  # (layer, head, bos_token)
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)

     def reset_parameters(self):
         self.encoder.reset_parameters()
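For reference, the reduction inside the renamed head_to_utility_token_attention_logits can be traced with dummy tensors. The (batch, layer, head, seq_len, seq_len) layout comes from the comment on attention_logits in an earlier hunk; the sizes below are placeholders:

import torch

batch, layers, heads, seq_len, n_utility = 2, 7, 4, 19, 3

all_attention = torch.randn(batch, layers, heads, seq_len, seq_len)
batch_averages = torch.mean(all_attention, dim=0)       # (layer, head, query, key)
sequence_averages = torch.mean(batch_averages, dim=-1)  # (layer, head, query)

# Keep only the leading utility-token positions, as the method does.
utility_logits = sequence_averages[:, :, :n_utility]
assert utility_logits.shape == (layers, heads, n_utility)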
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "5.2.0"
+version = "6.0.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}