broccoli-ml 5.2.0__tar.gz → 6.0.1__tar.gz
This diff shows the changes between two publicly released versions of this package, as they appear in the public registry, and is provided for informational purposes only.
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/PKG-INFO +1 -1
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/transformer.py +33 -28
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/vit.py +25 -12
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/pyproject.toml +1 -1
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/LICENSE +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/README.md +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/__init__.py +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/activation.py +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/cnn.py +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/linear.py +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/rope.py +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/tensor.py +0 -0
- {broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/utils.py +0 -0
{broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/transformer.py

@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
@@ -76,7 +77,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -129,7 +130,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.
+        self.utility_tokens = utility_tokens
 
         self.reset_parameters()
 
@@ -156,7 +157,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except
+        Apply Axial RoPE to all tokens except utility tokens
         """
 
         if len(self.source_size) == 1:
@@ -180,8 +181,8 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-
-
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
 
         q_img = rearrange(
             q_img,
@@ -208,9 +209,9 @@ class MHAttention(nn.Module):
             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
         )
 
-        # Re-combine the
-        q = torch.cat([
-        k = torch.cat([
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)
 
         return q, k
 
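The hunks above change `MHAttention` so that Axial RoPE skips the new utility tokens: `q` and `k` are split at `utility_tokens`, the rotary embedding is applied only to the image part, and the pieces are concatenated back together. Below is a minimal standalone sketch of that split-transform-recombine pattern; the function name and the identity stand-in for the rotary transform are illustrative, not broccoli's API.

```python
import torch


def apply_rope_except_utility(q, k, utility_tokens, rope_fn):
    """Apply a positional transform only to the non-utility part of q and k.

    q, k: (batch, seq, dim) tensors whose first `utility_tokens` positions
    hold utility tokens that should not receive positional information.
    rope_fn: any callable mapping (batch, seq, dim) -> (batch, seq, dim).
    """
    q_util, q_img = q[:, :utility_tokens, :], q[:, utility_tokens:, :]
    k_util, k_img = k[:, :utility_tokens, :], k[:, utility_tokens:, :]

    q_img, k_img = rope_fn(q_img), rope_fn(k_img)

    # Re-attach the untouched utility tokens in front of the transformed tokens
    q = torch.cat([q_util, q_img], dim=1)
    k = torch.cat([k_util, k_img], dim=1)
    return q, k


# Example with 2 utility tokens and an identity stand-in for the rotary transform
q, k = torch.randn(1, 6, 8), torch.randn(1, 6, 8)
q2, k2 = apply_rope_except_utility(q, k, utility_tokens=2, rope_fn=lambda t: t)
assert torch.equal(q2, q) and torch.equal(k2, k)
```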
@@ -414,7 +415,7 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -472,7 +473,7 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
-
+            utility_tokens=utility_tokens,
             scaling=msa_scaling,
         )
 
@@ -571,8 +572,8 @@ class TransformerEncoder(nn.Module):
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-
-
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
@@ -596,16 +597,18 @@ class TransformerEncoder(nn.Module):
         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self.
-        self.
-
-        # Initialise
-        if self.
-            self.
-
-
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self.
+            self._utility_token_embedding = None
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
@@ -639,7 +642,7 @@ class TransformerEncoder(nn.Module):
                 n_heads,
                 relative_position_embedding=relative_position_embedding,
                 source_size=source_size,
-
+                utility_tokens=utility_tokens,
                 mlp_ratio=mlp_ratio,
                 activation=activation,
                 activation_kwargs=activation_kwargs,
@@ -667,8 +670,10 @@ class TransformerEncoder(nn.Module):
         return ",".join([str(block._kv_distance) for block in self.blocks])
 
     def preprocess(self, x):
-        if self.
-            x = torch.cat(
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x
 
@@ -690,8 +695,8 @@ class TransformerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x)
 
-        if self.
-            return x[:, self.
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x
 
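Taken together, the `__init__`, `preprocess` and forward-pass changes above implement a prepend-and-strip scheme: a learned `(utility_tokens, d_model)` parameter is normally initialised, expanded across the batch and concatenated in front of the sequence, and the utility positions are sliced off again after the blocks unless `return_utility_tokens` is set. Below is a self-contained sketch of the same scheme in plain PyTorch; the wrapper class and its names are illustrative, not broccoli's API.

```python
import torch
import torch.nn as nn


class UtilityTokenWrapper(nn.Module):
    """Prepend learned utility tokens to a sequence, then optionally strip them."""

    def __init__(self, d_model, utility_tokens=0, return_utility_tokens=False):
        super().__init__()
        self.utility_tokens = utility_tokens
        self.return_utility_tokens = return_utility_tokens
        if utility_tokens:
            # Normal init, like usual PyTorch embeddings
            self.utility_token_embedding = nn.Parameter(
                torch.empty(utility_tokens, d_model)
            )
            nn.init.normal_(self.utility_token_embedding, mean=0.0, std=1.0)
        else:
            self.utility_token_embedding = None

    def forward(self, x, blocks=()):
        # x: (batch, seq, d_model)
        if self.utility_tokens:
            # Broadcast the learned tokens over the batch and prepend them
            x = torch.cat(
                [self.utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
            )
        for block in blocks:
            x = block(x)
        if self.utility_tokens and not self.return_utility_tokens:
            return x[:, self.utility_tokens:, :]  # drop the utility positions
        return x


# Usage: 3 utility tokens are prepended, passed through the blocks, then stripped
wrapper = UtilityTokenWrapper(d_model=16, utility_tokens=3)
out = wrapper(torch.randn(2, 10, 16))
assert out.shape == (2, 10, 16)
```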
@@ -710,8 +715,8 @@ class TransformerEncoder(nn.Module):
         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
 
     def reset_parameters(self):
-        if self.
-            nn.init.normal_(self.
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
 
         if self.absolute_position_embedding is not None:
             self.absolute_position_embedding.reset_parameters()
{broccoli_ml-5.2.0 → broccoli_ml-6.0.1}/broccoli/vit.py

@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class GetCLSToken(nn.Module):
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+    def attention_scores(self, x):
+        return self.attention(x)
+
     def reset_parameters(self):
         # Iterate over modules in the sequential block
         for module in self.attention:
@@ -169,8 +171,8 @@ class ViTEncoder(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-
-
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -336,8 +338,8 @@ class ViTEncoder(nn.Module):
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
             linear_module=linear_module,
-
-
+            utility_tokens=transformer_utility_tokens,
+            return_utility_tokens=transformer_return_utility_tokens,
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
             post_norm=transformer_post_norm,
@@ -449,8 +451,8 @@ class ViT(nn.Module):
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-
-
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
@@ -516,8 +518,8 @@ class ViT(nn.Module):
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
-
-
+            transformer_utility_tokens=transformer_utility_tokens,
+            transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
@@ -549,12 +551,23 @@ class ViT(nn.Module):
     def attention_logits(self, x):
         return self.encoder.attention_logits(x)
 
-    def 
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )
+
+    def head_to_utility_token_attention_logits(self, x):
         all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-
-        return sequence_averages[
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)
 
     def reset_parameters(self):
         self.encoder.reset_parameters()
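The new `head_to_utility_token_attention_logits` method averages the `(batch, layer, head, seq_len, seq_len)` attention logits over the batch and the last (key) dimension, then keeps only the leading utility-token positions. Below is a standalone sketch of that reduction; the helper name is illustrative, and only the tensor layout and the reduction steps mirror the diff.

```python
import torch


def utility_token_attention_summary(attention_logits, n_utility_tokens):
    """Summarise mean attention logits for the utility-token positions.

    attention_logits: (batch, layer, head, seq_len, seq_len), with utility
    tokens occupying the first `n_utility_tokens` sequence positions.
    Returns a (layer, head, n_utility_tokens) tensor of mean logits.
    """
    batch_averages = torch.mean(attention_logits, dim=0)    # (layer, head, seq, seq)
    sequence_averages = torch.mean(batch_averages, dim=-1)  # (layer, head, seq)
    return sequence_averages[:, :, :n_utility_tokens]       # (layer, head, utility)


# Example on random "logits": 8 samples, 2 layers, 4 heads, 10 positions, 3 utility tokens
logits = torch.randn(8, 2, 4, 10, 10)
summary = utility_token_attention_summary(logits, n_utility_tokens=3)
assert summary.shape == (2, 4, 3)
```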