broccoli-ml 9.6.0__tar.gz → 9.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 9.6.0
+Version: 9.7.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -78,6 +79,7 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -96,6 +98,15 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
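Note: the two new n_heads x n_heads linear maps follow the "talking heads" pattern (Shazeer et al., "Talking-Heads Attention", 2020): one mixes attention logits across heads before the softmax (head_projection) and one mixes the attention weights after the softmax (sample_projection). A minimal standalone sketch of that flow, using illustrative shapes and plain tensors rather than the package's MHAttention class:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: batch=2, heads=4, sequence length=8, head dim=16.
    b, h, t, d = 2, 4, 8, 16
    q, k, v = (torch.randn(b, h, t, d) for _ in range(3))
    logits_mix = torch.randn(h, h)   # stands in for head_projection.weight
    weights_mix = torch.randn(h, h)  # stands in for sample_projection.weight

    logits = (q @ k.transpose(-2, -1)) / d**0.5               # (b, h, t, t)
    logits = torch.einsum("b h i j, o h -> b o i j", logits, logits_mix)
    weights = F.softmax(logits, dim=-1)
    weights = torch.einsum("b h i j, o h -> b o i j", weights, weights_mix)
    out = weights @ v                                          # (b, h, t, d)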
@@ -243,7 +254,7 @@ class MHAttention(nn.Module):
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +282,22 @@ class MHAttention(nn.Module):
 
             qk_scores *= self.scaling_factor
 
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+                )
+
             # Apply mask if causal (must come before softmax)
             if self.causal:
                 qk_scores.masked_fill_(self.mask, float("-inf"))
 
             qk_scores = F.softmax(qk_scores, dim=-1)
 
+            if self.talking_heads:
+                qk_scores = torch.einsum(
+                    "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+                )
+
             qk_scores = self.dropout(qk_scores)
 
             output_with_heads = qk_scores @ v
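For reference, the einsum used above is a learned linear map over the head dimension applied at every (query, key) position; it is equivalent to permuting heads to the last axis and applying the projection as an ordinary nn.Linear. A small self-contained check (names and shapes are illustrative):

    import torch
    import torch.nn as nn

    b, h, i, j = 2, 4, 5, 5
    qk_scores = torch.randn(b, h, i, j)
    mixing = nn.Linear(h, h, bias=False)  # plays the role of head_projection / sample_projection

    mixed = torch.einsum("b h i j, o h -> b o i j", qk_scores, mixing.weight)
    reference = mixing(qk_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(mixed, reference, atol=1e-6)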
@@ -310,6 +331,10 @@ class MHAttention(nn.Module):
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
         self.out_proj.reset_parameters()
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
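Because nn.init.eye_ sets both projections to the exact identity, a freshly reset module with talking_heads=True should behave like standard multi-head attention until training moves the mixing matrices away from the identity. A quick illustrative check of that property:

    import torch
    import torch.nn as nn

    proj = nn.Linear(4, 4, bias=False)
    nn.init.eye_(proj.weight)  # same initialization as head_projection / sample_projection

    scores = torch.randn(2, 4, 5, 5)
    mixed = torch.einsum("b h i j, o h -> b o i j", scores, proj.weight)
    assert torch.allclose(mixed, scores)  # identity mixing leaves the scores unchanged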
@@ -453,6 +478,7 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -513,6 +539,7 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
         )
 
@@ -616,6 +643,7 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
@@ -644,6 +672,13 @@ class TransformerEncoder(nn.Module):
         )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
@@ -695,6 +730,7 @@ class TransformerEncoder(nn.Module):
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
                 utility_tokens=utility_tokens,
+                talking_heads=talking_heads,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
@@ -174,6 +174,7 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -345,6 +346,7 @@ class ViTEncoder(nn.Module):
             causal=False,
             linear_module=linear_module,
             utility_tokens=transformer_utility_tokens,
+            talking_heads=transformer_talking_heads,
             return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
@@ -472,6 +474,7 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -543,6 +546,7 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "9.6.0"
+version = "9.7.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}