broccoli-ml 6.0.0__py3-none-any.whl → 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +12 -2
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-7.0.0.dist-info}/METADATA +1 -1
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-7.0.0.dist-info}/RECORD +5 -5
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-7.0.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-7.0.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
|
|
|
13
13
|
try:
|
|
14
14
|
from flash_attn import flash_attn_func
|
|
15
15
|
|
|
16
|
+
print("Using flash-attn.")
|
|
16
17
|
FLASH_ATTN = True
|
|
17
18
|
except ImportError:
|
|
18
19
|
pass
|
|
@@ -590,7 +591,13 @@ class TransformerEncoder(nn.Module):
|
|
|
590
591
|
if relative_position_embedding and (source_size is None):
|
|
591
592
|
raise ValueError(
|
|
592
593
|
"`source_size` for TransformerEncoder cannot be None if"
|
|
593
|
-
" `
|
|
594
|
+
" `relative_position_embedding` is True"
|
|
595
|
+
)
|
|
596
|
+
|
|
597
|
+
if absolute_position_embedding and (seq_len is None):
|
|
598
|
+
raise ValueError(
|
|
599
|
+
"`seq_len` for TransformerEncoder cannot be None if"
|
|
600
|
+
" `absolute_position_embedding` is True"
|
|
594
601
|
)
|
|
595
602
|
|
|
596
603
|
super().__init__()
|
|
@@ -605,9 +612,12 @@ class TransformerEncoder(nn.Module):
|
|
|
605
612
|
torch.empty(self._utility_tokens, d_model)
|
|
606
613
|
)
|
|
607
614
|
nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
|
|
608
|
-
self.full_sequence_length = self.seq_len + self._utility_tokens
|
|
609
615
|
else:
|
|
610
616
|
self._utility_token_embedding = None
|
|
617
|
+
|
|
618
|
+
if self._utility_tokens and (self.seq_len is not None):
|
|
619
|
+
self.full_sequence_length = self.seq_len + self._utility_tokens
|
|
620
|
+
else:
|
|
611
621
|
self.full_sequence_length = self.seq_len
|
|
612
622
|
|
|
613
623
|
self.d_model = d_model
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256
|
|
7
|
+
broccoli/transformer.py,sha256=-NRO8mzvOkQPCuTrF6OTDQ8sUIQ3_j4HxP-NfTUxu10,23636
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
9
|
broccoli/vit.py,sha256=9oyh76ulmX5lDPMCDicQhhqm8RYCvJIgAJkDbYRVdi4,20873
|
|
10
|
-
broccoli_ml-
|
|
11
|
-
broccoli_ml-
|
|
12
|
-
broccoli_ml-
|
|
13
|
-
broccoli_ml-
|
|
10
|
+
broccoli_ml-7.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-7.0.0.dist-info/METADATA,sha256=OdAph_5ItouyYHeZwVgbrTiyexljFb9uKNLNWlnR4wM,1368
|
|
12
|
+
broccoli_ml-7.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-7.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|