mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -21,10 +21,12 @@ import binascii
|
|
|
21
21
|
import copy
|
|
22
22
|
import json
|
|
23
23
|
import os
|
|
24
|
+
import re
|
|
24
25
|
import shutil
|
|
25
26
|
import stat
|
|
26
27
|
import threading
|
|
27
28
|
from threading import Thread, RLock
|
|
29
|
+
from multiprocessing import Process
|
|
28
30
|
from collections import defaultdict, OrderedDict
|
|
29
31
|
from io import BytesIO
|
|
30
32
|
|
|
@@ -58,21 +60,25 @@ from mindspore.common.file_system import FileSystem, _register_basic_file_system
|
|
|
58
60
|
from mindspore.communication.management import get_rank, get_group_size
|
|
59
61
|
from mindspore.experimental import MapParameter
|
|
60
62
|
from mindspore.ops import Cast
|
|
61
|
-
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
|
63
|
+
from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
|
|
62
64
|
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
|
|
63
65
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
|
64
|
-
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
|
|
66
|
+
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
|
|
67
|
+
_get_device_num, _is_parallel_mode
|
|
68
|
+
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
|
|
65
69
|
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
|
|
66
|
-
_restore_group_info_list
|
|
70
|
+
_restore_group_info_list, _get_param_list_when_first_dim_sharded
|
|
67
71
|
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
|
|
68
72
|
_store_warm_up_ptr_by_tensor_list, _cache_enable
|
|
69
73
|
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
|
|
70
|
-
from mindspore.
|
|
74
|
+
from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
|
|
75
|
+
_extract_pipeline_stage_num
|
|
76
|
+
from mindspore.train._utils import read_proto, get_parameter_redundancy
|
|
71
77
|
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
|
|
72
78
|
split_mindir, split_dynamic_mindir
|
|
73
79
|
from mindspore.common.generator import Generator
|
|
74
|
-
from
|
|
75
|
-
from
|
|
80
|
+
from safetensors.numpy import save_file
|
|
81
|
+
from safetensors import safe_open
|
|
76
82
|
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
77
83
|
|
|
78
84
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
@@ -116,6 +122,68 @@ def init_ckpt_file_system(fs: FileSystem):
|
|
|
116
122
|
init_ckpt_file_system(_ckpt_fs)
|
|
117
123
|
|
|
118
124
|
|
|
125
|
+
def _get_cur_rank_dp(parameter_layout_dict):
|
|
126
|
+
""" Get dp and tp from layout dict. """
|
|
127
|
+
pp_num = _get_auto_parallel_context("pipeline_stages")
|
|
128
|
+
dev_num = _get_device_num()
|
|
129
|
+
global_rank = get_rank()
|
|
130
|
+
pipe_size = dev_num // pp_num
|
|
131
|
+
initial_rank = (global_rank // pipe_size) * pipe_size
|
|
132
|
+
parameter_redundancy_dict = get_parameter_redundancy(
|
|
133
|
+
parameter_layout_dict, initial_rank)
|
|
134
|
+
value_len = sys.maxsize
|
|
135
|
+
min_value = ()
|
|
136
|
+
for key, value in parameter_redundancy_dict.items():
|
|
137
|
+
if "accu_grads" in key or "inputs" in key:
|
|
138
|
+
continue
|
|
139
|
+
for item in value:
|
|
140
|
+
if len(item) < value_len and global_rank in item:
|
|
141
|
+
value_len = len(item)
|
|
142
|
+
min_value = item
|
|
143
|
+
return min_value
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
|
|
147
|
+
"""
|
|
148
|
+
Find available checkpoint file path from all backup checkpoint files of current rank.
|
|
149
|
+
It suppose that checkpoint path contains substring 'rank_{rank_id}' which is used to
|
|
150
|
+
distinguish between different path.If cur_ckpt_path doesn't have 'rank_{rank_id}' substring, will return
|
|
151
|
+
cur_ckpt_path itself when cur_ckpt_path is exist, otherwise return None.
|
|
152
|
+
|
|
153
|
+
Note:
|
|
154
|
+
This API must be called after the communication is initialized because the cluster information
|
|
155
|
+
needs to be obtained internally.
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
cur_ckpt_path (str): the checkpoint file path which cur rank needs.
|
|
159
|
+
cur_strategy_path (str): strategy file path for current rank.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
- new_ckpt_file (String), if found available checkpoint file , return it.
|
|
163
|
+
- None, if not found available checkpoint, return None.
|
|
164
|
+
|
|
165
|
+
Examples:
|
|
166
|
+
>>> import mindspore as ms
|
|
167
|
+
>>> from mindspore.communication import init
|
|
168
|
+
>>> from mindspore import get_ckpt_path_with_strategy
|
|
169
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
170
|
+
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
|
171
|
+
>>> init()
|
|
172
|
+
>>> ckpt_file= "./rank_5/iteration-1_40.ckpt"
|
|
173
|
+
>>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
|
|
174
|
+
>>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
|
|
175
|
+
>>> print(ckpt_file_new)
|
|
176
|
+
"""
|
|
177
|
+
dp = _get_cur_rank_dp(cur_strategy_path)
|
|
178
|
+
pattern = r'rank_\d+'
|
|
179
|
+
for i in dp:
|
|
180
|
+
new_ckpt_path = re.sub(pattern, f"rank_{str(i)}", cur_ckpt_path)
|
|
181
|
+
if not os.path.isfile(new_ckpt_path):
|
|
182
|
+
continue
|
|
183
|
+
return new_ckpt_path
|
|
184
|
+
return None
|
|
185
|
+
|
|
186
|
+
|
|
119
187
|
class ParamDictFuture:
|
|
120
188
|
def __init__(self, executor, param_dict_future):
|
|
121
189
|
self.executor = executor
|
|
@@ -252,57 +320,72 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
|
|
|
252
320
|
logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
|
|
253
321
|
|
|
254
322
|
|
|
255
|
-
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False
|
|
323
|
+
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
|
|
324
|
+
format="ckpt"):
|
|
256
325
|
"""Execute the process of saving checkpoint into file."""
|
|
257
326
|
try:
|
|
258
327
|
with _ckpt_mutex:
|
|
328
|
+
file_name_list = list(os.path.splitext(ckpt_file_name))
|
|
329
|
+
file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
|
|
330
|
+
tmp_name = ''.join(file_name_list)
|
|
259
331
|
if os.path.exists(ckpt_file_name):
|
|
260
332
|
os.chmod(ckpt_file_name, stat.S_IWUSR)
|
|
261
333
|
os.remove(ckpt_file_name)
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
334
|
+
if os.path.exists(tmp_name):
|
|
335
|
+
os.chmod(tmp_name, stat.S_IWUSR)
|
|
336
|
+
os.remove(tmp_name)
|
|
337
|
+
if format == "ckpt":
|
|
338
|
+
with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
|
|
339
|
+
plain_data = None
|
|
340
|
+
if enc_key is not None:
|
|
341
|
+
plain_data = BytesIO()
|
|
342
|
+
|
|
343
|
+
crc_num = 0
|
|
344
|
+
for name, value in data_list.items():
|
|
345
|
+
if name == "random_op":
|
|
346
|
+
_write_random_seed(name, value, f)
|
|
347
|
+
continue
|
|
348
|
+
if value[0] == "mapparameter":
|
|
349
|
+
_write_mapparameter(name, value, f, map_param_inc)
|
|
350
|
+
continue
|
|
351
|
+
if value[0] == "offload_parameter":
|
|
352
|
+
new_value = value[1:]
|
|
353
|
+
new_value[2] = value[3]
|
|
354
|
+
_write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
|
|
355
|
+
_offload_if_config(value[3])
|
|
356
|
+
continue
|
|
357
|
+
if value[1] == "str":
|
|
358
|
+
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
359
|
+
continue
|
|
360
|
+
if isinstance(value[2], np.ndarray):
|
|
361
|
+
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
362
|
+
continue
|
|
363
|
+
if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
|
|
364
|
+
_write_hugeparameter(name, value, f)
|
|
365
|
+
continue
|
|
366
|
+
|
|
367
|
+
crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
368
|
+
|
|
369
|
+
if enc_key is not None:
|
|
370
|
+
plain_data.seek(0)
|
|
371
|
+
max_block_size = ENCRYPT_BLOCK_SIZE * 1024
|
|
299
372
|
block_data = plain_data.read(max_block_size)
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
373
|
+
while block_data:
|
|
374
|
+
f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
|
|
375
|
+
block_data = plain_data.read(max_block_size)
|
|
376
|
+
if crc_check:
|
|
377
|
+
f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
|
|
378
|
+
elif format == "safetensors":
|
|
379
|
+
save_dict = {}
|
|
380
|
+
for name, value in data_list.items():
|
|
381
|
+
save_dict[name] = value[2].asnumpy()
|
|
382
|
+
save_file(save_dict, tmp_name)
|
|
383
|
+
if not os.path.exists(tmp_name):
|
|
384
|
+
logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
|
|
385
|
+
f"simultaneously modified a file.")
|
|
386
|
+
else:
|
|
387
|
+
os.rename(tmp_name, ckpt_file_name)
|
|
304
388
|
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
|
305
|
-
|
|
306
389
|
except BaseException as e:
|
|
307
390
|
logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
|
|
308
391
|
"or the disk space is insufficient and so on.", ckpt_file_name)
|
|
@@ -415,8 +498,11 @@ def _write_hugeparameter(name, value, f):
|
|
|
415
498
|
offset += numpy_data.shape[0]
|
|
416
499
|
|
|
417
500
|
|
|
418
|
-
def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
501
|
+
def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
|
|
419
502
|
"""Check save_obj and ckpt_file_name for save_checkpoint."""
|
|
503
|
+
if format not in ["safetensors", "ckpt"]:
|
|
504
|
+
raise ValueError(f"For 'save_checkpoint', the format must be "
|
|
505
|
+
f"'safetensors' or 'ckpt', but got {format}.")
|
|
420
506
|
if not isinstance(save_obj, (nn.Cell, list, dict)):
|
|
421
507
|
raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
|
|
422
508
|
"but got {}.".format(type(save_obj)))
|
|
@@ -424,18 +510,26 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
|
424
510
|
raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
|
|
425
511
|
"'ckpt_file_name' must be "
|
|
426
512
|
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
|
|
427
|
-
ckpt_file_name = os.path.
|
|
513
|
+
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
|
428
514
|
if os.path.isdir(ckpt_file_name):
|
|
429
515
|
raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
|
|
430
516
|
"it must be a file name.".format(ckpt_file_name))
|
|
431
|
-
if not ckpt_file_name.endswith(
|
|
432
|
-
ckpt_file_name += ".
|
|
517
|
+
if not ckpt_file_name.endswith(format):
|
|
518
|
+
ckpt_file_name += f".{format}"
|
|
433
519
|
return ckpt_file_name
|
|
434
520
|
|
|
435
521
|
|
|
522
|
+
def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
|
|
523
|
+
global_step_num=None):
|
|
524
|
+
param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
|
|
525
|
+
or map_param_inc or global_step_num is not None)
|
|
526
|
+
if format == "safetensors" and param_not_default:
|
|
527
|
+
raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
|
|
528
|
+
|
|
529
|
+
|
|
436
530
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
437
531
|
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
|
|
438
|
-
crc_check=False, **kwargs):
|
|
532
|
+
crc_check=False, format="ckpt", **kwargs):
|
|
439
533
|
r"""
|
|
440
534
|
Save checkpoint to a specified file.
|
|
441
535
|
|
|
@@ -465,6 +559,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
465
559
|
be saved. Default: ``None`` .
|
|
466
560
|
crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
|
|
467
561
|
result to the file. Default: ``False`` .
|
|
562
|
+
format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
468
563
|
kwargs (dict): Configuration options dictionary.
|
|
469
564
|
|
|
470
565
|
Raises:
|
|
@@ -498,7 +593,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
498
593
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
499
594
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
500
595
|
"""
|
|
501
|
-
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
|
|
596
|
+
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
|
|
502
597
|
integrated_save = Validator.check_bool(integrated_save)
|
|
503
598
|
async_save = Validator.check_bool(async_save)
|
|
504
599
|
append_dict = _check_append_dict(append_dict)
|
|
@@ -508,10 +603,19 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
508
603
|
map_param_inc = kwargs.get('incremental', False)
|
|
509
604
|
logger.info("Execute the process of saving checkpoint files.")
|
|
510
605
|
global_step_num = kwargs.get('global_step_num', None)
|
|
606
|
+
_check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
|
|
511
607
|
|
|
512
|
-
|
|
608
|
+
if append_dict and "__exception_save__" in append_dict:
|
|
609
|
+
s1 = mindspore.hal.Stream()
|
|
610
|
+
with mindspore.hal.StreamCtx(s1):
|
|
611
|
+
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
612
|
+
s1.synchronize()
|
|
613
|
+
else:
|
|
614
|
+
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
513
615
|
|
|
514
616
|
if append_dict:
|
|
617
|
+
if "__exception_save__" in append_dict:
|
|
618
|
+
del append_dict["__exception_save__"]
|
|
515
619
|
append_info_list = []
|
|
516
620
|
for k_name, value in append_dict.items():
|
|
517
621
|
if isinstance(value, Generator):
|
|
@@ -527,12 +631,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
527
631
|
for param in save_obj:
|
|
528
632
|
if param["name"] == "random_op":
|
|
529
633
|
if os.getenv("AITURBO") == "1":
|
|
530
|
-
data_list_np["random_op"] =
|
|
634
|
+
data_list_np["random_op"] = []
|
|
635
|
+
data_list_np["random_op"].append(param["data"])
|
|
636
|
+
if crc_check:
|
|
637
|
+
bytes_value = bytes(data_list_np[key][0])
|
|
638
|
+
data_list_np[key].append(binascii.crc32(bytes_value))
|
|
531
639
|
else:
|
|
532
640
|
data_list["random_op"] = param["data"]
|
|
533
641
|
continue
|
|
534
642
|
key = param["name"]
|
|
535
643
|
data_list[key] = []
|
|
644
|
+
data_list_np[key] = []
|
|
536
645
|
if isinstance(param["data"], MapParameter):
|
|
537
646
|
data_list[param["name"]].append("mapparameter")
|
|
538
647
|
data_list[param["name"]].append(param["data"])
|
|
@@ -546,7 +655,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
546
655
|
|
|
547
656
|
if isinstance(param["data"], str):
|
|
548
657
|
if os.getenv("AITURBO") == "1":
|
|
549
|
-
data_list_np[key]
|
|
658
|
+
data_list_np[key].append(np.array(param["data"]))
|
|
659
|
+
if crc_check:
|
|
660
|
+
bytes_value = data_list_np[key][0].tobytes()
|
|
661
|
+
data_list_np[key].append(binascii.crc32(bytes_value))
|
|
550
662
|
else:
|
|
551
663
|
data_list[key].append([0])
|
|
552
664
|
data_list[key].append('str')
|
|
@@ -556,7 +668,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
556
668
|
if isinstance(param["data"], Parameter):
|
|
557
669
|
param["data"].init_data()
|
|
558
670
|
if os.getenv("AITURBO") == "1":
|
|
559
|
-
data_list_np[key]
|
|
671
|
+
data_list_np[key].append(param["data"].asnumpy())
|
|
672
|
+
if crc_check:
|
|
673
|
+
bytes_value = data_list_np[key][0].tobytes()
|
|
674
|
+
data_list_np[key].append(binascii.crc32(bytes_value))
|
|
560
675
|
else:
|
|
561
676
|
dims = []
|
|
562
677
|
for dim in param['data'].shape:
|
|
@@ -568,16 +683,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
568
683
|
data_list[key].append(data)
|
|
569
684
|
|
|
570
685
|
if os.getenv("AITURBO") == "1":
|
|
571
|
-
import aiturbo
|
|
686
|
+
from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
|
|
572
687
|
ckpt_name = os.path.basename(ckpt_file_name)
|
|
573
|
-
aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np)
|
|
688
|
+
aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
|
|
574
689
|
elif async_save:
|
|
575
690
|
data_copy = copy.deepcopy(data_list)
|
|
576
|
-
thr = Thread(target=_exec_save,
|
|
691
|
+
thr = Thread(target=_exec_save,
|
|
692
|
+
args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
|
|
577
693
|
name="asyn_save_ckpt")
|
|
578
694
|
thr.start()
|
|
579
695
|
else:
|
|
580
|
-
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check)
|
|
696
|
+
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
|
|
581
697
|
|
|
582
698
|
logger.info("Saving checkpoint process is finished.")
|
|
583
699
|
|
|
@@ -692,11 +808,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
692
808
|
param_data.append(value.key)
|
|
693
809
|
else:
|
|
694
810
|
param_data = value.data
|
|
811
|
+
if append_dict and "__exception_save__" in append_dict:
|
|
812
|
+
param_data = Tensor(Tensor_.move_to(value, "CPU", False))
|
|
695
813
|
|
|
696
814
|
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
697
815
|
# which should be combined before saving
|
|
698
816
|
if key in parameter_layout_dict:
|
|
699
|
-
|
|
817
|
+
if not append_dict or "__exception_save__" not in append_dict:
|
|
818
|
+
param_data = Tensor(value.data)
|
|
700
819
|
param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
|
|
701
820
|
integrated_save)
|
|
702
821
|
|
|
@@ -812,7 +931,7 @@ def load(file_name, **kwargs):
|
|
|
812
931
|
if not os.path.exists(file_name):
|
|
813
932
|
raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
|
|
814
933
|
"please check whether the 'file_name' is correct.")
|
|
815
|
-
file_name = os.path.
|
|
934
|
+
file_name = os.path.realpath(file_name)
|
|
816
935
|
|
|
817
936
|
# set customized functions for dynamic obfuscation
|
|
818
937
|
obfuscated = _check_load_obfuscate(**kwargs)
|
|
@@ -875,7 +994,7 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=T
|
|
|
875
994
|
if not os.path.exists(file_name):
|
|
876
995
|
raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
|
|
877
996
|
"please check whether the 'file_name' is correct.")
|
|
878
|
-
file_name = os.path.
|
|
997
|
+
file_name = os.path.realpath(file_name)
|
|
879
998
|
|
|
880
999
|
logger.info("Execute the process of export and split mindir.")
|
|
881
1000
|
dynamic = True
|
|
@@ -1074,9 +1193,14 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
1074
1193
|
|
|
1075
1194
|
|
|
1076
1195
|
def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1077
|
-
dec_mode, crc_check):
|
|
1196
|
+
dec_mode, crc_check, format):
|
|
1078
1197
|
"""load parameter into parameter_dict"""
|
|
1079
|
-
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
1198
|
+
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
|
|
1199
|
+
if format == "safetensors":
|
|
1200
|
+
with safe_open(ckpt_file_name, framework='np') as f:
|
|
1201
|
+
for k in f.keys():
|
|
1202
|
+
parameter_dict[k] = Parameter(f.get_tensor(k))
|
|
1203
|
+
return
|
|
1080
1204
|
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
|
|
1081
1205
|
try:
|
|
1082
1206
|
param_data_list = []
|
|
@@ -1138,7 +1262,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
|
|
|
1138
1262
|
|
|
1139
1263
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
|
1140
1264
|
dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
|
|
1141
|
-
crc_check=False):
|
|
1265
|
+
crc_check=False, remove_redundancy=False, format="ckpt"):
|
|
1142
1266
|
"""
|
|
1143
1267
|
Load checkpoint info from a specified file.
|
|
1144
1268
|
|
|
@@ -1148,6 +1272,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1148
1272
|
- `specify_prefix` and `filter_prefix` are in the process of being deprecated,
|
|
1149
1273
|
`choice_func` is recommended instead.
|
|
1150
1274
|
And using either of those two args will override `choice_func` at the same time.
|
|
1275
|
+
- When loading a checkpoint that has removed redundancy, the network should be compiled.
|
|
1151
1276
|
|
|
1152
1277
|
Args:
|
|
1153
1278
|
ckpt_file_name (str): Checkpoint file name.
|
|
@@ -1170,6 +1295,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1170
1295
|
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
1171
1296
|
matches the custom condition will be removed. Default: ``None`` .
|
|
1172
1297
|
crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1298
|
+
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1299
|
+
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1300
|
+
redundant-free loading is not enabled.
|
|
1301
|
+
format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
1173
1302
|
|
|
1174
1303
|
Returns:
|
|
1175
1304
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -1219,24 +1348,35 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1219
1348
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
|
1220
1349
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
|
1221
1350
|
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
|
|
1351
|
+
remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
1352
|
+
_check_format_and_other_params(format, dec_key, dec_mode, crc_check)
|
|
1222
1353
|
logger.info("Execute the process of loading checkpoint files.")
|
|
1223
1354
|
|
|
1224
1355
|
parameter_dict = {}
|
|
1225
1356
|
|
|
1226
1357
|
if os.getenv("AITURBO") == "1":
|
|
1227
1358
|
rank_id = get_rank()
|
|
1228
|
-
import aiturbo
|
|
1359
|
+
from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
|
|
1229
1360
|
ckpt_path = os.path.dirname(ckpt_file_name)
|
|
1230
1361
|
ckpt_name = os.path.basename(ckpt_file_name)
|
|
1231
|
-
np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id)
|
|
1362
|
+
np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id, crc_check)
|
|
1232
1363
|
for key, value in np_dict.items():
|
|
1364
|
+
if crc_check and len(value) != 2:
|
|
1365
|
+
raise ValueError(f"When loading a checkpoint from AITurbo, if CRC check is enabled, "
|
|
1366
|
+
f"the length of the value must be 2, but got {len(value)}.")
|
|
1233
1367
|
if isinstance(value, str):
|
|
1234
|
-
|
|
1368
|
+
if crc_check and value[1] != binascii.crc32(np.array(value[0]).tobytes()):
|
|
1369
|
+
raise ValueError(f"When loading a checkpoint from AITurbo, the value of the string has not "
|
|
1370
|
+
f"passed the CRC check and has been corrupted.")
|
|
1371
|
+
parameter_dict[key] = value[0]
|
|
1235
1372
|
else:
|
|
1236
|
-
|
|
1373
|
+
if crc_check and value[1] != binascii.crc32(value[0].tobytes()):
|
|
1374
|
+
raise ValueError(f"When loading a checkpoint from AITurbo, the value of the parameter has not "
|
|
1375
|
+
f"passed the CRC check and has been corrupted.")
|
|
1376
|
+
parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
|
|
1237
1377
|
else:
|
|
1238
1378
|
_load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1239
|
-
dec_mode, crc_check)
|
|
1379
|
+
dec_mode, crc_check, format)
|
|
1240
1380
|
|
|
1241
1381
|
if not parameter_dict:
|
|
1242
1382
|
raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
|
|
@@ -1245,7 +1385,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1245
1385
|
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1246
1386
|
(is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
|
|
1247
1387
|
if net is not None:
|
|
1248
|
-
load_param_into_net(net, parameter_dict, strict_load)
|
|
1388
|
+
load_param_into_net(net, parameter_dict, strict_load, remove_redundancy)
|
|
1249
1389
|
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1250
1390
|
_warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
|
|
1251
1391
|
|
|
@@ -1362,17 +1502,20 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
|
|
|
1362
1502
|
parameter_dict[element.tag] = map_array
|
|
1363
1503
|
|
|
1364
1504
|
|
|
1365
|
-
def _check_ckpt_file_name(ckpt_file_name):
|
|
1505
|
+
def _check_ckpt_file_name(ckpt_file_name, format):
|
|
1366
1506
|
"""Check function load_checkpoint's ckpt_file_name."""
|
|
1367
1507
|
if not isinstance(ckpt_file_name, str):
|
|
1368
1508
|
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
|
1369
1509
|
"but got {}.".format(type(ckpt_file_name)))
|
|
1370
1510
|
|
|
1371
|
-
if
|
|
1372
|
-
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
|
|
1511
|
+
if format not in ['ckpt', 'safetensors']:
|
|
1512
|
+
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt' or '.safetensors', please "
|
|
1373
1513
|
"input the correct 'ckpt_file_name'.")
|
|
1514
|
+
if not ckpt_file_name.endswith(format):
|
|
1515
|
+
raise ValueError(f"For 'load_checkpoint', the checkpoint file format must same with 'format', but got "
|
|
1516
|
+
f"file_name:'{ckpt_file_name}', format:'{format}'")
|
|
1374
1517
|
|
|
1375
|
-
ckpt_file_name = os.path.
|
|
1518
|
+
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
|
1376
1519
|
if not os.path.exists(ckpt_file_name):
|
|
1377
1520
|
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
|
|
1378
1521
|
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
|
|
@@ -1414,7 +1557,7 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
|
|
|
1414
1557
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
1415
1558
|
if pb_content is None:
|
|
1416
1559
|
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
|
|
1417
|
-
if crc_check and pb_content[-17:-10]
|
|
1560
|
+
if crc_check and pb_content[-17:-10] != b"crc_num":
|
|
1418
1561
|
logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
|
|
1419
1562
|
if pb_content[-17:-10] == b"crc_num":
|
|
1420
1563
|
crc_num_bytes = pb_content[-10:]
|
|
@@ -1484,10 +1627,13 @@ def _check_load_param_into_net(net, parameter_dict):
|
|
|
1484
1627
|
parameter_dict.pop("random_op")
|
|
1485
1628
|
|
|
1486
1629
|
|
|
1487
|
-
def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
1630
|
+
def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
|
|
1488
1631
|
"""
|
|
1489
1632
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
1490
1633
|
|
|
1634
|
+
Note:
|
|
1635
|
+
- When loading a parameter dict that has removed redundancy, the network should be compiled.
|
|
1636
|
+
|
|
1491
1637
|
Args:
|
|
1492
1638
|
net (Cell): The network where the parameters will be loaded.
|
|
1493
1639
|
parameter_dict (dict): The dictionary generated by load checkpoint file,
|
|
@@ -1496,6 +1642,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1496
1642
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
1497
1643
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
1498
1644
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1645
|
+
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1646
|
+
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1647
|
+
redundant-free loading is not enabled.
|
|
1499
1648
|
|
|
1500
1649
|
Returns:
|
|
1501
1650
|
- param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
@@ -1529,10 +1678,11 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1529
1678
|
raise TypeError(msg)
|
|
1530
1679
|
|
|
1531
1680
|
strict_load = Validator.check_bool(strict_load)
|
|
1681
|
+
remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
1532
1682
|
logger.info("Execute the process of loading parameters into net.")
|
|
1533
1683
|
for _, param in net.parameters_and_names():
|
|
1534
1684
|
param.from_ckpt = True
|
|
1535
|
-
if not _is_in_auto_parallel_mode():
|
|
1685
|
+
if not (_is_in_auto_parallel_mode() or _is_parallel_mode()):
|
|
1536
1686
|
net.init_parameters_data()
|
|
1537
1687
|
else:
|
|
1538
1688
|
_init_parameter_data_in_parallel_mode(net, parameter_dict)
|
|
@@ -1560,16 +1710,26 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1560
1710
|
logger.warning("For 'load_param_into_net', "
|
|
1561
1711
|
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1562
1712
|
"'parameter_dict', please check whether the network structure is consistent "
|
|
1563
|
-
"when training and loading checkpoint."
|
|
1713
|
+
"when training and loading checkpoint. Another possibility is that "
|
|
1714
|
+
"the redundant loading is not enabled, but the loaded checkpoint is saved with "
|
|
1715
|
+
"redundancy removed. ".format(len(param_not_load)))
|
|
1564
1716
|
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1565
|
-
if
|
|
1717
|
+
if remove_redundancy:
|
|
1718
|
+
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
1719
|
+
if parallel_mode == "stand_alone":
|
|
1720
|
+
raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
|
|
1721
|
+
f"in parallel scenarios, but got {parallel_mode}.")
|
|
1722
|
+
if not net.compile_cache and not net.parameter_layout_dict:
|
|
1723
|
+
raise ValueError("When loading a parameter dict that has removed redundancy, "
|
|
1724
|
+
"the network should be compiled.")
|
|
1566
1725
|
param_layout = net.parameter_layout_dict
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1572
|
-
|
|
1726
|
+
rank_id = get_rank()
|
|
1727
|
+
device_num = _get_device_num()
|
|
1728
|
+
stage_num = _get_auto_parallel_context("pipeline_stages")
|
|
1729
|
+
chunk_size = device_num // stage_num
|
|
1730
|
+
initial_rank = (rank_id // chunk_size) * chunk_size
|
|
1731
|
+
_single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
|
|
1732
|
+
|
|
1573
1733
|
return param_not_load, ckpt_not_load
|
|
1574
1734
|
|
|
1575
1735
|
|
|
@@ -1675,7 +1835,7 @@ def _save_graph(network, file_name):
|
|
|
1675
1835
|
"""
|
|
1676
1836
|
logger.info("Execute the process of saving graph.")
|
|
1677
1837
|
|
|
1678
|
-
file_name = os.path.
|
|
1838
|
+
file_name = os.path.realpath(file_name)
|
|
1679
1839
|
graph_pb = network.get_func_graph_proto()
|
|
1680
1840
|
if graph_pb:
|
|
1681
1841
|
with open(file_name, "wb") as f:
|
|
@@ -1790,7 +1950,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1790
1950
|
- AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
|
|
1791
1951
|
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
|
|
1792
1952
|
- MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
|
|
1793
|
-
for MindSpore models.
|
|
1953
|
+
for MindSpore models. MINDIR does not support operators which have dictionary attribute.
|
|
1794
1954
|
|
|
1795
1955
|
kwargs (dict): Configuration options dictionary.
|
|
1796
1956
|
|
|
@@ -1889,7 +2049,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1889
2049
|
+ str(columns))
|
|
1890
2050
|
inputs = tuple(inputs_col)
|
|
1891
2051
|
|
|
1892
|
-
file_name = os.path.
|
|
2052
|
+
file_name = os.path.realpath(file_name)
|
|
1893
2053
|
if 'enc_key' in kwargs.keys():
|
|
1894
2054
|
kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
|
|
1895
2055
|
_export(net, file_name, file_format, *inputs, **kwargs)
|
|
@@ -1982,8 +2142,8 @@ def _save_air(net, file_name, *inputs, **kwargs):
|
|
|
1982
2142
|
if os.path.exists(file_name):
|
|
1983
2143
|
os.chmod(file_name, stat.S_IWUSR)
|
|
1984
2144
|
if "/" in file_name:
|
|
1985
|
-
real_path = os.path.
|
|
1986
|
-
os.makedirs(real_path, exist_ok=True)
|
|
2145
|
+
real_path = os.path.realpath(file_name[:file_name.rfind("/")])
|
|
2146
|
+
os.makedirs(real_path, mode=0o700, exist_ok=True)
|
|
1987
2147
|
if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
|
|
1988
2148
|
_executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
|
|
1989
2149
|
else:
|
|
@@ -2093,12 +2253,12 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
2093
2253
|
file_prefix = file_name.split("/")[-1]
|
|
2094
2254
|
if file_prefix.endswith(".mindir"):
|
|
2095
2255
|
file_prefix = file_prefix[:-7]
|
|
2096
|
-
current_path = os.path.
|
|
2256
|
+
current_path = os.path.realpath(file_name)
|
|
2097
2257
|
dirname = os.path.dirname(current_path)
|
|
2098
2258
|
data_path = os.path.join(dirname, file_prefix + "_variables")
|
|
2099
2259
|
if os.path.exists(data_path):
|
|
2100
2260
|
shutil.rmtree(data_path)
|
|
2101
|
-
os.makedirs(data_path, exist_ok=True)
|
|
2261
|
+
os.makedirs(data_path, mode=0o700, exist_ok=True)
|
|
2102
2262
|
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
|
|
2103
2263
|
index = 0
|
|
2104
2264
|
external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
|
|
@@ -2267,9 +2427,9 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
2267
2427
|
"the data of parameter cannot be exported.".format(map_param_proto.name))
|
|
2268
2428
|
if not file_name.endswith('.mindir'):
|
|
2269
2429
|
file_name += ".mindir"
|
|
2270
|
-
current_path = os.path.
|
|
2430
|
+
current_path = os.path.realpath(file_name)
|
|
2271
2431
|
dirname = os.path.dirname(current_path)
|
|
2272
|
-
os.makedirs(dirname, exist_ok=True)
|
|
2432
|
+
os.makedirs(dirname, mode=0o700, exist_ok=True)
|
|
2273
2433
|
if os.path.exists(file_name):
|
|
2274
2434
|
os.chmod(file_name, stat.S_IWUSR)
|
|
2275
2435
|
with open(file_name, 'wb') as f:
|
|
@@ -2398,7 +2558,7 @@ def parse_print(print_file_name):
|
|
|
2398
2558
|
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
|
2399
2559
|
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
|
2400
2560
|
"""
|
|
2401
|
-
print_file_path = os.path.
|
|
2561
|
+
print_file_path = os.path.realpath(print_file_name)
|
|
2402
2562
|
|
|
2403
2563
|
if os.path.getsize(print_file_path) == 0:
|
|
2404
2564
|
raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
|
|
@@ -2687,14 +2847,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
2687
2847
|
return merged_parameter
|
|
2688
2848
|
|
|
2689
2849
|
|
|
2690
|
-
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
|
|
2691
|
-
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'
|
|
2850
|
+
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
|
|
2851
|
+
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
|
|
2852
|
+
format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
|
|
2692
2853
|
"""
|
|
2693
2854
|
Load checkpoint into net for distributed predication. Used in the case of distributed inference.
|
|
2694
2855
|
|
|
2695
2856
|
Args:
|
|
2696
2857
|
network (Cell): Network for distributed predication.
|
|
2697
|
-
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
|
|
2858
|
+
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
|
|
2698
2859
|
predict_strategy (dict): Strategy of predication process. It means that using one device to predict
|
|
2699
2860
|
when setting predict_strategy as None. Default: ``None`` .
|
|
2700
2861
|
train_strategy_filename (str): The filename of training strategy protocol buffer file.
|
|
@@ -2711,6 +2872,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2711
2872
|
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
|
|
2712
2873
|
mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
|
|
2713
2874
|
Default: ``'AES-GCM'`` .
|
|
2875
|
+
format (str): Input weight format to be loaded into the network.
|
|
2876
|
+
It can be set to either "ckpt" or "safetensors". Default: "ckpt".
|
|
2877
|
+
unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
|
|
2878
|
+
Default: ``None`` .
|
|
2879
|
+
dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
|
|
2880
|
+
rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
|
|
2881
|
+
globally by initializing the network; In save mode, save the file according to the input
|
|
2882
|
+
sequence number. If it is not input, save the entire file.
|
|
2714
2883
|
|
|
2715
2884
|
Raises:
|
|
2716
2885
|
TypeError: The type of inputs do not match the requirements.
|
|
@@ -2725,14 +2894,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2725
2894
|
|
|
2726
2895
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
2727
2896
|
Please see the `rank table startup
|
|
2728
|
-
<https://www.mindspore.cn/
|
|
2897
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
|
|
2729
2898
|
for more details.
|
|
2730
2899
|
|
|
2731
2900
|
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
|
|
2732
|
-
<https://www.mindspore.cn/
|
|
2901
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
|
|
2733
2902
|
|
|
2734
2903
|
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
2735
|
-
Startup <https://www.mindspore.cn/
|
|
2904
|
+
Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
|
|
2736
2905
|
|
|
2737
2906
|
>>> import os
|
|
2738
2907
|
>>> import numpy as np
|
|
@@ -2814,6 +2983,54 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2814
2983
|
...
|
|
2815
2984
|
[ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
|
|
2816
2985
|
"""
|
|
2986
|
+
if format not in ['safetensors', 'ckpt']:
|
|
2987
|
+
raise ValueError(
|
|
2988
|
+
f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
|
|
2989
|
+
|
|
2990
|
+
if format == 'safetensors':
|
|
2991
|
+
if unified_safetensors_dir is None:
|
|
2992
|
+
raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
|
|
2993
|
+
f"when format is 'safetensors'.")
|
|
2994
|
+
unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
|
|
2995
|
+
for param in unsupport_param:
|
|
2996
|
+
if param is not None:
|
|
2997
|
+
raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
|
|
2998
|
+
f"when format is 'safetensors'.")
|
|
2999
|
+
if strict_load or dec_mode != 'AES-GCM':
|
|
3000
|
+
raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
|
|
3001
|
+
f"when format is 'safetensors'.")
|
|
3002
|
+
if network is not None:
|
|
3003
|
+
rank_id = get_rank()
|
|
3004
|
+
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
|
|
3005
|
+
else:
|
|
3006
|
+
if dst_safetensors_dir is None:
|
|
3007
|
+
raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
|
|
3008
|
+
f"when network is None.")
|
|
3009
|
+
if rank_id is not None:
|
|
3010
|
+
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
|
|
3011
|
+
rank_id)
|
|
3012
|
+
else:
|
|
3013
|
+
dst_strategy_dict = _build_searched_strategy(predict_strategy)
|
|
3014
|
+
dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
|
|
3015
|
+
dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
|
|
3016
|
+
dst_device_num = dst_stage_device_num * dst_stage_num
|
|
3017
|
+
processes = []
|
|
3018
|
+
activate_processes = 0
|
|
3019
|
+
for rank in range(0, dst_device_num):
|
|
3020
|
+
p = Process(target=_load_parallel_checkpoint, args=(
|
|
3021
|
+
unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
|
|
3022
|
+
p.start()
|
|
3023
|
+
processes.append(p)
|
|
3024
|
+
activate_processes += 1
|
|
3025
|
+
max_processes = 64
|
|
3026
|
+
if activate_processes >= max_processes:
|
|
3027
|
+
p = processes.pop(0)
|
|
3028
|
+
p.join()
|
|
3029
|
+
activate_processes -= 1
|
|
3030
|
+
for p in processes:
|
|
3031
|
+
p.join()
|
|
3032
|
+
return
|
|
3033
|
+
|
|
2817
3034
|
network = Validator.check_isinstance("network", network, nn.Cell)
|
|
2818
3035
|
_check_checkpoint_file(checkpoint_filenames)
|
|
2819
3036
|
_check_predict_strategy(predict_strategy)
|
|
@@ -2858,17 +3075,24 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2858
3075
|
param_rank = rank_list.get(param.name)[0]
|
|
2859
3076
|
skip_merge_split = rank_list.get(param.name)[1]
|
|
2860
3077
|
shard_stride = train_strategy.get(param.name)[4]
|
|
3078
|
+
tensor_map = train_strategy.get(param.name)[1]
|
|
3079
|
+
first_dim_shard_idx = tensor_map[0] if tensor_map else -1
|
|
3080
|
+
device_arrangement = train_strategy.get(param.name)[0]
|
|
3081
|
+
first_dim_shard_size = 1
|
|
3082
|
+
if first_dim_shard_idx >= 0:
|
|
3083
|
+
first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
|
|
2861
3084
|
if train_strategy.get(param.name)[5]:
|
|
2862
|
-
shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
|
|
3085
|
+
shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
|
|
2863
3086
|
else:
|
|
2864
3087
|
shard_size = 0
|
|
2865
3088
|
for rank in param_rank:
|
|
2866
3089
|
param_total_list = list(range(0, ckpt_file_len))
|
|
3090
|
+
if first_dim_shard_size != 1:
|
|
3091
|
+
param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
|
|
2867
3092
|
if shard_size > 0:
|
|
2868
|
-
|
|
2869
|
-
|
|
2870
|
-
|
|
2871
|
-
param_total_list = shard_total_list[rank // shard_size]
|
|
3093
|
+
rank_index = param_total_list.index(rank)
|
|
3094
|
+
start = rank_index // shard_size * shard_size
|
|
3095
|
+
param_total_list = param_total_list[start:start + shard_size]
|
|
2872
3096
|
if shard_stride > 0:
|
|
2873
3097
|
param_stride = []
|
|
2874
3098
|
# merge pre parameter
|
|
@@ -3040,7 +3264,7 @@ def _get_mindir_inputs(file_name):
|
|
|
3040
3264
|
>>> input_tensor = get_mindir_inputs("lenet.mindir")
|
|
3041
3265
|
"""
|
|
3042
3266
|
Validator.check_file_name_by_regular(file_name)
|
|
3043
|
-
file_name = os.path.
|
|
3267
|
+
file_name = os.path.realpath(file_name)
|
|
3044
3268
|
model = read_proto(file_name)
|
|
3045
3269
|
input_tensor = []
|
|
3046
3270
|
|