paddleformers 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paddleformers/__init__.py +60 -0
- paddleformers/data/__init__.py +21 -0
- paddleformers/data/blendable_dataset.py +184 -0
- paddleformers/data/causal_dataset.py +711 -0
- paddleformers/data/collate.py +321 -0
- paddleformers/data/data_collator.py +782 -0
- paddleformers/data/dist_dataloader.py +215 -0
- paddleformers/data/indexed_dataset.py +972 -0
- paddleformers/data/iterator.py +15 -0
- paddleformers/data/sampler.py +416 -0
- paddleformers/data/tokenizer.py +33 -0
- paddleformers/data/vocab.py +579 -0
- paddleformers/datasets/__init__.py +18 -0
- paddleformers/datasets/dataset.py +773 -0
- paddleformers/datasets/embedding_dataset.py +256 -0
- paddleformers/datasets/rlhf_datasets/__init__.py +16 -0
- paddleformers/datasets/rlhf_datasets/protocol.py +607 -0
- paddleformers/datasets/rlhf_datasets/rl_dataset.py +168 -0
- paddleformers/datasets/zero_padding_dataset.py +234 -0
- paddleformers/generation/__init__.py +34 -0
- paddleformers/generation/configuration_utils.py +597 -0
- paddleformers/generation/logits_process.py +646 -0
- paddleformers/generation/stopping_criteria.py +91 -0
- paddleformers/generation/streamers.py +216 -0
- paddleformers/generation/utils.py +1809 -0
- paddleformers/mergekit/__init__.py +19 -0
- paddleformers/mergekit/merge_config.py +186 -0
- paddleformers/mergekit/merge_method.py +207 -0
- paddleformers/mergekit/merge_model.py +803 -0
- paddleformers/mergekit/merge_utils.py +73 -0
- paddleformers/mergekit/sparsify_method.py +118 -0
- paddleformers/ops/__init__.py +15 -0
- paddleformers/ops/topo.py +84 -0
- paddleformers/peft/__init__.py +19 -0
- paddleformers/peft/lokr/__init__.py +19 -0
- paddleformers/peft/lokr/lokr_config.py +141 -0
- paddleformers/peft/lokr/lokr_layers.py +240 -0
- paddleformers/peft/lokr/lokr_model.py +296 -0
- paddleformers/peft/lora/__init__.py +18 -0
- paddleformers/peft/lora/auto_lora_model.py +799 -0
- paddleformers/peft/lora/lora_config.py +197 -0
- paddleformers/peft/lora/lora_layers.py +942 -0
- paddleformers/peft/lora/lora_model.py +839 -0
- paddleformers/peft/lora/lora_quant_layers.py +272 -0
- paddleformers/peft/lora/lora_quantization_layers.py +284 -0
- paddleformers/peft/lora/lora_quick_layers.py +223 -0
- paddleformers/peft/lora/loraga_utils.py +395 -0
- paddleformers/peft/lora/utils.py +21 -0
- paddleformers/peft/prefix/__init__.py +23 -0
- paddleformers/peft/prefix/prefix_config.py +102 -0
- paddleformers/peft/prefix/prefix_model.py +549 -0
- paddleformers/peft/prefix/utils.py +52 -0
- paddleformers/peft/reft/__init__.py +24 -0
- paddleformers/peft/reft/interventions.py +148 -0
- paddleformers/peft/reft/modeling_utils.py +175 -0
- paddleformers/peft/reft/predict.py +132 -0
- paddleformers/peft/reft/reft_config.py +85 -0
- paddleformers/peft/reft/reft_model.py +365 -0
- paddleformers/peft/vera/__init__.py +17 -0
- paddleformers/peft/vera/vera_config.py +131 -0
- paddleformers/peft/vera/vera_layers.py +149 -0
- paddleformers/peft/vera/vera_model.py +284 -0
- paddleformers/quantization/__init__.py +15 -0
- paddleformers/quantization/checkpoint_quantization_utils.py +364 -0
- paddleformers/quantization/hadamard_utils.py +69 -0
- paddleformers/quantization/qat_utils.py +433 -0
- paddleformers/quantization/qlora.py +115 -0
- paddleformers/quantization/quantization_config.py +246 -0
- paddleformers/quantization/quantization_linear.py +721 -0
- paddleformers/quantization/quantization_utils.py +278 -0
- paddleformers/quantization/unified_checkpoint_quantization.py +221 -0
- paddleformers/trainer/__init__.py +21 -0
- paddleformers/trainer/argparser.py +469 -0
- paddleformers/trainer/auto_trainer.py +1036 -0
- paddleformers/trainer/auto_training_args.py +147 -0
- paddleformers/trainer/integrations.py +521 -0
- paddleformers/trainer/plugins/__init__.py +13 -0
- paddleformers/trainer/plugins/npu_plugin.py +127 -0
- paddleformers/trainer/plugins/timer.py +176 -0
- paddleformers/trainer/trainer.py +3829 -0
- paddleformers/trainer/trainer_callback.py +610 -0
- paddleformers/trainer/trainer_utils.py +1254 -0
- paddleformers/trainer/training_args.py +2508 -0
- paddleformers/trainer/unified_checkpoint/__init__.py +15 -0
- paddleformers/trainer/unified_checkpoint/async_handler.py +265 -0
- paddleformers/trainer/unified_checkpoint/check_completion.py +248 -0
- paddleformers/trainer/unified_checkpoint/load_dynamic.py +491 -0
- paddleformers/trainer/unified_checkpoint/load_local.py +297 -0
- paddleformers/trainer/unified_checkpoint/load_save_single_card.py +234 -0
- paddleformers/trainer/unified_checkpoint/sharding_split_param_utils.py +471 -0
- paddleformers/trainer/unified_checkpoint/shared_memory_utils.py +149 -0
- paddleformers/trainer/unified_checkpoint/unified_checkpoint.py +789 -0
- paddleformers/trainer/unified_checkpoint/utils.py +804 -0
- paddleformers/trainer/utils/__init__.py +20 -0
- paddleformers/trainer/utils/async_save.py +126 -0
- paddleformers/trainer/utils/ckpt_converter.py +1221 -0
- paddleformers/trainer/utils/doc.py +54 -0
- paddleformers/trainer/utils/helper.py +338 -0
- paddleformers/trainer/utils/reshard/__init__.py +23 -0
- paddleformers/trainer/utils/reshard/common.py +587 -0
- paddleformers/trainer/utils/reshard/pp_reshard.py +365 -0
- paddleformers/trainer/utils/reshard/sharding_v1.py +42 -0
- paddleformers/trainer/utils/reshard/sharding_v2.py +231 -0
- paddleformers/trainer/utils/sharding_io.py +605 -0
- paddleformers/trainer/utils/zero_cost_checkpoint.py +984 -0
- paddleformers/transformers/__init__.py +72 -0
- paddleformers/transformers/activations.py +174 -0
- paddleformers/transformers/aistudio_utils.py +70 -0
- paddleformers/transformers/attention_utils.py +619 -0
- paddleformers/transformers/audio_utils.py +694 -0
- paddleformers/transformers/auto/__init__.py +14 -0
- paddleformers/transformers/auto/configuration.py +355 -0
- paddleformers/transformers/auto/factory.py +146 -0
- paddleformers/transformers/auto/image_processing.py +176 -0
- paddleformers/transformers/auto/modeling.py +984 -0
- paddleformers/transformers/auto/processing.py +184 -0
- paddleformers/transformers/auto/tokenizer.py +478 -0
- paddleformers/transformers/auto_utils.py +59 -0
- paddleformers/transformers/bert/__init__.py +13 -0
- paddleformers/transformers/bert/configuration.py +407 -0
- paddleformers/transformers/bert/modeling.py +1420 -0
- paddleformers/transformers/bert/modeling.pyi +347 -0
- paddleformers/transformers/bert/tokenizer.py +630 -0
- paddleformers/transformers/bert/tokenizer_fast.py +165 -0
- paddleformers/transformers/configuration_utils.py +1236 -0
- paddleformers/transformers/context_parallel_utils.py +64 -0
- paddleformers/transformers/contrastive_loss.py +152 -0
- paddleformers/transformers/conversion_utils.py +1661 -0
- paddleformers/transformers/convert_slow_tokenizer.py +691 -0
- paddleformers/transformers/deepseek_v2/__init__.py +19 -0
- paddleformers/transformers/deepseek_v2/configuration.py +237 -0
- paddleformers/transformers/deepseek_v2/fp8_linear.py +138 -0
- paddleformers/transformers/deepseek_v2/kernel.py +227 -0
- paddleformers/transformers/deepseek_v2/mfu_utils.py +206 -0
- paddleformers/transformers/deepseek_v2/modeling.py +2334 -0
- paddleformers/transformers/deepseek_v2/modeling_auto.py +1263 -0
- paddleformers/transformers/deepseek_v2/modeling_pp.py +501 -0
- paddleformers/transformers/deepseek_v2/tokenizer_fast.py +54 -0
- paddleformers/transformers/deepseek_v3/__init__.py +18 -0
- paddleformers/transformers/deepseek_v3/configuration.py +33 -0
- paddleformers/transformers/deepseek_v3/modeling.py +170 -0
- paddleformers/transformers/deepseek_v3/modeling_auto.py +209 -0
- paddleformers/transformers/deepseek_v3/modeling_pp.py +40 -0
- paddleformers/transformers/dpo_criterion.py +379 -0
- paddleformers/transformers/embedding_utils.py +51 -0
- paddleformers/transformers/export.py +68 -0
- paddleformers/transformers/feature_extraction_sequence_utils.py +365 -0
- paddleformers/transformers/feature_extraction_utils.py +377 -0
- paddleformers/transformers/fused_a2a.py +216 -0
- paddleformers/transformers/image_processing_utils.py +547 -0
- paddleformers/transformers/image_transforms.py +655 -0
- paddleformers/transformers/image_utils.py +621 -0
- paddleformers/transformers/kto_criterion.py +275 -0
- paddleformers/transformers/linear_utils.py +90 -0
- paddleformers/transformers/llama/__init__.py +21 -0
- paddleformers/transformers/llama/configuration.py +211 -0
- paddleformers/transformers/llama/fusion_ops.py +313 -0
- paddleformers/transformers/llama/modeling.py +2161 -0
- paddleformers/transformers/llama/modeling_auto.py +1360 -0
- paddleformers/transformers/llama/modeling_network.py +1224 -0
- paddleformers/transformers/llama/modeling_pp.py +468 -0
- paddleformers/transformers/llama/tokenizer.py +502 -0
- paddleformers/transformers/llama/tokenizer_fast.py +171 -0
- paddleformers/transformers/long_sequence_strategies/__init__.py +18 -0
- paddleformers/transformers/long_sequence_strategies/attention_strategies.py +51 -0
- paddleformers/transformers/long_sequence_strategies/embedding_strategies.py +223 -0
- paddleformers/transformers/long_sequence_strategies/long_sequence_strategies.py +68 -0
- paddleformers/transformers/mc2_parallel_linear.py +230 -0
- paddleformers/transformers/model_outputs.py +1568 -0
- paddleformers/transformers/model_utils.py +3315 -0
- paddleformers/transformers/moe_gate.py +588 -0
- paddleformers/transformers/moe_gate_auto.py +656 -0
- paddleformers/transformers/moe_layer.py +378 -0
- paddleformers/transformers/moe_layer_auto.py +329 -0
- paddleformers/transformers/moe_utils.py +101 -0
- paddleformers/transformers/ofa_utils.py +326 -0
- paddleformers/transformers/optimization.py +304 -0
- paddleformers/transformers/processing_utils.py +136 -0
- paddleformers/transformers/qwen/__init__.py +19 -0
- paddleformers/transformers/qwen/configuration.py +86 -0
- paddleformers/transformers/qwen/modeling.py +1293 -0
- paddleformers/transformers/qwen/modeling_auto.py +970 -0
- paddleformers/transformers/qwen/modeling_network.py +830 -0
- paddleformers/transformers/qwen/modeling_pp.py +220 -0
- paddleformers/transformers/qwen/tokenizer.py +270 -0
- paddleformers/transformers/qwen2/__init__.py +20 -0
- paddleformers/transformers/qwen2/configuration.py +164 -0
- paddleformers/transformers/qwen2/modeling.py +1973 -0
- paddleformers/transformers/qwen2/modeling_pp.py +363 -0
- paddleformers/transformers/qwen2/tokenizer.py +448 -0
- paddleformers/transformers/qwen2/tokenizer_fast.py +131 -0
- paddleformers/transformers/qwen2_moe/__init__.py +18 -0
- paddleformers/transformers/qwen2_moe/configuration.py +186 -0
- paddleformers/transformers/qwen2_moe/modeling.py +1735 -0
- paddleformers/transformers/qwen2_moe/modeling_pp.py +354 -0
- paddleformers/transformers/qwen3/__init__.py +18 -0
- paddleformers/transformers/qwen3/configuration.py +195 -0
- paddleformers/transformers/qwen3/modeling.py +1448 -0
- paddleformers/transformers/qwen3/modeling_pp.py +363 -0
- paddleformers/transformers/qwen3_moe/__init__.py +17 -0
- paddleformers/transformers/qwen3_moe/configuration.py +219 -0
- paddleformers/transformers/qwen3_moe/modeling.py +996 -0
- paddleformers/transformers/qwen3_moe/modeling_pp.py +254 -0
- paddleformers/transformers/refined_recompute.py +791 -0
- paddleformers/transformers/ring_flash_attention.py +353 -0
- paddleformers/transformers/segment_parallel_utils.py +137 -0
- paddleformers/transformers/sentencepiece_model_pb2.py +1534 -0
- paddleformers/transformers/sequence_parallel_utils.py +139 -0
- paddleformers/transformers/tensor_parallel_utils.py +583 -0
- paddleformers/transformers/token_dispatcher.py +284 -0
- paddleformers/transformers/tokenizer_utils.py +2199 -0
- paddleformers/transformers/tokenizer_utils_base.py +3655 -0
- paddleformers/transformers/tokenizer_utils_fast.py +882 -0
- paddleformers/transformers/transposed_linear.py +59 -0
- paddleformers/transformers/utils.py +1006 -0
- paddleformers/trl/__init__.py +28 -0
- paddleformers/trl/dpo_auto_trainer.py +1153 -0
- paddleformers/trl/dpo_trainer.py +571 -0
- paddleformers/trl/embedding_trainer.py +192 -0
- paddleformers/trl/extras/__init__.py +13 -0
- paddleformers/trl/extras/dataset_formatting.py +129 -0
- paddleformers/trl/kto_trainer.py +554 -0
- paddleformers/trl/llm_utils.py +944 -0
- paddleformers/trl/model_config.py +147 -0
- paddleformers/trl/quant_config.py +118 -0
- paddleformers/trl/sft_auto_trainer.py +911 -0
- paddleformers/trl/sft_config.py +102 -0
- paddleformers/trl/sft_trainer.py +421 -0
- paddleformers/trl/sftdata_config.py +68 -0
- paddleformers/trl/trl_data.py +268 -0
- paddleformers/trl/trl_utils.py +49 -0
- paddleformers/trl/utils.py +43 -0
- paddleformers/utils/__init__.py +50 -0
- paddleformers/utils/adamw_triton.py +194 -0
- paddleformers/utils/batch_sampler.py +182 -0
- paddleformers/utils/converter.py +18 -0
- paddleformers/utils/distributed.py +224 -0
- paddleformers/utils/doc_parser.py +432 -0
- paddleformers/utils/download/__init__.py +340 -0
- paddleformers/utils/download/aistudio_hub_download.py +728 -0
- paddleformers/utils/download/bos_download.py +287 -0
- paddleformers/utils/download/common.py +662 -0
- paddleformers/utils/downloader.py +592 -0
- paddleformers/utils/env.py +166 -0
- paddleformers/utils/fault_tolerance.py +38 -0
- paddleformers/utils/ie_utils.py +142 -0
- paddleformers/utils/image_utils.py +734 -0
- paddleformers/utils/import_utils.py +384 -0
- paddleformers/utils/infohub.py +53 -0
- paddleformers/utils/initializer.py +337 -0
- paddleformers/utils/log.py +204 -0
- paddleformers/utils/memory_utils.py +39 -0
- paddleformers/utils/nested.py +128 -0
- paddleformers/utils/optimizer.py +608 -0
- paddleformers/utils/paddle_patch.py +148 -0
- paddleformers/utils/pdc_sdk.py +629 -0
- paddleformers/utils/profiler.py +130 -0
- paddleformers/utils/safetensors.py +320 -0
- paddleformers/utils/serialization.py +309 -0
- paddleformers/utils/tools.py +247 -0
- paddleformers/version/__init__.py +25 -0
- paddleformers/version/git.py +48 -0
- paddleformers-0.1.dist-info/METADATA +48 -0
- paddleformers-0.1.dist-info/RECORD +268 -0
- paddleformers-0.1.dist-info/WHEEL +5 -0
- paddleformers-0.1.dist-info/entry_points.txt +2 -0
- paddleformers-0.1.dist-info/licenses/LICENSE +203 -0
- paddleformers-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import sys
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
|
|
19
|
+
# Name of the environment variable that requests an undated ("stable") version string.
PADDLEFORMERS_STABLE_VERSION = "PADDLEFORMERS_STABLE_VERSION"

# Version used for development and testing; setup.py pins a fixed version for releases.
__version__ = "0.1.post"
if not os.getenv(PADDLEFORMERS_STABLE_VERSION):
    # Development build: append today's date to the ".post" suffix (".postYYYYMMDD").
    formatted_date = datetime.now().date().strftime("%Y%m%d")
    __version__ = __version__.replace(".post", f".post{formatted_date}")
else:
    # Stable build requested through the environment: drop the ".post" suffix.
    __version__ = __version__.replace(".post", "")

# the next line will be replaced by setup.py for release version.

__version__ = "0.1"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Emit a warning when a module named ``datasets`` is already present in
# ``sys.modules`` at the time this package is imported.
if "datasets" in sys.modules:
    # Imported only on this path, so the logger is not loaded unless the warning fires.
    from paddleformers.utils.log import logger

    logger.warning(
        "Detected that datasets module was imported before paddleformers. "
        "This may cause PaddleFormers datasets to be unavailable in intranet. "
        "Please import paddleformers before datasets module to avoid download issues"
    )
|
|
44
|
+
import paddle
|
|
45
|
+
|
|
46
|
+
from . import (
|
|
47
|
+
data,
|
|
48
|
+
datasets,
|
|
49
|
+
mergekit,
|
|
50
|
+
ops,
|
|
51
|
+
peft,
|
|
52
|
+
quantization,
|
|
53
|
+
trainer,
|
|
54
|
+
transformers,
|
|
55
|
+
trl,
|
|
56
|
+
utils,
|
|
57
|
+
version,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# Runs once, at package import time, after all subpackages have been imported.
# NOTE(review): presumably turns off the signal handlers that Paddle installs so
# they do not interfere with the host program's own handling — confirm against
# the Paddle documentation.
paddle.disable_signal_handler()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .blendable_dataset import *
|
|
16
|
+
from .causal_dataset import *
|
|
17
|
+
from .collate import *
|
|
18
|
+
from .data_collator import *
|
|
19
|
+
from .dist_dataloader import *
|
|
20
|
+
from .sampler import *
|
|
21
|
+
from .vocab import *
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import hashlib
|
|
16
|
+
import importlib.metadata
|
|
17
|
+
import os
|
|
18
|
+
import time
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import paddle
|
|
22
|
+
|
|
23
|
+
# Integer value of the PADDLE_RANK_IN_NODE environment variable; 0 when it is unset.
local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", 0))
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def print_rank_0(*args, **kwargs):
    """Call ``print`` with the given arguments only when the distributed rank is 0.

    Args:
        *args: Positional arguments forwarded unchanged to ``print``.
        **kwargs: Keyword arguments forwarded unchanged to ``print``.
    """
    # Every rank other than 0 stays silent.
    if paddle.distributed.get_rank() != 0:
        return
    print(*args, **kwargs)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BlendableDataset(paddle.io.Dataset):
    """Fixed-size blend of several datasets driven by two precomputed index arrays.

    Item ``idx`` is served from ``datasets[dataset_index[idx]]`` at position
    ``dataset_sample_index[idx]``. Both index arrays are produced by
    ``fast_dataindex.helpers.build_blending_indices`` from the normalized
    weights. When a cache directory is given, one process builds and saves the
    arrays while the others wait for the files and memory-map them.

    Args:
        datasets: Sequence of source datasets. Each must expose a string
            ``desc`` attribute, and ``dataset[i]`` must return a mapping
            (it is unpacked into the returned sample dict).
        weights: One blending weight per dataset. They are normalized to sum
            to 1, so their sum must be positive.
        size: Number of samples in the blended dataset.
        share_folder: If truthy, the cache is built only by the process whose
            global distributed rank is 0; otherwise by every process whose
            module-level ``local_rank`` (``PADDLE_RANK_IN_NODE``) is 0.
        data_cache_path: Directory for the cached description and index
            files. When falsy, the indices are built in memory by every
            process and nothing is written to disk.

    Raises:
        AssertionError: On mismatched ``datasets``/``weights`` lengths, a
            non-positive weight sum, too many datasets for the index dtype,
            or cached arrays whose length differs from ``size``.
        RuntimeError: If indexing one past the end does not raise
            ``IndexError``.
    """

    # Seconds a waiting process sleeps between two looks at the cache files.
    _CACHE_POLL_SECONDS = 3

    def __init__(self, datasets, weights, size, share_folder, *, data_cache_path=None):
        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = size

        # Normalize weights so that they sum to 1.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Human-readable description; its MD5 digest also names the cache files.
        desc_parts = ["Blendable dataset\n\n", "Datasets:\n"]
        desc_parts.extend(dataset.desc + "\n\n" for dataset in datasets)
        desc_parts.append(f"Weights: {weights}\n")
        desc_parts.append(f"Size: {size}\n")
        self.desc = "".join(desc_parts)

        if data_cache_path:
            self._init_indices_from_cache(weights, num_datasets, share_folder, data_cache_path)
        else:
            print_rank_0(
                "building indices for the blendable dataset, Since --data_cache is not specified, the index file will not be stored.",
                flush=True,
            )
            self.dataset_index, self.dataset_sample_index = self._build_indices(weights, num_datasets)

        # Check size: the last valid index must resolve, and one past the end
        # must raise IndexError.
        _ = self[self.size - 1]
        try:
            _ = self[self.size]
        except IndexError:
            pass
        else:
            raise RuntimeError("BlendedDataset size is improperly bounded")
        print_rank_0(f"> size of blendable dataset: {self.size} samples")

    def _build_indices(self, weights, num_datasets):
        """Build the dataset index and the within-dataset sample index.

        Args:
            weights: Normalized float64 weight array, one entry per dataset.
            num_datasets: Number of source datasets.

        Returns:
            Tuple ``(dataset_index, dataset_sample_index)``; both are 1-D
            arrays with ``self.size`` entries.
        """
        start_time = time.time()

        # NOTE(review): this is a lexicographic string comparison, not a real
        # version ordering. It gives the expected answer for plain X.Y.Z
        # releases around this threshold, but pre-release/dev suffixes on
        # 0.1.1 would compare as newer.
        fast_dataindex_version = importlib.metadata.version("fast_dataindex")
        if fast_dataindex_version > "0.1.1":
            assert (
                num_datasets < 32767
            ), f"Detect num_datasets({num_datasets})>=32767. Currently, num_datasets should be less than 32767."
            dataset_index = np.zeros(self.size, dtype=np.int16)
        else:
            assert (
                num_datasets < 255
            ), f"Detect num_datasets:({num_datasets})>=255. When 'fast_dataindex<=0.1.1', num_datasets should be less than 255. To support num_datasets greater than 255, please upgrade `fast_dataindex>=0.1.2`."
            dataset_index = np.zeros(self.size, dtype=np.uint8)
        dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        # Imported lazily so the extension is only required when indices are built.
        from fast_dataindex import helpers

        # The helper's return value is unused and the two arrays are returned
        # afterwards, so it presumably fills them in place. The last argument
        # is true only where local_rank is 0 (presumably a verbosity switch —
        # TODO confirm against fast_dataindex).
        helpers.build_blending_indices(
            dataset_index,
            dataset_sample_index,
            weights,
            num_datasets,
            self.size,
            local_rank == 0,
        )
        elapsed = time.time() - start_time
        print_rank_0(f"> elapsed time for building blendable dataset indices: {elapsed:.2f} (sec)")
        return dataset_index, dataset_sample_index

    @classmethod
    def _wait_for_index_files(cls, index_path, sample_index_path):
        """Block until both cached index files exist and can be memory-mapped.

        Both files are validated, not only the first one: the builder writes
        the sample index last, so it may still be incomplete when the dataset
        index already loads. Leaving the loop early would make the later
        unguarded load of the sample index fail.
        """
        while True:
            if not (os.path.isfile(index_path) and os.path.isfile(sample_index_path)):
                print("building indices on rank 0 ...", flush=True)
                time.sleep(cls._CACHE_POLL_SECONDS)
                continue
            try:
                for path in (index_path, sample_index_path):
                    np.load(path, allow_pickle=True, mmap_mode="r")
            except Exception:
                # A partially written file fails to load; report it and retry.
                print("%s file is still writing or damaged, please wait for a moment." % path)
                time.sleep(cls._CACHE_POLL_SECONDS)
            else:
                print("build success", flush=True)
                return

    def _init_indices_from_cache(self, weights, num_datasets, share_folder, data_cache_path):
        """Populate the index arrays from the on-disk cache, building it first if needed.

        The process selected as builder creates the arrays and writes them to
        ``data_cache_path``; every other process waits until the files are
        readable. All processes then memory-map the files read-only.
        """
        desc_hash = hashlib.md5(self.desc.encode("utf-8")).hexdigest()
        desc_path = os.path.join(data_cache_path, desc_hash + ".dsc")
        index_path = os.path.join(data_cache_path, desc_hash + "_index.npy")
        sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy")
        cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path)

        # Exactly one kind of process builds a missing cache: global rank 0
        # when the folder is shared, otherwise each process with local_rank 0.
        if share_folder:
            check_rank_flag = not cache_hit and paddle.distributed.get_rank() == 0
        else:
            check_rank_flag = not cache_hit and local_rank == 0

        print(
            f"searching for blendable dataset, cache_hit={cache_hit}, share_folder {share_folder}, check_rank_flag {check_rank_flag}",
            flush=True,
        )
        if check_rank_flag:
            print(
                " > WARNING: could not find index map files for blendable"
                " dataset, building indices on rank 0 ...",
                flush=True,
            )
            dataset_index, dataset_sample_index = self._build_indices(weights, num_datasets)
            try:
                os.makedirs(os.path.dirname(index_path), exist_ok=True)
                with open(desc_path, "wt") as fd:
                    fd.write(self.desc)
                # NOTE(review): the arrays are plain integer arrays, so
                # allow_pickle is not needed to store them; kept as in the
                # original to avoid changing the cache-writing call.
                np.save(index_path, dataset_index, allow_pickle=True)
                np.save(sample_index_path, dataset_sample_index, allow_pickle=True)
            except OSError:
                # NOTE(review): the failure is only reported. The loads below
                # will then fail on this process, and processes waiting for
                # the files keep polling.
                print(f"There was an error trying to create the data cache directory ({data_cache_path})")
                print("or a file in it. This is set with the --data-cache-path argument. Please")
                print("ensure you have write access to this directory or specify one that you do have")
                print("write access to.")
        else:
            self._wait_for_index_files(index_path, sample_index_path)

        # Load on all ranks (including the builder, which re-reads its own files).
        print_rank_0(f"> loading blendable dataset index: {index_path}")
        self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode="r")
        assert self.dataset_index.size == self.size

        print_rank_0(f"> loading blendable dataset sample index: {sample_index_path}")
        self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode="r")
        assert self.dataset_sample_index.size == self.size

    def __len__(self):
        """Return the number of samples in the blended dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return the sample at ``idx`` with the originating dataset's index added.

        Returns:
            A dict holding ``"dataset_idx"`` followed by every key of the
            underlying sample; a ``"dataset_idx"`` key in the sample itself
            overrides the added one.

        Raises:
            IndexError: If ``idx`` is outside the index arrays.
        """
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return {
            "dataset_idx": dataset_idx,
            **self.datasets[dataset_idx][sample_idx],
        }
|