mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0

mindspore/parallel/transform_safetensors.py (new file)
@@ -0,0 +1,1146 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transform distributed safetensors"""
from __future__ import absolute_import

import os
import time
import glob
import re
import math
import json
from collections import defaultdict

import multiprocessing as mp
import numpy as np
import mindspore as ms
from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
    _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
    _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src
from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
    _get_needed_rank_transform_operator_map_by_layouts, \
    _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
    _extract_layout_item, _load_tensor_shape, _apply_operator
from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
    _convert_to_list

from safetensors.numpy import save_file, load_file
from safetensors import safe_open


def _load_and_transform(path, name_map, load_func, transform_func):
    if load_func is not None:
        param_dict = load_func(path)
    else:
        param_dict = path
    transform_dict = {}
    for k, v in param_dict.items():
        new_name = name_map.get(k, k) if name_map is not None else k
        transform_dict[new_name] = transform_func(v, new_name)
    return transform_dict


def _transform_tensor_to_numpy(path, name_map=None):
    return _load_and_transform(path, name_map, ms.load_checkpoint, lambda v, new_name: v.asnumpy())


def _transform_numpy_to_tensor(path, name_map=None):
    return _load_and_transform(path, name_map, load_file, lambda v, new_name: ms.Parameter(v, name=new_name))
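
# The three helpers above share one shape: load a dict of parameters, optionally rename keys
# via `name_map`, and convert each value with `transform_func`. Roughly, a ckpt -> safetensors
# round trip composes them as follows (illustrative only; file names are made up):
#
#     np_dict = _transform_tensor_to_numpy("model.ckpt", name_map={"old": "new"})
#     save_file(np_dict, "model.safetensors")          # numpy dict -> .safetensors
#     ms_dict = _transform_numpy_to_tensor("model.safetensors")
#     ms.save_checkpoint(ms_dict, "model_back.ckpt")   # Parameter dict -> .ckpt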


def _process_file(file_info):
    cur_ckpt_path, name_map, save_path, file = file_info
    param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
    safetensors_filename = file.replace(".ckpt", ".safetensors")
    dst_file = os.path.join(save_path, safetensors_filename)
    save_file(param_dict_numpy, dst_file)


def _process_file_safetensors(file_info):
    cur_safe_path, name_map, save_path, file = file_info
    param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
    ckpt_filename = file.replace(".safetensors", ".ckpt")
    dst_file = os.path.join(save_path, ckpt_filename)
    ms.save_checkpoint(param_dict_tensor, dst_file)


def _gather_tasks(file_path, save_path, file_name_regex, name_map):
    """gather transform rank together"""
    tasks = []
    for root, dirs, _ in os.walk(file_path):
        if root != file_path:
            continue

        rank_dirs = [d for d in dirs if d.startswith('rank')]
        if not rank_dirs:
            raise ValueError(
                f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")

        for rank_dir in rank_dirs:
            rank_dir_path = os.path.join(root, rank_dir)
            dst_root = os.path.join(save_path,
                                    os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
            os.makedirs(dst_root, exist_ok=True)
            tasks.extend(
                (os.path.join(rank_dir_path, file), name_map, dst_root, file)
                for file in os.listdir(rank_dir_path)
                if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
            )
    return tasks


def _progress_bar(iterable, total=None):
    """
    Decorate an iterable object, returning an iterator which acts exactly
    like the original iterable, but prints a dynamically updating
    progressbar every time a value is requested.
    """
    if total is None:
        total = len(iterable)

    start_time = time.time()

    def print_progress_bar(iteration):
        percent = f"{100 * (iteration / float(total)):.1f}"
        bar_length = 40
        filled_length = int(bar_length * iteration // total)
        bar = '█' * filled_length + '-' * (bar_length - filled_length)

        elapsed_time = time.time() - start_time
        estimated_total_time = elapsed_time / iteration * total
        remaining_time = estimated_total_time - elapsed_time

        elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
        remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))

        print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
        if iteration == total:
            print()

    for i, item in enumerate(iterable, start=1):
        yield item
        print_progress_bar(i)


def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
    """
    Converts MindSpore checkpoint files into safetensors format and saves them to `save_path`.
    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
    used for securely storing Tensors with fast speed (zero copy).

    Note:
        The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
        too large, otherwise it may cause freezing.
        The safetensors format does not support the enc verification function. If ckpt is enabled to save enc
        verification, an error will be generated when performing the conversion.
        The safetensors format currently does not support crc verification function. If ckpt contains crc verification
        information, the crc verification information will be lost after conversion to safetensors.

    Args:
        file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
        save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
        name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
        file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
            Defaults: ``None``.
        processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
    Raises:
        ValueError: If the input path is invalid or the save_path is not a directory,
            or the file_path does not end with '.ckpt'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> ms.ckpt_to_safetensors("./ckpt_save_path")
        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
        >>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
        >>> namemap = {"lin.weight":"new_name"}
        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
    """
    is_dir = os.path.isdir(file_path)
    is_file = os.path.isfile(file_path)
    if not is_dir and not is_file:
        raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
    if save_path and os.path.splitext(save_path)[1]:
        raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
    if name_map is not None and not isinstance(name_map, dict):
        raise ValueError(
            f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")

    if is_dir:
        tasks = _gather_tasks(file_path, save_path, file_name_regex, name_map)
        with mp.Pool(processes=processes_num) as pool:
            list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
    elif is_file:
        if not file_path.endswith(".ckpt"):
            raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
        if file_name_regex is not None and not re.findall(file_name_regex, file_path):
            raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
        if save_path and not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)

        param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
        safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
        dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
        save_file(param_dict_numpy, dst_file)


def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
    """gather transform rank together"""
    tasks = []
    for root, dirs, _ in os.walk(file_path):
        if root != file_path:
            continue

        rank_dirs = [d for d in dirs if d.startswith('rank')]
        if not rank_dirs:
            raise ValueError(
                f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")

        for rank_dir in rank_dirs:
            rank_dir_path = os.path.join(root, rank_dir)
            dst_root = os.path.join(save_path,
                                    os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
            os.makedirs(dst_root, exist_ok=True)
            tasks.extend(
                (os.path.join(rank_dir_path, file), name_map, dst_root, file)
                for file in os.listdir(rank_dir_path)
                if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
            )
    return tasks


def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
    """
    Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
    used for securely storing Tensors with fast speed (zero copy).

    Note:
        The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
        too large, otherwise it may cause freezing.

    Args:
        file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
        save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
        name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
        file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
            Defaults: ``None``.
        processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.

    Raises:
        ValueError: If the input path is invalid, the save_path is not a directory,
            or the file_path does not end with '.safetensors'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> ms.safetensors_to_ckpt("./safetensors_save_path")
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
        >>> namemap = {"lin.weight":"new_name"}
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
    """
    is_dir = os.path.isdir(file_path)
    is_file = os.path.isfile(file_path)
    if not is_dir and not is_file:
        raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
    if save_path and os.path.splitext(save_path)[1]:
        raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
    if name_map is not None and not isinstance(name_map, dict):
        raise ValueError(
            f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")

    if is_dir:
        tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
        with mp.Pool(processes=processes_num) as pool:
            list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
    elif is_file:
        if not file_path.endswith(".safetensors"):
            raise ValueError(
                f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
        if file_name_regex is not None and not re.findall(file_name_regex, file_path):
            raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
        if save_path and not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)

        param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
        ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
        dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
        ms.save_checkpoint(param_dict_tensor, dst_file)


def _check_transform_safetensors(src_safetensors_dir, ckpt_prefix, src_strategy_file, dst_strategy_file):
    """check _transform_safetensors input"""
    if not isinstance(ckpt_prefix, str):
        raise TypeError("The ckpt_prefix should be a str.")
    if src_strategy_file and os.path.dirname(src_strategy_file) and not os.path.exists(
            os.path.dirname(src_strategy_file)):
        raise ValueError("The director of src_strategy_file: {} is not exists.".
                         format(os.path.dirname(src_strategy_file)))
    if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists(
            os.path.dirname(dst_strategy_file)):
        raise ValueError("The director of dst_strategy_file: {} is not exists.".
                         format(os.path.dirname(dst_strategy_file)))


def _check_output_format(output_format):
    if output_format not in ["safetensors", "ckpt"]:
        raise ValueError(f"For 'transform_safetensors', the output_format must be "
                         f"'safetensors' or 'ckpt', but got {output_format}.")


def _split_protobuf_strategy(merged_strategy_file):
    """split src_strategy_file by pp"""
    dst_parallel_strategy_map = _load_protobuf_strategy(merged_strategy_file)
    if not dst_parallel_strategy_map.parallel_strategy_item or not dst_parallel_strategy_map.parallel_layout_item:
        raise ValueError(f"The merged strategy file {merged_strategy_file} is empty")

    src_dict = {}
    for layout_item in dst_parallel_strategy_map.parallel_layout_item:
        stage, _ = layout_item.param_name.split('-', 1)
        stage = int(stage)
        if stage not in src_dict:
            src_dict[stage] = {}
        parameter_name = layout_item.param_name
        layout = layout_item.parallel_layouts
        src_dict[stage][parameter_name] = layout
    return src_dict
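
# In a merged pipeline-parallel strategy file, every parameter name carries its stage id as a
# prefix, e.g. "0-layers.0.attention.dense.weight" or "1-layers.4.mlp.mapping.weight"
# (hypothetical names). _split_protobuf_strategy() groups those layout items into one sub-dict
# per stage, e.g. {0: {...}, 1: {...}}, so that _transform_safetensors() below can hand each
# stage to its own worker process.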


def _transform_safetensors(src_safetensors_dir, dst_safetensors_dir, ckpt_prefix, src_strategy_file=None,
                           dst_strategy_file=None, process_num=1, output_format="safetensors"):
    """Transform distributed safetensors from source sharding strategy to destination sharding strategy for a rank."""
    _check_transform_safetensors(src_safetensors_dir, ckpt_prefix, src_strategy_file, dst_strategy_file)
    _check_output_format(output_format)
    _make_dir(dst_safetensors_dir, "path")
    all_safetensor_files_map = _collect_safetensor_files(src_safetensors_dir)

    dst_strategy_dict = _build_searched_strategy(dst_strategy_file)
    pipeline_stage_num = _extract_pipeline_stage_num(src_strategy_file)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)

    if pipeline_stage_num > 1 and dst_stage_num == 1:
        stage_dict = _split_protobuf_strategy(src_strategy_file)

        processes = []
        manager = mp.Manager()
        _transform_param_list = manager.list()
        for _, src_strategy_dict in stage_dict.items():
            p = mp.Process(target=_transform_stage_safetensors,
                           args=(src_strategy_dict, dst_strategy_dict, ckpt_prefix,
                                 dst_safetensors_dir, output_format, all_safetensor_files_map, process_num,
                                 _transform_param_list))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

        _save_final_safetensors(_transform_param_list, output_format)
    else:
        src_strategy_dict = _build_searched_strategy(src_strategy_file)
        _transform_stage_safetensors(src_strategy_dict, dst_strategy_dict, ckpt_prefix,
                                     dst_safetensors_dir, output_format, all_safetensor_files_map, process_num,
                                     _transform_param_list=None)


def _transform_stage_safetensors(src_strategy_dict, dst_strategy_dict, ckpt_prefix,
                                 dst_safetensors_dir, output_format, all_safetensor_files_map, process_num,
                                 _transform_param_list):
    """Transform distributed safetensors by stage"""
    src_stage_device_num = _get_device_num_from_strategy(src_strategy_dict)
    dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)

    origin_src_strategy_list = _extract_layout_map(src_strategy_dict)
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_dict)

    needed_rank_list_map = _find_needed_ranks(src_strategy_dict, dst_strategy_dict)
    for needed_rank_list, rank in needed_rank_list_map.items():
        for needed_rank in needed_rank_list.split("-"):
            if int(needed_rank) not in all_safetensor_files_map:
                raise ValueError("The safetensor file of rank{} is needed for converting rank{}'s safetensor, "
                                 "but it is missing.".format(needed_rank, rank))
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
    if not (len(needed_rank_list_map) == 1 and dst_stage_num > 1) and process_num > len(needed_rank_list_map):
        ms.log.warning("The value of process_num cannot be greater than that of needed_rank_list_map.")
        process_num = len(needed_rank_list_map)
    _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
                                         dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
                                         origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
                                         dst_safetensors_dir, process_num, output_format,
                                         _transform_param_list)


def _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num):
    """
    Distributes files across multiple processes based on file size to balance the processing load.
    """
    if process_num == 1:
        return [needed_rank_list_map]
    # Calculate the size of each file.
    # if src==1, dst pp>1, split for pp number.
    if len(needed_rank_list_map) == 1:
        src_rank = next(iter(needed_rank_list_map.keys()))
        dst_list = next(iter(needed_rank_list_map.values()))
        size = len(dst_list) // process_num
        split_list = [dst_list[i:i + size] for i in range(0, len(dst_list), size)]
        part_list_dict = [dict() for _ in range(process_num)]
        for index in range(process_num):
            part_list_dict[index][src_rank] = split_list[index]
        return part_list_dict

    rank_size = dict()
    for rank_id, file_name in all_safetensor_files_map.items():
        tmp_size = os.path.getsize(file_name) / 1024 / 1024
        rank_size[rank_id] = tmp_size
    # Obtain the rank and size required by all parts.
    part_total = []
    for index, (k, v) in enumerate(needed_rank_list_map.items()):
        tmp_part = []
        key_ele = k.split("-")
        tmp_size = 0
        for ele in key_ele:
            tmp_size += rank_size[int(ele)]
        tmp_part.append(index)
        tmp_part.append(tmp_size)
        part_total.append(tmp_part)
    # Sort each part by size.
    part_total = sorted(part_total, key=lambda x: x[1], reverse=True)
    part_list = [[] for _ in range(process_num)]
    part_size = [[] for _ in range(process_num)]
    for [index, size] in part_total:
        min_sum = float('inf')
        min_idx = -1
        for ele in range(process_num):
            if sum(part_size[ele]) < min_sum:
                min_sum = sum(part_size[ele])
                min_idx = ele
        part_list[min_idx].append(index)
        part_size[min_idx].append(size)

    part_list_dict = [dict() for _ in range(process_num)]
    for index, (k, v) in enumerate(needed_rank_list_map.items()):
        for idd, ele in enumerate(part_list):
            if index in ele:
                part_list_dict[idd][k] = v
                break
    return part_list_dict
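
# The multi-file branch above is a greedy longest-processing-time assignment: each entry of
# `needed_rank_list_map` is weighted by the total size (MB) of the source files it needs,
# entries are sorted largest first, and each one lands on the process whose assigned total is
# currently smallest. For example (sizes are made up), with process_num=2 and entry sizes
# [900, 600, 500, 400], process 0 ends up with {900, 400} and process 1 with {600, 500},
# keeping the per-process load roughly balanced.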


def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
                                         dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
                                         origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
                                         dst_safetensors_dir, process_num, output_format,
                                         _transform_param_list):
    """
    Transforms safetensors files to a specified format using parallel processing.
    """
    # cal param name for every pipeline, save in pipe_param_list.
    pipe_num = _extract_pipeline_stage_num(dst_strategy_dict)
    pipe_param_list = [None for _ in range(max(pipe_num, process_num))]
    if len(needed_rank_list_map) == 1 and pipe_num > 1:
        process_num = pipe_num
        pipe_param_list = [[] for _ in range(pipe_num)]
        layout_map = _convert_to_list(dst_strategy_dict)

        for name, layout in layout_map.items():
            pipe_param_list[layout[6][0]].append(name)

    part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)
    processes = []
    for i in range(process_num):
        p = mp.Process(target=_transform_safetensors_single, args=(
            part_list_dict[i], all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
            src_strategy_dict, dst_strategy_dict, origin_src_strategy_list, origin_dst_strategy_list,
            ckpt_prefix, dst_safetensors_dir, output_format, _transform_param_list, pipe_param_list[i]))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
    """Obtain the specified redundant group."""
    redundancy_tuple = redundancy_dict.get(param_name)
    for rank_list in redundancy_tuple:
        for rank in rank_list:
            if rank_num % device_num == rank % device_num:
                return set(rank_list)
    return set()
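
# _count_redundancy_list() compares ranks modulo device_num, so a global rank from any pipeline
# stage can be matched against a redundancy group expressed in per-stage rank ids. For example
# (made-up numbers), with device_num=8 and groups ((0, 4), (1, 5)), rank_num=12 matches the
# group containing 4 (12 % 8 == 4 % 8) and the function returns {0, 4}; if no group matches,
# it returns an empty set.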


def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
                                    needed_rank, device_num):
    """Find the rank_id under redundant groups."""
    for param_name in pipe_param_list:
        rank_num = int(needed_rank)
        redundancy_ranks = _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num)
        open_file_id = None
        if single_param_dict.get(param_name) is None:
            continue
        for real_rank in single_param_dict[param_name]:
            for redundancy_rank in redundancy_ranks:
                if real_rank % device_num == redundancy_rank % device_num:
                    open_file_id = real_rank
                    break
        if open_file_id is not None:
            output = file_dict[open_file_id].get_tensor(param_name)
            saftensor_dict[param_name] = output
        else:
            raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
                             f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")


def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
                                  dst_stage_device_num,
                                  src_strategy_dict, dst_strategy_dict, origin_src_strategy_list,
                                  origin_dst_strategy_list,
                                  ckpt_prefix, dst_safetensors_dir, output_format,
                                  _transform_param_list, pipe_param_list=None, file_index=None, unified_flag=False,
                                  src_strategy_file=None):
    """
    Transforms safetensors files to a specified format without using parallel processing.
    """
    if src_strategy_file is not None:
        from mindspore.train._utils import get_parameter_redundancy
        redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file)
        redundancy_dict = {}
        device_num = 0
        for param_name, redundancy in redundancy_dict_tmp.items():
            if device_num == 0:
                device_num = max(max(redundancy)) + 1
            origin_param_name = param_name
            pipeline_stage = 0
            if "-" in param_name:
                pipeline_stage, origin_param_name = param_name.split("-")
                pipeline_stage = int(pipeline_stage)
            redundancy_new = tuple(
                (tuple(x + pipeline_stage * device_num for x in subtuple)) for subtuple in redundancy)
            redundancy_dict[origin_param_name] = redundancy_new
    file_dict = {}
    single_param_dict = {}
    for file_id, _ in all_safetensor_files_map.items():
        f = safe_open(all_safetensor_files_map.get(file_id), framework="np")
        file_dict[file_id] = f
        for param_name in f.keys():
            if param_name not in single_param_dict.keys():
                single_param_dict[param_name] = {file_id}
            else:
                single_param_dict[param_name].add(file_id)
    src_strategy_list_keys = _convert_to_list(src_strategy_dict).keys() if src_strategy_dict else []
    dst_strategy_list_keys = _convert_to_list(dst_strategy_dict).keys() if dst_strategy_dict else []
    for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
        param_total_dict = defaultdict(dict)
        param_attr_dict = defaultdict(dict)
        needed_rank_list = needed_rank_list_key.split("-")
        for needed_rank in needed_rank_list:
            if pipe_param_list:
                saftensor_dict = dict()
                if src_strategy_file is not None:
                    _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict,
                                                    redundancy_dict, needed_rank, device_num)
                else:
                    with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
                        if not unified_flag:
                            all_param_name_set = set(f.keys())
                            src_param_name_set = set(src_strategy_list_keys)
                            dst_param_name_set = set(dst_strategy_list_keys)
                            hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
                            pipe_param_list.extend(list(hyper_param_set))
                        for param_name in pipe_param_list:
                            if param_name not in f.keys():
                                # param not in ckpt file, check reason
                                continue
                            output = f.get_tensor(param_name)
                            saftensor_dict[param_name] = output
            else:
                saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
            for param_name, param in saftensor_dict.items():
                src_rank = int(needed_rank) % src_stage_device_num
                param_total_dict[param_name][src_rank] = param
                param_attr_dict[param_name][src_rank] = (True, False)

        for transform_rank in transform_rank_list:
            param_total_dict_keys = list(param_total_dict.keys())
            src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_dict,
                                                                               dst_strategy_dict)
            # cut the parameter not in the pipeline stage.
            for param in list(param_total_dict.keys()):
                if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \
                        and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list):
                    param_total_dict_keys.remove(param)

            local_rank_id = transform_rank % dst_stage_device_num
            transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                                  param_attr_dict, src_strategy_list, dst_strategy_list,
                                                                  param_total_dict_keys, src_strategy_file)
            if file_index is not None:
                save_safetensor_file = f"part{file_index}.{output_format}"
                save_safetensor_file_dir = dst_safetensors_dir
            else:
                save_safetensor_file = f"{ckpt_prefix}{transform_rank}.{output_format}"
                save_safetensor_file_dir = os.path.join(dst_safetensors_dir, "rank_{}".format(transform_rank))

            if not os.path.exists(save_safetensor_file_dir):
                _make_dir(save_safetensor_file_dir, "path")
            save_file_name = os.path.join(save_safetensor_file_dir, save_safetensor_file)
            if _transform_param_list is not None:
                _transform_param_list.append({save_file_name: transform_param_dict})
            else:
                if output_format == "safetensors":
                    save_file(transform_param_dict, save_file_name)
                else:
                    transform_param_dict = _load_and_transform(transform_param_dict, None, None,
                                                               transform_func=lambda v, name: ms.Parameter(v,
                                                                                                            name=name))
                    ms.save_checkpoint(transform_param_dict, save_file_name)
            del param_total_dict_keys
        del param_total_dict


def _save_final_safetensors(_transform_param_list, output_format):
    """save file with list"""
    new_transform_dict = {}
    for transform_dict in _transform_param_list:
        for save_file_name, transform_param_dict in transform_dict.items():
            if save_file_name not in new_transform_dict:
                new_transform_dict[save_file_name] = transform_param_dict
            else:
                new_transform_dict[save_file_name].update(transform_param_dict)
    for save_file_name, transform_param_dict in new_transform_dict.items():
        if output_format == "safetensors":
            save_file(transform_param_dict, save_file_name)
        else:
            transform_param_dict = _load_and_transform(transform_param_dict, None, None,
                                                       transform_func=lambda v, name: ms.Parameter(v, name=name))
            ms.save_checkpoint(transform_param_dict, save_file_name)


def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckpt_prefix,
                                   src_strategy_file,
                                   dst_strategy_file=None):
    """Transform safetensor for stage in src_strategy_file"""
    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    src_strategy_list, dst_strategy_list, stage_id = _extract_src_dst_layout_map_by_src(src_strategy_file, \
                                                                                        dst_strategy_file)
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \
        is not None else 1
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
        is not None else 1
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    safetensor_files_map = {}
    src_rank_id_start = stage_id * src_stage_device_num
    for local_rank in range(src_stage_device_num):
        rank_id = src_rank_id_start + local_rank
        safetensor_file_name = os.path.join(src_safetensors_dir, "rank_{}".format(rank_id), "*.safetensors")
        rank_ckpts = glob.glob(safetensor_file_name)
        rank_ckpts.sort()
        for safetensor_file in rank_ckpts:
            if not os.path.isfile(safetensor_file):
                continue
            safetensor_files_map[rank_id] = safetensor_file
    for rank, local_file in safetensor_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("safetensor file {} in rank {} not exits: ".format(local_file, rank))
    for rank, file_name in safetensor_files_map.items():
        saftensor_dict = load_file(file_name)
        for param_name, param in saftensor_dict.items():
            # cut the parameter not in the pipeline stage.
            if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                    and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
                continue
            src_rank = rank % src_stage_device_num
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            param_total_dict[param_name][src_rank] = param
            param_attr_dict[param_name][src_rank] = (True, False)
    for local_rank_id in range(dst_stage_device_num):
        transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                              param_attr_dict, src_strategy_list, dst_strategy_list,
                                                              param_type_dict)
        save_safetensor_file = "{}{}_part{}.safetensors".format(ckpt_prefix, local_rank_id, stage_id)
        save_safetensor_file_dir = os.path.join(dst_safetensors_dir, "rank_{}".format(local_rank_id))
        if not os.path.exists(save_safetensor_file_dir):
            _make_dir(save_safetensor_file_dir, "path")
        save_safetensor_file_name = os.path.join(save_safetensor_file_dir, save_safetensor_file)
        save_file(transform_param_dict, save_safetensor_file_name)


def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor_file_name,
                                  src_strategy_file=None, dst_strategy_file=None):
    """
    Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank.
    """
    if not isinstance(safetensor_files_map, dict):
        raise TypeError("The safetensor_files_map should be a dict.")
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be a int.")
    if not isinstance(save_safetensor_file_name, str):
        raise TypeError("The save_safetensor_file_name should be a str.")
    if not save_safetensor_file_name.endswith(".safetensors"):
        raise ValueError(
            "The save_safetensor_file_name {} should end with .safetensors".format(save_safetensor_file_name))
    if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists(
            os.path.dirname(dst_strategy_file)):
        raise ValueError("The director of dst_strategy_file: {} is not exists.".
                         format(os.path.dirname(dst_strategy_file)))
    for rank, local_file in safetensor_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("safetensor file {} in rank {} not exits: ".format(local_file, rank))
    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
    # src rank => local rank inside pipeline stage
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \
        is not None else 1
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
        is not None else 1
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    for rank, file_name in safetensor_files_map.items():
        saftensor_dict = load_file(file_name)
        for param_name, param in saftensor_dict.items():
            # cut the parameter not in the pipeline stage.
            if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                    and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
                continue
            src_rank = rank % src_stage_device_num
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            # if param.data.dtype == mstype.bfloat16:
            #     param.set_dtype(mstype.float32)
            param_total_dict[param_name][src_rank] = param
            param_attr_dict[param_name][src_rank] = (True, False)
    local_rank_id = rank_id % dst_stage_device_num
    transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                          param_attr_dict, src_strategy_list, dst_strategy_list,
                                                          param_type_dict)
    save_file(transform_param_dict, save_safetensor_file_name)


def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
    """
    Collects all safetensors files from the specified directory and its subdirectories.
    """
    if os.path.isfile(src_safetensors_dir) and format == 'safetensors' and src_safetensors_dir.endswith('safetensors'):
        return {0: src_safetensors_dir}
    safetensors_rank_dir_list = os.path.join(src_safetensors_dir, "rank_[0-9]*")
    all_safetensor_files_map = {}
    for safetensor_dir in glob.glob(safetensors_rank_dir_list):
        if not os.path.isdir(safetensor_dir):
            ms.log.warning("{} is not a directory.".format(safetensor_dir))
            continue
        rank_id_str = safetensor_dir.split('rank_')[-1]
        if not rank_id_str.isdigit():
            ms.log.warning("{} is not a expected directory, the directory should end with rank_0/rank_1.....".
                           format(safetensor_dir))
            continue
        rank_id = int(rank_id_str)
        if file_suffix is None:
            safetensor_file_name = os.path.join(safetensor_dir, f"*.{format}")
        else:
            safetensor_file_name = os.path.join(safetensor_dir, f"*{file_suffix}.{format}")
        rank_ckpts = glob.glob(safetensor_file_name)
        rank_ckpts.sort()
        for safetensor_file in rank_ckpts:
            if not os.path.isfile(safetensor_file):
                ms.log.warning("{} is not a safetensor file.".format(safetensor_file))
                continue
            all_safetensor_files_map[rank_id] = safetensor_file
    return all_safetensor_files_map
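
# _collect_safetensor_files() expects the layout produced by distributed saving, i.e. one
# sub-directory per rank under `src_safetensors_dir` (for illustration):
#
#     src_safetensors_dir/
#         rank_0/xxx.safetensors
#         rank_1/xxx.safetensors
#         ...
#
# and returns {0: ".../rank_0/xxx.safetensors", 1: ".../rank_1/xxx.safetensors", ...}.
# When several files match inside one rank directory, the last match in sorted order wins.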


def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):
    """
    Identifies the ranks needed for transformation based on source and destination strategies.
    """
    needed_rank_list_map = defaultdict(list)
    dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
    dst_device_num = dst_stage_device_num * dst_stage_num
    for rank in _progress_bar(range(dst_device_num)):
        needed_rank_list = ms.rank_list_for_transform(rank, src_strategy_dict, dst_strategy_dict)
        needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
        needed_rank_list_map[needed_rank_list_key].append(rank)
    return needed_rank_list_map
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
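The returned map groups destination ranks that need exactly the same set of source files, so each group of source files only has to be read once. An illustrative result (rank numbers are made up):

# Destination ranks 0 and 1 can both be produced from source files 0 and 1;
# destination ranks 2 and 3 from source files 2 and 3.
needed_rank_list_map = {
    "0-1": [0, 1],
    "2-3": [2, 3],
}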
+def load_file_by_param_name(filename, parme_name_list):
+    result = {}
+    with safe_open(filename, framework="np") as f:
+        for k in parme_name_list:
+            result[k] = f.get_tensor(k)
+    return result
+
+
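load_file_by_param_name reads only the requested tensors instead of the whole file. A minimal usage sketch built directly on the safetensors API it wraps (the file name and parameter names are hypothetical):

from safetensors import safe_open

with safe_open("part0.safetensors", framework="np") as f:
    wanted = [k for k in f.keys() if k.endswith(".weight")]
    tensors = {k: f.get_tensor(k) for k in wanted}   # name -> NumPy array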
+def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
+                                   dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None):
+    """
+    Transform model parallel dimension for distributed safetensor files.
+    """
+    transform_param_dict = {}
+    device_num = -1
+    param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
+    for param_name in param_total_dict_keys:
+        tensor_shape = list(param_total_dict[param_name].values())[0].shape
+        from_dev_matrix = [1]
+        from_tensor_map = [-1] * len(tensor_shape)
+        from_opt_shard_step = 0
+        from_opt_shard_size = 0
+        if src_strategy_list is not None:
+            if param_name not in src_strategy_list:
+                continue
+            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
+                src_strategy_list.get(param_name))
+        to_dev_matrix_origin = [1]
+        to_tensor_map_origin = [-1] * len(tensor_shape)
+        to_opt_shard_step = 0
+        to_opt_shard_size = 0
+        if dst_strategy_list is not None:
+            if param_name not in dst_strategy_list:
+                continue
+            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
+                dst_strategy_list.get(param_name))
+        # Add optimizer sharding dim for tensor layout
+        device_num = np.prod(from_dev_matrix)
+        if device_num < 1:
+            raise ValueError("None of the parameters in safetensor file are in either src strategy or "
+                             "dst strategy. Please check correctness of strategy files. "
+                             "Param name is: {}, rank_id is {}.".format(param_name, rank_id))
+        param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+        origin_tensor_shape = ()
+        for i, item in enumerate(tensor_shape):
+            if i == 0 and from_opt_shard_size > 0:
+                origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
+                continue
+            origin_tensor_shape += (item * param_strategy[i],)
+
+        from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
+        to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
+        # Convert tensor layout to same device num
+        from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
+                                                                                from_tensor_map, to_full_tensor_shape,
+                                                                                to_dev_matrix, to_tensor_map)
+
+        # when the from_layout uses fewer devices, the safetensor map for map[device_num] should use map[0]
+        device_list = list(range(0, np.prod(from_tensor_layout[0])))
+        if rank_id % device_num not in param_attr_dict[param_name] and src_strategy_file is None:
+            raise ValueError("The safetensor of rank {} is missing.".format(rank_id % device_num))
+        param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
+                                                                            device_list, rank_id)
+
+        from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
+        to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
+        _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+        transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
+        param_total_dict_copy = param_total_dict[param_name].copy()
+        _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
+
+        transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
+
+    # Handle parameters such as learning_rate and global_step which are not in the strategy file.
+    for param_name in param_total_dict_keys:
+        if param_name not in transform_param_dict:
+            transform_para = param_total_dict[param_name][rank_id % device_num]
+            transform_param_dict[param_name] = transform_para
+    return transform_param_dict
+
+
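Inside the per-parameter loop above, the full (unsharded) shape is recovered from a slice's shape by multiplying every dimension by its shard count, with dimension 0 additionally scaled by the optimizer-shard size. A standalone sketch of that arithmetic with made-up strategy values:

def recover_origin_shape(slice_shape, param_strategy, opt_shard_size=0):
    # Rebuild the full tensor shape from a shard's shape (simplified sketch).
    origin_shape = ()
    for i, item in enumerate(slice_shape):
        if i == 0 and opt_shard_size > 0:
            origin_shape += (item * param_strategy[i] * opt_shard_size,)
            continue
        origin_shape += (item * param_strategy[i],)
    return origin_shape

# A (256, 1024) slice, split 4 ways on dim 0 with optimizer shard size 2, came from a (2048, 1024) tensor.
assert recover_origin_shape((256, 1024), (4, 1), opt_shard_size=2) == (2048, 1024)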
+def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None):
+    """
+    Merge multiple safetensor files into a unified safetensor file.
+
+    Args:
+        src_dir (str): Source weight saving directory.
+        src_strategy_file (str): Source weight segmentation strategy file.
+        dst_dir (str): Target save directory.
+        merge_with_redundancy (bool, optional): Whether the merged source weight files are de-duplicated and
+            saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
+        file_suffix (str, optional): Specify the filename suffix for merging safetensors files. Default: ``None``,
+            meaning all safetensors files in the source weight directory will be merged.
+
+    Raises:
+        ValueError: If the safetensors file of rank is missing.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
+        >>> src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
+        >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
+        >>> ms.unified_safetensors(src_dir, src_strategy_file, dst_dir)
+    """
+    _check_transform_safetensors(src_dir, "", src_strategy_file, None)
+    _make_dir(dst_dir, "path")
+    if os.path.isfile(src_dir):
+        raise ValueError("For 'unified_safetensors', the 'src_dir' can not be a file.")
+    all_safetensor_files_map = _collect_safetensor_files(src_dir, format="safetensors", file_suffix=file_suffix)
+    all_ckpt_files_map = _collect_safetensor_files(src_dir, format="ckpt", file_suffix=file_suffix)
+    if all_safetensor_files_map and all_ckpt_files_map:
+        raise ValueError("For 'unified_safetensors', the 'src_dir' cannot contain "
+                         "both ckpt file and safetensors file simultaneously")
+    src_strategy_dict = _build_searched_strategy(src_strategy_file)
+    src_stage_device_num = _get_device_num_from_strategy(src_strategy_dict)
+    dst_stage_device_num = 1
+    origin_src_strategy_list = _extract_layout_map(src_strategy_dict)
+    origin_dst_strategy_list = None
+
+    needed_rank_list_map = _find_needed_ranks(src_strategy_dict, dst_strategy_dict=None)
+    for needed_rank_list, rank in needed_rank_list_map.items():
+        for needed_rank in needed_rank_list.split("-"):
+            if int(needed_rank) not in all_safetensor_files_map:
+                raise ValueError("The safetensor file of rank{} is needed for converting rank{}'s safetensor, "
+                                 "but it is missing.".format(needed_rank, rank))
+    layout_map = _convert_to_list(src_strategy_dict)
+
+    total_size = 0
+    actual_params = set()
+    for _, file_name in all_safetensor_files_map.items():
+        total_size += os.path.getsize(file_name) / 1024 / 1024 / 1024
+        with safe_open(file_name, framework="np") as f:
+            actual_params.update(f.keys())
+    split_num = math.ceil(total_size / 3)
+    params_to_store = actual_params & set(layout_map.keys())
+
+    name_list = []
+    for name in list(params_to_store):
+        if name.startswith("accu_grads"):
+            continue
+        name_list.append(name)
+    split_list = _split_list(name_list, split_num)
+
+    with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
+        all_key = f.keys()
+        hyper_parameter = set(all_key) - set(name_list)
+        if hyper_parameter:
+            hyper_dict = {}
+            for key in hyper_parameter:
+                hyper_dict[key] = f.get_tensor(key)
+            save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
+
+    # save parameter map json
+    param_name_dict = dict()
+    for index, part_list in enumerate(split_list):
+        for name in part_list:
+            param_name_dict[name] = f"part{index}.safetensors"
+    json_str = json.dumps(param_name_dict, indent=4)
+    map_file = os.path.join(dst_dir, "param_name_map.json")
+    with open(map_file, 'w') as f:
+        f.write(json_str)
+
+    max_process = min(split_num, 100)
+    res = [i for i in range(split_num)]
+    res = _split_list(res, max_process)
+    processes = []
+    src_strategy_name = None
+    if not merge_with_redundancy:
+        src_strategy_name = src_strategy_file
+    for i in range(max_process):
+        p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
+            needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
+            src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
+            "", dst_dir, "safetensors", None, split_list, res[i], True, src_strategy_name))
+        p.start()
+        processes.append(p)
+    for p in processes:
+        p.join()
+
+
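unified_safetensors sizes the output by total source volume: the merged weights are spread over ceil(total_GB / 3) part files, and param_name_map.json records which part holds each parameter. A worked example with illustrative numbers and a hypothetical parameter name:

import math

total_size_gb = 10.4                      # combined size of all source safetensors files
split_num = math.ceil(total_size_gb / 3)  # -> 4 part files

# The destination directory then contains, per the code above:
#   part0.safetensors ... part3.safetensors   merged parameter shards
#   hyper_param.safetensors                   parameters that are not in the strategy file
#   param_name_map.json                       e.g. {"embedding.weight": "part0.safetensors", ...}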
+def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map,
+                                            src_stage_device_num,
+                                            dst_stage_device_num,
+                                            src_strategy_dict, dst_strategy_dict, origin_src_strategy_list,
+                                            origin_dst_strategy_list,
+                                            ckpt_prefix, dst_safetensors_dir, output_format,
+                                            _transform_param_list, pipe_param_list=None, file_index=None,
+                                            unified_flag=False, src_strategy_file=None):
+    for i in file_index:
+        _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
+                                      dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
+                                      origin_src_strategy_list,
+                                      origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir, output_format,
+                                      _transform_param_list, pipe_param_list[i], i, unified_flag, src_strategy_file)
+
+
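Each worker process receives a list of part indices (res[i]) and calls _transform_safetensors_single once per index, with pipe_param_list[i] selecting that part's parameter names. An illustrative schedule, assuming seven parts (the files written are presumably the partN.safetensors entries recorded in param_name_map.json):

split_num = 7
max_process = min(split_num, 100)              # 7 worker processes
# res = _split_list(list(range(split_num)), max_process)
# -> [[0], [1], [2], [3], [4], [5], [6]]
# Worker i handles part index res[i][0] and the parameter names in split_list[res[i][0]];
# with at most 100 parts every worker gets exactly one part.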
+def _split_list(split_list, split_num):
+    split_array = np.array_split(split_list, split_num)
+    return [array.tolist() for array in split_array]
+
+
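np.array_split tolerates uneven splits, so the parts differ in size by at most one element:

import numpy as np

parts = [a.tolist() for a in np.array_split(["w1", "w2", "w3", "w4", "w5"], 3)]
# parts == [["w1", "w2"], ["w3", "w4"], ["w5"]]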
+def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num):
+    """apply safetensors object operators"""
+    if not transform_operator_stack:
+        return sf_obj[:]
+    level = transform_operator_stack[-1][1]
+    level_operators = []
+    while True:
+        if not transform_operator_stack or (level != transform_operator_stack[-1][1]):
+            tmp_tensor_dict = {}
+            if not level_operators:
+                continue
+            op_name = level_operators[0][2][0]
+            for operator_pair in level_operators:
+                rank_id = operator_pair[0]
+                cur_level = operator_pair[1]
+                operator = operator_pair[2]
+                if operator[0] != op_name:
+                    raise ValueError("The operator in the same level should be equal in the transform tensor operator "
+                                     "list, but found {} and {} in level {}".format(op_name, operator[0], cur_level))
+                if operator[0] != "AllConcat":
+                    sf_obj = _apply_operator(operator[0])(sf_obj, operator)
+                    continue
+                for rank in operator[1][:-1]:
+                    if rank % device_num not in sf_obj:
+                        raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num))
+                allgather_list = [sf_obj for _ in operator[1][:-1]]
+                tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
+            if op_name == "AllConcat":
+                for rank, value in tmp_tensor_dict.items():
+                    sf_obj = value
+            level_operators.clear()
+        if not transform_operator_stack:
+            break
+        operator_pair = transform_operator_stack.pop()
+        level = operator_pair[1]
+        level_operators.append(operator_pair)
+    return sf_obj
+
+
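Each entry popped from the stack is a (rank_id, level, operator) triple, and entries that share a level are applied together; operator[0] names the transform, with "AllConcat" handled specially above. A rough sketch of the shape of one entry (the argument layout is internal and simplified here, values are illustrative only):

operator_pair = (0, 2, ["AllConcat", [0, 1, 0]])   # illustrative values only
rank_id, level, operator = operator_pair
op_name = operator[0]                              # "AllConcat"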
+def _load_parallel_checkpoint(total_safetensors_dir, dst_strategy_file, net=None, dst_safetensors_dir=None,
+                              rank_id=None):
+    """load parallel safetensors by merged file."""
+    file_list = os.listdir(total_safetensors_dir)
+    json_files = [file for file in file_list if file.endswith('.json')]
+    if len(json_files) != 1:
+        raise ValueError(f"For 'load_parallel_checkpoint', the number of json files in 'total_safetensors_dir' "
+                         f"must be 1, but got {len(json_files)}.")
+    param_name_json = os.path.join(total_safetensors_dir, json_files[0])
+    with open(param_name_json, 'r') as f:
+        param_name_map = json.load(f)
+    if dst_strategy_file is not None:
+        _, dst_strategy_list = _extract_src_dst_layout_map(rank_id, None, dst_strategy_file)
+        param_list = dst_strategy_list.keys()
+    else:
+        dst_strategy_list = None
+        param_list = param_name_map.keys()
+
+    total_param = dict()
+    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
+        is not None else 1
+    local_rank_id = rank_id % dst_stage_device_num
+    for param_name in param_list:
+        if param_name not in param_name_map:
+            continue
+        file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
+        with safe_open(file_name, framework="np") as f:
+            if param_name not in f.keys():
+                continue
+            sf_obj = f.get_slice(param_name)
+
+        tensor_shape = sf_obj.get_shape()
+        from_dev_matrix = [1]
+        from_tensor_map = [-1] * len(tensor_shape)
+        from_opt_shard_step = 0
+        from_opt_shard_size = 0
+        if dst_strategy_list is not None:
+            if param_name not in dst_strategy_list:
+                continue
+            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
+                dst_strategy_list.get(param_name))
+
+            device_num = np.prod(from_dev_matrix)
+            param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+            origin_tensor_shape = ()
+            for i, item in enumerate(tensor_shape):
+                if i == 0 and from_opt_shard_size > 0:
+                    origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
+                    continue
+                origin_tensor_shape += (item * param_strategy[i],)
+
+            from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+                from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
+            to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+                to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
+            # Convert tensor layout to same device num
+            from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
+                                                                                    from_dev_matrix,
+                                                                                    from_tensor_map,
+                                                                                    to_full_tensor_shape,
+                                                                                    to_dev_matrix, to_tensor_map)
+
+            # when the from_layout uses fewer devices, the safetensor map for map[device_num] should use map[0]
+            device_list = list(range(0, np.prod(from_tensor_layout[0])))
+            param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
+                                                                                device_list, local_rank_id)
+
+            from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
+            to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
+            _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+            transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
+
+            slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
+        else:
+            slice_param = sf_obj[:]
+
+        total_param[param_name] = ms.Parameter(slice_param)
+
+    if 'hyper_param.safetensors' in file_list:
+        hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
+        with safe_open(hyper_parameter_file_name, framework="np") as f:
+            for key in f.keys():
+                total_param[key] = ms.Parameter(f.get_tensor(key))
+    if net is not None:
+        param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
+        return param_not_load, ckpt_not_load
+    _make_dir(os.path.join(dst_safetensors_dir, f"rank_{rank_id}"), "path")
+    ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.safetensors"),
+                       format='safetensors')
+    return None
+
+
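On the loading side, param_name_map.json tells which part file holds each parameter, and the part file is opened lazily so only the needed slice is materialized. A reader sketch for the merged directory (the directory path and parameter name are hypothetical; the file names follow the code above):

import json
import os
from safetensors import safe_open

merged_dir = "/data/merged"
with open(os.path.join(merged_dir, "param_name_map.json"), "r") as f:
    param_name_map = json.load(f)

param_name = "embedding.weight"
part_file = os.path.join(merged_dir, param_name_map[param_name])   # e.g. "part0.safetensors"
with safe_open(part_file, framework="np") as f:
    full_tensor = f.get_slice(param_name)[:]    # materialize the whole parameter as a NumPy array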
+def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):
+    """get slice op"""
+    tensor_shape = sf_obj.get_shape()
+    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
+        dst_strategy_list.get(param_name))
+    # Add optimizer sharding dim for tensor layout
+    to_dev_matrix, to_tensor_map, _ = _construct_tensor_layout_for_opt_shard(
+        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, tensor_shape)
+    slice_op = _load_tensor_shape(to_dev_matrix, to_tensor_map, full_shape=tensor_shape, rank_id=rank_id)
+    shape = None
+    if to_opt_shard_size > 0:
+        to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
+        to_slice_tensor_shape = ()
+        for i, item in enumerate(tensor_shape):
+            if i == 0 and to_opt_shard_size > 0:
+                to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
+                continue
+            to_slice_tensor_shape += (item // to_tensor_strategy[i],)
+        shape = list(to_slice_tensor_shape)
+
+    return slice_op, shape
+
+
+__all__ = ["_transform_safetensors", "transform_safetensors_by_stage",
+           "transform_safetensors_by_rank", "ckpt_to_safetensors", "safetensors_to_ckpt", "unified_safetensors"]