broccoli-ml 5.1.0__tar.gz → 7.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 5.1.0
+Version: 7.0.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func

+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
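# Sketch (not part of the diff): the optional-dependency pattern the hunk above
# extends. flash-attn is imported if it is installed, a module-level flag records
# the outcome, and 7.0.0 adds a print when the fast path is active. The default
# value and the fallback behaviour here are assumptions for illustration only.
FLASH_ATTN = False
try:
    from flash_attn import flash_attn_func  # optional fused attention kernel

    print("Using flash-attn.")
    FLASH_ATTN = True
except ImportError:
    pass  # fall back to plain PyTorch attention elsewhere in the module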
@@ -76,7 +77,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-        bos_tokens=0,
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -129,7 +130,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.bos_tokens = bos_tokens
+        self.utility_tokens = utility_tokens

         self.reset_parameters()

@@ -156,7 +157,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except BOS tokens
+        Apply Axial RoPE to all tokens except utility tokens
         """

         if len(self.source_size) == 1:
@@ -180,8 +181,8 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )

-        q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-        k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]

         q_img = rearrange(
             q_img,
@@ -208,9 +209,9 @@ class MHAttention(nn.Module):
             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
         )

-        # Re-combine the BOS tokens and the RoPE-enhanced image tokens
-        q = torch.cat([q_bos, q_img], dim=1)
-        k = torch.cat([k_bos, k_img], dim=1)
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)

         return q, k

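# Sketch (not part of the diff): the split/recombine pattern used above, as a
# standalone function. Utility tokens (formerly "BOS tokens") sit at the front of
# the sequence and bypass rotary position encoding; only the trailing sequence
# tokens are rotated before the two groups are concatenated again. `toy_rope` is
# a shape-preserving stand-in for the package's rotary embedding.
import torch


def toy_rope(t: torch.Tensor) -> torch.Tensor:
    return t  # placeholder for RotaryEmbedding / apply_rotary_emb


def rope_except_utility(q: torch.Tensor, k: torch.Tensor, utility_tokens: int):
    q_util, q_img = q[:, :utility_tokens, :], q[:, utility_tokens:, :]
    k_util, k_img = k[:, :utility_tokens, :], k[:, utility_tokens:, :]
    q_img, k_img = toy_rope(q_img), toy_rope(k_img)
    # Re-combine the utility tokens and the position-encoded sequence tokens
    return torch.cat([q_util, q_img], dim=1), torch.cat([k_util, k_img], dim=1)


q = k = torch.randn(2, 3 + 16, 64)  # 3 utility tokens + 16 sequence tokens
q, k = rope_except_utility(q, k, utility_tokens=3)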
@@ -284,7 +285,7 @@ class MHAttention(nn.Module):

         return self.out_proj(output_without_heads)

-    def attention_scores(self, q, k, v):
+    def attention_logits(self, q, k, v):

         q, k, v = self.project_qkv(q, k, v)

@@ -301,8 +302,6 @@ class MHAttention(nn.Module):
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))

-        qk_scores = F.softmax(qk_scores, dim=-1)
-
         return qk_scores  # (batch, head, seq_len, seq_len)

     def reset_parameters(self):
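# Sketch (not part of the diff): with the softmax removed, `attention_logits`
# returns raw, optionally causal-masked QK scores. Callers that want attention
# probabilities now have to normalise the result themselves, roughly like this
# (the tensor below stands in for MHAttention.attention_logits output of shape
# (batch, head, seq_len, seq_len)):
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 10, 10)
attn_probs = F.softmax(logits, dim=-1)  # rows now sum to 1 over the key dimension
assert torch.allclose(attn_probs.sum(dim=-1), torch.ones(2, 4, 10))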
@@ -416,7 +415,7 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-        bos_tokens=0,
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -474,7 +473,7 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-            bos_tokens=bos_tokens,
+            utility_tokens=utility_tokens,
             scaling=msa_scaling,
         )

@@ -529,15 +528,15 @@ class TransformerBlock(nn.Module):

         return x

-    def attention_scores(self, x):
+    def attention_logits(self, x):
         """
         Give back the attention scores used in this layer.
         """
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            return self.attn(x, x, x)
+            return self.attn.attention_logits(x, x, x)
         else:
-            return self.attn(x, x, x)
+            return self.attn.attention_logits(x, x, x)

     def reset_parameters(self):
         self.layer_norm_1.reset_parameters()
@@ -573,8 +572,8 @@ class TransformerEncoder(nn.Module):
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-        bos_tokens=0,
-        return_bos_tokens=False,
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
@@ -592,22 +591,33 @@ class TransformerEncoder(nn.Module):
         if relative_position_embedding and (source_size is None):
             raise ValueError(
                 "`source_size` for TransformerEncoder cannot be None if"
-                " `position_embedding_type` is relative"
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
             )

         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self._bos_tokens = bos_tokens
-        self.return_bos_tokens = return_bos_tokens
-
-        # Initialise BOS tokens with normal init, like usual Pytorch embeddings
-        if self._bos_tokens:
-            self._bos_embedding = nn.Parameter(torch.empty(self._bos_tokens, d_model))
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
-            self.full_sequence_length = self.seq_len + self._bos_tokens
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+        else:
+            self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
             self.full_sequence_length = self.seq_len

         self.d_model = d_model
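# Sketch (not part of the diff): a condensed standalone module mirroring the new
# initialisation above. Utility tokens are a learned (n_tokens, d_model) parameter
# with normal init, and the full sequence length is only extended when a fixed
# `seq_len` is actually known. This is an illustration, not the package's class.
import torch
import torch.nn as nn


class UtilityTokenDemo(nn.Module):
    def __init__(self, d_model: int, utility_tokens: int = 0, seq_len=None):
        super().__init__()
        self._utility_tokens = utility_tokens
        if utility_tokens:
            self._utility_token_embedding = nn.Parameter(
                torch.empty(utility_tokens, d_model)
            )
            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
        else:
            self._utility_token_embedding = None

        if utility_tokens and seq_len is not None:
            self.full_sequence_length = seq_len + utility_tokens
        else:
            self.full_sequence_length = seq_len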
@@ -641,7 +651,7 @@ class TransformerEncoder(nn.Module):
                 n_heads,
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
-                bos_tokens=bos_tokens,
+                utility_tokens=utility_tokens,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
@@ -669,8 +679,10 @@ class TransformerEncoder(nn.Module):
         return ",".join([str(block._kv_distance) for block in self.blocks])

     def preprocess(self, x):
-        if self._bos_tokens:
-            x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x

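# Sketch (not part of the diff): what the prepend step above does. The learned
# (n_util, d_model) parameter is broadcast over the batch with `expand` and
# concatenated in front of the input along the sequence dimension.
import torch


def prepend_utility_tokens(x: torch.Tensor, utility_token_embedding: torch.Tensor):
    # x: (batch, seq_len, d_model); utility_token_embedding: (n_util, d_model)
    util = utility_token_embedding.expand(x.size(0), -1, -1)
    return torch.cat([util, x], dim=1)  # (batch, n_util + seq_len, d_model)


out = prepend_utility_tokens(torch.randn(2, 16, 64), torch.randn(3, 64))
assert out.shape == (2, 19, 64)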
@@ -683,6 +695,8 @@ class TransformerEncoder(nn.Module):
                 )  # to shape (1, seq_len) to broadcast over batch
             )

+        return x
+
     def forward(self, x):

         x = self.preprocess(x)
@@ -690,12 +704,12 @@ class TransformerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x)

-        if self._bos_tokens and not self.return_bos_tokens:
-            return x[:, self._bos_tokens :, :]
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x

-    def attention_scores(self, x):
+    def attention_logits(self, x):

         x = self.preprocess(x)

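# Sketch (not part of the diff): the output side of the same mechanism. Unless
# `return_utility_tokens=True`, the encoder slices the prepended tokens back off
# before returning, as in the hunk above.
import torch

n_util, return_utility_tokens = 3, False
x = torch.randn(2, 3 + 16, 64)  # (batch, n_util + seq_len, d_model)
out = x[:, n_util:, :] if (n_util and not return_utility_tokens) else x
assert out.shape == (2, 16, 64)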
@@ -703,15 +717,15 @@ class TransformerEncoder(nn.Module):

         for block in self.blocks:
             # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
-            layer_attention_scores = block.attention_scores(x).unsqueeze(1)
-            layer_scores.append(layer_attention_scores)
+            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+            layer_scores.append(layer_attention_logits)
             x = block(x)

         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)

     def reset_parameters(self):
-        if self._bos_embedding is not None:
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)

         if self.absolute_position_embedding is not None:
             self.absolute_position_embedding.reset_parameters()
@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange

 import torch
 import torch.nn as nn
-import torch.nn.functional as F


 class GetCLSToken(nn.Module):
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")

+    def attention_scores(self, x):
+        return self.attention(x)
+
     def reset_parameters(self):
         # Iterate over modules in the sequential block
         for module in self.attention:
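# Sketch (not part of the diff): a minimal sequence-pool module illustrating the
# einsum above and the new `attention_scores` accessor. The real SequencePool's
# `self.attention` submodule is assumed here to map (batch, seq, d_model) to
# per-token weights of shape (batch, seq); this toy version exists only so the
# example runs.
import torch
import torch.nn as nn
from einops import einsum


class TinySequencePool(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(d_model, 1),    # one score per token
            nn.Flatten(start_dim=1),  # (batch, seq)
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        weights = self.attention(x)
        return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")

    def attention_scores(self, x):
        return self.attention(x)


pooled = TinySequencePool(64)(torch.randn(2, 16, 64))
assert pooled.shape == (2, 64)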
@@ -169,8 +171,8 @@ class ViTEncoder(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -336,8 +338,8 @@ class ViTEncoder(nn.Module):
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
             linear_module=linear_module,
-            bos_tokens=transformer_bos_tokens,
-            return_bos_tokens=transformer_return_bos_tokens,
+            utility_tokens=transformer_utility_tokens,
+            return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
             post_norm=transformer_post_norm,
@@ -400,9 +402,9 @@ class ViTEncoder(nn.Module):
     def forward(self, x):
         return self.encoder(x)

-    def attention_scores(self, x):
+    def attention_logits(self, x):
         x = self.encoder[:-1](x)
-        return self.encoder[-1].attention_scores(x)
+        return self.encoder[-1].attention_logits(x)

     def reset_parameters(self):
         for module in self.encoder:
@@ -449,8 +451,8 @@ class ViT(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -516,8 +518,8 @@ class ViT(nn.Module):
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
-            transformer_bos_tokens=transformer_bos_tokens,
-            transformer_return_bos_tokens=transformer_return_bos_tokens,
+            transformer_utility_tokens=transformer_utility_tokens,
+            transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
@@ -546,16 +548,26 @@ class ViT(nn.Module):
     def forward(self, x):
         return self.pool(self.encoder(x))

-    def attention_scores(self, x):
-        return self.encoder.attention_scores(x)
+    def attention_logits(self, x):
+        return self.encoder.attention_logits(x)
+
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )

-    def head_to_bos_token_attention(self, x):
-        all_attention = self.attention_scores(x)
+    def head_to_utility_token_attention_logits(self, x):
+        all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-        n_bos_tokens = self.encoder.encoder._bos_tokens
-        just_bos = sequence_averages[:, :, :n_bos_tokens]
-        return F.softmax(just_bos, dim=-1)  # (layer, head, bos_token)
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)

     def reset_parameters(self):
         self.encoder.reset_parameters()
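# Sketch (not part of the diff): the reduction performed by the new
# `head_to_utility_token_attention_logits`. Starting from stacked per-layer logits
# of shape (batch, layer, head, seq_len, seq_len), it averages over the batch,
# then over the last dimension (keys, per the shape comments in the diff), and
# keeps only the utility-token rows. Unlike the old BOS-token version, no softmax
# is applied, so the result is raw logits of shape (layer, head, n_utility_tokens).
import torch

n_util = 3
all_attention = torch.randn(2, 7, 4, 19, 19)  # (batch, layer, head, seq, seq)
batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
util_logits = sequence_averages[:, :, :n_util]
assert util_logits.shape == (7, 4, n_util)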
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "5.1.0"
+version = "7.0.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}