broccoli-ml 5.1.2__py3-none-any.whl → 5.1.4__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- broccoli/transformer.py +7 -9
- broccoli/vit.py +6 -6
- {broccoli_ml-5.1.2.dist-info → broccoli_ml-5.1.4.dist-info}/METADATA +1 -1
- {broccoli_ml-5.1.2.dist-info → broccoli_ml-5.1.4.dist-info}/RECORD +6 -6
- {broccoli_ml-5.1.2.dist-info → broccoli_ml-5.1.4.dist-info}/LICENSE +0 -0
- {broccoli_ml-5.1.2.dist-info → broccoli_ml-5.1.4.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -284,7 +284,7 @@ class MHAttention(nn.Module):
 
         return self.out_proj(output_without_heads)
 
-    def
+    def attention_logits(self, q, k, v):
 
         q, k, v = self.project_qkv(q, k, v)
 
@@ -301,8 +301,6 @@ class MHAttention(nn.Module):
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
-        qk_scores = F.softmax(qk_scores, dim=-1)
-
         return qk_scores  # (batch, head, seq_len, seq_len)
 
     def reset_parameters(self):
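With the in-method `F.softmax` removed, the renamed `attention_logits` now returns raw (masked, pre-softmax) scores rather than normalised attention weights, so any caller that wants probabilities has to normalise explicitly. A minimal standalone sketch of that contract, using a generic scaled-dot-product stand-in with made-up shapes, not the library's exact computation:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; d_head = 8 is not taken from the library.
q = torch.randn(1, 2, 5, 8)  # (batch, head, seq_len, d_head)
k = torch.randn(1, 2, 5, 8)

# What attention_logits() now hands back: dot-product scores,
# possibly masked, but *not* softmaxed.
qk_scores = (q @ k.transpose(-2, -1)) / 8 ** 0.5

# The normalisation that used to happen inside the method is now
# the caller's job when probabilities are needed.
attn_weights = F.softmax(qk_scores, dim=-1)
print(attn_weights.sum(dim=-1))  # each row of weights sums to 1
```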
@@ -529,15 +527,15 @@ class TransformerBlock(nn.Module):
 
         return x
 
-    def
+    def attention_logits(self, x):
         """
         Give back the attention scores used in this layer.
         """
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            return self.attn.
+            return self.attn.attention_logits(x, x, x)
         else:
-            return self.attn.
+            return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
         self.layer_norm_1.reset_parameters()
@@ -697,7 +695,7 @@ class TransformerEncoder(nn.Module):
         else:
             return x
 
-    def
+    def attention_logits(self, x):
 
         x = self.preprocess(x)
 
@@ -705,8 +703,8 @@ class TransformerEncoder(nn.Module):
 
         for block in self.blocks:
             # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
-
-            layer_scores.append(
+            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+            layer_scores.append(layer_attention_logits)
             x = block(x)
 
         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
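The encoder loop above stacks each block's logits along a new `layer` axis via `unsqueeze(1)` followed by `torch.cat`. A small sketch of that bookkeeping with dummy tensors (all sizes are illustrative):

```python
import torch

batch, n_layers, n_heads, seq_len = 2, 3, 4, 16

# Stand-ins for what each block.attention_logits(x) returns:
# (batch, head, seq_len, seq_len)
per_block = [torch.randn(batch, n_heads, seq_len, seq_len) for _ in range(n_layers)]

layer_scores = []
for block_logits in per_block:
    # unsqueeze(1) adds the layer axis: (batch, 1, head, seq_len, seq_len)
    layer_scores.append(block_logits.unsqueeze(1))

stacked = torch.cat(layer_scores, dim=1)
print(stacked.shape)  # torch.Size([2, 3, 4, 16, 16]) -> (batch, layer, head, seq, seq)
```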
broccoli/vit.py
CHANGED
@@ -400,9 +400,9 @@ class ViTEncoder(nn.Module):
     def forward(self, x):
         return self.encoder(x)
 
-    def
+    def attention_logits(self, x):
         x = self.encoder[:-1](x)
-        return self.encoder[-1].
+        return self.encoder[-1].attention_logits(x)
 
     def reset_parameters(self):
         for module in self.encoder:
@@ -546,16 +546,16 @@ class ViT(nn.Module):
     def forward(self, x):
         return self.pool(self.encoder(x))
 
-    def
-        return self.encoder.
+    def attention_logits(self, x):
+        return self.encoder.attention_logits(x)
 
     def head_to_bos_token_attention(self, x):
-        all_attention = self.
+        all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
         n_bos_tokens = self.encoder.encoder[-1]._bos_tokens
         just_bos = sequence_averages[:, :, :n_bos_tokens]
-        return F.softmax(just_bos, dim
+        return F.softmax(just_bos, dim=1)  # (layer, head, bos_token)
 
     def reset_parameters(self):
         self.encoder.reset_parameters()
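`head_to_bos_token_attention` reduces the stacked logits step by step, and since `attention_logits` now returns unnormalised scores, the final `F.softmax(..., dim=1)` produces a distribution over heads for each BoS token. A shape walkthrough with dummy data (sizes are illustrative, and `n_bos_tokens` is set by hand here rather than read from the encoder's `_bos_tokens`):

```python
import torch
import torch.nn.functional as F

# (batch, layer, head, seq_len, seq_len); values and sizes made up
all_attention = torch.randn(2, 3, 4, 16, 16)

batch_averages = torch.mean(all_attention, dim=0)        # (layer, head, seq, seq)
sequence_averages = torch.mean(batch_averages, dim=-1)   # (layer, head, seq)

n_bos_tokens = 2  # in the library this comes from the encoder
just_bos = sequence_averages[:, :, :n_bos_tokens]        # (layer, head, bos_token)

# dim=1 is the head axis, so each (layer, bos_token) slot gets
# a distribution over heads.
per_head = F.softmax(just_bos, dim=1)
print(per_head.shape)       # torch.Size([3, 4, 2])
print(per_head.sum(dim=1))  # all ones: weights over heads sum to 1
```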
{broccoli_ml-5.1.2.dist-info → broccoli_ml-5.1.4.dist-info}/RECORD
CHANGED
 
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
 broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=hLk4p-aS2xJiKHellkFKow5bgkAinoMlS8xTT5aQnII,23054
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=
-broccoli_ml-5.1.
-broccoli_ml-5.1.
-broccoli_ml-5.1.
-broccoli_ml-5.1.
+broccoli/vit.py,sha256=Dy1egtEPft0l3812BENooc7SNogEibRPamXzy4jaJxA,20428
+broccoli_ml-5.1.4.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-5.1.4.dist-info/METADATA,sha256=s59fAtdmIjKhdP29kTPEgMp5buRGJzj5WgaBxsW83yg,1368
+broccoli_ml-5.1.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-5.1.4.dist-info/RECORD,,
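For reference, the `sha256=` values in RECORD are SHA-256 digests in urlsafe base64 without padding (per the wheel binary distribution format), so an entry can be re-derived from the file on disk. A small sketch (the path is illustrative):

```python
import base64
import hashlib

def record_digest(path: str) -> str:
    """SHA-256 digest in RECORD's urlsafe-base64, unpadded encoding."""
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# For the 5.1.4 wheel this should reproduce the entry above, e.g.
# record_digest("broccoli/transformer.py")
# == "hLk4p-aS2xJiKHellkFKow5bgkAinoMlS8xTT5aQnII"
```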
{broccoli_ml-5.1.2.dist-info → broccoli_ml-5.1.4.dist-info}/LICENSE
File without changes
 
{broccoli_ml-5.1.2.dist-info → broccoli_ml-5.1.4.dist-info}/WHEEL
File without changes