broccoli-ml 5.1.0__py3-none-any.whl → 5.1.2__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those versions as they appear in their respective public registries.
- broccoli/transformer.py +4 -2
- broccoli/vit.py +1 -1
- {broccoli_ml-5.1.0.dist-info → broccoli_ml-5.1.2.dist-info}/METADATA +1 -1
- {broccoli_ml-5.1.0.dist-info → broccoli_ml-5.1.2.dist-info}/RECORD +6 -6
- {broccoli_ml-5.1.0.dist-info → broccoli_ml-5.1.2.dist-info}/LICENSE +0 -0
- {broccoli_ml-5.1.0.dist-info → broccoli_ml-5.1.2.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -535,9 +535,9 @@ class TransformerBlock(nn.Module):
|
|
|
535
535
|
"""
|
|
536
536
|
if self.pre_norm:
|
|
537
537
|
x = self.layer_norm_1(x)
|
|
538
|
-
return self.attn(x, x, x)
|
|
538
|
+
return self.attn.attention_scores(x, x, x)
|
|
539
539
|
else:
|
|
540
|
-
return self.attn(x, x, x)
|
|
540
|
+
return self.attn.attention_scores(x, x, x)
|
|
541
541
|
|
|
542
542
|
def reset_parameters(self):
|
|
543
543
|
self.layer_norm_1.reset_parameters()
|
|
@@ -683,6 +683,8 @@ class TransformerEncoder(nn.Module):
|
|
|
683
683
|
) # to shape (1, seq_len) to broadcast over batch
|
|
684
684
|
)
|
|
685
685
|
|
|
686
|
+
return x
|
|
687
|
+
|
|
686
688
|
def forward(self, x):
|
|
687
689
|
|
|
688
690
|
x = self.preprocess(x)
|
broccoli/vit.py
CHANGED
|
@@ -553,7 +553,7 @@ class ViT(nn.Module):
|
|
|
553
553
|
all_attention = self.attention_scores(x)
|
|
554
554
|
batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
|
|
555
555
|
sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
|
|
556
|
-
n_bos_tokens = self.encoder.encoder._bos_tokens
|
|
556
|
+
n_bos_tokens = self.encoder.encoder[-1]._bos_tokens
|
|
557
557
|
just_bos = sequence_averages[:, :, :n_bos_tokens]
|
|
558
558
|
return F.softmax(just_bos, dim=-1) # (layer, head, bos_token)
|
|
559
559
|
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256=
|
|
7
|
+
broccoli/transformer.py,sha256=Gn8fhwgSq-izRc2nrsZ4JC5NsTIzQWFrol8pJ2pfzL4,23104
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
|
-
broccoli/vit.py,sha256=
|
|
10
|
-
broccoli_ml-5.1.
|
|
11
|
-
broccoli_ml-5.1.
|
|
12
|
-
broccoli_ml-5.1.
|
|
13
|
-
broccoli_ml-5.1.
|
|
9
|
+
broccoli/vit.py,sha256=R73GTxx41FjVwAu4KDlYBAhytFo_9xVbbRbtBjAIW0s,20429
|
|
10
|
+
broccoli_ml-5.1.2.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-5.1.2.dist-info/METADATA,sha256=iYvlWeqDeDaGDO_up_xC0NBLo3AhJOi6C_5BqWEgNYI,1368
|
|
12
|
+
broccoli_ml-5.1.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-5.1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|