broccoli-ml 9.6.0__tar.gz → 10.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/PKG-INFO +1 -1
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/transformer.py +82 -12
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/vit.py +8 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/pyproject.toml +1 -1
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/LICENSE +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/README.md +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/linear.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.6.0 → broccoli_ml-10.1.0}/broccoli/utils.py +0 -0
broccoli/transformer.py

@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -20,6 +21,15 @@ except ImportError:
     FLASH_ATTN = False
 
 
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-4):
+        super().__init__()
+        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x * self.nondecay_scale
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
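Note: LayerScale (as in CaiT, Touvron et al., 2021) puts a learnable per-channel gain on each residual branch, initialised near zero so that a new block starts out contributing almost nothing; the `nondecay_` prefix presumably signals that the parameter should be excluded from weight decay. A minimal standalone sketch of the effect, mirroring the class added above:

    import torch
    import torch.nn as nn

    class LayerScale(nn.Module):  # mirrors the definition added in this release
        def __init__(self, dim, init_values=1e-4):
            super().__init__()
            self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))

        def forward(self, x):
            return x * self.nondecay_scale

    scale = LayerScale(dim=8)
    x = torch.randn(2, 4, 8)  # (batch, tokens, channels)
    print((scale(x).abs().max() / x.abs().max()).item())  # ~1e-4: the branch starts almost "off"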
@@ -78,6 +88,7 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -96,6 +107,15 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
@@ -243,7 +263,7 @@ class MHAttention(nn.Module):
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +291,22 @@ class MHAttention(nn.Module):
 
         qk_scores *= self.scaling_factor
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+            )
+
         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
         qk_scores = F.softmax(qk_scores, dim=-1)
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+            )
+
         qk_scores = self.dropout(qk_scores)
 
         output_with_heads = qk_scores @ v
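Note: these two blocks implement talking-heads attention (Shazeer et al., 2020): a learned n_heads × n_heads mixing of the attention logits before the softmax (`head_projection`) and of the attention weights after it (`sample_projection`). Because this operates on the full score matrix, it cannot go through the flash-attention kernel, hence the `FLASH_ATTN and not self.talking_heads` guard earlier. The einsum is just a matrix multiply over the head axis; a standalone sketch with illustrative shapes:

    import torch

    b, h, i, j = 2, 4, 5, 5                 # batch, heads, query tokens, key tokens
    qk_scores = torch.randn(b, h, i, j)
    weight = torch.randn(h, h)              # nn.Linear(h, h).weight has shape (out, in)

    mixed = torch.einsum("b h i j, o h -> b o i j", qk_scores, weight)

    # Equivalent: move heads to the last axis, matmul, move them back.
    alt = (qk_scores.permute(0, 2, 3, 1) @ weight.T).permute(0, 3, 1, 2)
    print(torch.allclose(mixed, alt, atol=1e-5))  # True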
@@ -310,6 +340,10 @@ class MHAttention(nn.Module):
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
         self.out_proj.reset_parameters()
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
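Note: `nn.init.eye_` sets both mixing matrices to the identity, so a freshly reset talking-heads layer behaves exactly like standard multi-head attention until training moves the weights. A quick standalone check:

    import torch
    import torch.nn as nn

    n_heads = 4
    proj = nn.Linear(n_heads, n_heads, bias=False)
    nn.init.eye_(proj.weight)

    qk_scores = torch.randn(2, n_heads, 5, 5)
    mixed = torch.einsum("b h i j, o h -> b o i j", qk_scores, proj.weight)
    print(torch.allclose(mixed, qk_scores))  # True: identity mixing is a no-op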
@@ -365,11 +399,15 @@ class FeedforwardBlock(nn.Module):
         )
 
         self.max_features = (
-            2 * ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
         )
 
         self.linear_in = linear_module_up(input_features, self.max_features)
-        self.linear_out = linear_module_down(
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
+        )
 
         self.process = nn.Sequential(
             *[
@@ -377,7 +415,11 @@ class FeedforwardBlock(nn.Module):
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
-
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
                 self.linear_out,
                 self.outer_dropout,
             ]
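Note: the feed-forward sizing now casts `ratio * output_features` to `int` (so fractional ratios give whole layer widths) and doubles the up-projection only when `self.xglu` is set, since GLU-style activations split their input into a value half and a gate half; the optional `nn.LayerNorm` between the activation and the down-projection appears to be the NormFormer-style mid-FFN norm already controlled by the `normformer` flag. A generic sketch of why the GLU path needs twice the hidden width (not the library's own activation classes):

    import torch
    import torch.nn as nn

    class SimpleGLU(nn.Module):
        """Generic GLU-style activation: split features in half, gate one half."""
        def forward(self, x):
            value, gate = x.chunk(2, dim=-1)
            return value * torch.sigmoid(gate)

    hidden = 96                       # stands in for int(ratio * output_features)
    up = nn.Linear(64, 2 * hidden)    # doubled width feeds the GLU
    down = nn.Linear(hidden, 64)      # the down-projection only sees `hidden` features

    x = torch.randn(2, 10, 64)
    print(down(SimpleGLU()(up(x))).shape)  # torch.Size([2, 10, 64])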
@@ -453,6 +495,7 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -470,6 +513,7 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        layerscale=True,
     ):
         """
         Args:
@@ -491,6 +535,13 @@ class TransformerBlock(nn.Module):
         self.layer_norm_2 = nn.LayerNorm(d_model)
         self.layer_norm_3 = nn.LayerNorm(d_model)
 
+        if layerscale:
+            self.layerscale1 = LayerScale(d_model)
+            self.layerscale2 = LayerScale(d_model)
+        else:
+            self.layerscale1 = nn.Identity()
+            self.layerscale2 = nn.Identity()
+
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
             if d_model < 16:
@@ -513,6 +564,7 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
         )
 
@@ -553,19 +605,19 @@ class TransformerBlock(nn.Module):
 
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
             x = self.layer_norm_2(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
             if self.post_norm:  # i.e. in addition! Pre and post.
                 x = self.layer_norm_3(x)
         elif self.post_norm:  # i.e. only, not prenorm, just post
-            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
             x = self.layer_norm_2(x)
         else:  # Not pre or post norm. Stand well back.
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
 
         return x
 
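Note: in every branch the residual update is now `x + drop_path(layerscale(...))`, so with the default `init_values=1e-4` a freshly initialised block is close to an identity map, which is the usual motivation for LayerScale in deep ViTs; with `layerscale=False` the wrappers are `nn.Identity()` and the forward pass matches the previous release. A rough standalone illustration, using a plain linear layer as a stand-in for the attention/feed-forward branch:

    import torch
    import torch.nn as nn

    d_model = 32
    branch = nn.Linear(d_model, d_model)              # stand-in for self.attn / self.ff
    gamma = nn.Parameter(1e-4 * torch.ones(d_model))  # LayerScale's nondecay_scale

    x = torch.randn(2, 8, d_model)
    out = x + gamma * branch(x)                       # scaled residual update
    print((out - x).abs().max().item())               # tiny: the block starts near-identity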
@@ -616,12 +668,14 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        layerscale=True,
     ):
         """
         Args:
@@ -644,11 +698,23 @@ class TransformerEncoder(nn.Module):
         )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
         self.return_utility_tokens = return_utility_tokens
 
+        if layerscale:
+            self.layerscale = LayerScale(d_model)
+        else:
+            self.layerscale = None
+
         # Initialise utility tokens with normal init, like usual Pytorch embeddings
         if self._utility_tokens:
             self._utility_token_embedding = nn.Parameter(
@@ -695,6 +761,7 @@ class TransformerEncoder(nn.Module):
                     relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
                     utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
@@ -712,6 +779,7 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    layerscale=layerscale,
                 )
                 for i in range(n_layers)
             ]
@@ -732,15 +800,17 @@ class TransformerEncoder(nn.Module):
         x = x
 
         if self.absolute_position_embedding is not None:
-
+            position_embedding = self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
                 ).unsqueeze(
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
+            if self.layerscale is not None:
+                position_embedding = self.layerscale(position_embedding)
 
-        return x
+        return x + position_embedding
 
     def forward(self, x):
 
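Note: this hunk reworks the absolute-position-embedding step so the embedding is built explicitly, optionally scaled by the encoder-level LayerScale created above, and actually added to the activations (`return x + position_embedding` where the method previously ended with `return x`). A standalone sketch of the resulting logic, with hypothetical sizes:

    import torch
    import torch.nn as nn

    seq_len, d_model = 10, 32
    absolute_position_embedding = nn.Embedding(seq_len, d_model)
    layerscale = None  # would be a LayerScale(d_model) instance when layerscale=True

    x = torch.randn(2, seq_len, d_model)
    positions = torch.arange(0, seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
    position_embedding = absolute_position_embedding(positions)  # (1, seq_len, d_model)
    if layerscale is not None:
        position_embedding = layerscale(position_embedding)
    out = x + position_embedding  # broadcasts over the batch dimension
    print(out.shape)  # torch.Size([2, 10, 32])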
broccoli/vit.py

@@ -174,6 +174,7 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -186,6 +187,7 @@ class ViTEncoder(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
+        transformer_layerscale=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -345,11 +347,13 @@ class ViTEncoder(nn.Module):
                 causal=False,
                 linear_module=linear_module,
                 utility_tokens=transformer_utility_tokens,
+                talking_heads=transformer_talking_heads,
                 return_utility_tokens=transformer_return_utility_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                layerscale=transformer_layerscale,
             )
         else:
             self.transformer = nn.Identity()
@@ -472,6 +476,7 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -484,6 +489,7 @@ class ViT(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
+        transformer_layerscale=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -543,6 +549,7 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
@@ -555,6 +562,7 @@ class ViT(nn.Module):
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
+            transformer_layerscale=transformer_layerscale,
             linear_module=linear_module,
         )
 