broccoli-ml 5.1.0__tar.gz → 7.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/PKG-INFO +1 -1
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/transformer.py +52 -38
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/vit.py +30 -18
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/pyproject.toml +1 -1
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/LICENSE +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/README.md +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/activation.py +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/linear.py +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/rope.py +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-5.1.0 → broccoli_ml-7.0.0}/broccoli/utils.py +0 -0
broccoli/transformer.py

```diff
@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
```
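The added print only fires when the optional flash-attn dependency imports, so `FLASH_ATTN` ends up doubling as a feature flag. A minimal sketch of how such a flag is typically consumed downstream; the `attend` helper and its SDPA fallback are illustrative assumptions, not code from this package:

```python
import torch
import torch.nn.functional as F

FLASH_ATTN = False
try:
    from flash_attn import flash_attn_func

    FLASH_ATTN = True
except ImportError:
    pass


def attend(q, k, v):
    """Hypothetical dispatch on the flag; q, k, v are (batch, seq_len, n_heads, head_dim)."""
    if FLASH_ATTN:
        # flash_attn_func consumes (batch, seq_len, n_heads, head_dim) tensors directly.
        return flash_attn_func(q, k, v)
    # Fallback: PyTorch's built-in SDPA expects (batch, n_heads, seq_len, head_dim).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
```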
```diff
@@ -76,7 +77,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-        …
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
```
```diff
@@ -129,7 +130,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.…
+        self.utility_tokens = utility_tokens
 
         self.reset_parameters()
 
```
```diff
@@ -156,7 +157,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except …
+        Apply Axial RoPE to all tokens except utility tokens
         """
 
         if len(self.source_size) == 1:
```
```diff
@@ -180,8 +181,8 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-        …
-        …
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
 
         q_img = rearrange(
             q_img,
```
```diff
@@ -208,9 +209,9 @@ class MHAttention(nn.Module):
             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
         )
 
-        # Re-combine the …
-        q = torch.cat([…
-        k = torch.cat([…
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)
 
         return q, k
 
```
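Taken together, the hunks above split off the leading utility tokens, apply axial RoPE only to the remaining sequence tokens, and then concatenate the two parts back. A standalone sketch of that split/apply/recombine pattern, with a generic `apply_rope` callable standing in for the package's `RotaryEmbedding` machinery:

```python
import torch


def rope_skip_utility(q: torch.Tensor, k: torch.Tensor, n_utility: int, apply_rope):
    """Apply a rotary embedding to every token except the leading utility tokens.

    q, k: (batch, seq_len, dim); apply_rope is any callable mapping (q, k) -> (q, k).
    """
    q_util, q_img = q[:, :n_utility, :], q[:, n_utility:, :]
    k_util, k_img = k[:, :n_utility, :], k[:, n_utility:, :]

    # Positional rotation only makes sense for tokens that occupy a spatial position.
    q_img, k_img = apply_rope(q_img, k_img)

    # Re-combine the utility tokens with the RoPE-enhanced sequence tokens.
    return torch.cat([q_util, q_img], dim=1), torch.cat([k_util, k_img], dim=1)
```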
```diff
@@ -284,7 +285,7 @@ class MHAttention(nn.Module):
 
         return self.out_proj(output_without_heads)
 
-    def …
+    def attention_logits(self, q, k, v):
 
         q, k, v = self.project_qkv(q, k, v)
 
```
```diff
@@ -301,8 +302,6 @@ class MHAttention(nn.Module):
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
-        qk_scores = F.softmax(qk_scores, dim=-1)
-
         return qk_scores  # (batch, head, seq_len, seq_len)
 
     def reset_parameters(self):
```
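With the softmax dropped, `attention_logits` now returns raw (scaled and, when causal, masked) query-key scores rather than normalized attention weights. A small sketch of how a caller could recover probabilities from such logits; the tensor below is a stand-in, not the module's actual output:

```python
import torch
import torch.nn.functional as F

# Stand-in for what attention_logits returns: (batch, head, seq_len, seq_len) scores.
logits = torch.randn(2, 4, 16, 16)

# Normalized attention weights are one softmax away.
weights = F.softmax(logits, dim=-1)
assert torch.allclose(weights.sum(dim=-1), torch.ones(2, 4, 16))
```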
```diff
@@ -416,7 +415,7 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-        …
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
```
```diff
@@ -474,7 +473,7 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-            …
+            utility_tokens=utility_tokens,
             scaling=msa_scaling,
         )
 
```
```diff
@@ -529,15 +528,15 @@ class TransformerBlock(nn.Module):
 
         return x
 
-    def …
+    def attention_logits(self, x):
         """
         Give back the attention scores used in this layer.
         """
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            return self.attn(x, x, x)
+            return self.attn.attention_logits(x, x, x)
         else:
-            return self.attn(x, x, x)
+            return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
         self.layer_norm_1.reset_parameters()
```
```diff
@@ -573,8 +572,8 @@ class TransformerEncoder(nn.Module):
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-        …
-        …
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
```
```diff
@@ -592,22 +591,33 @@ class TransformerEncoder(nn.Module):
         if relative_position_embedding and (source_size is None):
             raise ValueError(
                 "`source_size` for TransformerEncoder cannot be None if"
-                " `…
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
             )
 
         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self.…
-        self.…
-        …
-        # Initialise …
-        if self.…
-            self.…
-        …
-        …
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+        else:
+            self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
```
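The encoder now owns a learnable `(utility_tokens, d_model)` parameter with a normal init and, per the `preprocess` hunk below, expands it over the batch and prepends it to the input sequence. A self-contained sketch of that mechanism as an illustrative module (not broccoli's actual `TransformerEncoder`):

```python
import torch
import torch.nn as nn


class PrependUtilityTokens(nn.Module):
    """Illustrative module: learn `n_tokens` vectors and prepend them to every sequence."""

    def __init__(self, n_tokens: int, d_model: int):
        super().__init__()
        self.embedding = nn.Parameter(torch.empty(n_tokens, d_model))
        nn.init.normal_(self.embedding, mean=0.0, std=1.0)  # matches nn.Embedding's default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model) -> (batch, n_tokens + seq_len, d_model)
        tokens = self.embedding.expand(x.size(0), -1, -1)
        return torch.cat([tokens, x], dim=1)


x = torch.randn(8, 64, 128)
print(PrependUtilityTokens(2, 128)(x).shape)  # torch.Size([8, 66, 128])
```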
```diff
@@ -641,7 +651,7 @@ class TransformerEncoder(nn.Module):
                 n_heads,
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
-                …
+                utility_tokens=utility_tokens,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
```
```diff
@@ -669,8 +679,10 @@ class TransformerEncoder(nn.Module):
         return ",".join([str(block._kv_distance) for block in self.blocks])
 
     def preprocess(self, x):
-        if self.…
-            x = torch.cat(…
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x
 
```
```diff
@@ -683,6 +695,8 @@ class TransformerEncoder(nn.Module):
                 )  # to shape (1, seq_len) to broadcast over batch
             )
 
+        return x
+
     def forward(self, x):
 
         x = self.preprocess(x)
```
```diff
@@ -690,12 +704,12 @@ class TransformerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x)
 
-        if self.…
-            return x[:, self.…
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x
 
-    def …
+    def attention_logits(self, x):
 
         x = self.preprocess(x)
 
```
```diff
@@ -703,15 +717,15 @@ class TransformerEncoder(nn.Module):
 
         for block in self.blocks:
             # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
-            …
-            layer_scores.append(…
+            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+            layer_scores.append(layer_attention_logits)
             x = block(x)
 
         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
 
     def reset_parameters(self):
-        if self.…
-            nn.init.normal_(self.…
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
 
         if self.absolute_position_embedding is not None:
             self.absolute_position_embedding.reset_parameters()
```
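The encoder-level `attention_logits` collects one `(batch, head, seq_len, seq_len)` tensor per block, inserts a layer axis with `unsqueeze(1)`, and concatenates along it. A toy sketch of that stacking step using stand-in tensors:

```python
import torch

batch, n_layers, n_heads, seq_len = 2, 3, 4, 16

# Stand-ins for block.attention_logits(x) at each layer.
per_layer = [torch.randn(batch, n_heads, seq_len, seq_len) for _ in range(n_layers)]

stacked = torch.cat([scores.unsqueeze(1) for scores in per_layer], dim=1)
print(stacked.shape)  # torch.Size([2, 3, 4, 16, 16])
```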
broccoli/vit.py

```diff
@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class GetCLSToken(nn.Module):
```
```diff
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+    def attention_scores(self, x):
+        return self.attention(x)
+
     def reset_parameters(self):
         # Iterate over modules in the sequential block
         for module in self.attention:
```
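`SequencePool` computes one weight per sequence position and contracts the sequence away with the einsum shown above; the new `attention_scores` method simply exposes those weights. A generic sketch of the pooling contraction with stand-in weights:

```python
import torch
from einops import einsum

batch, seq, d_model = 2, 16, 128
x = torch.randn(batch, seq, d_model)

# Stand-in for SequencePool's learned attention: one weight per position, summing to 1.
weights = torch.softmax(torch.randn(batch, seq), dim=-1)

pooled = einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
print(pooled.shape)  # torch.Size([2, 128])
```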
```diff
@@ -169,8 +171,8 @@ class ViTEncoder(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        …
-        …
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
```
```diff
@@ -336,8 +338,8 @@ class ViTEncoder(nn.Module):
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
             linear_module=linear_module,
-            …
-            …
+            utility_tokens=transformer_utility_tokens,
+            return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
             post_norm=transformer_post_norm,
```
```diff
@@ -400,9 +402,9 @@ class ViTEncoder(nn.Module):
     def forward(self, x):
         return self.encoder(x)
 
-    def …
+    def attention_logits(self, x):
         x = self.encoder[:-1](x)
-        return self.encoder[-1].…
+        return self.encoder[-1].attention_logits(x)
 
     def reset_parameters(self):
         for module in self.encoder:
```
```diff
@@ -449,8 +451,8 @@ class ViT(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        …
-        …
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
```
```diff
@@ -516,8 +518,8 @@ class ViT(nn.Module):
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
-            …
-            …
+            transformer_utility_tokens=transformer_utility_tokens,
+            transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
```
```diff
@@ -546,16 +548,26 @@ class ViT(nn.Module):
     def forward(self, x):
         return self.pool(self.encoder(x))
 
-    def …
-        return self.encoder.…
+    def attention_logits(self, x):
+        return self.encoder.attention_logits(x)
+
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )
 
-    def …
-        all_attention = self.…
+    def head_to_utility_token_attention_logits(self, x):
+        all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-        …
-        …
-        …
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)
 
     def reset_parameters(self):
         self.encoder.reset_parameters()
```
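`head_to_utility_token_attention_logits` averages the stacked logits over the batch axis and over the last sequence (key) axis, then keeps the first `utility_tokens` entries of the remaining query axis, giving a `(layer, head, utility_tokens)` summary. A sketch of those reductions on a stand-in tensor:

```python
import torch

n_layers, n_heads, seq_len, n_utility = 3, 4, 18, 2

# Stand-in for ViT.attention_logits(x): (batch, layer, head, seq_len, seq_len).
all_attention = torch.randn(8, n_layers, n_heads, seq_len, seq_len)

batch_averages = torch.mean(all_attention, dim=0, keepdim=False)       # (layer, head, q, k)
sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)  # (layer, head, q)

summary = sequence_averages[:, :, :n_utility]  # (layer, head, utility_tokens)
print(summary.shape)  # torch.Size([3, 4, 2])
```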
The remaining files are unchanged: LICENSE, README.md, broccoli/__init__.py, broccoli/activation.py, broccoli/cnn.py, broccoli/linear.py, broccoli/rope.py, broccoli/tensor.py, broccoli/utils.py.