broccoli-ml 5.1.4__tar.gz → 6.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/PKG-INFO +1 -1
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/transformer.py +32 -28
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/vit.py +25 -13
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/pyproject.toml +1 -1
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/LICENSE +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/README.md +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/activation.py +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/linear.py +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/rope.py +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-5.1.4 → broccoli_ml-6.0.0}/broccoli/utils.py +0 -0
broccoli/transformer.py

@@ -76,7 +76,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -129,7 +129,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.
+        self.utility_tokens = utility_tokens

         self.reset_parameters()

@@ -156,7 +156,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except
+        Apply Axial RoPE to all tokens except utility tokens
         """

        if len(self.source_size) == 1:
@@ -180,8 +180,8 @@ class MHAttention(nn.Module):
                "`source_size` must be a tuple of 1, 2 or 3 integers"
            )

-
-
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]

        q_img = rearrange(
            q_img,
@@ -208,9 +208,9 @@ class MHAttention(nn.Module):
            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
        )

-        # Re-combine the
-        q = torch.cat([
-        k = torch.cat([
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)

        return q, k

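The two hunks above split the query and key tensors into utility tokens and sequence tokens, apply Axial RoPE only to the sequence tokens, and concatenate the pieces back together. A minimal, self-contained sketch of that pattern; the function name and the `rotate_fn` stand-in are illustrative, not part of broccoli-ml:

```python
import torch

def apply_rope_excluding_utility(q, k, n_utility_tokens, rotate_fn):
    """Rotate every token except the leading utility tokens, then re-attach
    the utility tokens unchanged. `rotate_fn` stands in for whatever rotary
    embedding the attention module is configured with."""
    # Split off the utility tokens, which carry no spatial position.
    q_util, q_seq = q[:, :n_utility_tokens, :], q[:, n_utility_tokens:, :]
    k_util, k_seq = k[:, :n_utility_tokens, :], k[:, n_utility_tokens:, :]

    # Rotate only the sequence (e.g. image-patch) tokens.
    q_seq, k_seq = rotate_fn(q_seq, k_seq)

    # Re-combine the utility tokens and the rotated sequence tokens.
    q = torch.cat([q_util, q_seq], dim=1)
    k = torch.cat([k_util, k_seq], dim=1)
    return q, k

# Shape check with an identity "rotation":
q = torch.randn(2, 4 + 16, 32)  # 4 utility tokens + 16 patch tokens
k = torch.randn(2, 4 + 16, 32)
q2, k2 = apply_rope_excluding_utility(q, k, 4, lambda a, b: (a, b))
assert q2.shape == q.shape and k2.shape == k.shape
```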
@@ -414,7 +414,7 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -472,7 +472,7 @@ class TransformerBlock(nn.Module):
            linear_module=linear_module,
            rotary_embedding=self.rotary_embedding,
            source_size=source_size,
-
+            utility_tokens=utility_tokens,
            scaling=msa_scaling,
        )

@@ -571,8 +571,8 @@ class TransformerEncoder(nn.Module):
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-
-
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
@@ -596,16 +596,18 @@ class TransformerEncoder(nn.Module):
         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self.
-        self.
-
-        # Initialise
-        if self.
-            self.
-
-
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self.
+            self._utility_token_embedding = None
             self.full_sequence_length = self.seq_len

         self.d_model = d_model
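The new constructor code above stores the utility-token count and, when it is non-zero, creates a learned `(utility_tokens, d_model)` parameter drawn from a standard normal, the same initialisation PyTorch uses for embedding tables. A minimal sketch of that initialisation in isolation (the variable names are illustrative):

```python
import torch
import torch.nn as nn

n_utility_tokens, d_model = 4, 256

# Learned utility-token embeddings, initialised like nn.Embedding weights.
utility_token_embedding = nn.Parameter(torch.empty(n_utility_tokens, d_model))
nn.init.normal_(utility_token_embedding, mean=0.0, std=1.0)
```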
@@ -639,7 +641,7 @@ class TransformerEncoder(nn.Module):
            n_heads,
            relative_position_embedding=relative_position_embedding,
            source_size=source_size,
-
+            utility_tokens=utility_tokens,
            mlp_ratio=mlp_ratio,
            activation=activation,
            activation_kwargs=activation_kwargs,
@@ -667,8 +669,10 @@ class TransformerEncoder(nn.Module):
        return ",".join([str(block._kv_distance) for block in self.blocks])

    def preprocess(self, x):
-        if self.
-            x = torch.cat(
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
        else:
            x = x

@@ -690,8 +694,8 @@ class TransformerEncoder(nn.Module):
        for block in self.blocks:
            x = block(x)

-        if self.
-            return x[:, self.
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
        else:
            return x

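`preprocess` broadcasts the learned utility tokens across the batch and prepends them to the input sequence, and the forward pass strips them from the output again unless `return_utility_tokens` is set. A self-contained sketch of both halves of that pattern; the wrapper class and its identity "blocks" are illustrative, not broccoli-ml's API:

```python
import torch
import torch.nn as nn

class UtilityTokenWrapper(nn.Module):
    """Illustrative module: prepend learned utility tokens before the blocks,
    optionally drop them from the output afterwards."""

    def __init__(self, d_model, n_utility_tokens, return_utility_tokens=False):
        super().__init__()
        self.n_utility_tokens = n_utility_tokens
        self.return_utility_tokens = return_utility_tokens
        self.utility_token_embedding = nn.Parameter(
            torch.empty(n_utility_tokens, d_model)
        )
        nn.init.normal_(self.utility_token_embedding, mean=0.0, std=1.0)
        self.blocks = nn.Sequential(nn.Identity())  # stand-in for real blocks

    def forward(self, x):  # x: (batch, seq_len, d_model)
        if self.n_utility_tokens:
            # Broadcast the learned tokens across the batch and prepend them.
            x = torch.cat(
                [self.utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
            )
        x = self.blocks(x)
        if self.n_utility_tokens and not self.return_utility_tokens:
            # Keep only the original sequence positions.
            return x[:, self.n_utility_tokens:, :]
        return x

x = torch.randn(2, 16, 256)
y = UtilityTokenWrapper(256, 4)(x)
assert y.shape == x.shape
```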
@@ -710,8 +714,8 @@ class TransformerEncoder(nn.Module):
        return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)

    def reset_parameters(self):
-        if self.
-            nn.init.normal_(self.
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)

        if self.absolute_position_embedding is not None:
            self.absolute_position_embedding.reset_parameters()
broccoli/vit.py

@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange

 import torch
 import torch.nn as nn
-import torch.nn.functional as F


 class GetCLSToken(nn.Module):
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
        weights = self.attention(x)
        return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")

+    def attention_scores(self, x):
+        return self.attention(x)
+
    def reset_parameters(self):
        # Iterate over modules in the sequential block
        for module in self.attention:
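The new `attention_scores` method simply exposes the per-token weights that `SequencePool` already computes; the pooled output is the attention-weighted sum expressed by the `einsum` above. A minimal sketch of that pooling step, with an assumed linear scoring layer standing in for `self.attention`:

```python
import torch
from torch import nn
from einops import einsum

batch, seq, d_model = 2, 16, 64
x = torch.randn(batch, seq, d_model)

score = nn.Linear(d_model, 1)  # stand-in scoring layer (assumption)
weights = torch.softmax(score(x).squeeze(-1), dim=-1)  # (batch, seq)

# Weighted sum of token embeddings, as in the hunk above.
pooled = einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
assert pooled.shape == (batch, d_model)
```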
@@ -169,8 +171,8 @@ class ViTEncoder(nn.Module):
        transformer_layers=7,
        transformer_heads=4,
        transformer_mlp_ratio=2,
-
-
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
        transformer_activation: nn.Module = SquaredReLU,
        transformer_activation_kwargs: Optional[dict] = None,
        transformer_ff_linear_module_up=None,
@@ -336,8 +338,8 @@ class ViTEncoder(nn.Module):
            stochastic_depth=transformer_stochastic_depth,
            causal=False,
            linear_module=linear_module,
-
-
+            utility_tokens=transformer_utility_tokens,
+            return_utility_tokens=transformer_return_utility_tokens,
            pre_norm=transformer_pre_norm,
            normformer=transformer_normformer,
            post_norm=transformer_post_norm,
@@ -449,8 +451,8 @@ class ViT(nn.Module):
        transformer_layers=7,
        transformer_heads=4,
        transformer_mlp_ratio=2,
-
-
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
        transformer_activation: nn.Module = SquaredReLU,
        transformer_activation_kwargs: Optional[dict] = None,
        transformer_ff_linear_module_up=None,
@@ -516,8 +518,8 @@ class ViT(nn.Module):
            transformer_layers=transformer_layers,
            transformer_heads=transformer_heads,
            transformer_mlp_ratio=transformer_mlp_ratio,
-
-
+            transformer_utility_tokens=transformer_utility_tokens,
+            transformer_return_utility_tokens=transformer_return_utility_tokens,
            transformer_activation=transformer_activation,
            transformer_activation_kwargs=transformer_activation_kwargs,
            transformer_ff_linear_module_up=transformer_ff_linear_module_up,
@@ -549,13 +551,23 @@ class ViT(nn.Module):
    def attention_logits(self, x):
        return self.encoder.attention_logits(x)

-    def
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )
+
+    def head_to_utility_token_attention_logits(self, x):
        all_attention = self.attention_logits(x)
        batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
        sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-
-
-
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)

    def reset_parameters(self):
        self.encoder.reset_parameters()
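`head_to_utility_token_attention_logits` averages the stacked attention logits over the batch and over the last (key) dimension, then keeps only the first `n_utility_tokens` positions, giving one value per layer, head and utility token. A sketch of the same reduction on a random tensor with the shape documented for `attention_logits` (batch, layer, head, seq_len, seq_len):

```python
import torch

batch, layers, heads, seq_len = 2, 7, 4, 20
n_utility_tokens = 4

# Stand-in for the model's stacked attention logits.
all_attention = torch.randn(batch, layers, heads, seq_len, seq_len)

batch_averages = all_attention.mean(dim=0)       # (layer, head, query, key)
sequence_averages = batch_averages.mean(dim=-1)  # (layer, head, query)
utility_logits = sequence_averages[:, :, :n_utility_tokens]
print(utility_logits.shape)  # torch.Size([7, 4, 4]) -> (layer, head, utility_tokens)
```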