broccoli-ml 9.6.0__tar.gz → 9.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/PKG-INFO +1 -1
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/transformer.py +37 -1
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/vit.py +4 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/pyproject.toml +1 -1
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/LICENSE +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/README.md +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/linear.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-9.7.0}/broccoli/utils.py +0 -0
--- broccoli_ml-9.6.0/broccoli/transformer.py
+++ broccoli_ml-9.7.0/broccoli/transformer.py
@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -78,6 +79,7 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -96,6 +98,15 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
@@ -243,7 +254,7 @@ class MHAttention(nn.Module):
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +282,22 @@ class MHAttention(nn.Module):
 
         qk_scores *= self.scaling_factor
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+            )
+
         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
         qk_scores = F.softmax(qk_scores, dim=-1)
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+            )
+
         qk_scores = self.dropout(qk_scores)
 
         output_with_heads = qk_scores @ v
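For readers unfamiliar with the feature: the two new branches above mix the per-head attention scores with a small learned matrix, once on the raw logits before the softmax and once on the attention weights after it, which is what the `talking_heads` name refers to. Below is a minimal standalone sketch of just that mixing step, with made-up shapes and near-identity matrices standing in for the learned `nn.Linear` weights; it is an illustration, not code taken from the package.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=2, heads=4, query length=8, key length=8
b, h, i, j = 2, 4, 8, 8
qk_scores = torch.randn(b, h, i, j)

# Stand-ins for the learned head-mixing matrices (the package uses
# bias-free nn.Linear layers initialized close to identity).
head_projection = torch.eye(h) + 0.01 * torch.randn(h, h)
sample_projection = torch.eye(h) + 0.01 * torch.randn(h, h)

# Mix the logits across the head dimension before the softmax ...
qk_scores = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection)

# ... normalize over the key dimension ...
attn = F.softmax(qk_scores, dim=-1)

# ... then mix the attention weights across heads again after the softmax.
attn = torch.einsum("b h i j, o h -> b o i j", attn, sample_projection)
```

Note that after the post-softmax mixing each row of `attn` need not sum to exactly 1 any more; how much each output head borrows from the others is left to the learned matrices.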
@@ -310,6 +331,10 @@ class MHAttention(nn.Module):
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
         self.out_proj.reset_parameters()
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
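Since `reset_parameters` fills both mixing matrices with `nn.init.eye_`, a freshly initialized module applies exact identity maps and therefore behaves like plain multi-head attention until training moves the weights. A quick sketch of that property, using hypothetical shapes rather than package code:

```python
import torch
import torch.nn as nn

n_heads = 4
head_projection = nn.Linear(n_heads, n_heads, bias=False)
nn.init.eye_(head_projection.weight)  # same initialization as in the diff

scores = torch.randn(2, n_heads, 8, 8)  # (batch, heads, query, key)
mixed = torch.einsum("b h i j, o h -> b o i j", scores, head_projection.weight)

# With an identity weight matrix the head mixing is a no-op.
assert torch.allclose(mixed, scores)
```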
@@ -453,6 +478,7 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -513,6 +539,7 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
         )
 
@@ -616,6 +643,7 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
@@ -644,6 +672,13 @@ class TransformerEncoder(nn.Module):
         )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
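The warning reflects a genuine limitation rather than a style choice: a fused flash-attention kernel never materializes the full attention score matrix, so there is no point at which the pre- and post-softmax head mixing could be applied, and the `if FLASH_ATTN and not self.talking_heads` check earlier in the diff routes talking-heads models through the unfused path instead.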
@@ -695,6 +730,7 @@ class TransformerEncoder(nn.Module):
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
                 utility_tokens=utility_tokens,
+                talking_heads=talking_heads,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
--- broccoli_ml-9.6.0/broccoli/vit.py
+++ broccoli_ml-9.7.0/broccoli/vit.py
@@ -174,6 +174,7 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -345,6 +346,7 @@ class ViTEncoder(nn.Module):
             causal=False,
             linear_module=linear_module,
             utility_tokens=transformer_utility_tokens,
+            talking_heads=transformer_talking_heads,
             return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
@@ -472,6 +474,7 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -543,6 +546,7 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,