mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as possibly problematic.

Files changed (308)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +3 -1
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +50 -9
  9. mindspore/_extends/parse/compile_config.py +41 -0
  10. mindspore/_extends/parse/parser.py +9 -7
  11. mindspore/_extends/parse/standard_method.py +52 -14
  12. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  13. mindspore/amp.py +24 -10
  14. mindspore/atlprov.dll +0 -0
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/c1.dll +0 -0
  21. mindspore/c1xx.dll +0 -0
  22. mindspore/c2.dll +0 -0
  23. mindspore/common/__init__.py +6 -4
  24. mindspore/common/_pijit_context.py +190 -0
  25. mindspore/common/_register_for_tensor.py +2 -1
  26. mindspore/common/_tensor_overload.py +139 -0
  27. mindspore/common/api.py +102 -87
  28. mindspore/common/dump.py +5 -6
  29. mindspore/common/generator.py +1 -7
  30. mindspore/common/hook_handle.py +14 -26
  31. mindspore/common/mindir_util.py +2 -2
  32. mindspore/common/parameter.py +46 -13
  33. mindspore/common/recompute.py +39 -9
  34. mindspore/common/sparse_tensor.py +7 -3
  35. mindspore/common/tensor.py +209 -29
  36. mindspore/communication/__init__.py +1 -1
  37. mindspore/communication/_comm_helper.py +38 -3
  38. mindspore/communication/comm_func.py +310 -55
  39. mindspore/communication/management.py +14 -14
  40. mindspore/context.py +123 -22
  41. mindspore/dataset/__init__.py +1 -1
  42. mindspore/dataset/audio/__init__.py +1 -1
  43. mindspore/dataset/core/config.py +7 -0
  44. mindspore/dataset/core/validator_helpers.py +7 -0
  45. mindspore/dataset/engine/cache_client.py +1 -1
  46. mindspore/dataset/engine/datasets.py +72 -44
  47. mindspore/dataset/engine/datasets_audio.py +7 -7
  48. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  49. mindspore/dataset/engine/datasets_text.py +20 -20
  50. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  51. mindspore/dataset/engine/datasets_vision.py +33 -33
  52. mindspore/dataset/engine/iterators.py +29 -0
  53. mindspore/dataset/engine/obs/util.py +7 -0
  54. mindspore/dataset/engine/queue.py +114 -60
  55. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  56. mindspore/dataset/engine/validators.py +34 -14
  57. mindspore/dataset/text/__init__.py +1 -4
  58. mindspore/dataset/transforms/__init__.py +0 -3
  59. mindspore/dataset/utils/line_reader.py +2 -0
  60. mindspore/dataset/vision/__init__.py +1 -4
  61. mindspore/dataset/vision/utils.py +1 -1
  62. mindspore/dataset/vision/validators.py +2 -1
  63. mindspore/dnnl.dll +0 -0
  64. mindspore/dpcmi.dll +0 -0
  65. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  66. mindspore/experimental/es/embedding_service.py +883 -0
  67. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  68. mindspore/experimental/llm_boost/__init__.py +21 -0
  69. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  70. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  71. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  72. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  73. mindspore/experimental/llm_boost/register.py +129 -0
  74. mindspore/experimental/llm_boost/utils.py +31 -0
  75. mindspore/experimental/optim/adamw.py +85 -0
  76. mindspore/experimental/optim/optimizer.py +3 -0
  77. mindspore/hal/__init__.py +3 -3
  78. mindspore/hal/contiguous_tensors_handle.py +175 -0
  79. mindspore/hal/stream.py +18 -0
  80. mindspore/include/api/model_group.h +13 -1
  81. mindspore/include/api/types.h +10 -10
  82. mindspore/include/dataset/config.h +2 -2
  83. mindspore/include/dataset/constants.h +2 -2
  84. mindspore/include/dataset/execute.h +2 -2
  85. mindspore/include/dataset/vision.h +4 -0
  86. mindspore/jpeg62.dll +0 -0
  87. mindspore/log.py +1 -1
  88. mindspore/mindrecord/filewriter.py +68 -51
  89. mindspore/mindspore_backend.dll +0 -0
  90. mindspore/mindspore_common.dll +0 -0
  91. mindspore/mindspore_core.dll +0 -0
  92. mindspore/mindspore_glog.dll +0 -0
  93. mindspore/mindspore_np_dtype.dll +0 -0
  94. mindspore/mindspore_ops.dll +0 -0
  95. mindspore/mint/__init__.py +495 -46
  96. mindspore/mint/distributed/__init__.py +31 -0
  97. mindspore/mint/distributed/distributed.py +254 -0
  98. mindspore/mint/nn/__init__.py +266 -21
  99. mindspore/mint/nn/functional.py +125 -19
  100. mindspore/mint/nn/layer/__init__.py +39 -0
  101. mindspore/mint/nn/layer/activation.py +133 -0
  102. mindspore/mint/nn/layer/normalization.py +477 -0
  103. mindspore/mint/nn/layer/pooling.py +110 -0
  104. mindspore/mint/optim/adamw.py +28 -7
  105. mindspore/mint/special/__init__.py +63 -0
  106. mindspore/msobj140.dll +0 -0
  107. mindspore/mspdb140.dll +0 -0
  108. mindspore/mspdbcore.dll +0 -0
  109. mindspore/mspdbst.dll +0 -0
  110. mindspore/mspft140.dll +0 -0
  111. mindspore/msvcdis140.dll +0 -0
  112. mindspore/msvcp140_1.dll +0 -0
  113. mindspore/msvcp140_2.dll +0 -0
  114. mindspore/msvcp140_atomic_wait.dll +0 -0
  115. mindspore/msvcp140_codecvt_ids.dll +0 -0
  116. mindspore/multiprocessing/__init__.py +2 -1
  117. mindspore/nn/__init__.py +0 -1
  118. mindspore/nn/cell.py +275 -93
  119. mindspore/nn/layer/activation.py +211 -44
  120. mindspore/nn/layer/basic.py +113 -3
  121. mindspore/nn/layer/embedding.py +120 -2
  122. mindspore/nn/layer/normalization.py +101 -5
  123. mindspore/nn/layer/padding.py +34 -48
  124. mindspore/nn/layer/pooling.py +161 -7
  125. mindspore/nn/layer/transformer.py +3 -3
  126. mindspore/nn/loss/__init__.py +2 -2
  127. mindspore/nn/loss/loss.py +84 -6
  128. mindspore/nn/optim/__init__.py +2 -1
  129. mindspore/nn/optim/adadelta.py +1 -1
  130. mindspore/nn/optim/adam.py +1 -1
  131. mindspore/nn/optim/lamb.py +1 -1
  132. mindspore/nn/optim/tft_wrapper.py +127 -0
  133. mindspore/nn/wrap/cell_wrapper.py +12 -23
  134. mindspore/nn/wrap/grad_reducer.py +5 -5
  135. mindspore/nn/wrap/loss_scale.py +17 -3
  136. mindspore/numpy/__init__.py +1 -1
  137. mindspore/numpy/array_creations.py +65 -68
  138. mindspore/numpy/array_ops.py +64 -60
  139. mindspore/numpy/fft.py +610 -75
  140. mindspore/numpy/logic_ops.py +11 -10
  141. mindspore/numpy/math_ops.py +85 -84
  142. mindspore/numpy/utils_const.py +4 -4
  143. mindspore/opencv_core452.dll +0 -0
  144. mindspore/opencv_imgcodecs452.dll +0 -0
  145. mindspore/opencv_imgproc452.dll +0 -0
  146. mindspore/ops/__init__.py +6 -4
  147. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  148. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  149. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  150. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  151. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  152. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  153. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  154. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  155. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  156. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  157. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  158. mindspore/ops/composite/base.py +85 -48
  159. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  160. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  161. mindspore/ops/function/__init__.py +22 -0
  162. mindspore/ops/function/array_func.py +490 -153
  163. mindspore/ops/function/debug_func.py +113 -1
  164. mindspore/ops/function/fft_func.py +15 -2
  165. mindspore/ops/function/grad/grad_func.py +3 -2
  166. mindspore/ops/function/math_func.py +558 -207
  167. mindspore/ops/function/nn_func.py +817 -383
  168. mindspore/ops/function/other_func.py +3 -2
  169. mindspore/ops/function/random_func.py +184 -8
  170. mindspore/ops/function/reshard_func.py +13 -11
  171. mindspore/ops/function/sparse_unary_func.py +1 -1
  172. mindspore/ops/function/vmap_func.py +3 -2
  173. mindspore/ops/functional.py +24 -14
  174. mindspore/ops/op_info_register.py +3 -3
  175. mindspore/ops/operations/__init__.py +6 -1
  176. mindspore/ops/operations/_grad_ops.py +2 -76
  177. mindspore/ops/operations/_infer_ops.py +1 -1
  178. mindspore/ops/operations/_inner_ops.py +71 -94
  179. mindspore/ops/operations/array_ops.py +12 -146
  180. mindspore/ops/operations/comm_ops.py +42 -53
  181. mindspore/ops/operations/custom_ops.py +83 -19
  182. mindspore/ops/operations/debug_ops.py +42 -10
  183. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  184. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  185. mindspore/ops/operations/math_ops.py +12 -223
  186. mindspore/ops/operations/nn_ops.py +20 -114
  187. mindspore/ops/operations/other_ops.py +7 -4
  188. mindspore/ops/operations/random_ops.py +46 -1
  189. mindspore/ops/primitive.py +18 -6
  190. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  191. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  192. mindspore/ops_generate/gen_constants.py +36 -0
  193. mindspore/ops_generate/gen_ops.py +67 -52
  194. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  195. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  196. mindspore/ops_generate/op_proto.py +10 -3
  197. mindspore/ops_generate/pyboost_utils.py +14 -1
  198. mindspore/ops_generate/template.py +43 -21
  199. mindspore/parallel/__init__.py +3 -1
  200. mindspore/parallel/_auto_parallel_context.py +28 -8
  201. mindspore/parallel/_cell_wrapper.py +83 -0
  202. mindspore/parallel/_parallel_serialization.py +47 -19
  203. mindspore/parallel/_tensor.py +81 -11
  204. mindspore/parallel/_utils.py +13 -1
  205. mindspore/parallel/algo_parameter_config.py +5 -5
  206. mindspore/parallel/checkpoint_transform.py +46 -39
  207. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  208. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  209. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  210. mindspore/parallel/parameter_broadcast.py +3 -4
  211. mindspore/parallel/shard.py +162 -31
  212. mindspore/parallel/transform_safetensors.py +993 -0
  213. mindspore/pgodb140.dll +0 -0
  214. mindspore/pgort140.dll +0 -0
  215. mindspore/profiler/__init__.py +2 -1
  216. mindspore/profiler/common/constant.py +29 -0
  217. mindspore/profiler/common/registry.py +47 -0
  218. mindspore/profiler/common/util.py +28 -0
  219. mindspore/profiler/dynamic_profiler.py +694 -0
  220. mindspore/profiler/envprofiling.py +17 -19
  221. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  222. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  223. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  230. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  231. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  232. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  233. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  234. mindspore/profiler/parser/framework_parser.py +1 -391
  235. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  236. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  237. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  238. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  239. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  240. mindspore/profiler/parser/profiler_info.py +78 -6
  241. mindspore/profiler/profiler.py +153 -0
  242. mindspore/profiler/profiling.py +280 -412
  243. mindspore/rewrite/__init__.py +1 -2
  244. mindspore/rewrite/common/namespace.py +4 -4
  245. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  246. mindspore/run_check/_check_version.py +36 -103
  247. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  248. mindspore/swresample-4.dll +0 -0
  249. mindspore/swscale-6.dll +0 -0
  250. mindspore/tbbmalloc.dll +0 -0
  251. mindspore/tinyxml2.dll +0 -0
  252. mindspore/train/__init__.py +4 -3
  253. mindspore/train/_utils.py +28 -2
  254. mindspore/train/amp.py +171 -53
  255. mindspore/train/callback/__init__.py +2 -2
  256. mindspore/train/callback/_callback.py +4 -4
  257. mindspore/train/callback/_checkpoint.py +85 -22
  258. mindspore/train/callback/_cluster_monitor.py +1 -1
  259. mindspore/train/callback/_flops_collector.py +1 -0
  260. mindspore/train/callback/_loss_monitor.py +3 -3
  261. mindspore/train/callback/_on_request_exit.py +134 -31
  262. mindspore/train/callback/_summary_collector.py +5 -5
  263. mindspore/train/callback/_tft_register.py +352 -0
  264. mindspore/train/dataset_helper.py +7 -3
  265. mindspore/train/metrics/metric.py +3 -3
  266. mindspore/train/metrics/roc.py +4 -4
  267. mindspore/train/mind_ir_pb2.py +44 -39
  268. mindspore/train/model.py +134 -58
  269. mindspore/train/serialization.py +336 -112
  270. mindspore/turbojpeg.dll +0 -0
  271. mindspore/utils/__init__.py +21 -0
  272. mindspore/utils/utils.py +60 -0
  273. mindspore/vcmeta.dll +0 -0
  274. mindspore/vcruntime140.dll +0 -0
  275. mindspore/vcruntime140_1.dll +0 -0
  276. mindspore/version.py +1 -1
  277. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  278. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
  279. mindspore/include/c_api/ms/abstract.h +0 -67
  280. mindspore/include/c_api/ms/attribute.h +0 -197
  281. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  282. mindspore/include/c_api/ms/base/macros.h +0 -32
  283. mindspore/include/c_api/ms/base/status.h +0 -33
  284. mindspore/include/c_api/ms/base/types.h +0 -283
  285. mindspore/include/c_api/ms/context.h +0 -102
  286. mindspore/include/c_api/ms/graph.h +0 -160
  287. mindspore/include/c_api/ms/node.h +0 -606
  288. mindspore/include/c_api/ms/tensor.h +0 -161
  289. mindspore/include/c_api/ms/value.h +0 -84
  290. mindspore/mindspore_shared_lib.dll +0 -0
  291. mindspore/nn/extend/basic.py +0 -140
  292. mindspore/nn/extend/embedding.py +0 -143
  293. mindspore/nn/extend/layer/normalization.py +0 -109
  294. mindspore/nn/extend/pooling.py +0 -117
  295. mindspore/nn/layer/embedding_service.py +0 -531
  296. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  297. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  298. mindspore/ops/extend/__init__.py +0 -53
  299. mindspore/ops/extend/array_func.py +0 -218
  300. mindspore/ops/extend/math_func.py +0 -76
  301. mindspore/ops/extend/nn_func.py +0 -308
  302. mindspore/ops/silent_check.py +0 -162
  303. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  304. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  305. mindspore/train/callback/_mindio_ttp.py +0 -443
  306. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  307. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  308. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -21,10 +21,12 @@ import binascii
  import copy
  import json
  import os
+ import re
  import shutil
  import stat
  import threading
  from threading import Thread, RLock
+ from multiprocessing import Process
  from collections import defaultdict, OrderedDict
  from io import BytesIO
 
@@ -58,21 +60,25 @@ from mindspore.common.file_system import FileSystem, _register_basic_file_system
  from mindspore.communication.management import get_rank, get_group_size
  from mindspore.experimental import MapParameter
  from mindspore.ops import Cast
- from mindspore.parallel._cell_wrapper import get_allgather_cell
+ from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
  from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
  from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
- from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
+ from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
+     _get_device_num, _is_parallel_mode
+ from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
  from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
-     _restore_group_info_list
+     _restore_group_info_list, _get_param_list_when_first_dim_sharded
  from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
      _store_warm_up_ptr_by_tensor_list, _cache_enable
  from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
- from mindspore.train._utils import read_proto
+ from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
+     _extract_pipeline_stage_num
+ from mindspore.train._utils import read_proto, get_parameter_redundancy
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
      split_mindir, split_dynamic_mindir
  from mindspore.common.generator import Generator
- from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
- from mindspore.parallel.parameter_broadcast import parameter_broadcast
+ from safetensors.numpy import save_file
+ from safetensors import safe_open
  from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
 
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
@@ -116,6 +122,68 @@ def init_ckpt_file_system(fs: FileSystem):
  init_ckpt_file_system(_ckpt_fs)
 
 
+ def _get_cur_rank_dp(parameter_layout_dict):
+     """Get dp and tp from layout dict."""
+     pp_num = _get_auto_parallel_context("pipeline_stages")
+     dev_num = _get_device_num()
+     global_rank = get_rank()
+     pipe_size = dev_num // pp_num
+     initial_rank = (global_rank // pipe_size) * pipe_size
+     parameter_redundancy_dict = get_parameter_redundancy(
+         parameter_layout_dict, initial_rank)
+     value_len = sys.maxsize
+     min_value = ()
+     for key, value in parameter_redundancy_dict.items():
+         if "accu_grads" in key or "inputs" in key:
+             continue
+         for item in value:
+             if len(item) < value_len and global_rank in item:
+                 value_len = len(item)
+                 min_value = item
+     return min_value
+
+
+ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
+     """
+     Find an available checkpoint file path among all backup checkpoint files of the current rank.
+     It assumes the checkpoint path contains the substring 'rank_{rank_id}', which is used to
+     distinguish between different paths. If cur_ckpt_path doesn't contain a 'rank_{rank_id}'
+     substring, cur_ckpt_path itself is returned when it exists, otherwise None is returned.
+
+     Note:
+         This API must be called after the communication is initialized because the cluster information
+         needs to be obtained internally.
+
+     Args:
+         cur_ckpt_path (str): the checkpoint file path which the current rank needs.
+         cur_strategy_path (str): strategy file path for the current rank.
+
+     Returns:
+         - new_ckpt_file (str), the available checkpoint file, if one is found.
+         - None, if no available checkpoint is found.
+
+     Examples:
+         >>> import mindspore as ms
+         >>> from mindspore.communication import init
+         >>> from mindspore import get_ckpt_path_with_strategy
+         >>> ms.set_context(mode=ms.GRAPH_MODE)
+         >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+         >>> init()
+         >>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
+         >>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
+         >>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
+         >>> print(ckpt_file_new)
+     """
+     dp = _get_cur_rank_dp(cur_strategy_path)
+     pattern = r'rank_\d+'
+     for i in dp:
+         new_ckpt_path = re.sub(pattern, f"rank_{str(i)}", cur_ckpt_path)
+         if not os.path.isfile(new_ckpt_path):
+             continue
+         return new_ckpt_path
+     return None
+
+
  class ParamDictFuture:
      def __init__(self, executor, param_dict_future):
          self.executor = executor
@@ -252,57 +320,72 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
          logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
 
 
- def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False):
+ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
+                format="ckpt"):
      """Execute the process of saving checkpoint into file."""
      try:
          with _ckpt_mutex:
+             file_name_list = list(os.path.splitext(ckpt_file_name))
+             file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
+             tmp_name = ''.join(file_name_list)
              if os.path.exists(ckpt_file_name):
                  os.chmod(ckpt_file_name, stat.S_IWUSR)
                  os.remove(ckpt_file_name)
-             with _ckpt_fs.create(ckpt_file_name, *_ckpt_fs.create_args) as f:
-                 plain_data = None
-                 if enc_key is not None:
-                     plain_data = BytesIO()
-
-                 crc_num = 0
-                 for name, value in data_list.items():
-                     if name == "random_op":
-                         _write_random_seed(name, value, f)
-                         continue
-                     if value[0] == "mapparameter":
-                         _write_mapparameter(name, value, f, map_param_inc)
-                         continue
-                     if value[0] == "offload_parameter":
-                         new_value = value[1:]
-                         new_value[2] = value[3]
-                         _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
-                         _offload_if_config(value[3])
-                         continue
-                     if value[1] == "str":
-                         crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
-                         continue
-                     if isinstance(value[2], np.ndarray):
-                         crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
-                         continue
-                     if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
-                         _write_hugeparameter(name, value, f)
-                         continue
-
-                     crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
-
-                 if enc_key is not None:
-                     plain_data.seek(0)
-                     max_block_size = ENCRYPT_BLOCK_SIZE * 1024
-                     block_data = plain_data.read(max_block_size)
-                     while block_data:
-                         f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
+             if os.path.exists(tmp_name):
+                 os.chmod(tmp_name, stat.S_IWUSR)
+                 os.remove(tmp_name)
+             if format == "ckpt":
+                 with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
+                     plain_data = None
+                     if enc_key is not None:
+                         plain_data = BytesIO()
+
+                     crc_num = 0
+                     for name, value in data_list.items():
+                         if name == "random_op":
+                             _write_random_seed(name, value, f)
+                             continue
+                         if value[0] == "mapparameter":
+                             _write_mapparameter(name, value, f, map_param_inc)
+                             continue
+                         if value[0] == "offload_parameter":
+                             new_value = value[1:]
+                             new_value[2] = value[3]
+                             _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
+                             _offload_if_config(value[3])
+                             continue
+                         if value[1] == "str":
+                             crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+                             continue
+                         if isinstance(value[2], np.ndarray):
+                             crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+                             continue
+                         if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
+                             _write_hugeparameter(name, value, f)
+                             continue
+
+                         crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+
+                     if enc_key is not None:
+                         plain_data.seek(0)
+                         max_block_size = ENCRYPT_BLOCK_SIZE * 1024
                          block_data = plain_data.read(max_block_size)
-
-                 if crc_check:
-                     f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
-
+                         while block_data:
+                             f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
+                             block_data = plain_data.read(max_block_size)
+                     if crc_check:
+                         f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
+             elif format == "safetensors":
+                 save_dict = {}
+                 for name, value in data_list.items():
+                     save_dict[name] = value[2].asnumpy()
+                 save_file(save_dict, tmp_name)
+             if not os.path.exists(tmp_name):
+                 logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
+                                f"simultaneously modified a file.")
+             else:
+                 os.rename(tmp_name, ckpt_file_name)
              os.chmod(ckpt_file_name, stat.S_IRUSR)
-
      except BaseException as e:
          logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
                          "or the disk space is insufficient and so on.", ckpt_file_name)
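The rewritten _exec_save never writes the final checkpoint path directly: both the ckpt and safetensors branches write to a sibling .tmp file and only rename it into place once the write has completed. A minimal standalone sketch of that pattern (the path, payload, and helper name are illustrative, not MindSpore API; os.replace is used here because, unlike os.rename on Windows, it overwrites an existing destination, whereas the MindSpore code removes the destination first and then calls os.rename):

import os
import stat

def atomic_write(path, payload):
    # Write to a sibling ".tmp" file first, so an interrupted write can
    # never leave a truncated file at the final path.
    tmp_path = os.path.splitext(path)[0] + ".tmp"
    with open(tmp_path, "wb") as f:
        f.write(payload)
    # Publish the finished file in a single rename step, then drop write permission.
    os.replace(tmp_path, path)
    os.chmod(path, stat.S_IRUSR)

atomic_write("model.ckpt", b"\x00" * 16)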
@@ -415,8 +498,11 @@ def _write_hugeparameter(name, value, f):
          offset += numpy_data.shape[0]
 
 
- def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
+ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
      """Check save_obj and ckpt_file_name for save_checkpoint."""
+     if format not in ["safetensors", "ckpt"]:
+         raise ValueError(f"For 'save_checkpoint', the format must be "
+                          f"'safetensors' or 'ckpt', but got {format}.")
      if not isinstance(save_obj, (nn.Cell, list, dict)):
          raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
                          "but got {}.".format(type(save_obj)))
@@ -424,18 +510,26 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
          raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
                          "'ckpt_file_name' must be "
                          "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
-     ckpt_file_name = os.path.abspath(ckpt_file_name)
+     ckpt_file_name = os.path.realpath(ckpt_file_name)
      if os.path.isdir(ckpt_file_name):
          raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
                                  "it must be a file name.".format(ckpt_file_name))
-     if not ckpt_file_name.endswith('.ckpt'):
-         ckpt_file_name += ".ckpt"
+     if not ckpt_file_name.endswith(format):
+         ckpt_file_name += f".{format}"
      return ckpt_file_name
 
 
+ def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
+                                    global_step_num=None):
+     param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
+                          or map_param_inc or global_step_num is not None)
+     if format == "safetensors" and param_not_default:
+         raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
+
+
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                      async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
-                     crc_check=False, **kwargs):
+                     crc_check=False, format="ckpt", **kwargs):
      r"""
      Save checkpoint to a specified file.
 
@@ -465,6 +559,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
              be saved. Default: ``None`` .
          crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
              result to the file. Default: ``False`` .
+         format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
          kwargs (dict): Configuration options dictionary.
 
      Raises:
@@ -498,7 +593,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
          - `Saving and Loading the Model - Saving and Loading the Model Weight
            <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
      """
-     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
+     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
      integrated_save = Validator.check_bool(integrated_save)
      async_save = Validator.check_bool(async_save)
      append_dict = _check_append_dict(append_dict)
@@ -508,10 +603,19 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
      map_param_inc = kwargs.get('incremental', False)
      logger.info("Execute the process of saving checkpoint files.")
      global_step_num = kwargs.get('global_step_num', None)
+     _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
 
-     save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+     if append_dict and "__exception_save__" in append_dict:
+         s1 = mindspore.hal.Stream()
+         with mindspore.hal.StreamCtx(s1):
+             save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+         s1.synchronize()
+     else:
+         save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
 
      if append_dict:
+         if "__exception_save__" in append_dict:
+             del append_dict["__exception_save__"]
          append_info_list = []
          for k_name, value in append_dict.items():
              if isinstance(value, Generator):
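The "__exception_save__" branch above moves the parameter conversion onto a side stream so that an exception-time checkpoint does not serialize against the default compute stream. The same Stream/StreamCtx pattern in isolation (the tensor and op are placeholders):

import mindspore
from mindspore import Tensor, ops

s1 = mindspore.hal.Stream()
with mindspore.hal.StreamCtx(s1):
    # Work issued inside the context runs on s1, not the default stream.
    y = ops.abs(Tensor([-1.0, 2.0]))
s1.synchronize()  # block until the side-stream work is done before using y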
@@ -527,12 +631,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
      for param in save_obj:
          if param["name"] == "random_op":
              if os.getenv("AITURBO") == "1":
-                 data_list_np["random_op"] = param["data"]
+                 data_list_np["random_op"] = []
+                 data_list_np["random_op"].append(param["data"])
+                 if crc_check:
+                     bytes_value = bytes(data_list_np[key][0])
+                     data_list_np[key].append(binascii.crc32(bytes_value))
              else:
                  data_list["random_op"] = param["data"]
              continue
          key = param["name"]
          data_list[key] = []
+         data_list_np[key] = []
          if isinstance(param["data"], MapParameter):
              data_list[param["name"]].append("mapparameter")
              data_list[param["name"]].append(param["data"])
@@ -546,7 +655,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
 
          if isinstance(param["data"], str):
              if os.getenv("AITURBO") == "1":
-                 data_list_np[key] = np.array(param["data"])
+                 data_list_np[key].append(np.array(param["data"]))
+                 if crc_check:
+                     bytes_value = data_list_np[key][0].tobytes()
+                     data_list_np[key].append(binascii.crc32(bytes_value))
              else:
                  data_list[key].append([0])
                  data_list[key].append('str')
@@ -556,7 +668,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
          if isinstance(param["data"], Parameter):
              param["data"].init_data()
              if os.getenv("AITURBO") == "1":
-                 data_list_np[key] = param["data"].asnumpy()
+                 data_list_np[key].append(param["data"].asnumpy())
+                 if crc_check:
+                     bytes_value = data_list_np[key][0].tobytes()
+                     data_list_np[key].append(binascii.crc32(bytes_value))
              else:
                  dims = []
                  for dim in param['data'].shape:
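With these three hunks, every AITURBO entry in data_list_np becomes a two-element list, [payload, crc32-of-payload-bytes], rather than a bare payload; load_checkpoint below re-checks the pair. The invariant in isolation (array values are illustrative):

import binascii
import numpy as np

payload = np.arange(4, dtype=np.float32)
entry = [payload, binascii.crc32(payload.tobytes())]
# Load side: recompute the checksum and compare before trusting the payload.
assert entry[1] == binascii.crc32(entry[0].tobytes())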
@@ -568,16 +683,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                  data_list[key].append(data)
 
      if os.getenv("AITURBO") == "1":
-         import aiturbo
+         from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
          ckpt_name = os.path.basename(ckpt_file_name)
-         aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np)
+         aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
      elif async_save:
          data_copy = copy.deepcopy(data_list)
-         thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check),
+         thr = Thread(target=_exec_save,
+                      args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
                       name="asyn_save_ckpt")
          thr.start()
      else:
-         _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check)
+         _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
 
      logger.info("Saving checkpoint process is finished.")
 
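Taken together, these hunks thread the new format argument from save_checkpoint down to _exec_save. A usage sketch, assuming any constructed Cell (per _check_format_and_other_params, safetensors cannot be combined with encryption, CRC, async, or incremental saving):

import mindspore as ms
from mindspore import nn

net = nn.Dense(2, 3)  # any Cell works here
# Unchanged default path: protobuf-based .ckpt output.
ms.save_checkpoint(net, "dense.ckpt")
# New in 2.4.0: safetensors output; the other options must stay at their defaults.
ms.save_checkpoint(net, "dense.safetensors", format="safetensors")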
@@ -692,11 +808,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
                  param_data.append(value.key)
          else:
              param_data = value.data
+             if append_dict and "__exception_save__" in append_dict:
+                 param_data = Tensor(Tensor_.move_to(value, "CPU", False))
 
          # in automatic model parallel scenario, some parameters were split to all the devices,
          # which should be combined before saving
          if key in parameter_layout_dict:
-             param_data = Tensor(value.data)
+             if not append_dict or "__exception_save__" not in append_dict:
+                 param_data = Tensor(value.data)
              param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
                                                  integrated_save)
 
@@ -812,7 +931,7 @@ def load(file_name, **kwargs):
      if not os.path.exists(file_name):
          raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
                           "please check whether the 'file_name' is correct.")
-     file_name = os.path.abspath(file_name)
+     file_name = os.path.realpath(file_name)
 
      # set customized functions for dynamic obfuscation
      obfuscated = _check_load_obfuscate(**kwargs)
@@ -875,7 +994,7 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=T
      if not os.path.exists(file_name):
          raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
                           "please check whether the 'file_name' is correct.")
-     file_name = os.path.abspath(file_name)
+     file_name = os.path.realpath(file_name)
 
      logger.info("Execute the process of export and split mindir.")
      dynamic = True
@@ -1074,9 +1193,14 @@ def obfuscate_model(obf_config, **kwargs):
 
 
  def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
-                           dec_mode, crc_check):
+                           dec_mode, crc_check, format):
      """load parameter into parameter_dict"""
-     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
+     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
+     if format == "safetensors":
+         with safe_open(ckpt_file_name, framework='np') as f:
+             for k in f.keys():
+                 parameter_dict[k] = Parameter(f.get_tensor(k))
+         return
      checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
      try:
          param_data_list = []
@@ -1138,7 +1262,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
 
  def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
                      dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
-                     crc_check=False):
+                     crc_check=False, remove_redundancy=False, format="ckpt"):
      """
      Load checkpoint info from a specified file.
 
@@ -1148,6 +1272,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
          - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
            `choice_func` is recommended instead.
            And using either of those two args will override `choice_func` at the same time.
+         - When loading a checkpoint that has removed redundancy, the network should be compiled.
 
      Args:
          ckpt_file_name (str): Checkpoint file name.
@@ -1170,6 +1295,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
              that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
              matches the custom condition will be removed. Default: ``None`` .
          crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
+         remove_redundancy (bool): Whether to enable loading of a checkpoint saved with redundancy removal.
+             Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` ,
+             meaning redundancy-free loading is not enabled.
+         format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
 
      Returns:
          Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
@@ -1219,24 +1348,35 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
      dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
      dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
      crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
+     remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
+     _check_format_and_other_params(format, dec_key, dec_mode, crc_check)
      logger.info("Execute the process of loading checkpoint files.")
 
      parameter_dict = {}
 
      if os.getenv("AITURBO") == "1":
          rank_id = get_rank()
-         import aiturbo
+         from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
          ckpt_path = os.path.dirname(ckpt_file_name)
          ckpt_name = os.path.basename(ckpt_file_name)
-         np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id)
+         np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id, crc_check)
          for key, value in np_dict.items():
+             if crc_check and len(value) != 2:
+                 raise ValueError(f"When loading a checkpoint from AITurbo, if CRC check is enabled, "
+                                  f"the length of the value must be 2, but got {len(value)}.")
              if isinstance(value, str):
-                 parameter_dict[key] = value
+                 if crc_check and value[1] != binascii.crc32(np.array(value[0]).tobytes()):
+                     raise ValueError(f"When loading a checkpoint from AITurbo, the value of the string has not "
+                                      f"passed the CRC check and has been corrupted.")
+                 parameter_dict[key] = value[0]
              else:
-                 parameter_dict[key] = Parameter(Tensor(value), name=key)
+                 if crc_check and value[1] != binascii.crc32(value[0].tobytes()):
+                     raise ValueError(f"When loading a checkpoint from AITurbo, the value of the parameter has not "
+                                      f"passed the CRC check and has been corrupted.")
+                 parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
      else:
          _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
-                               dec_mode, crc_check)
+                               dec_mode, crc_check, format)
 
      if not parameter_dict:
          raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -1245,7 +1385,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
      if _warm_up_host_cache_enabled(parameter_dict):
          (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
      if net is not None:
-         load_param_into_net(net, parameter_dict, strict_load)
+         load_param_into_net(net, parameter_dict, strict_load, remove_redundancy)
      if _warm_up_host_cache_enabled(parameter_dict):
          _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
 
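The load path mirrors the save path: format selects between the protobuf parser and safe_open, and _check_ckpt_file_name now insists the file extension match format. A corresponding sketch (file names continue the save example above):

import mindspore as ms
from mindspore import nn

net = nn.Dense(2, 3)
# Each tensor in the safetensors file comes back as a Parameter in the dict.
param_dict = ms.load_checkpoint("dense.safetensors", format="safetensors")
ms.load_param_into_net(net, param_dict)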
@@ -1362,17 +1502,20 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
          parameter_dict[element.tag] = map_array
 
 
- def _check_ckpt_file_name(ckpt_file_name):
+ def _check_ckpt_file_name(ckpt_file_name, format):
      """Check function load_checkpoint's ckpt_file_name."""
      if not isinstance(ckpt_file_name, str):
          raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
                          "but got {}.".format(type(ckpt_file_name)))
 
-     if ckpt_file_name[-5:] != ".ckpt":
-         raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
+     if format not in ['ckpt', 'safetensors']:
+         raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt' or '.safetensors', please "
                           "input the correct 'ckpt_file_name'.")
+     if not ckpt_file_name.endswith(format):
+         raise ValueError(f"For 'load_checkpoint', the checkpoint file format must same with 'format', but got "
+                          f"file_name:'{ckpt_file_name}', format:'{format}'")
 
-     ckpt_file_name = os.path.abspath(ckpt_file_name)
+     ckpt_file_name = os.path.realpath(ckpt_file_name)
      if not os.path.exists(ckpt_file_name):
          raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
                           "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
@@ -1414,7 +1557,7 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
          pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
          if pb_content is None:
              raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
-     if crc_check and pb_content[-17:-10] == b"crc_num":
+     if crc_check and pb_content[-17:-10] != b"crc_num":
          logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
      if pb_content[-17:-10] == b"crc_num":
          crc_num_bytes = pb_content[-10:]
@@ -1484,10 +1627,13 @@ def _check_load_param_into_net(net, parameter_dict):
          parameter_dict.pop("random_op")
 
 
- def load_param_into_net(net, parameter_dict, strict_load=False):
+ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
      """
      Load parameters into network, return parameter list that are not loaded in the network.
 
+     Note:
+         - When loading a parameter dict that has removed redundancy, the network should be compiled.
+
      Args:
          net (Cell): The network where the parameters will be loaded.
          parameter_dict (dict): The dictionary generated by load checkpoint file,
@@ -1496,6 +1642,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
              into net when parameter name's suffix in checkpoint file is the same as the
              parameter in the network. When the types are inconsistent perform type conversion
              on the parameters of the same type, such as float32 to float16. Default: ``False`` .
+         remove_redundancy (bool): Whether to enable loading of a checkpoint saved with redundancy removal.
+             Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` ,
+             meaning redundancy-free loading is not enabled.
 
      Returns:
          - param_not_load (List), the parameter name in model which are not loaded into the network.
@@ -1529,10 +1678,11 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
          raise TypeError(msg)
 
      strict_load = Validator.check_bool(strict_load)
+     remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
      logger.info("Execute the process of loading parameters into net.")
      for _, param in net.parameters_and_names():
          param.from_ckpt = True
-     if not _is_in_auto_parallel_mode():
+     if not (_is_in_auto_parallel_mode() or _is_parallel_mode()):
          net.init_parameters_data()
      else:
          _init_parameter_data_in_parallel_mode(net, parameter_dict)
@@ -1560,16 +1710,26 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
          logger.warning("For 'load_param_into_net', "
                         "{} parameters in the 'net' are not loaded, because they are not in the "
                         "'parameter_dict', please check whether the network structure is consistent "
-                        "when training and loading checkpoint.".format(len(param_not_load)))
+                        "when training and loading checkpoint. Another possibility is that "
+                        "the redundant loading is not enabled, but the loaded checkpoint is saved with "
+                        "redundancy removed. ".format(len(param_not_load)))
          logger.warning("{} are not loaded.".format(param_not_load))
-     if os.getenv("AITURBO") == "1" and net.parameter_layout_dict is not None:
+     if remove_redundancy:
+         parallel_mode = context.get_auto_parallel_context("parallel_mode")
+         if parallel_mode == "stand_alone":
+             raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
+                             f"in parallel scenarios, but got {parallel_mode}.")
+         if not net.compile_cache and not net.parameter_layout_dict:
+             raise ValueError("When loading a parameter dict that has removed redundancy, "
+                              "the network should be compiled.")
          param_layout = net.parameter_layout_dict
-         param_redundancy = get_parameter_redundancy(param_layout)
-         remove_param_redundancy_dict = remove_param_redundancy(param_redundancy)
-         target_parameter_name_set = set(parameter_dict.keys())
-         for rank_id, param_name_set in remove_param_redundancy_dict:
-             if param_name_set == target_parameter_name_set:
-                 parameter_broadcast(net, param_layout, rank_id)
+         rank_id = get_rank()
+         device_num = _get_device_num()
+         stage_num = _get_auto_parallel_context("pipeline_stages")
+         chunk_size = device_num // stage_num
+         initial_rank = (rank_id // chunk_size) * chunk_size
+         _single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
+
      return param_not_load, ckpt_not_load
 
 
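The remove_redundancy path replaces the old AITURBO-only broadcast: each rank computes the first rank of its pipeline stage and lets _single_parameter_broadcast redistribute the parameters it alone holds. A hedged sketch of the intended call pattern (assumes a multi-card job with communication initialized and an already-compiled network; build_network and the checkpoint name are hypothetical, and stand_alone mode raises TypeError by design):

import mindspore as ms
from mindspore.communication import init, get_rank

init()  # cluster info must be available before loading
net = build_network()  # hypothetical helper; the net must be compiled in a parallel mode
param_dict = ms.load_checkpoint(f"./rank_{get_rank()}/ckpt_dedup.ckpt")
# Broadcast each parameter from the rank that uniquely owns it.
ms.load_param_into_net(net, param_dict, remove_redundancy=True)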
@@ -1675,7 +1835,7 @@ def _save_graph(network, file_name):
      """
      logger.info("Execute the process of saving graph.")
 
-     file_name = os.path.abspath(file_name)
+     file_name = os.path.realpath(file_name)
      graph_pb = network.get_func_graph_proto()
      if graph_pb:
          with open(file_name, "wb") as f:
@@ -1790,7 +1950,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
          - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
          - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
          - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
-           for MindSpore models.
+           for MindSpore models. MINDIR does not support operators which have dictionary attribute.
 
      kwargs (dict): Configuration options dictionary.
 
@@ -1889,7 +2049,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
                               + str(columns))
          inputs = tuple(inputs_col)
 
-     file_name = os.path.abspath(file_name)
+     file_name = os.path.realpath(file_name)
      if 'enc_key' in kwargs.keys():
          kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
      _export(net, file_name, file_format, *inputs, **kwargs)
@@ -1982,8 +2142,8 @@ def _save_air(net, file_name, *inputs, **kwargs):
      if os.path.exists(file_name):
          os.chmod(file_name, stat.S_IWUSR)
      if "/" in file_name:
-         real_path = os.path.abspath(file_name[:file_name.rfind("/")])
-         os.makedirs(real_path, exist_ok=True)
+         real_path = os.path.realpath(file_name[:file_name.rfind("/")])
+         os.makedirs(real_path, mode=0o700, exist_ok=True)
      if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
          _executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
      else:
@@ -2093,12 +2253,12 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
      file_prefix = file_name.split("/")[-1]
      if file_prefix.endswith(".mindir"):
          file_prefix = file_prefix[:-7]
-     current_path = os.path.abspath(file_name)
+     current_path = os.path.realpath(file_name)
      dirname = os.path.dirname(current_path)
      data_path = os.path.join(dirname, file_prefix + "_variables")
      if os.path.exists(data_path):
          shutil.rmtree(data_path)
-     os.makedirs(data_path, exist_ok=True)
+     os.makedirs(data_path, mode=0o700, exist_ok=True)
      os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
      index = 0
      external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
@@ -2267,9 +2427,9 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
                             "the data of parameter cannot be exported.".format(map_param_proto.name))
      if not file_name.endswith('.mindir'):
          file_name += ".mindir"
-     current_path = os.path.abspath(file_name)
+     current_path = os.path.realpath(file_name)
      dirname = os.path.dirname(current_path)
-     os.makedirs(dirname, exist_ok=True)
+     os.makedirs(dirname, mode=0o700, exist_ok=True)
      if os.path.exists(file_name):
          os.chmod(file_name, stat.S_IWUSR)
      with open(file_name, 'wb') as f:
@@ -2398,7 +2558,7 @@ def parse_print(print_file_name):
          [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
           [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
      """
-     print_file_path = os.path.abspath(print_file_name)
+     print_file_path = os.path.realpath(print_file_name)
 
      if os.path.getsize(print_file_path) == 0:
          raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
@@ -2687,14 +2847,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
      return merged_parameter
 
 
- def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
-                                 train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
+ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
+                                 train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
+                                 format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
      """
      Load checkpoint into net for distributed predication. Used in the case of distributed inference.
 
      Args:
          network (Cell): Network for distributed predication.
-         checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
+         checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
          predict_strategy (dict): Strategy of predication process. It means that using one device to predict
              when setting predict_strategy as None. Default: ``None`` .
          train_strategy_filename (str): The filename of training strategy protocol buffer file.
@@ -2711,6 +2872,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
          dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
              mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
              Default: ``'AES-GCM'`` .
+         format (str): Input weight format to be loaded into the network.
+             It can be set to either "ckpt" or "safetensors". Default: "ckpt".
+         unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
+             Default: ``None`` .
+         dst_safetensors_dir (str): In the save-mode scenario, the save directory for safetensors.
+         rank_id (int): The logical sequence number of the card. In non-save mode, it is obtained
+             automatically by initializing the network; in save mode, the file is saved according to the input
+             sequence number, and the entire file is saved when it is not given.
 
      Raises:
          TypeError: The type of inputs do not match the requirements.
@@ -2725,14 +2894,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
 
          For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
          Please see the `rank table startup
-         <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
+         <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
          for more details.
 
          For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
-         <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
+         <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
 
          For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
-         Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .
+         Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
 
          >>> import os
          >>> import numpy as np
@@ -2814,6 +2983,54 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
          ...
          [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
      """
+     if format not in ['safetensors', 'ckpt']:
+         raise ValueError(
+             f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
+
+     if format == 'safetensors':
+         if unified_safetensors_dir is None:
+             raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
+                              f"when format is 'safetensors'.")
+         unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
+         for param in unsupport_param:
+             if param is not None:
+                 raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
+                                  f"when format is 'safetensors'.")
+         if strict_load or dec_mode != 'AES-GCM':
+             raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
+                              f"when format is 'safetensors'.")
+         if network is not None:
+             rank_id = get_rank()
+             _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
+         else:
+             if dst_safetensors_dir is None:
+                 raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
+                                  f"when network is None.")
+             if rank_id is not None:
+                 _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
+                                           rank_id)
+             else:
+                 dst_strategy_dict = _build_searched_strategy(predict_strategy)
+                 dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
+                 dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
+                 dst_device_num = dst_stage_device_num * dst_stage_num
+                 processes = []
+                 activate_processes = 0
+                 for rank in range(0, dst_device_num):
+                     p = Process(target=_load_parallel_checkpoint, args=(
+                         unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
+                     p.start()
+                     processes.append(p)
+                     activate_processes += 1
+                     max_processes = 64
+                     if activate_processes >= max_processes:
+                         p = processes.pop(0)
+                         p.join()
+                         activate_processes -= 1
+                 for p in processes:
+                     p.join()
+         return
+
      network = Validator.check_isinstance("network", network, nn.Cell)
      _check_checkpoint_file(checkpoint_filenames)
      _check_predict_strategy(predict_strategy)
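The new safetensors branch supports two modes, sketched below (directory and strategy-file names are hypothetical; in the second mode the function fans out one Process per destination rank, keeping at most 64 alive at a time):

import mindspore as ms

# net: a compiled Cell prepared earlier (not shown).
# Mode 1: network given - slice the unified safetensors weights for this rank
# and load them straight into the network.
ms.load_distributed_checkpoint(network=net, format='safetensors',
                               unified_safetensors_dir="./unified_weights",
                               predict_strategy="./dst_strategy.ckpt")

# Mode 2: no network ("save mode") - re-slice the unified weights into
# per-rank safetensors files under dst_safetensors_dir instead.
ms.load_distributed_checkpoint(network=None, format='safetensors',
                               unified_safetensors_dir="./unified_weights",
                               predict_strategy="./dst_strategy.ckpt",
                               dst_safetensors_dir="./per_rank_weights")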
@@ -2858,17 +3075,24 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
          param_rank = rank_list.get(param.name)[0]
          skip_merge_split = rank_list.get(param.name)[1]
          shard_stride = train_strategy.get(param.name)[4]
+         tensor_map = train_strategy.get(param.name)[1]
+         first_dim_shard_idx = tensor_map[0] if tensor_map else -1
+         device_arrangement = train_strategy.get(param.name)[0]
+         first_dim_shard_size = 1
+         if first_dim_shard_idx >= 0:
+             first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
          if train_strategy.get(param.name)[5]:
-             shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
+             shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
          else:
              shard_size = 0
          for rank in param_rank:
              param_total_list = list(range(0, ckpt_file_len))
+             if first_dim_shard_size != 1:
+                 param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
              if shard_size > 0:
-                 shard_total_list = []
-                 for i in range(0, ckpt_file_len, shard_size):
-                     shard_total_list.append(param_total_list[i:i + shard_size])
-                 param_total_list = shard_total_list[rank // shard_size]
+                 rank_index = param_total_list.index(rank)
+                 start = rank_index // shard_size * shard_size
+                 param_total_list = param_total_list[start:start + shard_size]
              if shard_stride > 0:
                  param_stride = []
                  # merge pre parameter
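A worked instance of the revised slice-window arithmetic (numbers are illustrative): the old code pre-built every window over range(0, ckpt_file_len) and indexed it with rank // shard_size, which breaks once param_total_list is no longer the full range; the new code locates rank inside whatever list it was handed and takes the surrounding window.

ckpt_file_len, shard_stride, opt_shard_size, first_dim_shard_size = 8, 2, 2, 1
shard_size = int(ckpt_file_len / shard_stride / opt_shard_size / first_dim_shard_size)  # 2
param_total_list = list(range(ckpt_file_len))
rank = 5
start = param_total_list.index(rank) // shard_size * shard_size  # 5 // 2 * 2 = 4
print(param_total_list[start:start + shard_size])  # [4, 5]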
@@ -3040,7 +3264,7 @@ def _get_mindir_inputs(file_name):
          >>> input_tensor = get_mindir_inputs("lenet.mindir")
      """
      Validator.check_file_name_by_regular(file_name)
-     file_name = os.path.abspath(file_name)
+     file_name = os.path.realpath(file_name)
      model = read_proto(file_name)
      input_tensor = []