mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -21,10 +21,12 @@ import binascii
|
|
|
21
21
|
import copy
|
|
22
22
|
import json
|
|
23
23
|
import os
|
|
24
|
+
import re
|
|
24
25
|
import shutil
|
|
25
26
|
import stat
|
|
26
27
|
import threading
|
|
27
28
|
from threading import Thread, RLock
|
|
29
|
+
from multiprocessing import Process
|
|
28
30
|
from collections import defaultdict, OrderedDict
|
|
29
31
|
from io import BytesIO
|
|
30
32
|
|
|
@@ -58,21 +60,25 @@ from mindspore.common.file_system import FileSystem, _register_basic_file_system
|
|
|
58
60
|
from mindspore.communication.management import get_rank, get_group_size
|
|
59
61
|
from mindspore.experimental import MapParameter
|
|
60
62
|
from mindspore.ops import Cast
|
|
61
|
-
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
|
63
|
+
from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
|
|
62
64
|
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
|
|
63
65
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
|
64
|
-
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
|
|
66
|
+
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
|
|
67
|
+
_get_device_num
|
|
68
|
+
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
|
|
65
69
|
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
|
|
66
|
-
_restore_group_info_list
|
|
70
|
+
_restore_group_info_list, _get_param_list_when_first_dim_sharded
|
|
67
71
|
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
|
|
68
72
|
_store_warm_up_ptr_by_tensor_list, _cache_enable
|
|
69
73
|
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
|
|
70
|
-
from mindspore.
|
|
74
|
+
from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
|
|
75
|
+
_extract_pipeline_stage_num
|
|
76
|
+
from mindspore.train._utils import read_proto, get_parameter_redundancy
|
|
71
77
|
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
|
|
72
78
|
split_mindir, split_dynamic_mindir
|
|
73
79
|
from mindspore.common.generator import Generator
|
|
74
|
-
from
|
|
75
|
-
from
|
|
80
|
+
from safetensors.numpy import save_file
|
|
81
|
+
from safetensors import safe_open
|
|
76
82
|
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
77
83
|
|
|
78
84
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
@@ -116,6 +122,68 @@ def init_ckpt_file_system(fs: FileSystem):
|
|
|
116
122
|
init_ckpt_file_system(_ckpt_fs)
|
|
117
123
|
|
|
118
124
|
|
|
125
|
+
def _get_cur_rank_dp(parameter_layout_dict):
|
|
126
|
+
""" Get dp and tp from layout dict. """
|
|
127
|
+
pp_num = _get_auto_parallel_context("pipeline_stages")
|
|
128
|
+
dev_num = _get_device_num()
|
|
129
|
+
global_rank = get_rank()
|
|
130
|
+
pipe_size = dev_num // pp_num
|
|
131
|
+
initial_rank = (global_rank // pipe_size) * pipe_size
|
|
132
|
+
parameter_redundancy_dict = get_parameter_redundancy(
|
|
133
|
+
parameter_layout_dict, initial_rank)
|
|
134
|
+
value_len = sys.maxsize
|
|
135
|
+
min_value = ()
|
|
136
|
+
for key, value in parameter_redundancy_dict.items():
|
|
137
|
+
if "accu_grads" in key or "inputs" in key:
|
|
138
|
+
continue
|
|
139
|
+
for item in value:
|
|
140
|
+
if len(item) < value_len and global_rank in item:
|
|
141
|
+
value_len = len(item)
|
|
142
|
+
min_value = item
|
|
143
|
+
return min_value
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
|
|
147
|
+
"""
|
|
148
|
+
Find available checkpoint file path from all backup checkpoint files of current rank.
|
|
149
|
+
It suppose that checkpoint path contains substring 'rank_{rank_id}' which is used to
|
|
150
|
+
distinguish between different path.If cur_ckpt_path doesn't have 'rank_{rank_id}' substring, will return
|
|
151
|
+
cur_ckpt_path itself when cur_ckpt_path is exist, otherwise return None.
|
|
152
|
+
|
|
153
|
+
Note:
|
|
154
|
+
This API must be called after the communication is initialized because the cluster information
|
|
155
|
+
needs to be obtained internally.
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
cur_ckpt_path (str): the checkpoint file path which cur rank needs.
|
|
159
|
+
cur_strategy_path (str): strategy file path for current rank.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
- new_ckpt_file (String), if found available checkpoint file , return it.
|
|
163
|
+
- None, if not found available checkpoint, return None.
|
|
164
|
+
|
|
165
|
+
Examples:
|
|
166
|
+
>>> import mindspore as ms
|
|
167
|
+
>>> from mindspore.communication import init
|
|
168
|
+
>>> from mindspore import get_ckpt_path_with_strategy
|
|
169
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
170
|
+
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
|
171
|
+
>>> init()
|
|
172
|
+
>>> ckpt_file= "./rank_5/iteration-1_40.ckpt"
|
|
173
|
+
>>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
|
|
174
|
+
>>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
|
|
175
|
+
>>> print(ckpt_file_new)
|
|
176
|
+
"""
|
|
177
|
+
dp = _get_cur_rank_dp(cur_strategy_path)
|
|
178
|
+
pattern = r'rank_\d+'
|
|
179
|
+
for i in dp:
|
|
180
|
+
new_ckpt_path = re.sub(pattern, f"rank_{str(i)}", cur_ckpt_path)
|
|
181
|
+
if not os.path.isfile(new_ckpt_path):
|
|
182
|
+
continue
|
|
183
|
+
return new_ckpt_path
|
|
184
|
+
return None
|
|
185
|
+
|
|
186
|
+
|
|
119
187
|
class ParamDictFuture:
|
|
120
188
|
def __init__(self, executor, param_dict_future):
|
|
121
189
|
self.executor = executor
|
|
@@ -252,57 +320,72 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
|
|
|
252
320
|
logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
|
|
253
321
|
|
|
254
322
|
|
|
255
|
-
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False
|
|
323
|
+
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
|
|
324
|
+
format="ckpt"):
|
|
256
325
|
"""Execute the process of saving checkpoint into file."""
|
|
257
326
|
try:
|
|
258
327
|
with _ckpt_mutex:
|
|
328
|
+
file_name_list = list(os.path.splitext(ckpt_file_name))
|
|
329
|
+
file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
|
|
330
|
+
tmp_name = ''.join(file_name_list)
|
|
259
331
|
if os.path.exists(ckpt_file_name):
|
|
260
332
|
os.chmod(ckpt_file_name, stat.S_IWUSR)
|
|
261
333
|
os.remove(ckpt_file_name)
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
334
|
+
if os.path.exists(tmp_name):
|
|
335
|
+
os.chmod(tmp_name, stat.S_IWUSR)
|
|
336
|
+
os.remove(tmp_name)
|
|
337
|
+
if format == "ckpt":
|
|
338
|
+
with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
|
|
339
|
+
plain_data = None
|
|
340
|
+
if enc_key is not None:
|
|
341
|
+
plain_data = BytesIO()
|
|
342
|
+
|
|
343
|
+
crc_num = 0
|
|
344
|
+
for name, value in data_list.items():
|
|
345
|
+
if name == "random_op":
|
|
346
|
+
_write_random_seed(name, value, f)
|
|
347
|
+
continue
|
|
348
|
+
if value[0] == "mapparameter":
|
|
349
|
+
_write_mapparameter(name, value, f, map_param_inc)
|
|
350
|
+
continue
|
|
351
|
+
if value[0] == "offload_parameter":
|
|
352
|
+
new_value = value[1:]
|
|
353
|
+
new_value[2] = value[3]
|
|
354
|
+
_write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
|
|
355
|
+
_offload_if_config(value[3])
|
|
356
|
+
continue
|
|
357
|
+
if value[1] == "str":
|
|
358
|
+
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
359
|
+
continue
|
|
360
|
+
if isinstance(value[2], np.ndarray):
|
|
361
|
+
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
362
|
+
continue
|
|
363
|
+
if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
|
|
364
|
+
_write_hugeparameter(name, value, f)
|
|
365
|
+
continue
|
|
366
|
+
|
|
367
|
+
crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
368
|
+
|
|
369
|
+
if enc_key is not None:
|
|
370
|
+
plain_data.seek(0)
|
|
371
|
+
max_block_size = ENCRYPT_BLOCK_SIZE * 1024
|
|
299
372
|
block_data = plain_data.read(max_block_size)
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
373
|
+
while block_data:
|
|
374
|
+
f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
|
|
375
|
+
block_data = plain_data.read(max_block_size)
|
|
376
|
+
if crc_check:
|
|
377
|
+
f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
|
|
378
|
+
elif format == "safetensors":
|
|
379
|
+
save_dict = {}
|
|
380
|
+
for name, value in data_list.items():
|
|
381
|
+
save_dict[name] = value[2].asnumpy()
|
|
382
|
+
save_file(save_dict, tmp_name)
|
|
383
|
+
if not os.path.exists(tmp_name):
|
|
384
|
+
logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
|
|
385
|
+
f"simultaneously modified a file.")
|
|
386
|
+
else:
|
|
387
|
+
os.rename(tmp_name, ckpt_file_name)
|
|
304
388
|
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
|
305
|
-
|
|
306
389
|
except BaseException as e:
|
|
307
390
|
logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
|
|
308
391
|
"or the disk space is insufficient and so on.", ckpt_file_name)
|
|
@@ -415,8 +498,11 @@ def _write_hugeparameter(name, value, f):
|
|
|
415
498
|
offset += numpy_data.shape[0]
|
|
416
499
|
|
|
417
500
|
|
|
418
|
-
def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
501
|
+
def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
|
|
419
502
|
"""Check save_obj and ckpt_file_name for save_checkpoint."""
|
|
503
|
+
if format not in ["safetensors", "ckpt"]:
|
|
504
|
+
raise ValueError(f"For 'save_checkpoint', the format must be "
|
|
505
|
+
f"'safetensors' or 'ckpt', but got {format}.")
|
|
420
506
|
if not isinstance(save_obj, (nn.Cell, list, dict)):
|
|
421
507
|
raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
|
|
422
508
|
"but got {}.".format(type(save_obj)))
|
|
@@ -424,18 +510,26 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
|
424
510
|
raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
|
|
425
511
|
"'ckpt_file_name' must be "
|
|
426
512
|
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
|
|
427
|
-
ckpt_file_name = os.path.
|
|
513
|
+
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
|
428
514
|
if os.path.isdir(ckpt_file_name):
|
|
429
515
|
raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
|
|
430
516
|
"it must be a file name.".format(ckpt_file_name))
|
|
431
|
-
if not ckpt_file_name.endswith(
|
|
432
|
-
ckpt_file_name += ".
|
|
517
|
+
if not ckpt_file_name.endswith(format):
|
|
518
|
+
ckpt_file_name += f".{format}"
|
|
433
519
|
return ckpt_file_name
|
|
434
520
|
|
|
435
521
|
|
|
522
|
+
def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
|
|
523
|
+
global_step_num=None):
|
|
524
|
+
param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
|
|
525
|
+
or map_param_inc or global_step_num is not None)
|
|
526
|
+
if format == "safetensors" and param_not_default:
|
|
527
|
+
raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
|
|
528
|
+
|
|
529
|
+
|
|
436
530
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
437
531
|
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
|
|
438
|
-
crc_check=False, **kwargs):
|
|
532
|
+
crc_check=False, format="ckpt", **kwargs):
|
|
439
533
|
r"""
|
|
440
534
|
Save checkpoint to a specified file.
|
|
441
535
|
|
|
@@ -465,6 +559,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
465
559
|
be saved. Default: ``None`` .
|
|
466
560
|
crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
|
|
467
561
|
result to the file. Default: ``False`` .
|
|
562
|
+
format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
468
563
|
kwargs (dict): Configuration options dictionary.
|
|
469
564
|
|
|
470
565
|
Raises:
|
|
@@ -498,7 +593,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
498
593
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
499
594
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
500
595
|
"""
|
|
501
|
-
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
|
|
596
|
+
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
|
|
502
597
|
integrated_save = Validator.check_bool(integrated_save)
|
|
503
598
|
async_save = Validator.check_bool(async_save)
|
|
504
599
|
append_dict = _check_append_dict(append_dict)
|
|
@@ -508,10 +603,19 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
508
603
|
map_param_inc = kwargs.get('incremental', False)
|
|
509
604
|
logger.info("Execute the process of saving checkpoint files.")
|
|
510
605
|
global_step_num = kwargs.get('global_step_num', None)
|
|
606
|
+
_check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
|
|
511
607
|
|
|
512
|
-
|
|
608
|
+
if append_dict and "__exception_save__" in append_dict:
|
|
609
|
+
s1 = mindspore.hal.Stream()
|
|
610
|
+
with mindspore.hal.StreamCtx(s1):
|
|
611
|
+
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
612
|
+
s1.synchronize()
|
|
613
|
+
else:
|
|
614
|
+
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
513
615
|
|
|
514
616
|
if append_dict:
|
|
617
|
+
if "__exception_save__" in append_dict:
|
|
618
|
+
del append_dict["__exception_save__"]
|
|
515
619
|
append_info_list = []
|
|
516
620
|
for k_name, value in append_dict.items():
|
|
517
621
|
if isinstance(value, Generator):
|
|
@@ -527,12 +631,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
527
631
|
for param in save_obj:
|
|
528
632
|
if param["name"] == "random_op":
|
|
529
633
|
if os.getenv("AITURBO") == "1":
|
|
530
|
-
data_list_np["random_op"] =
|
|
634
|
+
data_list_np["random_op"] = []
|
|
635
|
+
data_list_np["random_op"].append(param["data"])
|
|
636
|
+
if crc_check:
|
|
637
|
+
bytes_value = bytes(data_list_np[key][0])
|
|
638
|
+
data_list_np[key].append(binascii.crc32(bytes_value))
|
|
531
639
|
else:
|
|
532
640
|
data_list["random_op"] = param["data"]
|
|
533
641
|
continue
|
|
534
642
|
key = param["name"]
|
|
535
643
|
data_list[key] = []
|
|
644
|
+
data_list_np[key] = []
|
|
536
645
|
if isinstance(param["data"], MapParameter):
|
|
537
646
|
data_list[param["name"]].append("mapparameter")
|
|
538
647
|
data_list[param["name"]].append(param["data"])
|
|
@@ -546,7 +655,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
546
655
|
|
|
547
656
|
if isinstance(param["data"], str):
|
|
548
657
|
if os.getenv("AITURBO") == "1":
|
|
549
|
-
data_list_np[key]
|
|
658
|
+
data_list_np[key].append(np.array(param["data"]))
|
|
659
|
+
if crc_check:
|
|
660
|
+
bytes_value = data_list_np[key][0].tobytes()
|
|
661
|
+
data_list_np[key].append(binascii.crc32(bytes_value))
|
|
550
662
|
else:
|
|
551
663
|
data_list[key].append([0])
|
|
552
664
|
data_list[key].append('str')
|
|
@@ -556,7 +668,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
556
668
|
if isinstance(param["data"], Parameter):
|
|
557
669
|
param["data"].init_data()
|
|
558
670
|
if os.getenv("AITURBO") == "1":
|
|
559
|
-
data_list_np[key]
|
|
671
|
+
data_list_np[key].append(param["data"].asnumpy())
|
|
672
|
+
if crc_check:
|
|
673
|
+
bytes_value = data_list_np[key][0].tobytes()
|
|
674
|
+
data_list_np[key].append(binascii.crc32(bytes_value))
|
|
560
675
|
else:
|
|
561
676
|
dims = []
|
|
562
677
|
for dim in param['data'].shape:
|
|
@@ -568,16 +683,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
568
683
|
data_list[key].append(data)
|
|
569
684
|
|
|
570
685
|
if os.getenv("AITURBO") == "1":
|
|
571
|
-
import aiturbo
|
|
686
|
+
from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
|
|
572
687
|
ckpt_name = os.path.basename(ckpt_file_name)
|
|
573
|
-
aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np)
|
|
688
|
+
aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
|
|
574
689
|
elif async_save:
|
|
575
690
|
data_copy = copy.deepcopy(data_list)
|
|
576
|
-
thr = Thread(target=_exec_save,
|
|
691
|
+
thr = Thread(target=_exec_save,
|
|
692
|
+
args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
|
|
577
693
|
name="asyn_save_ckpt")
|
|
578
694
|
thr.start()
|
|
579
695
|
else:
|
|
580
|
-
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check)
|
|
696
|
+
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
|
|
581
697
|
|
|
582
698
|
logger.info("Saving checkpoint process is finished.")
|
|
583
699
|
|
|
@@ -692,11 +808,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
692
808
|
param_data.append(value.key)
|
|
693
809
|
else:
|
|
694
810
|
param_data = value.data
|
|
811
|
+
if append_dict and "__exception_save__" in append_dict:
|
|
812
|
+
param_data = Tensor(Tensor_.move_to(value, "CPU", False))
|
|
695
813
|
|
|
696
814
|
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
697
815
|
# which should be combined before saving
|
|
698
816
|
if key in parameter_layout_dict:
|
|
699
|
-
|
|
817
|
+
if not append_dict or "__exception_save__" not in append_dict:
|
|
818
|
+
param_data = Tensor(value.data)
|
|
700
819
|
param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
|
|
701
820
|
integrated_save)
|
|
702
821
|
|
|
@@ -812,7 +931,7 @@ def load(file_name, **kwargs):
|
|
|
812
931
|
if not os.path.exists(file_name):
|
|
813
932
|
raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
|
|
814
933
|
"please check whether the 'file_name' is correct.")
|
|
815
|
-
file_name = os.path.
|
|
934
|
+
file_name = os.path.realpath(file_name)
|
|
816
935
|
|
|
817
936
|
# set customized functions for dynamic obfuscation
|
|
818
937
|
obfuscated = _check_load_obfuscate(**kwargs)
|
|
@@ -875,7 +994,7 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=T
|
|
|
875
994
|
if not os.path.exists(file_name):
|
|
876
995
|
raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
|
|
877
996
|
"please check whether the 'file_name' is correct.")
|
|
878
|
-
file_name = os.path.
|
|
997
|
+
file_name = os.path.realpath(file_name)
|
|
879
998
|
|
|
880
999
|
logger.info("Execute the process of export and split mindir.")
|
|
881
1000
|
dynamic = True
|
|
@@ -1074,9 +1193,14 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
1074
1193
|
|
|
1075
1194
|
|
|
1076
1195
|
def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1077
|
-
dec_mode, crc_check):
|
|
1196
|
+
dec_mode, crc_check, format):
|
|
1078
1197
|
"""load parameter into parameter_dict"""
|
|
1079
|
-
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
1198
|
+
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
|
|
1199
|
+
if format == "safetensors":
|
|
1200
|
+
with safe_open(ckpt_file_name, framework='np') as f:
|
|
1201
|
+
for k in f.keys():
|
|
1202
|
+
parameter_dict[k] = Parameter(f.get_tensor(k))
|
|
1203
|
+
return
|
|
1080
1204
|
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
|
|
1081
1205
|
try:
|
|
1082
1206
|
param_data_list = []
|
|
@@ -1138,7 +1262,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
|
|
|
1138
1262
|
|
|
1139
1263
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
|
1140
1264
|
dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
|
|
1141
|
-
crc_check=False):
|
|
1265
|
+
crc_check=False, remove_redundancy=False, format="ckpt"):
|
|
1142
1266
|
"""
|
|
1143
1267
|
Load checkpoint info from a specified file.
|
|
1144
1268
|
|
|
@@ -1148,6 +1272,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1148
1272
|
- `specify_prefix` and `filter_prefix` are in the process of being deprecated,
|
|
1149
1273
|
`choice_func` is recommended instead.
|
|
1150
1274
|
And using either of those two args will override `choice_func` at the same time.
|
|
1275
|
+
- When loading a checkpoint that has removed redundancy, the network should be compiled.
|
|
1151
1276
|
|
|
1152
1277
|
Args:
|
|
1153
1278
|
ckpt_file_name (str): Checkpoint file name.
|
|
@@ -1170,6 +1295,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1170
1295
|
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
1171
1296
|
matches the custom condition will be removed. Default: ``None`` .
|
|
1172
1297
|
crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1298
|
+
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1299
|
+
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1300
|
+
redundant-free loading is not enabled.
|
|
1301
|
+
format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
1173
1302
|
|
|
1174
1303
|
Returns:
|
|
1175
1304
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -1219,24 +1348,35 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1219
1348
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
|
1220
1349
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
|
1221
1350
|
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
|
|
1351
|
+
remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
1352
|
+
_check_format_and_other_params(format, dec_key, dec_mode, crc_check)
|
|
1222
1353
|
logger.info("Execute the process of loading checkpoint files.")
|
|
1223
1354
|
|
|
1224
1355
|
parameter_dict = {}
|
|
1225
1356
|
|
|
1226
1357
|
if os.getenv("AITURBO") == "1":
|
|
1227
1358
|
rank_id = get_rank()
|
|
1228
|
-
import aiturbo
|
|
1359
|
+
from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
|
|
1229
1360
|
ckpt_path = os.path.dirname(ckpt_file_name)
|
|
1230
1361
|
ckpt_name = os.path.basename(ckpt_file_name)
|
|
1231
|
-
np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id)
|
|
1362
|
+
np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id, crc_check)
|
|
1232
1363
|
for key, value in np_dict.items():
|
|
1364
|
+
if crc_check and len(value) != 2:
|
|
1365
|
+
raise ValueError(f"When loading a checkpoint from AITurbo, if CRC check is enabled, "
|
|
1366
|
+
f"the length of the value must be 2, but got {len(value)}.")
|
|
1233
1367
|
if isinstance(value, str):
|
|
1234
|
-
|
|
1368
|
+
if crc_check and value[1] != binascii.crc32(np.array(value[0]).tobytes()):
|
|
1369
|
+
raise ValueError(f"When loading a checkpoint from AITurbo, the value of the string has not "
|
|
1370
|
+
f"passed the CRC check and has been corrupted.")
|
|
1371
|
+
parameter_dict[key] = value[0]
|
|
1235
1372
|
else:
|
|
1236
|
-
|
|
1373
|
+
if crc_check and value[1] != binascii.crc32(value[0].tobytes()):
|
|
1374
|
+
raise ValueError(f"When loading a checkpoint from AITurbo, the value of the parameter has not "
|
|
1375
|
+
f"passed the CRC check and has been corrupted.")
|
|
1376
|
+
parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
|
|
1237
1377
|
else:
|
|
1238
1378
|
_load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1239
|
-
dec_mode, crc_check)
|
|
1379
|
+
dec_mode, crc_check, format)
|
|
1240
1380
|
|
|
1241
1381
|
if not parameter_dict:
|
|
1242
1382
|
raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
|
|
@@ -1245,7 +1385,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1245
1385
|
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1246
1386
|
(is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
|
|
1247
1387
|
if net is not None:
|
|
1248
|
-
load_param_into_net(net, parameter_dict, strict_load)
|
|
1388
|
+
load_param_into_net(net, parameter_dict, strict_load, remove_redundancy)
|
|
1249
1389
|
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1250
1390
|
_warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
|
|
1251
1391
|
|
|
@@ -1362,17 +1502,20 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
|
|
|
1362
1502
|
parameter_dict[element.tag] = map_array
|
|
1363
1503
|
|
|
1364
1504
|
|
|
1365
|
-
def _check_ckpt_file_name(ckpt_file_name):
|
|
1505
|
+
def _check_ckpt_file_name(ckpt_file_name, format):
|
|
1366
1506
|
"""Check function load_checkpoint's ckpt_file_name."""
|
|
1367
1507
|
if not isinstance(ckpt_file_name, str):
|
|
1368
1508
|
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
|
1369
1509
|
"but got {}.".format(type(ckpt_file_name)))
|
|
1370
1510
|
|
|
1371
|
-
if
|
|
1372
|
-
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
|
|
1511
|
+
if format not in ['ckpt', 'safetensors']:
|
|
1512
|
+
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt' or '.safetensors', please "
|
|
1373
1513
|
"input the correct 'ckpt_file_name'.")
|
|
1514
|
+
if not ckpt_file_name.endswith(format):
|
|
1515
|
+
raise ValueError(f"For 'load_checkpoint', the checkpoint file format must same with 'format', but got "
|
|
1516
|
+
f"file_name:'{ckpt_file_name}', format:'{format}'")
|
|
1374
1517
|
|
|
1375
|
-
ckpt_file_name = os.path.
|
|
1518
|
+
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
|
1376
1519
|
if not os.path.exists(ckpt_file_name):
|
|
1377
1520
|
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
|
|
1378
1521
|
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
|
|
@@ -1414,7 +1557,7 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
|
|
|
1414
1557
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
1415
1558
|
if pb_content is None:
|
|
1416
1559
|
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
|
|
1417
|
-
if crc_check and pb_content[-17:-10]
|
|
1560
|
+
if crc_check and pb_content[-17:-10] != b"crc_num":
|
|
1418
1561
|
logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
|
|
1419
1562
|
if pb_content[-17:-10] == b"crc_num":
|
|
1420
1563
|
crc_num_bytes = pb_content[-10:]
|
|
@@ -1426,6 +1569,9 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
|
|
|
1426
1569
|
raise ValueError("For 'load_checkpoint', the crc check is failed, "
|
|
1427
1570
|
"please check whether the ckpt file is damaged.")
|
|
1428
1571
|
checkpoint_list.ParseFromString(pb_content)
|
|
1572
|
+
except google.protobuf.message.DecodeError as e:
|
|
1573
|
+
raise ValueError(f"Failed to read the checkpoint file {ckpt_file_name}. "
|
|
1574
|
+
f"The file may be corrupted, and the content cannot be parsed.") from e
|
|
1429
1575
|
except BaseException as e:
|
|
1430
1576
|
if _is_cipher_file(ckpt_file_name):
|
|
1431
1577
|
err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tempered with, " \
|
|
@@ -1455,19 +1601,6 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
|
|
1455
1601
|
return whether_load
|
|
1456
1602
|
|
|
1457
1603
|
|
|
1458
|
-
def _init_parameter_data_in_parallel_mode(net, parameter_dict):
|
|
1459
|
-
"""In parallel mode, only init the paraemters in ckpt."""
|
|
1460
|
-
is_train_phase = net.phase.startswith('train')
|
|
1461
|
-
for _, param in net.parameters_and_names():
|
|
1462
|
-
if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
|
|
1463
|
-
param.shape = tuple(parameter_dict[param.name].shape)
|
|
1464
|
-
continue
|
|
1465
|
-
if param.name in parameter_dict and param.has_init:
|
|
1466
|
-
logger.warning("{} is not init while load ckpt.".format(param.name))
|
|
1467
|
-
new_tensor = param.init_data()
|
|
1468
|
-
param._update_tensor_data(new_tensor)
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
1604
|
def _check_load_param_into_net(net, parameter_dict):
|
|
1472
1605
|
"""check load_param_into_net"""
|
|
1473
1606
|
if not isinstance(net, nn.Cell):
|
|
@@ -1484,10 +1617,13 @@ def _check_load_param_into_net(net, parameter_dict):
|
|
|
1484
1617
|
parameter_dict.pop("random_op")
|
|
1485
1618
|
|
|
1486
1619
|
|
|
1487
|
-
def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
1620
|
+
def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
|
|
1488
1621
|
"""
|
|
1489
1622
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
1490
1623
|
|
|
1624
|
+
Note:
|
|
1625
|
+
- When loading a parameter dict that has removed redundancy, the network should be compiled.
|
|
1626
|
+
|
|
1491
1627
|
Args:
|
|
1492
1628
|
net (Cell): The network where the parameters will be loaded.
|
|
1493
1629
|
parameter_dict (dict): The dictionary generated by load checkpoint file,
|
|
@@ -1496,6 +1632,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1496
1632
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
1497
1633
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
1498
1634
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1635
|
+
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1636
|
+
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1637
|
+
redundant-free loading is not enabled.
|
|
1499
1638
|
|
|
1500
1639
|
Returns:
|
|
1501
1640
|
- param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
@@ -1529,13 +1668,10 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1529
1668
|
raise TypeError(msg)
|
|
1530
1669
|
|
|
1531
1670
|
strict_load = Validator.check_bool(strict_load)
|
|
1671
|
+
remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
1532
1672
|
logger.info("Execute the process of loading parameters into net.")
|
|
1533
1673
|
for _, param in net.parameters_and_names():
|
|
1534
1674
|
param.from_ckpt = True
|
|
1535
|
-
if not _is_in_auto_parallel_mode():
|
|
1536
|
-
net.init_parameters_data()
|
|
1537
|
-
else:
|
|
1538
|
-
_init_parameter_data_in_parallel_mode(net, parameter_dict)
|
|
1539
1675
|
param_not_load = []
|
|
1540
1676
|
ckpt_not_load = list(parameter_dict.keys())
|
|
1541
1677
|
for _, param in net.parameters_and_names():
|
|
@@ -1548,6 +1684,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1548
1684
|
continue
|
|
1549
1685
|
new_param = parameter_dict[param.name]
|
|
1550
1686
|
_update_param(param, new_param, strict_load)
|
|
1687
|
+
if hasattr(param, "init_param") and not param.init_param:
|
|
1688
|
+
param.init_param = True
|
|
1551
1689
|
ckpt_not_load.remove(param.name)
|
|
1552
1690
|
else:
|
|
1553
1691
|
param_not_load.append(param.name)
|
|
@@ -1560,16 +1698,26 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1560
1698
|
logger.warning("For 'load_param_into_net', "
|
|
1561
1699
|
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1562
1700
|
"'parameter_dict', please check whether the network structure is consistent "
|
|
1563
|
-
"when training and loading checkpoint."
|
|
1701
|
+
"when training and loading checkpoint. Another possibility is that "
|
|
1702
|
+
"the redundant loading is not enabled, but the loaded checkpoint is saved with "
|
|
1703
|
+
"redundancy removed. ".format(len(param_not_load)))
|
|
1564
1704
|
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1565
|
-
if
|
|
1705
|
+
if remove_redundancy:
|
|
1706
|
+
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
1707
|
+
if parallel_mode == "stand_alone":
|
|
1708
|
+
raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
|
|
1709
|
+
f"in parallel scenarios, but got {parallel_mode}.")
|
|
1710
|
+
if not net.compile_cache and not net.parameter_layout_dict:
|
|
1711
|
+
raise ValueError("When loading a parameter dict that has removed redundancy, "
|
|
1712
|
+
"the network should be compiled.")
|
|
1566
1713
|
param_layout = net.parameter_layout_dict
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1572
|
-
|
|
1714
|
+
rank_id = get_rank()
|
|
1715
|
+
device_num = _get_device_num()
|
|
1716
|
+
stage_num = _get_auto_parallel_context("pipeline_stages")
|
|
1717
|
+
chunk_size = device_num // stage_num
|
|
1718
|
+
initial_rank = (rank_id // chunk_size) * chunk_size
|
|
1719
|
+
_single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
|
|
1720
|
+
|
|
1573
1721
|
return param_not_load, ckpt_not_load
|
|
1574
1722
|
|
|
1575
1723
|
|
|
@@ -1662,6 +1810,8 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_loa
|
|
|
1662
1810
|
if param.name in param_not_load and new_param_name in parameter_dict:
|
|
1663
1811
|
new_param = parameter_dict[new_param_name]
|
|
1664
1812
|
_update_param(param, new_param, strict_load)
|
|
1813
|
+
if hasattr(param, "init_param") and not param.init_param:
|
|
1814
|
+
param.init_param = True
|
|
1665
1815
|
param_not_load.remove(param.name)
|
|
1666
1816
|
|
|
1667
1817
|
|
|
@@ -1675,7 +1825,7 @@ def _save_graph(network, file_name):
|
|
|
1675
1825
|
"""
|
|
1676
1826
|
logger.info("Execute the process of saving graph.")
|
|
1677
1827
|
|
|
1678
|
-
file_name = os.path.
|
|
1828
|
+
file_name = os.path.realpath(file_name)
|
|
1679
1829
|
graph_pb = network.get_func_graph_proto()
|
|
1680
1830
|
if graph_pb:
|
|
1681
1831
|
with open(file_name, "wb") as f:
|
|
@@ -1790,7 +1940,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1790
1940
|
- AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
|
|
1791
1941
|
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
|
|
1792
1942
|
- MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
|
|
1793
|
-
for MindSpore models.
|
|
1943
|
+
for MindSpore models. MINDIR does not support operators which have dictionary attribute.
|
|
1794
1944
|
|
|
1795
1945
|
kwargs (dict): Configuration options dictionary.
|
|
1796
1946
|
|
|
@@ -1889,7 +2039,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1889
2039
|
+ str(columns))
|
|
1890
2040
|
inputs = tuple(inputs_col)
|
|
1891
2041
|
|
|
1892
|
-
file_name = os.path.
|
|
2042
|
+
file_name = os.path.realpath(file_name)
|
|
1893
2043
|
if 'enc_key' in kwargs.keys():
|
|
1894
2044
|
kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
|
|
1895
2045
|
_export(net, file_name, file_format, *inputs, **kwargs)
|
|
@@ -1982,8 +2132,8 @@ def _save_air(net, file_name, *inputs, **kwargs):
|
|
|
1982
2132
|
if os.path.exists(file_name):
|
|
1983
2133
|
os.chmod(file_name, stat.S_IWUSR)
|
|
1984
2134
|
if "/" in file_name:
|
|
1985
|
-
real_path = os.path.
|
|
1986
|
-
os.makedirs(real_path, exist_ok=True)
|
|
2135
|
+
real_path = os.path.realpath(file_name[:file_name.rfind("/")])
|
|
2136
|
+
os.makedirs(real_path, mode=0o700, exist_ok=True)
|
|
1987
2137
|
if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
|
|
1988
2138
|
_executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
|
|
1989
2139
|
else:
|
|
@@ -2093,12 +2243,12 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
2093
2243
|
file_prefix = file_name.split("/")[-1]
|
|
2094
2244
|
if file_prefix.endswith(".mindir"):
|
|
2095
2245
|
file_prefix = file_prefix[:-7]
|
|
2096
|
-
current_path = os.path.
|
|
2246
|
+
current_path = os.path.realpath(file_name)
|
|
2097
2247
|
dirname = os.path.dirname(current_path)
|
|
2098
2248
|
data_path = os.path.join(dirname, file_prefix + "_variables")
|
|
2099
2249
|
if os.path.exists(data_path):
|
|
2100
2250
|
shutil.rmtree(data_path)
|
|
2101
|
-
os.makedirs(data_path, exist_ok=True)
|
|
2251
|
+
os.makedirs(data_path, mode=0o700, exist_ok=True)
|
|
2102
2252
|
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
|
|
2103
2253
|
index = 0
|
|
2104
2254
|
external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
|
|
@@ -2267,9 +2417,9 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
2267
2417
|
"the data of parameter cannot be exported.".format(map_param_proto.name))
|
|
2268
2418
|
if not file_name.endswith('.mindir'):
|
|
2269
2419
|
file_name += ".mindir"
|
|
2270
|
-
current_path = os.path.
|
|
2420
|
+
current_path = os.path.realpath(file_name)
|
|
2271
2421
|
dirname = os.path.dirname(current_path)
|
|
2272
|
-
os.makedirs(dirname, exist_ok=True)
|
|
2422
|
+
os.makedirs(dirname, mode=0o700, exist_ok=True)
|
|
2273
2423
|
if os.path.exists(file_name):
|
|
2274
2424
|
os.chmod(file_name, stat.S_IWUSR)
|
|
2275
2425
|
with open(file_name, 'wb') as f:
|
|
@@ -2398,7 +2548,7 @@ def parse_print(print_file_name):
|
|
|
2398
2548
|
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
|
2399
2549
|
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
|
2400
2550
|
"""
|
|
2401
|
-
print_file_path = os.path.
|
|
2551
|
+
print_file_path = os.path.realpath(print_file_name)
|
|
2402
2552
|
|
|
2403
2553
|
if os.path.getsize(print_file_path) == 0:
|
|
2404
2554
|
raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
|
|
@@ -2687,14 +2837,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
2687
2837
|
return merged_parameter
|
|
2688
2838
|
|
|
2689
2839
|
|
|
2690
|
-
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
|
|
2691
|
-
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'
|
|
2840
|
+
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
|
|
2841
|
+
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
|
|
2842
|
+
format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
|
|
2692
2843
|
"""
|
|
2693
2844
|
Load checkpoint into net for distributed predication. Used in the case of distributed inference.
|
|
2694
2845
|
|
|
2695
2846
|
Args:
|
|
2696
2847
|
network (Cell): Network for distributed predication.
|
|
2697
|
-
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
|
|
2848
|
+
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
|
|
2698
2849
|
predict_strategy (dict): Strategy of predication process. It means that using one device to predict
|
|
2699
2850
|
when setting predict_strategy as None. Default: ``None`` .
|
|
2700
2851
|
train_strategy_filename (str): The filename of training strategy protocol buffer file.
|
|
@@ -2711,6 +2862,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2711
2862
|
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
|
|
2712
2863
|
mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
|
|
2713
2864
|
Default: ``'AES-GCM'`` .
|
|
2865
|
+
format (str): Input weight format to be loaded into the network.
|
|
2866
|
+
It can be set to either "ckpt" or "safetensors". Default: "ckpt".
|
|
2867
|
+
unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
|
|
2868
|
+
Default: ``None`` .
|
|
2869
|
+
dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
|
|
2870
|
+
rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
|
|
2871
|
+
globally by initializing the network; In save mode, save the file according to the input
|
|
2872
|
+
sequence number. If it is not input, save the entire file.
|
|
2714
2873
|
|
|
2715
2874
|
Raises:
|
|
2716
2875
|
TypeError: The type of inputs do not match the requirements.
|
|
@@ -2725,14 +2884,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2725
2884
|
|
|
2726
2885
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
2727
2886
|
Please see the `rank table startup
|
|
2728
|
-
<https://www.mindspore.cn/
|
|
2887
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
|
|
2729
2888
|
for more details.
|
|
2730
2889
|
|
|
2731
2890
|
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
|
|
2732
|
-
<https://www.mindspore.cn/
|
|
2891
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
|
|
2733
2892
|
|
|
2734
2893
|
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
2735
|
-
Startup <https://www.mindspore.cn/
|
|
2894
|
+
Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
|
|
2736
2895
|
|
|
2737
2896
|
>>> import os
|
|
2738
2897
|
>>> import numpy as np
|
|
@@ -2814,6 +2973,54 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2814
2973
|
...
|
|
2815
2974
|
[ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
|
|
2816
2975
|
"""
|
|
2976
|
+
if format not in ['safetensors', 'ckpt']:
|
|
2977
|
+
raise ValueError(
|
|
2978
|
+
f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
|
|
2979
|
+
|
|
2980
|
+
if format == 'safetensors':
|
|
2981
|
+
if unified_safetensors_dir is None:
|
|
2982
|
+
raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
|
|
2983
|
+
f"when format is 'safetensors'.")
|
|
2984
|
+
unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
|
|
2985
|
+
for param in unsupport_param:
|
|
2986
|
+
if param is not None:
|
|
2987
|
+
raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
|
|
2988
|
+
f"when format is 'safetensors'.")
|
|
2989
|
+
if strict_load or dec_mode != 'AES-GCM':
|
|
2990
|
+
raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
|
|
2991
|
+
f"when format is 'safetensors'.")
|
|
2992
|
+
if network is not None:
|
|
2993
|
+
rank_id = get_rank()
|
|
2994
|
+
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
|
|
2995
|
+
else:
|
|
2996
|
+
if dst_safetensors_dir is None:
|
|
2997
|
+
raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
|
|
2998
|
+
f"when network is None.")
|
|
2999
|
+
if rank_id is not None:
|
|
3000
|
+
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
|
|
3001
|
+
rank_id)
|
|
3002
|
+
else:
|
|
3003
|
+
dst_strategy_dict = _build_searched_strategy(predict_strategy)
|
|
3004
|
+
dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
|
|
3005
|
+
dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
|
|
3006
|
+
dst_device_num = dst_stage_device_num * dst_stage_num
|
|
3007
|
+
processes = []
|
|
3008
|
+
activate_processes = 0
|
|
3009
|
+
for rank in range(0, dst_device_num):
|
|
3010
|
+
p = Process(target=_load_parallel_checkpoint, args=(
|
|
3011
|
+
unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
|
|
3012
|
+
p.start()
|
|
3013
|
+
processes.append(p)
|
|
3014
|
+
activate_processes += 1
|
|
3015
|
+
max_processes = 64
|
|
3016
|
+
if activate_processes >= max_processes:
|
|
3017
|
+
p = processes.pop(0)
|
|
3018
|
+
p.join()
|
|
3019
|
+
activate_processes -= 1
|
|
3020
|
+
for p in processes:
|
|
3021
|
+
p.join()
|
|
3022
|
+
return
|
|
3023
|
+
|
|
2817
3024
|
network = Validator.check_isinstance("network", network, nn.Cell)
|
|
2818
3025
|
_check_checkpoint_file(checkpoint_filenames)
|
|
2819
3026
|
_check_predict_strategy(predict_strategy)
|
|
@@ -2858,17 +3065,24 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2858
3065
|
param_rank = rank_list.get(param.name)[0]
|
|
2859
3066
|
skip_merge_split = rank_list.get(param.name)[1]
|
|
2860
3067
|
shard_stride = train_strategy.get(param.name)[4]
|
|
3068
|
+
tensor_map = train_strategy.get(param.name)[1]
|
|
3069
|
+
first_dim_shard_idx = tensor_map[0] if tensor_map else -1
|
|
3070
|
+
device_arrangement = train_strategy.get(param.name)[0]
|
|
3071
|
+
first_dim_shard_size = 1
|
|
3072
|
+
if first_dim_shard_idx >= 0:
|
|
3073
|
+
first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
|
|
2861
3074
|
if train_strategy.get(param.name)[5]:
|
|
2862
|
-
shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
|
|
3075
|
+
shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
|
|
2863
3076
|
else:
|
|
2864
3077
|
shard_size = 0
|
|
2865
3078
|
for rank in param_rank:
|
|
2866
3079
|
param_total_list = list(range(0, ckpt_file_len))
|
|
3080
|
+
if first_dim_shard_size != 1:
|
|
3081
|
+
param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
|
|
2867
3082
|
if shard_size > 0:
|
|
2868
|
-
|
|
2869
|
-
|
|
2870
|
-
|
|
2871
|
-
param_total_list = shard_total_list[rank // shard_size]
|
|
3083
|
+
rank_index = param_total_list.index(rank)
|
|
3084
|
+
start = rank_index // shard_size * shard_size
|
|
3085
|
+
param_total_list = param_total_list[start:start + shard_size]
|
|
2872
3086
|
if shard_stride > 0:
|
|
2873
3087
|
param_stride = []
|
|
2874
3088
|
# merge pre parameter
|
|
@@ -3040,7 +3254,7 @@ def _get_mindir_inputs(file_name):
|
|
|
3040
3254
|
>>> input_tensor = get_mindir_inputs("lenet.mindir")
|
|
3041
3255
|
"""
|
|
3042
3256
|
Validator.check_file_name_by_regular(file_name)
|
|
3043
|
-
file_name = os.path.
|
|
3257
|
+
file_name = os.path.realpath(file_name)
|
|
3044
3258
|
model = read_proto(file_name)
|
|
3045
3259
|
input_tensor = []
|
|
3046
3260
|
|