broccoli-ml 5.2.0__tar.gz → 6.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 5.2.0
+Version: 6.0.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
@@ -76,7 +77,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-        bos_tokens=0,
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -129,7 +130,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.bos_tokens = bos_tokens
+        self.utility_tokens = utility_tokens
 
         self.reset_parameters()
 
@@ -156,7 +157,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except BOS tokens
+        Apply Axial RoPE to all tokens except utility tokens
         """
 
         if len(self.source_size) == 1:
@@ -180,8 +181,8 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-        q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-        k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
 
         q_img = rearrange(
             q_img,
@@ -208,9 +209,9 @@ class MHAttention(nn.Module):
             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
         )
 
-        # Re-combine the BOS tokens and the RoPE-enhanced image tokens
-        q = torch.cat([q_bos, q_img], dim=1)
-        k = torch.cat([k_bos, k_img], dim=1)
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)
 
         return q, k
 
@@ -414,7 +415,7 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-        bos_tokens=0,
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -472,7 +473,7 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-            bos_tokens=bos_tokens,
+            utility_tokens=utility_tokens,
             scaling=msa_scaling,
         )
 
@@ -571,8 +572,8 @@ class TransformerEncoder(nn.Module):
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-        bos_tokens=0,
-        return_bos_tokens=False,
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
@@ -596,16 +597,18 @@ class TransformerEncoder(nn.Module):
         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self._bos_tokens = bos_tokens
-        self.return_bos_tokens = return_bos_tokens
-
-        # Initialise BOS tokens with normal init, like usual Pytorch embeddings
-        if self._bos_tokens:
-            self._bos_embedding = nn.Parameter(torch.empty(self._bos_tokens, d_model))
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
-            self.full_sequence_length = self.seq_len + self._bos_tokens
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
+            self._utility_token_embedding = None
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
@@ -639,7 +642,7 @@ class TransformerEncoder(nn.Module):
                 n_heads,
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
-                bos_tokens=bos_tokens,
+                utility_tokens=utility_tokens,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
@@ -667,8 +670,10 @@ class TransformerEncoder(nn.Module):
         return ",".join([str(block._kv_distance) for block in self.blocks])
 
     def preprocess(self, x):
-        if self._bos_tokens:
-            x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x
 
@@ -690,8 +695,8 @@ class TransformerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x)
 
-        if self._bos_tokens and not self.return_bos_tokens:
-            return x[:, self._bos_tokens :, :]
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x
 
@@ -710,8 +715,8 @@ class TransformerEncoder(nn.Module):
         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
 
     def reset_parameters(self):
-        if self._bos_embedding is not None:
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
 
         if self.absolute_position_embedding is not None:
             self.absolute_position_embedding.reset_parameters()
@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class GetCLSToken(nn.Module):
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+    def attention_scores(self, x):
+        return self.attention(x)
+
     def reset_parameters(self):
         # Iterate over modules in the sequential block
         for module in self.attention:
@@ -169,8 +171,8 @@ class ViTEncoder(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -336,8 +338,8 @@ class ViTEncoder(nn.Module):
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
             linear_module=linear_module,
-            bos_tokens=transformer_bos_tokens,
-            return_bos_tokens=transformer_return_bos_tokens,
+            utility_tokens=transformer_utility_tokens,
+            return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
             post_norm=transformer_post_norm,
@@ -449,8 +451,8 @@ class ViT(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -516,8 +518,8 @@ class ViT(nn.Module):
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
-            transformer_bos_tokens=transformer_bos_tokens,
-            transformer_return_bos_tokens=transformer_return_bos_tokens,
+            transformer_utility_tokens=transformer_utility_tokens,
+            transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
@@ -549,12 +551,23 @@ class ViT(nn.Module):
     def attention_logits(self, x):
         return self.encoder.attention_logits(x)
 
-    def head_to_bos_token_attention_logits(self, x):
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )
+
+    def head_to_utility_token_attention_logits(self, x):
         all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-        n_bos_tokens = self.encoder.encoder[-1]._bos_tokens
-        return sequence_averages[:, :, :n_bos_tokens]  # (layer, head, bos_token)
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)
 
     def reset_parameters(self):
         self.encoder.reset_parameters()
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "5.2.0"
+version = "6.0.1"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}