paddleformers 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. paddleformers/__init__.py +60 -0
  2. paddleformers/data/__init__.py +21 -0
  3. paddleformers/data/blendable_dataset.py +184 -0
  4. paddleformers/data/causal_dataset.py +711 -0
  5. paddleformers/data/collate.py +321 -0
  6. paddleformers/data/data_collator.py +782 -0
  7. paddleformers/data/dist_dataloader.py +215 -0
  8. paddleformers/data/indexed_dataset.py +972 -0
  9. paddleformers/data/iterator.py +15 -0
  10. paddleformers/data/sampler.py +416 -0
  11. paddleformers/data/tokenizer.py +33 -0
  12. paddleformers/data/vocab.py +579 -0
  13. paddleformers/datasets/__init__.py +18 -0
  14. paddleformers/datasets/dataset.py +773 -0
  15. paddleformers/datasets/embedding_dataset.py +256 -0
  16. paddleformers/datasets/rlhf_datasets/__init__.py +16 -0
  17. paddleformers/datasets/rlhf_datasets/protocol.py +607 -0
  18. paddleformers/datasets/rlhf_datasets/rl_dataset.py +168 -0
  19. paddleformers/datasets/zero_padding_dataset.py +234 -0
  20. paddleformers/generation/__init__.py +34 -0
  21. paddleformers/generation/configuration_utils.py +597 -0
  22. paddleformers/generation/logits_process.py +646 -0
  23. paddleformers/generation/stopping_criteria.py +91 -0
  24. paddleformers/generation/streamers.py +216 -0
  25. paddleformers/generation/utils.py +1809 -0
  26. paddleformers/mergekit/__init__.py +19 -0
  27. paddleformers/mergekit/merge_config.py +186 -0
  28. paddleformers/mergekit/merge_method.py +207 -0
  29. paddleformers/mergekit/merge_model.py +803 -0
  30. paddleformers/mergekit/merge_utils.py +73 -0
  31. paddleformers/mergekit/sparsify_method.py +118 -0
  32. paddleformers/ops/__init__.py +15 -0
  33. paddleformers/ops/topo.py +84 -0
  34. paddleformers/peft/__init__.py +19 -0
  35. paddleformers/peft/lokr/__init__.py +19 -0
  36. paddleformers/peft/lokr/lokr_config.py +141 -0
  37. paddleformers/peft/lokr/lokr_layers.py +240 -0
  38. paddleformers/peft/lokr/lokr_model.py +296 -0
  39. paddleformers/peft/lora/__init__.py +18 -0
  40. paddleformers/peft/lora/auto_lora_model.py +799 -0
  41. paddleformers/peft/lora/lora_config.py +197 -0
  42. paddleformers/peft/lora/lora_layers.py +942 -0
  43. paddleformers/peft/lora/lora_model.py +839 -0
  44. paddleformers/peft/lora/lora_quant_layers.py +272 -0
  45. paddleformers/peft/lora/lora_quantization_layers.py +284 -0
  46. paddleformers/peft/lora/lora_quick_layers.py +223 -0
  47. paddleformers/peft/lora/loraga_utils.py +395 -0
  48. paddleformers/peft/lora/utils.py +21 -0
  49. paddleformers/peft/prefix/__init__.py +23 -0
  50. paddleformers/peft/prefix/prefix_config.py +102 -0
  51. paddleformers/peft/prefix/prefix_model.py +549 -0
  52. paddleformers/peft/prefix/utils.py +52 -0
  53. paddleformers/peft/reft/__init__.py +24 -0
  54. paddleformers/peft/reft/interventions.py +148 -0
  55. paddleformers/peft/reft/modeling_utils.py +175 -0
  56. paddleformers/peft/reft/predict.py +132 -0
  57. paddleformers/peft/reft/reft_config.py +85 -0
  58. paddleformers/peft/reft/reft_model.py +365 -0
  59. paddleformers/peft/vera/__init__.py +17 -0
  60. paddleformers/peft/vera/vera_config.py +131 -0
  61. paddleformers/peft/vera/vera_layers.py +149 -0
  62. paddleformers/peft/vera/vera_model.py +284 -0
  63. paddleformers/quantization/__init__.py +15 -0
  64. paddleformers/quantization/checkpoint_quantization_utils.py +364 -0
  65. paddleformers/quantization/hadamard_utils.py +69 -0
  66. paddleformers/quantization/qat_utils.py +433 -0
  67. paddleformers/quantization/qlora.py +115 -0
  68. paddleformers/quantization/quantization_config.py +246 -0
  69. paddleformers/quantization/quantization_linear.py +721 -0
  70. paddleformers/quantization/quantization_utils.py +278 -0
  71. paddleformers/quantization/unified_checkpoint_quantization.py +221 -0
  72. paddleformers/trainer/__init__.py +21 -0
  73. paddleformers/trainer/argparser.py +469 -0
  74. paddleformers/trainer/auto_trainer.py +1036 -0
  75. paddleformers/trainer/auto_training_args.py +147 -0
  76. paddleformers/trainer/integrations.py +521 -0
  77. paddleformers/trainer/plugins/__init__.py +13 -0
  78. paddleformers/trainer/plugins/npu_plugin.py +127 -0
  79. paddleformers/trainer/plugins/timer.py +176 -0
  80. paddleformers/trainer/trainer.py +3829 -0
  81. paddleformers/trainer/trainer_callback.py +610 -0
  82. paddleformers/trainer/trainer_utils.py +1254 -0
  83. paddleformers/trainer/training_args.py +2508 -0
  84. paddleformers/trainer/unified_checkpoint/__init__.py +15 -0
  85. paddleformers/trainer/unified_checkpoint/async_handler.py +265 -0
  86. paddleformers/trainer/unified_checkpoint/check_completion.py +248 -0
  87. paddleformers/trainer/unified_checkpoint/load_dynamic.py +491 -0
  88. paddleformers/trainer/unified_checkpoint/load_local.py +297 -0
  89. paddleformers/trainer/unified_checkpoint/load_save_single_card.py +234 -0
  90. paddleformers/trainer/unified_checkpoint/sharding_split_param_utils.py +471 -0
  91. paddleformers/trainer/unified_checkpoint/shared_memory_utils.py +149 -0
  92. paddleformers/trainer/unified_checkpoint/unified_checkpoint.py +789 -0
  93. paddleformers/trainer/unified_checkpoint/utils.py +804 -0
  94. paddleformers/trainer/utils/__init__.py +20 -0
  95. paddleformers/trainer/utils/async_save.py +126 -0
  96. paddleformers/trainer/utils/ckpt_converter.py +1221 -0
  97. paddleformers/trainer/utils/doc.py +54 -0
  98. paddleformers/trainer/utils/helper.py +338 -0
  99. paddleformers/trainer/utils/reshard/__init__.py +23 -0
  100. paddleformers/trainer/utils/reshard/common.py +587 -0
  101. paddleformers/trainer/utils/reshard/pp_reshard.py +365 -0
  102. paddleformers/trainer/utils/reshard/sharding_v1.py +42 -0
  103. paddleformers/trainer/utils/reshard/sharding_v2.py +231 -0
  104. paddleformers/trainer/utils/sharding_io.py +605 -0
  105. paddleformers/trainer/utils/zero_cost_checkpoint.py +984 -0
  106. paddleformers/transformers/__init__.py +72 -0
  107. paddleformers/transformers/activations.py +174 -0
  108. paddleformers/transformers/aistudio_utils.py +70 -0
  109. paddleformers/transformers/attention_utils.py +619 -0
  110. paddleformers/transformers/audio_utils.py +694 -0
  111. paddleformers/transformers/auto/__init__.py +14 -0
  112. paddleformers/transformers/auto/configuration.py +355 -0
  113. paddleformers/transformers/auto/factory.py +146 -0
  114. paddleformers/transformers/auto/image_processing.py +176 -0
  115. paddleformers/transformers/auto/modeling.py +984 -0
  116. paddleformers/transformers/auto/processing.py +184 -0
  117. paddleformers/transformers/auto/tokenizer.py +478 -0
  118. paddleformers/transformers/auto_utils.py +59 -0
  119. paddleformers/transformers/bert/__init__.py +13 -0
  120. paddleformers/transformers/bert/configuration.py +407 -0
  121. paddleformers/transformers/bert/modeling.py +1420 -0
  122. paddleformers/transformers/bert/modeling.pyi +347 -0
  123. paddleformers/transformers/bert/tokenizer.py +630 -0
  124. paddleformers/transformers/bert/tokenizer_fast.py +165 -0
  125. paddleformers/transformers/configuration_utils.py +1236 -0
  126. paddleformers/transformers/context_parallel_utils.py +64 -0
  127. paddleformers/transformers/contrastive_loss.py +152 -0
  128. paddleformers/transformers/conversion_utils.py +1661 -0
  129. paddleformers/transformers/convert_slow_tokenizer.py +691 -0
  130. paddleformers/transformers/deepseek_v2/__init__.py +19 -0
  131. paddleformers/transformers/deepseek_v2/configuration.py +237 -0
  132. paddleformers/transformers/deepseek_v2/fp8_linear.py +138 -0
  133. paddleformers/transformers/deepseek_v2/kernel.py +227 -0
  134. paddleformers/transformers/deepseek_v2/mfu_utils.py +206 -0
  135. paddleformers/transformers/deepseek_v2/modeling.py +2334 -0
  136. paddleformers/transformers/deepseek_v2/modeling_auto.py +1263 -0
  137. paddleformers/transformers/deepseek_v2/modeling_pp.py +501 -0
  138. paddleformers/transformers/deepseek_v2/tokenizer_fast.py +54 -0
  139. paddleformers/transformers/deepseek_v3/__init__.py +18 -0
  140. paddleformers/transformers/deepseek_v3/configuration.py +33 -0
  141. paddleformers/transformers/deepseek_v3/modeling.py +170 -0
  142. paddleformers/transformers/deepseek_v3/modeling_auto.py +209 -0
  143. paddleformers/transformers/deepseek_v3/modeling_pp.py +40 -0
  144. paddleformers/transformers/dpo_criterion.py +379 -0
  145. paddleformers/transformers/embedding_utils.py +51 -0
  146. paddleformers/transformers/export.py +68 -0
  147. paddleformers/transformers/feature_extraction_sequence_utils.py +365 -0
  148. paddleformers/transformers/feature_extraction_utils.py +377 -0
  149. paddleformers/transformers/fused_a2a.py +216 -0
  150. paddleformers/transformers/image_processing_utils.py +547 -0
  151. paddleformers/transformers/image_transforms.py +655 -0
  152. paddleformers/transformers/image_utils.py +621 -0
  153. paddleformers/transformers/kto_criterion.py +275 -0
  154. paddleformers/transformers/linear_utils.py +90 -0
  155. paddleformers/transformers/llama/__init__.py +21 -0
  156. paddleformers/transformers/llama/configuration.py +211 -0
  157. paddleformers/transformers/llama/fusion_ops.py +313 -0
  158. paddleformers/transformers/llama/modeling.py +2161 -0
  159. paddleformers/transformers/llama/modeling_auto.py +1360 -0
  160. paddleformers/transformers/llama/modeling_network.py +1224 -0
  161. paddleformers/transformers/llama/modeling_pp.py +468 -0
  162. paddleformers/transformers/llama/tokenizer.py +502 -0
  163. paddleformers/transformers/llama/tokenizer_fast.py +171 -0
  164. paddleformers/transformers/long_sequence_strategies/__init__.py +18 -0
  165. paddleformers/transformers/long_sequence_strategies/attention_strategies.py +51 -0
  166. paddleformers/transformers/long_sequence_strategies/embedding_strategies.py +223 -0
  167. paddleformers/transformers/long_sequence_strategies/long_sequence_strategies.py +68 -0
  168. paddleformers/transformers/mc2_parallel_linear.py +230 -0
  169. paddleformers/transformers/model_outputs.py +1568 -0
  170. paddleformers/transformers/model_utils.py +3315 -0
  171. paddleformers/transformers/moe_gate.py +588 -0
  172. paddleformers/transformers/moe_gate_auto.py +656 -0
  173. paddleformers/transformers/moe_layer.py +378 -0
  174. paddleformers/transformers/moe_layer_auto.py +329 -0
  175. paddleformers/transformers/moe_utils.py +101 -0
  176. paddleformers/transformers/ofa_utils.py +326 -0
  177. paddleformers/transformers/optimization.py +304 -0
  178. paddleformers/transformers/processing_utils.py +136 -0
  179. paddleformers/transformers/qwen/__init__.py +19 -0
  180. paddleformers/transformers/qwen/configuration.py +86 -0
  181. paddleformers/transformers/qwen/modeling.py +1293 -0
  182. paddleformers/transformers/qwen/modeling_auto.py +970 -0
  183. paddleformers/transformers/qwen/modeling_network.py +830 -0
  184. paddleformers/transformers/qwen/modeling_pp.py +220 -0
  185. paddleformers/transformers/qwen/tokenizer.py +270 -0
  186. paddleformers/transformers/qwen2/__init__.py +20 -0
  187. paddleformers/transformers/qwen2/configuration.py +164 -0
  188. paddleformers/transformers/qwen2/modeling.py +1973 -0
  189. paddleformers/transformers/qwen2/modeling_pp.py +363 -0
  190. paddleformers/transformers/qwen2/tokenizer.py +448 -0
  191. paddleformers/transformers/qwen2/tokenizer_fast.py +131 -0
  192. paddleformers/transformers/qwen2_moe/__init__.py +18 -0
  193. paddleformers/transformers/qwen2_moe/configuration.py +186 -0
  194. paddleformers/transformers/qwen2_moe/modeling.py +1735 -0
  195. paddleformers/transformers/qwen2_moe/modeling_pp.py +354 -0
  196. paddleformers/transformers/qwen3/__init__.py +18 -0
  197. paddleformers/transformers/qwen3/configuration.py +195 -0
  198. paddleformers/transformers/qwen3/modeling.py +1448 -0
  199. paddleformers/transformers/qwen3/modeling_pp.py +363 -0
  200. paddleformers/transformers/qwen3_moe/__init__.py +17 -0
  201. paddleformers/transformers/qwen3_moe/configuration.py +219 -0
  202. paddleformers/transformers/qwen3_moe/modeling.py +996 -0
  203. paddleformers/transformers/qwen3_moe/modeling_pp.py +254 -0
  204. paddleformers/transformers/refined_recompute.py +791 -0
  205. paddleformers/transformers/ring_flash_attention.py +353 -0
  206. paddleformers/transformers/segment_parallel_utils.py +137 -0
  207. paddleformers/transformers/sentencepiece_model_pb2.py +1534 -0
  208. paddleformers/transformers/sequence_parallel_utils.py +139 -0
  209. paddleformers/transformers/tensor_parallel_utils.py +583 -0
  210. paddleformers/transformers/token_dispatcher.py +284 -0
  211. paddleformers/transformers/tokenizer_utils.py +2199 -0
  212. paddleformers/transformers/tokenizer_utils_base.py +3655 -0
  213. paddleformers/transformers/tokenizer_utils_fast.py +882 -0
  214. paddleformers/transformers/transposed_linear.py +59 -0
  215. paddleformers/transformers/utils.py +1006 -0
  216. paddleformers/trl/__init__.py +28 -0
  217. paddleformers/trl/dpo_auto_trainer.py +1153 -0
  218. paddleformers/trl/dpo_trainer.py +571 -0
  219. paddleformers/trl/embedding_trainer.py +192 -0
  220. paddleformers/trl/extras/__init__.py +13 -0
  221. paddleformers/trl/extras/dataset_formatting.py +129 -0
  222. paddleformers/trl/kto_trainer.py +554 -0
  223. paddleformers/trl/llm_utils.py +944 -0
  224. paddleformers/trl/model_config.py +147 -0
  225. paddleformers/trl/quant_config.py +118 -0
  226. paddleformers/trl/sft_auto_trainer.py +911 -0
  227. paddleformers/trl/sft_config.py +102 -0
  228. paddleformers/trl/sft_trainer.py +421 -0
  229. paddleformers/trl/sftdata_config.py +68 -0
  230. paddleformers/trl/trl_data.py +268 -0
  231. paddleformers/trl/trl_utils.py +49 -0
  232. paddleformers/trl/utils.py +43 -0
  233. paddleformers/utils/__init__.py +50 -0
  234. paddleformers/utils/adamw_triton.py +194 -0
  235. paddleformers/utils/batch_sampler.py +182 -0
  236. paddleformers/utils/converter.py +18 -0
  237. paddleformers/utils/distributed.py +224 -0
  238. paddleformers/utils/doc_parser.py +432 -0
  239. paddleformers/utils/download/__init__.py +340 -0
  240. paddleformers/utils/download/aistudio_hub_download.py +728 -0
  241. paddleformers/utils/download/bos_download.py +287 -0
  242. paddleformers/utils/download/common.py +662 -0
  243. paddleformers/utils/downloader.py +592 -0
  244. paddleformers/utils/env.py +166 -0
  245. paddleformers/utils/fault_tolerance.py +38 -0
  246. paddleformers/utils/ie_utils.py +142 -0
  247. paddleformers/utils/image_utils.py +734 -0
  248. paddleformers/utils/import_utils.py +384 -0
  249. paddleformers/utils/infohub.py +53 -0
  250. paddleformers/utils/initializer.py +337 -0
  251. paddleformers/utils/log.py +204 -0
  252. paddleformers/utils/memory_utils.py +39 -0
  253. paddleformers/utils/nested.py +128 -0
  254. paddleformers/utils/optimizer.py +608 -0
  255. paddleformers/utils/paddle_patch.py +148 -0
  256. paddleformers/utils/pdc_sdk.py +629 -0
  257. paddleformers/utils/profiler.py +130 -0
  258. paddleformers/utils/safetensors.py +320 -0
  259. paddleformers/utils/serialization.py +309 -0
  260. paddleformers/utils/tools.py +247 -0
  261. paddleformers/version/__init__.py +25 -0
  262. paddleformers/version/git.py +48 -0
  263. paddleformers-0.1.dist-info/METADATA +48 -0
  264. paddleformers-0.1.dist-info/RECORD +268 -0
  265. paddleformers-0.1.dist-info/WHEEL +5 -0
  266. paddleformers-0.1.dist-info/entry_points.txt +2 -0
  267. paddleformers-0.1.dist-info/licenses/LICENSE +203 -0
  268. paddleformers-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ from datetime import datetime
18
+
19
# Name of the environment variable that, when set to a non-empty value,
# selects the plain release version instead of a dated ".post" dev version.
PADDLEFORMERS_STABLE_VERSION = "PADDLEFORMERS_STABLE_VERSION"

# this version is used for develop and test.
# release version will be added fixed version by setup.py.
__version__ = "0.1.post"
if not os.getenv(PADDLEFORMERS_STABLE_VERSION):
    # Dev build: expand the marker with today's local date, e.g. ".post20250101".
    __version__ = __version__.replace(".post", f".post{datetime.now().date():%Y%m%d}")
else:
    # Stable build: drop the ".post" marker entirely.
    __version__ = __version__.replace(".post", "")

# the next line will be replaced by setup.py for release version.
# NOTE: in this released build it unconditionally overrides the value computed above.

__version__ = "0.1"
33
+
34
+
35
+
36
# Warn when the third-party `datasets` package was imported before this one;
# the message below describes the consequence (datasets may be unavailable
# in intranet environments).
if "datasets" in sys.modules:
    from paddleformers.utils.log import logger

    logger.warning(
        "Detected that datasets module was imported before paddleformers. "
        "This may cause PaddleFormers datasets to be unavailable in intranet. "
        "Please import paddleformers before datasets module to avoid download issues"
    )
import paddle

# Eagerly import every public subpackage so `paddleformers.<name>` is usable
# right after `import paddleformers`.
from . import (
    data,
    datasets,
    mergekit,
    ops,
    peft,
    quantization,
    trainer,
    transformers,
    trl,
    utils,
    version,
)

paddle.disable_signal_handler()
@@ -0,0 +1,21 @@
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .blendable_dataset import *
16
+ from .causal_dataset import *
17
+ from .collate import *
18
+ from .data_collator import *
19
+ from .dist_dataloader import *
20
+ from .sampler import *
21
+ from .vocab import *
@@ -0,0 +1,184 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import hashlib
16
+ import importlib.metadata
17
+ import os
18
+ import time
19
+
20
+ import numpy as np
21
+ import paddle
22
+
23
# Rank of this process within its node, read from PADDLE_RANK_IN_NODE
# (defaults to 0 when the variable is unset).
local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))


def print_rank_0(*args, **kwargs):
    """Forward ``args``/``kwargs`` to :func:`print`, but only on global rank 0."""
    if paddle.distributed.get_rank() != 0:
        return
    print(*args, **kwargs)
29
+
30
+
31
def _version_tuple(version):
    """Return the leading numeric release components of ``version`` as ints.

    Parsing keeps the leading digits of each dot-separated component and stops
    at the first component that is not purely numeric, so ``"0.1.2"`` ->
    ``(0, 1, 2)``, ``"0.1.1rc1"`` -> ``(0, 1, 1)`` and ``"0.1.1.post1"`` ->
    ``(0, 1, 1)``. Tuples compare numerically, whereas version strings
    compared lexicographically rank e.g. ``"0.1.1rc1"`` above ``"0.1.1"``.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break
    return tuple(parts)


def _atomic_np_save(path, array):
    """Save ``array`` to ``path`` in ``.npy`` format so readers never see a partial file.

    The data is written to a uniquely named temporary file next to ``path`` and
    then moved into place with :func:`os.replace`. Passing an open file object
    to :func:`numpy.save` prevents it from appending ``.npy`` to the temporary
    name.

    Raises:
        OSError: if the temporary file cannot be written or renamed.
    """
    import uuid

    tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "wb") as fh:
            np.save(fh, array, allow_pickle=True)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the temporary file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


class BlendableDataset(paddle.io.Dataset):
    """Blend several datasets into one, drawing from each according to a weight.

    Sample ``i`` is ``datasets[dataset_index[i]][dataset_sample_index[i]]``
    merged into a dict together with a ``"dataset_idx"`` key. The two index
    arrays are produced by ``fast_dataindex.helpers.build_blending_indices`` and,
    when ``data_cache_path`` is given, cached on disk under a name derived from
    the MD5 of the dataset description.

    Args:
        datasets: Sequence of datasets. Each must expose a ``desc`` string (part
            of the cache key) and return a mapping when indexed.
        weights: One weight per dataset; normalized by their sum, which must be
            positive.
        size: Total number of samples in the blended dataset.
        share_folder: If true, only global rank 0 builds the cache files;
            otherwise local rank 0 (``PADDLE_RANK_IN_NODE``) builds them.
        data_cache_path: Directory for cached index files. When falsy, the
            indices are built in memory and not stored.
    """

    def __init__(self, datasets, weights, size, share_folder, *, data_cache_path=None):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = size

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        def _build_indices():
            """Return ``(dataset_index, dataset_sample_index)``, each of length ``self.size``."""
            start_time = time.time()

            fast_dataindex_version = importlib.metadata.version("fast_dataindex")
            # Newer fast_dataindex accepts an int16 dataset index; older releases
            # need uint8. Compare numerically rather than as strings.
            if _version_tuple(fast_dataindex_version) > (0, 1, 1):
                assert (
                    num_datasets < 32767
                ), f"Detect num_datasets({num_datasets})>=32767. Currently, num_datasets should be less than 32767."
                dataset_index = np.zeros(self.size, dtype=np.int16)
            else:
                assert (
                    num_datasets < 255
                ), f"Detect num_datasets:({num_datasets})>=255. When 'fast_dataindex<=0.1.1', num_datasets should be less than 255. To support num_datasets greater than 255, please upgrade `fast_dataindex>=0.1.2`."
                dataset_index = np.zeros(self.size, dtype=np.uint8)
            dataset_sample_index = np.zeros(self.size, dtype=np.int64)

            from fast_dataindex import helpers

            # The helper fills both arrays in place. The final flag is True on
            # local rank 0 only (presumably a verbosity switch; TODO confirm).
            helpers.build_blending_indices(
                dataset_index,
                dataset_sample_index,
                weights,
                num_datasets,
                self.size,
                local_rank == 0,
            )
            print_rank_0(
                "> elapsed time for building blendable dataset indices: "
                "{:.2f} (sec)".format(time.time() - start_time)
            )
            return dataset_index, dataset_sample_index

        # The description doubles as the cache key, so its format must stay stable.
        desc = "Blendable dataset\n\n"
        desc += "Datasets:\n"
        for dataset in datasets:
            desc += dataset.desc + "\n\n"
        desc += f"Weights: {weights}\n"
        desc += f"Size: {size}\n"
        self.desc = desc

        if data_cache_path:
            desc_hash = hashlib.md5(desc.encode("utf-8")).hexdigest()
            desc_path = os.path.join(data_cache_path, desc_hash + ".dsc")
            index_path = os.path.join(data_cache_path, desc_hash + "_index.npy")
            sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy")
            cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path)
            # Exactly one rank builds a missing cache: global rank 0 when the
            # folder is shared, otherwise local rank 0.
            check_rank_flag = not cache_hit and local_rank == 0
            if share_folder:
                check_rank_flag = not cache_hit and paddle.distributed.get_rank() == 0

            print(
                f"searching for blendable dataset, cache_hit={cache_hit}, share_folder {share_folder}, check_rank_flag {check_rank_flag}",
                flush=True,
            )
            # Set only when this rank built the indices but could not persist them.
            in_memory_indices = None
            if check_rank_flag:
                print(
                    " > WARNING: could not find index map files for blendable"
                    " dataset, building indices on rank 0 ...",
                    flush=True,
                )
                dataset_index, dataset_sample_index = _build_indices()
                try:
                    os.makedirs(os.path.dirname(index_path), exist_ok=True)
                    with open(desc_path, "wt") as fd:
                        fd.write(desc)
                    # Atomic writes: waiting ranks treat "both files exist" as
                    # "cache ready", so neither may appear half-written.
                    _atomic_np_save(index_path, dataset_index)
                    _atomic_np_save(sample_index_path, dataset_sample_index)
                except OSError:
                    print(f"There was an error trying to create the data cache directory ({data_cache_path})")
                    print("or a file in it. This is set with the --data-cache-path argument. Please")
                    print("ensure you have write access to this directory or specify one that you do have")
                    print("write access to.")
                    # Use the freshly built indices instead of loading files that
                    # were never written. NOTE(review): ranks polling below for
                    # these files will keep waiting.
                    in_memory_indices = (dataset_index, dataset_sample_index)
            else:
                # Poll until the building rank has published both index files.
                while True:
                    if not (os.path.isfile(index_path) and os.path.isfile(sample_index_path)):
                        print("building indices on rank 0 ...", flush=True)
                        time.sleep(3)
                        continue
                    try:
                        np.load(index_path, allow_pickle=True, mmap_mode="r")
                        np.load(sample_index_path, allow_pickle=True, mmap_mode="r")
                    except Exception:
                        # Deliberately broad: any load failure means "not ready yet".
                        print("%s file is still writing or damaged, please wait for a moment." % index_path)
                        time.sleep(3)
                    else:
                        print("build success", flush=True)
                        break

            if in_memory_indices is not None:
                self.dataset_index, self.dataset_sample_index = in_memory_indices
            else:
                # Load on all ranks.
                print_rank_0(f"> loading blendable dataset index: {index_path}")
                self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode="r")
                assert self.dataset_index.size == self.size

                print_rank_0(f"> loading blendable dataset sample index: {sample_index_path}")
                self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode="r")
                assert self.dataset_sample_index.size == self.size
        else:
            print_rank_0(
                "building indices for the blendable dataset, Since --data_cache is not specified, the index file will not be stored.",
                flush=True,
            )
            self.dataset_index, self.dataset_sample_index = _build_indices()

        # Check size: the last valid index must work and one past it must raise.
        _ = self.__getitem__(self.size - 1)
        try:
            _ = self.__getitem__(self.size)
            raise RuntimeError("BlendedDataset size is improperly bounded")
        except IndexError:
            pass
        print_rank_0("> size of blendable dataset: " "{} samples".format(self.size))

    def __len__(self):
        """Return the number of blended samples (``size`` from the constructor)."""
        return self.size

    def __getitem__(self, idx):
        """Return the ``idx``-th blended sample with its source dataset index.

        Raises:
            IndexError: if ``idx`` is outside the index arrays.
        """
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return {
            "dataset_idx": dataset_idx,
            **self.datasets[dataset_idx][sample_idx],
        }