broccoli-ml 6.0.1__py3-none-any.whl → 7.0.0__py3-none-any.whl
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +11 -2
- {broccoli_ml-6.0.1.dist-info → broccoli_ml-7.0.0.dist-info}/METADATA +1 -1
- {broccoli_ml-6.0.1.dist-info → broccoli_ml-7.0.0.dist-info}/RECORD +5 -5
- {broccoli_ml-6.0.1.dist-info → broccoli_ml-7.0.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-6.0.1.dist-info → broccoli_ml-7.0.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -591,7 +591,13 @@ class TransformerEncoder(nn.Module):
|
|
|
591
591
|
if relative_position_embedding and (source_size is None):
|
|
592
592
|
raise ValueError(
|
|
593
593
|
"`source_size` for TransformerEncoder cannot be None if"
|
|
594
|
-
" `
|
|
594
|
+
" `relative_position_embedding` is True"
|
|
595
|
+
)
|
|
596
|
+
|
|
597
|
+
if absolute_position_embedding and (seq_len is None):
|
|
598
|
+
raise ValueError(
|
|
599
|
+
"`seq_len` for TransformerEncoder cannot be None if"
|
|
600
|
+
" `absolute_position_embedding` is True"
|
|
595
601
|
)
|
|
596
602
|
|
|
597
603
|
super().__init__()
|
|
@@ -606,9 +612,12 @@ class TransformerEncoder(nn.Module):
|
|
|
606
612
|
torch.empty(self._utility_tokens, d_model)
|
|
607
613
|
)
|
|
608
614
|
nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
|
|
609
|
-
self.full_sequence_length = self.seq_len + self._utility_tokens
|
|
610
615
|
else:
|
|
611
616
|
self._utility_token_embedding = None
|
|
617
|
+
|
|
618
|
+
if self._utility_tokens and (self.seq_len is not None):
|
|
619
|
+
self.full_sequence_length = self.seq_len + self._utility_tokens
|
|
620
|
+
else:
|
|
612
621
|
self.full_sequence_length = self.seq_len
|
|
613
622
|
|
|
614
623
|
self.d_model = d_model
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256
|
|
7
|
+
broccoli/transformer.py,sha256=-NRO8mzvOkQPCuTrF6OTDQ8sUIQ3_j4HxP-NfTUxu10,23636
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
9
|
broccoli/vit.py,sha256=9oyh76ulmX5lDPMCDicQhhqm8RYCvJIgAJkDbYRVdi4,20873
|
|
10
|
-
broccoli_ml-
|
|
11
|
-
broccoli_ml-
|
|
12
|
-
broccoli_ml-
|
|
13
|
-
broccoli_ml-
|
|
10
|
+
broccoli_ml-7.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-7.0.0.dist-info/METADATA,sha256=OdAph_5ItouyYHeZwVgbrTiyexljFb9uKNLNWlnR4wM,1368
|
|
12
|
+
broccoli_ml-7.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-7.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|