mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (458)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +65 -84
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +58 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +178 -53
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +377 -203
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +5 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +117 -131
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +67 -55
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +70 -24
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +27 -7
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +6 -46
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +429 -23
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +140 -104
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +491 -623
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +117 -110
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +4 -6
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/asgd.py +2 -0
  188. mindspore/nn/optim/lamb.py +1 -3
  189. mindspore/nn/optim/optimizer.py +1 -1
  190. mindspore/nn/optim/tft_wrapper.py +2 -3
  191. mindspore/nn/optim/thor.py +2 -2
  192. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  193. mindspore/nn/probability/distribution/exponential.py +2 -1
  194. mindspore/nn/probability/distribution/poisson.py +2 -1
  195. mindspore/nn/sparse/sparse.py +3 -3
  196. mindspore/nn/wrap/cell_wrapper.py +73 -42
  197. mindspore/nn/wrap/grad_reducer.py +37 -52
  198. mindspore/nn/wrap/loss_scale.py +72 -74
  199. mindspore/numpy/array_creations.py +7 -7
  200. mindspore/numpy/fft.py +1 -1
  201. mindspore/numpy/math_ops.py +5 -5
  202. mindspore/numpy/utils_const.py +1 -1
  203. mindspore/opencv_core452.dll +0 -0
  204. mindspore/opencv_imgcodecs452.dll +0 -0
  205. mindspore/opencv_imgproc452.dll +0 -0
  206. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  207. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  208. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  209. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  210. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  211. mindspore/ops/_vmap/vmap_array_ops.py +31 -13
  212. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  213. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
  214. mindspore/ops/auto_generate/gen_extend_func.py +27 -145
  215. mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
  216. mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
  217. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  218. mindspore/ops/composite/__init__.py +10 -0
  219. mindspore/ops/composite/base.py +9 -5
  220. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  221. mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
  222. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  223. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  224. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  225. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  226. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  227. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  228. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  229. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  230. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  231. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  232. mindspore/ops/function/__init__.py +4 -1
  233. mindspore/ops/function/_add_attr_func.py +11 -6
  234. mindspore/ops/function/array_func.py +19 -102
  235. mindspore/ops/function/debug_func.py +8 -5
  236. mindspore/ops/function/grad/grad_func.py +5 -13
  237. mindspore/ops/function/math_func.py +77 -572
  238. mindspore/ops/function/nn_func.py +46 -94
  239. mindspore/ops/function/other_func.py +4 -1
  240. mindspore/ops/function/random_func.py +44 -5
  241. mindspore/ops/function/vmap_func.py +2 -1
  242. mindspore/ops/functional.py +4 -4
  243. mindspore/ops/functional_overload.py +594 -18
  244. mindspore/ops/op_info_register.py +21 -0
  245. mindspore/ops/operations/__init__.py +16 -11
  246. mindspore/ops/operations/_custom_ops_utils.py +689 -34
  247. mindspore/ops/operations/_inner_ops.py +14 -18
  248. mindspore/ops/operations/_sequence_ops.py +1 -1
  249. mindspore/ops/operations/array_ops.py +5 -51
  250. mindspore/ops/operations/comm_ops.py +186 -41
  251. mindspore/ops/operations/custom_ops.py +303 -177
  252. mindspore/ops/operations/debug_ops.py +59 -4
  253. mindspore/ops/operations/image_ops.py +13 -13
  254. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  255. mindspore/ops/operations/math_ops.py +8 -9
  256. mindspore/ops/operations/nn_ops.py +8 -40
  257. mindspore/ops/primitive.py +9 -20
  258. mindspore/ops/tensor_method.py +63 -15
  259. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  260. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  261. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  262. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  263. mindspore/ops_generate/common/base_generator.py +14 -0
  264. mindspore/ops_generate/common/gen_constants.py +8 -3
  265. mindspore/ops_generate/common/gen_utils.py +0 -19
  266. mindspore/ops_generate/common/op_proto.py +11 -4
  267. mindspore/ops_generate/common/template.py +88 -11
  268. mindspore/ops_generate/gen_ops.py +1 -1
  269. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  270. mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
  271. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  272. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  273. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  274. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  275. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  276. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  277. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  278. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  279. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  280. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  281. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  282. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  283. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  284. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  285. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  286. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  287. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  288. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  289. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  290. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  291. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  292. mindspore/parallel/_auto_parallel_context.py +16 -23
  293. mindspore/parallel/_cell_wrapper.py +113 -45
  294. mindspore/parallel/_parallel_serialization.py +4 -3
  295. mindspore/parallel/_ps_context.py +4 -6
  296. mindspore/parallel/_tensor.py +167 -12
  297. mindspore/parallel/_transformer/moe.py +1 -1
  298. mindspore/parallel/_transformer/transformer.py +17 -12
  299. mindspore/parallel/_utils.py +5 -11
  300. mindspore/parallel/auto_parallel.py +35 -14
  301. mindspore/parallel/checkpoint_convert.py +3 -3
  302. mindspore/parallel/checkpoint_transform.py +13 -7
  303. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  304. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  305. mindspore/parallel/cluster/run.py +48 -7
  306. mindspore/parallel/function/__init__.py +8 -1
  307. mindspore/parallel/function/reshard_func.py +12 -12
  308. mindspore/parallel/nn/__init__.py +15 -2
  309. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  310. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  311. mindspore/parallel/shard.py +10 -25
  312. mindspore/parallel/transform_safetensors.py +469 -174
  313. mindspore/pgodb140.dll +0 -0
  314. mindspore/pgort140.dll +0 -0
  315. mindspore/profiler/__init__.py +2 -1
  316. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  317. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  318. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
  319. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  321. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  322. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  323. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  324. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  325. mindspore/profiler/analysis/task_manager.py +1 -1
  326. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  327. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  328. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  329. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  330. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  331. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  332. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  333. mindspore/profiler/common/constant.py +16 -0
  334. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  335. mindspore/profiler/common/path_manager.py +9 -0
  336. mindspore/profiler/common/profiler_context.py +50 -29
  337. mindspore/profiler/common/profiler_info.py +0 -16
  338. mindspore/profiler/common/profiler_meta_data.py +1 -0
  339. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  340. mindspore/profiler/common/profiler_output_path.py +23 -8
  341. mindspore/profiler/common/profiler_parameters.py +128 -35
  342. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  343. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  344. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  345. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  346. mindspore/profiler/dynamic_profiler.py +374 -338
  347. mindspore/profiler/envprofiler.py +42 -12
  348. mindspore/profiler/experimental_config.py +112 -7
  349. mindspore/profiler/mstx.py +33 -12
  350. mindspore/profiler/platform/__init__.py +2 -3
  351. mindspore/profiler/platform/cpu_profiler.py +10 -4
  352. mindspore/profiler/platform/npu_profiler.py +30 -20
  353. mindspore/profiler/profiler.py +218 -154
  354. mindspore/profiler/profiler_action_controller.py +65 -77
  355. mindspore/profiler/profiler_interface.py +2 -2
  356. mindspore/profiler/schedule.py +10 -4
  357. mindspore/rewrite/common/config.py +1 -0
  358. mindspore/rewrite/common/namer.py +1 -0
  359. mindspore/rewrite/common/namespace.py +1 -0
  360. mindspore/rewrite/node/node.py +31 -11
  361. mindspore/rewrite/parsers/assign_parser.py +1 -1
  362. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  363. mindspore/run_check/_check_version.py +7 -10
  364. mindspore/runtime/__init__.py +8 -6
  365. mindspore/runtime/event.py +10 -4
  366. mindspore/runtime/executor.py +87 -45
  367. mindspore/runtime/memory.py +31 -32
  368. mindspore/runtime/thread_bind_core.py +299 -165
  369. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  370. mindspore/swresample-4.dll +0 -0
  371. mindspore/swscale-6.dll +0 -0
  372. mindspore/tbbmalloc.dll +0 -0
  373. mindspore/tinyxml2.dll +0 -0
  374. mindspore/train/_utils.py +17 -7
  375. mindspore/train/amp.py +43 -23
  376. mindspore/train/callback/__init__.py +5 -5
  377. mindspore/train/callback/_callback.py +2 -1
  378. mindspore/train/callback/_checkpoint.py +4 -14
  379. mindspore/train/callback/_flops_collector.py +11 -7
  380. mindspore/train/callback/_landscape.py +0 -1
  381. mindspore/train/callback/_train_fault_tolerance.py +98 -21
  382. mindspore/train/data_sink.py +15 -6
  383. mindspore/train/dataset_helper.py +14 -5
  384. mindspore/train/model.py +133 -69
  385. mindspore/train/serialization.py +168 -126
  386. mindspore/train/summary/summary_record.py +13 -2
  387. mindspore/train/train_thor/model_thor.py +2 -2
  388. mindspore/turbojpeg.dll +0 -0
  389. mindspore/utils/__init__.py +3 -2
  390. mindspore/utils/dryrun.py +0 -6
  391. mindspore/utils/runtime_execution_order_check.py +163 -77
  392. mindspore/utils/sdc_detect.py +68 -0
  393. mindspore/utils/utils.py +14 -17
  394. mindspore/vcmeta.dll +0 -0
  395. mindspore/vcruntime140.dll +0 -0
  396. mindspore/vcruntime140_1.dll +0 -0
  397. mindspore/version.py +1 -1
  398. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  399. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
  400. mindspore/_deprecated/jit.py +0 -198
  401. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  402. mindspore/communication/_hccl_management.py +0 -297
  403. mindspore/experimental/es/embedding_service.py +0 -891
  404. mindspore/experimental/es/embedding_service_layer.py +0 -581
  405. mindspore/profiler/common/validator/__init__.py +0 -14
  406. mindspore/profiler/common/validator/validate_path.py +0 -84
  407. mindspore/profiler/parser/__init__.py +0 -14
  408. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  409. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  410. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  411. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  412. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  413. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  414. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  415. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  416. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  417. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  418. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  419. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  420. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  421. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  422. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  423. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  424. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  425. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  426. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  427. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  428. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  429. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  430. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  431. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  432. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  433. mindspore/profiler/parser/container.py +0 -229
  434. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  435. mindspore/profiler/parser/flops_parser.py +0 -531
  436. mindspore/profiler/parser/framework_enum.py +0 -111
  437. mindspore/profiler/parser/framework_parser.py +0 -464
  438. mindspore/profiler/parser/framework_struct.py +0 -61
  439. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  440. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  441. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  442. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  443. mindspore/profiler/parser/hccl_parser.py +0 -573
  444. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  445. mindspore/profiler/parser/integrator.py +0 -526
  446. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  447. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  448. mindspore/profiler/parser/minddata_parser.py +0 -186
  449. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  450. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  451. mindspore/profiler/parser/optime_parser.py +0 -250
  452. mindspore/profiler/parser/profiler_info.py +0 -213
  453. mindspore/profiler/parser/step_trace_parser.py +0 -666
  454. mindspore/utils/hooks.py +0 -81
  455. mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  456. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  457. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  458. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
@@ -31,15 +31,14 @@ from multiprocessing import active_children
31
31
  import multiprocessing as mp
32
32
  from collections import OrderedDict
33
33
  from io import BytesIO
34
+ from functools import partial
34
35
 
35
36
  import math
36
37
  import sys
37
38
  import time
38
- import google
39
39
  import numpy as np
40
-
41
- from safetensors.numpy import save_file, load_file
42
- from safetensors import safe_open
40
+ from safetensors.numpy import save_file
41
+ import google
43
42
 
44
43
  from mindspore.train.checkpoint_pb2 import Checkpoint
45
44
  from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
@@ -53,7 +52,6 @@ from mindspore.log import vlog_print
53
52
  from mindspore._checkparam import check_input_data, check_input_dataset
54
53
  from mindspore import _checkparam as Validator
55
54
  from mindspore.common import dtype as mstype
56
- from mindspore.common import np_dtype
57
55
  from mindspore.common.api import _cell_graph_executor as _executor
58
56
  from mindspore.common.api import _JitExecutor
59
57
  from mindspore.common.api import _get_parameter_layout
@@ -76,6 +74,7 @@ from mindspore.parallel.checkpoint_transform import restore_group_info_list as n
76
74
  from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint as new_load_distributed_checkpoint
77
75
  from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
78
76
  from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
77
+ from mindspore.parallel.transform_safetensors import _fast_safe_open
79
78
  from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
80
79
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
81
80
  split_mindir, split_dynamic_mindir
@@ -86,12 +85,9 @@ tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype
86
85
  "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
87
86
  "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}
88
87
 
89
- tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
90
- "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
91
- "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
92
-
93
- if hasattr(np_dtype, "bfloat16"):
94
- tensor_to_np_type["BFloat16"] = np_dtype.bfloat16
88
+ _tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
89
+ "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
90
+ "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
95
91
 
96
92
  np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
97
93
 
@@ -99,6 +95,8 @@ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4:
99
95
  5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
100
96
  11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
101
97
 
98
+ safetensors_to_mstype = {'Int4': mstype.qint4x2}
99
+
102
100
  _ckpt_mutex = RLock()
103
101
 
104
102
  # unit is KB
@@ -112,6 +110,21 @@ INT_64_MAX = 9223372036854775807
112
110
  cpu_cast = Cast().set_device("CPU")
113
111
 
114
112
  _ckpt_fs = FileSystem()
113
+ _ckpt_fs_initialized = False
114
+
115
+
116
+ def tensor_to_np_type(tensor_type_str):
117
+ """tensor to numpy type"""
118
+ if tensor_type_str == "BFloat16":
119
+ from mindspore.common import np_dtype
120
+ if not np_dtype.np_dtype_valid(True):
121
+ raise TypeError(
122
+ "The Numpy bfloat16 data type is not supported now, please ensure that the current "
123
+ "Numpy version is not less than the version when the mindspore is compiled, "
124
+ "and the major versions are same."
125
+ )
126
+ return np_dtype.bfloat16
127
+ return _tensor_to_np_type.get(tensor_type_str)
115
128
 
116
129
 
117
130
  def init_ckpt_file_system(fs: FileSystem):
@@ -121,8 +134,12 @@ def init_ckpt_file_system(fs: FileSystem):
121
134
  _register_basic_file_system(fs)
122
135
 
123
136
 
124
- # Initialize checkpoint file system
125
- init_ckpt_file_system(_ckpt_fs)
137
+ def _ensure_ckpt_fs_initialized():
138
+ """Ensure checkpoint file system is initialized"""
139
+ global _ckpt_fs_initialized
140
+ if not _ckpt_fs_initialized:
141
+ init_ckpt_file_system(_ckpt_fs)
142
+ _ckpt_fs_initialized = True
126
143
 
127
144
 
128
145
  def _wait_async_process_save_ckpt():
@@ -272,10 +289,7 @@ def _update_param(param, new_param, strict_load):
272
289
 
273
290
  if param.data.dtype != new_param.data.dtype:
274
291
  if _type_convert(param, new_param, strict_load):
275
- if new_param.data.dtype == mstype.bfloat16:
276
- new_tensor = cpu_cast(new_param.data, param.data.dtype)
277
- else:
278
- new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
292
+ new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
279
293
  param.set_data(new_tensor, param.sliced)
280
294
  return
281
295
 
@@ -313,7 +327,7 @@ def _update_param(param, new_param, strict_load):
313
327
  def _type_convert(param, new_param, strict_load):
314
328
  """Whether to convert parameter's type during load checkpoint into network."""
315
329
  float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16)
316
- int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
330
+ int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.qint4x2)
317
331
  if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
318
332
  {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
319
333
  logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
@@ -359,7 +373,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
359
373
 
360
374
 
361
375
  def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
362
- format="ckpt"):
376
+ format="ckpt", remove_redundancy=None):
363
377
  """Execute the process of saving checkpoint into file."""
364
378
  try:
365
379
  with _ckpt_mutex:
@@ -383,9 +397,6 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
383
397
 
384
398
  crc_num = 0
385
399
  for name, value in data_list.items():
386
- if name == "random_op":
387
- _write_random_seed(name, value, f)
388
- continue
389
400
  if value[0] == "mapparameter":
390
401
  _write_mapparameter(name, value, f, map_param_inc)
391
402
  continue
@@ -428,16 +439,19 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
428
439
  elif format == "safetensors":
429
440
  save_dict = {}
430
441
  crc_num = 0
442
+ meta_data = {"format": "ms"}
443
+ if remove_redundancy is not None and isinstance(remove_redundancy, bool):
444
+ meta_data["remove_redundancy"] = str(remove_redundancy)
431
445
  for name in sorted(data_list.keys()):
432
446
  value = data_list[name]
433
447
  if isinstance(value[2], np.ndarray):
448
+ if value[1] == str(mstype.qint4x2):
449
+ meta_data[name] = str(mstype.qint4x2)
434
450
  save_dict[name] = value[2]
435
451
  else:
436
- bytes_data = value[2].get_bytes()
437
- np_type = tensor_to_np_type.get(value[1])
438
- np_array = np.frombuffer(bytes_data, np_type)
439
- new_np_array = np_array.reshape(value[0])
440
- save_dict[name] = new_np_array
452
+ if value[2].dtype == mstype.qint4x2:
453
+ meta_data[name] = str(mstype.qint4x2)
454
+ save_dict[name] = value[2].asnumpy()
441
455
 
442
456
  if crc_check:
443
457
  crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
@@ -445,10 +459,12 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
445
459
  bytes(save_dict[name]), crc_num)
446
460
  safetensors_save_time_start = time.time()
447
461
  if crc_check:
448
- save_file(save_dict, tmp_name, metadata={
449
- "crc_num": str(crc_num)})
462
+ meta_data.update({"crc_num": str(crc_num)})
463
+ if save_dict:
464
+ save_file(save_dict, tmp_name, metadata=meta_data)
450
465
  else:
451
466
  save_file(save_dict, tmp_name)
467
+
452
468
  safetensors_save_time_end = time.time()
453
469
  cost_time = safetensors_save_time_end - safetensors_save_time_start
454
470
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.")
@@ -457,25 +473,13 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
457
473
  f"simultaneously modified a file.")
458
474
  elif _ckpt_fs.backend != "mindio":
459
475
  os.rename(tmp_name, ckpt_file_name)
460
- os.chmod(ckpt_file_name, stat.S_IRUSR)
476
+ os.chmod(ckpt_file_name, stat.S_IRUSR)
461
477
  except BaseException as e:
462
478
  logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
463
479
  "or the disk space is insufficient and so on.", ckpt_file_name)
464
480
  raise e
465
481
 
466
482
 
467
- def _write_random_seed(name, value, f):
468
- """Write random op into protobuf file."""
469
- checkpoint_list = Checkpoint()
470
- param_value = checkpoint_list.value.add()
471
- param_value.tag = name
472
- param_tensor = param_value.tensor
473
- param_tensor.dims.extend(0)
474
- param_tensor.tensor_type = "random_op"
475
- param_tensor.tensor_content = value
476
- f.write(checkpoint_list.SerializeToString())
477
-
478
-
479
483
  def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
480
484
  """Write parameter data into protobuf file."""
481
485
  data_size = value[2].nbytes / 1024
@@ -599,7 +603,7 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
599
603
  return ckpt_file_name
600
604
 
601
605
 
602
- def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
606
+ def _check_load_checkpoint_unsupported_param(format, dec_key, dec_mode):
603
607
  """check load checkpoint unsupported param"""
604
608
  if format != "safetensors":
605
609
  return
@@ -614,7 +618,7 @@ def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
614
618
  f"be set to default value '{default_value}', but got '{current_value}'.")
615
619
 
616
620
 
617
- def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc=False, global_step_num=None):
621
+ def _check_save_checkpoint_unsupported_param(format, enc_key, enc_mode, map_param_inc=False, global_step_num=None):
618
622
  """check save checkpoint unsupported param"""
619
623
  if format != "safetensors":
620
624
  return
@@ -644,11 +648,11 @@ def _check_async_save(async_save):
644
648
 
645
649
 
646
650
  def _async_process_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False,
647
- crc_check=False, format="ckpt", cond=None):
651
+ crc_check=False, format="ckpt", cond=None, remove_redundancy=None):
648
652
  """Check whether the process is pulled up successfully, execute the process of saving checkpoint into file."""
649
653
  with cond:
650
654
  cond.notify()
651
- _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
655
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format, remove_redundancy)
652
656
 
653
657
 
654
658
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
@@ -729,6 +733,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
729
733
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
730
734
  """
731
735
  start_save_time = time.time()
736
+ _ensure_ckpt_fs_initialized()
732
737
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
733
738
  integrated_save = Validator.check_bool(integrated_save)
734
739
  async_save = _check_async_save(async_save)
@@ -739,7 +744,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
739
744
  map_param_inc = kwargs.get('incremental', False)
740
745
  logger.info("Execute the process of saving checkpoint files.")
741
746
  global_step_num = kwargs.get('global_step_num', None)
742
- _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc, global_step_num)
747
+ remove_redundancy = kwargs.get('remove_redundancy', None)
748
+ remove_redundancy = Validator.check_isinstance("remove_redundancy", remove_redundancy, (type(None), bool))
749
+ _check_save_checkpoint_unsupported_param(format, enc_key, enc_mode, map_param_inc, global_step_num)
743
750
 
744
751
  if append_dict and "__exception_save__" in append_dict:
745
752
  s1 = mindspore.hal.Stream()
@@ -768,16 +775,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
768
775
  data_list_np = OrderedDict()
769
776
  with _ckpt_mutex:
770
777
  for param in save_obj:
771
- if param["name"] == "random_op":
772
- if os.getenv("AITURBO") == "1":
773
- data_list_np["random_op"] = []
774
- data_list_np["random_op"].append(param["data"])
775
- if crc_check:
776
- bytes_value = bytes(data_list_np[key][0])
777
- data_list_np[key].append(binascii.crc32(bytes_value))
778
- else:
779
- data_list["random_op"] = param["data"]
780
- continue
781
778
  key = param["name"]
782
779
  data_list[key] = []
783
780
  data_list_np[key] = []
@@ -841,7 +838,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
841
838
  while process_flag:
842
839
  process = ctx.Process(target=_async_process_save,
843
840
  args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
844
- format, cond), daemon=True, name="asyn_save_ckpt")
841
+ format, cond, remove_redundancy), daemon=True, name="asyn_save_ckpt")
845
842
  process.start()
846
843
  with cond:
847
844
  wait_flag = cond.wait(timeout=5)
@@ -854,11 +851,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
854
851
  data_copy = copy.deepcopy(data_list)
855
852
  _wait_async_thread_save_ckpt()
856
853
  thr = Thread(target=_exec_save,
857
- args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
854
+ args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format,
855
+ remove_redundancy),
858
856
  name="asyn_save_ckpt")
859
857
  thr.start()
860
858
  else:
861
- _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
859
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format, remove_redundancy)
862
860
 
863
861
  mstx.range_end(range_id)
864
862
  logger.info("Saving checkpoint process is finished.")
@@ -926,10 +924,13 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
926
924
  """Convert a dict of Parameter to param_list."""
927
925
  param_list = []
928
926
  for (key, value) in save_obj.items():
929
- if isinstance(key, str) and (isinstance(value, (Parameter, str)) or _is_buffer_type(value)):
927
+ if isinstance(key, str):
930
928
  if choice_func is not None and not choice_func(key):
931
929
  continue
932
- each_param = {"name": key, "data": value}
930
+ if isinstance(value, np.ndarray):
931
+ each_param = {"name": key, "data": Parameter(Tensor.from_numpy(value))}
932
+ if isinstance(value, (Parameter, str)) or _is_buffer_type(value):
933
+ each_param = {"name": key, "data": value}
933
934
  param_list.append(each_param)
934
935
  else:
935
936
  raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
@@ -941,16 +942,12 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
941
942
  def _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode):
942
943
  """Convert cell.parameters_and_names to OrderedDict."""
943
944
  param_dict = OrderedDict()
945
+ is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
944
946
  for _, param in save_obj.parameters_and_names():
945
- if param.name.startswith("accu_grads") or param.name.endswith("expert_load"):
946
- continue
947
- not_sliced = not param.sliced
948
- is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
949
947
  # All parameters are initialized immediately under PyNative mode, skip this judgement.
950
- judgment = not_sliced or param.has_init
951
948
  if param.param_info.is_pipeline_shared_param:
952
949
  continue
953
- if is_graph_mode and is_parallel_mode and judgment:
950
+ if is_parallel_mode and is_graph_mode and (not param.sliced or param.has_init):
954
951
  continue
955
952
  if choice_func is not None and not choice_func(param.name):
956
953
  continue
@@ -974,12 +971,6 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
974
971
  if not is_parallel_mode:
975
972
  save_obj.init_parameters_data()
976
973
  param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
977
- if append_dict and "random_op" in append_dict:
978
- phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
979
- if phase in save_obj.compile_cache and _executor.has_compiled(phase):
980
- random_byte = _executor._graph_executor.get_random_status(phase)
981
- param_list.append({"name": "random_op", "data": random_byte})
982
- append_dict.pop("random_op")
983
974
  for (key, value) in param_dict.items():
984
975
  each_param = {"name": key}
985
976
  if isinstance(value, MapParameter):
@@ -1002,15 +993,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
1002
993
  param_data.append(str(param_tensor.dtype))
1003
994
  param_data.append(value.key)
1004
995
  else:
1005
- param_data = value.data
1006
996
  if append_dict and "__exception_save__" in append_dict:
1007
997
  param_data = Tensor(Tensor_.move_to(value, "CPU", False))
998
+ else:
999
+ param_data = Tensor(value.data)
1008
1000
 
1009
1001
  # in automatic model parallel scenario, some parameters were split to all the devices,
1010
1002
  # which should be combined before saving
1011
1003
  if key in parameter_layout_dict:
1012
- if not append_dict or "__exception_save__" not in append_dict:
1013
- param_data = Tensor(value.data)
1014
1004
  param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
1015
1005
  integrated_save)
1016
1006
 
@@ -1215,12 +1205,26 @@ def _check_param_type(param_config, key, target_type, requested):
1215
1205
  return None
1216
1206
 
1217
1207
 
1208
+ def _check_remove_redundancy(remove_redundancy, f):
1209
+ """Check whether remove_redundancy is consistent with the safetensors file."""
1210
+ if f.metadata() is not None and "remove_redundancy" in f.metadata().keys():
1211
+ if f.metadata()["remove_redundancy"] == "True" and not remove_redundancy:
1212
+ logger.warning("For 'load_checkpoint', the safetensors file is deduplicated, "
1213
+ "but remove_redundancy is set to False.")
1214
+ return True
1215
+ if f.metadata()["remove_redundancy"] == "False" and remove_redundancy:
1216
+ logger.warning("For 'load_checkpoint', the safetensors file is non-deduplicated, "
1217
+ "but remove_redundancy is set to True.")
1218
+ return False
1219
+ return remove_redundancy
1220
+
1221
+
1218
1222
  def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1219
- dec_mode, crc_check, format):
1223
+ dec_mode, crc_check, format, remove_redundancy):
1220
1224
  """load parameter into parameter_dict"""
1221
1225
  ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
1222
1226
  if format == "safetensors":
1223
- with safe_open(ckpt_file_name, framework='np') as f:
1227
+ with _fast_safe_open(ckpt_file_name, framework='np') as f:
1224
1228
  cal_crc_num = 0
1225
1229
  total_io_cost_time = 0
1226
1230
  for k in sorted(f.keys()):
@@ -1234,8 +1238,13 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1234
1238
  io_end_time = time.time()
1235
1239
  io_cost_time = io_end_time - io_start_time
1236
1240
  total_io_cost_time += io_cost_time
1237
- parameter_dict[k] = Parameter(Tensor.from_numpy(value))
1238
-
1241
+ if f.metadata() is not None and k in f.metadata().keys():
1242
+ sf_dtype = f.metadata()[k]
1243
+ ms_dtype = safetensors_to_mstype[sf_dtype]
1244
+ parameter_dict[k] = Parameter(Tensor(value, dtype=ms_dtype))
1245
+ else:
1246
+ parameter_dict[k] = Parameter(Tensor.from_numpy(value))
1247
+ remove_redundancy = _check_remove_redundancy(remove_redundancy, f)
1239
1248
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
1240
1249
  f"Load safetensors io cost time:{total_io_cost_time}.")
1241
1250
  if crc_check:
@@ -1248,7 +1257,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1248
1257
  if cal_crc_num != crc_num:
1249
1258
  raise ValueError("For 'load_checkpoint', the crc check has failed. "
1250
1259
  "Please check whether the ckpt file is damaged.")
1251
- return
1260
+ return remove_redundancy
1252
1261
  checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
1253
1262
  try:
1254
1263
  param_data_list = []
@@ -1261,9 +1270,6 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1261
1270
  logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
1262
1271
  "please use `choice_func` instead.")
1263
1272
  for element_id, element in enumerate(checkpoint_list.value):
1264
- if element.tag == "random_op":
1265
- parameter_dict["random_op"] = element.tensor.tensor_content
1266
- continue
1267
1273
  if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
1268
1274
  continue
1269
1275
  if specify_prefix is None and filter_prefix is None and \
@@ -1278,11 +1284,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1278
1284
  continue
1279
1285
  data = element.tensor.tensor_content
1280
1286
  data_type = element.tensor.tensor_type
1281
- np_type = tensor_to_np_type.get(data_type)
1282
1287
  ms_type = tensor_to_ms_type[data_type]
1283
- if data_type == 'str':
1284
- str_length = int(len(data) / 4)
1285
- np_type = np_type + str(str_length)
1286
1288
  param_data_list.append(data)
1287
1289
  if (element_id == len(checkpoint_list.value) - 1) or \
1288
1290
  (element.tag != checkpoint_list.value[element_id + 1].tag):
@@ -1290,6 +1292,8 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1290
1292
  param_data_list.clear()
1291
1293
  dims = element.tensor.dims
1292
1294
  if data_type == 'str':
1295
+ str_length = int(len(data) / 4)
1296
+ np_type = "U" + str(str_length)
1293
1297
  str_value = np.frombuffer(new_data, np_type)
1294
1298
  parameter_dict[element.tag] = str(str_value[0])
1295
1299
  else:
@@ -1301,6 +1305,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1301
1305
  _offload_if_config(parameter)
1302
1306
 
1303
1307
  logger.info("Loading checkpoint files process is finished.")
1308
+ return remove_redundancy
1304
1309
 
1305
1310
  except BaseException as e:
1306
1311
  logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
@@ -1320,6 +1325,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1320
1325
  And using either of those two args will override `choice_func` at the same time.
1321
1326
  - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
1322
1327
  - When loading a checkpoint that has removed redundancy, the network should be compiled.
1328
+ - When `net` is not None, it will verify whether the `remove_redundancy` parameter matches the
1329
+ deduplication flag in the loaded safetensors file. If they are different, load the file according to
1330
+ the deduplication flag in the file.
1323
1331
 
1324
1332
  Args:
1325
1333
  ckpt_file_name (str): Checkpoint file name.
@@ -1392,13 +1400,14 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1392
1400
  """
1393
1401
  start_load_time = time.time()
1394
1402
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
1403
+ _ensure_ckpt_fs_initialized()
1395
1404
  specify_prefix = _check_prefix(specify_prefix)
1396
1405
  filter_prefix = _check_prefix(filter_prefix)
1397
1406
  dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
1398
1407
  dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
1399
1408
  crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
1400
1409
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1401
- _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode)
1410
+ _check_load_checkpoint_unsupported_param(format, dec_key, dec_mode)
1402
1411
  logger.info("Execute the process of loading checkpoint files.")
1403
1412
 
1404
1413
  parameter_dict = {}
@@ -1424,8 +1433,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1424
1433
  f"passed the CRC check and has been corrupted.")
1425
1434
  parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
1426
1435
  else:
1427
- _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1428
- dec_mode, crc_check, format)
1436
+ remove_redundancy = _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix,
1437
+ choice_func, dec_key, dec_mode, crc_check, format, remove_redundancy)
1429
1438
 
1430
1439
  if not parameter_dict:
1431
1440
  raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -1672,9 +1681,22 @@ def _check_load_param_into_net(net, parameter_dict):
1672
1681
  msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
1673
1682
  "but got {}.".format(type(parameter_dict)))
1674
1683
  raise TypeError(msg)
1675
- if "random_op" in parameter_dict.keys():
1676
- net._add_attr("random_op_snapshot", parameter_dict["random_op"])
1677
- parameter_dict.pop("random_op")
1684
+ for key, value in parameter_dict.items():
1685
+ if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
1686
+ logger.critical("Load parameters into net failed.")
1687
+ msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
1688
+ "'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value)))
1689
+ raise TypeError(msg)
1690
+
1691
+
1692
+ def _check_remove_redundancy_net(net):
1693
+ """Check whether the network is compiled with the remove_redundancy feature."""
1694
+ if get_group_size() == 1:
1695
+ raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
1696
+ f"in parallel scenarios, but got stand_alone.")
1697
+ if not net.compile_cache and not net.parameter_layout_dict:
1698
+ raise ValueError("When loading a parameter dict that has removed redundancy, "
1699
+ "the network should be compiled.")
1678
1700
 
1679
1701
 
1680
1702
  def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
@@ -1721,18 +1743,14 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1721
1743
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1722
1744
  """
1723
1745
  _check_load_param_into_net(net, parameter_dict)
1724
- for key, value in parameter_dict.items():
1725
- if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
1726
- logger.critical("Load parameters into net failed.")
1727
- msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
1728
- "'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value)))
1729
- raise TypeError(msg)
1730
1746
 
1731
1747
  strict_load = Validator.check_bool(strict_load)
1732
1748
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1733
1749
  logger.info("Execute the process of loading parameters into net.")
1734
1750
  param_not_load = []
1751
+ param_loaded = set()
1735
1752
  ckpt_not_load = list(parameter_dict.keys())
1753
+ is_parallel_mode = _is_auto_parallel_mode(net)
1736
1754
  for _, param in net.parameters_and_names():
1737
1755
  if param.param_info.is_pipeline_shared_param:
1738
1756
  continue
@@ -1748,22 +1766,23 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1748
1766
  if hasattr(param, "init_param") and not param.init_param:
1749
1767
  param.init_param = True
1750
1768
  ckpt_not_load.remove(param.name)
1769
+ param_loaded.add(param.name)
1751
1770
  else:
1771
+ if param.name.startswith("accu_grads"):
1772
+ continue
1773
+ if param.param_info.is_pipeline_shared_param:
1774
+ continue
1775
+ if is_parallel_mode and not param.sliced:
1776
+ continue
1752
1777
  param_not_load.append(param.name)
1753
1778
 
1754
1779
  if param_not_load and not strict_load:
1755
1780
  _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
1756
1781
 
1757
1782
  if remove_redundancy:
1758
- if get_group_size() == 1:
1759
- raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
1760
- f"in parallel scenarios, but got stand_alone.")
1761
- if not net.compile_cache and not net.parameter_layout_dict:
1762
- raise ValueError("When loading a parameter dict that has removed redundancy, "
1763
- "the network should be compiled.")
1783
+ _check_remove_redundancy_net(net)
1764
1784
  param_layout = net.parameter_layout_dict
1765
- _single_parameter_broadcast(net, param_layout, param_not_load)
1766
- mindspore.hal.synchronize()
1785
+ _single_parameter_broadcast(net, param_layout, param_not_load, param_loaded)
1767
1786
 
1768
1787
  logger.info("Loading parameters into net is finished.")
1769
1788
  if param_not_load:
@@ -1878,9 +1897,10 @@ def _save_graph(network, file_name):
1878
1897
  file_name (str): Graph file name into which the graph will be saved.
1879
1898
  """
1880
1899
  logger.info("Execute the process of saving graph.")
1881
-
1882
1900
  file_name = os.path.realpath(file_name)
1883
1901
  graph_pb = network.get_func_graph_proto()
1902
+ if os.path.isfile(file_name) and graph_pb:
1903
+ os.remove(file_name)
1884
1904
  if graph_pb:
1885
1905
  with open(file_name, "wb") as f:
1886
1906
  os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
@@ -2193,6 +2213,11 @@ def _save_onnx(net, file_name, *inputs, **kwargs):
2193
2213
  file_name += ".onnx"
2194
2214
  if os.path.exists(file_name):
2195
2215
  os.chmod(file_name, stat.S_IWUSR)
2216
+ else:
2217
+ dir_path = os.path.dirname(file_name)
2218
+ if not os.path.exists(dir_path):
2219
+ os.makedirs(dir_path, mode=0o700, exist_ok=True)
2220
+ os.chmod(dir_path, 0o700)
2196
2221
  with open(file_name, 'wb') as f:
2197
2222
  f.write(onnx_stream)
2198
2223
  os.chmod(file_name, stat.S_IRUSR)
@@ -2242,7 +2267,7 @@ def _get_data_file(is_encrypt, kwargs, data_file_name):
2242
2267
  if is_encrypt():
2243
2268
  place_holder_data = _encrypt(place_holder_data, len(place_holder_data), kwargs["enc_key"],
2244
2269
  len(kwargs["enc_key"]), kwargs["enc_mode"])
2245
- parameter_size = (offset / 1024)
2270
+ parameter_size = offset / 1024
2246
2271
  try:
2247
2272
  f = open(data_file_name, "wb")
2248
2273
  f.write(place_holder_data)
@@ -2284,9 +2309,11 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
2284
2309
  external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
2285
2310
  data_file_name = os.path.join(dirname, external_local)
2286
2311
  f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
2312
+
2313
+ round = 0
2314
+ names = []
2315
+
2287
2316
  try:
2288
- round = 0
2289
- names = []
2290
2317
  for param_proto in model.graph.parameter:
2291
2318
  name = param_proto.name[param_proto.name.find(":") + 1:]
2292
2319
  names.append((name, param_proto))
@@ -2587,7 +2614,7 @@ def parse_print(print_file_name):
2587
2614
  dims = print_.tensor.dims
2588
2615
  data_type = print_.tensor.tensor_type
2589
2616
  data = print_.tensor.tensor_content
2590
- np_type = tensor_to_np_type.get(data_type)
2617
+ np_type = tensor_to_np_type(data_type)
2591
2618
  param_data = np.fromstring(data, np_type)
2592
2619
  ms_type = tensor_to_ms_type.get(data_type)
2593
2620
  if dims and dims != [0]:
@@ -2730,28 +2757,35 @@ def convert_model(mindir_file, convert_file, file_format):
2730
2757
  export(net, *net_input, file_name=convert_file, file_format=file_format)
2731
2758
 
2732
2759
 
2733
- def _transform_tensor_to_numpy(path, name_map=None):
2734
- return _load_and_transform(path, name_map, mindspore.load_checkpoint, lambda v, new_name: v.asnumpy())
2760
+ def _load_ckpt_to_new_name_map(path, name_map=None):
2761
+ return _load_and_transform(path, name_map, mindspore.load_checkpoint, None)
2735
2762
 
2736
2763
 
2737
- def _transform_numpy_to_tensor(path, name_map=None):
2738
- return _load_and_transform(path, name_map, load_file, lambda v, new_name: mindspore.Parameter(v, name=new_name))
2764
+ def _load_sf_to_new_name_map(path, name_map=None):
2765
+ load_func = partial(mindspore.load_checkpoint, format="safetensors")
2766
+ return _load_and_transform(path, name_map, load_func, None)
2739
2767
 
2740
2768
 
2741
2769
  def _process_file(file_info):
2742
2770
  cur_ckpt_path, name_map, save_path, file = file_info
2743
- param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
2771
+ if name_map is not None:
2772
+ param_dict = _load_ckpt_to_new_name_map(cur_ckpt_path, name_map)
2773
+ else:
2774
+ param_dict = mindspore.load_checkpoint(cur_ckpt_path)
2744
2775
  safetensors_filename = file.replace(".ckpt", ".safetensors")
2745
2776
  dst_file = os.path.join(save_path, safetensors_filename)
2746
- save_file(param_dict_numpy, dst_file)
2777
+ mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
2747
2778
 
2748
2779
 
2749
2780
  def _process_file_safetensors(file_info):
2750
2781
  cur_safe_path, name_map, save_path, file = file_info
2751
- param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
2782
+ if name_map is not None:
2783
+ param_dict = _load_sf_to_new_name_map(cur_safe_path, name_map)
2784
+ else:
2785
+ param_dict = mindspore.load_checkpoint(cur_safe_path, format="safetensors")
2752
2786
  ckpt_filename = file.replace(".safetensors", ".ckpt")
2753
2787
  dst_file = os.path.join(save_path, ckpt_filename)
2754
- mindspore.save_checkpoint(param_dict_tensor, dst_file)
2788
+ mindspore.save_checkpoint(param_dict, dst_file)
2755
2789
 
2756
2790
 
2757
2791
  def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
@@ -2862,10 +2896,14 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege
2862
2896
  if save_path and not os.path.exists(save_path):
2863
2897
  os.makedirs(save_path, exist_ok=True)
2864
2898
 
2865
- param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
2899
+ if name_map is not None:
2900
+ param_dict = _load_ckpt_to_new_name_map(file_path, name_map)
2901
+ else:
2902
+ param_dict = mindspore.load_checkpoint(file_path)
2903
+
2866
2904
  safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
2867
2905
  dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
2868
- save_file(param_dict_numpy, dst_file)
2906
+ mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
2869
2907
 
2870
2908
 
2871
2909
  def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
@@ -2924,10 +2962,14 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
2924
2962
  if save_path and not os.path.exists(save_path):
2925
2963
  os.makedirs(save_path, exist_ok=True)
2926
2964
 
2927
- param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
2965
+ if name_map is not None:
2966
+ param_dict = _load_sf_to_new_name_map(file_path, name_map)
2967
+ else:
2968
+ param_dict = mindspore.load_checkpoint(file_path, format="safetensors")
2969
+
2928
2970
  ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
2929
2971
  dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
2930
- mindspore.save_checkpoint(param_dict_tensor, dst_file)
2972
+ mindspore.save_checkpoint(param_dict, dst_file)
2931
2973
 
2932
2974
 
2933
2975
  def restore_group_info_list(group_info_file_name):