mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.5.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of mindspore might be problematic.

Files changed (389)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -3
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +0 -5
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/compile_config.py +64 -0
  11. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
  13. mindspore/_extends/parse/parser.py +23 -5
  14. mindspore/_extends/parse/standard_method.py +123 -27
  15. mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
  16. mindspore/amp.py +7 -1
  17. mindspore/atlprov.dll +0 -0
  18. mindspore/avcodec-59.dll +0 -0
  19. mindspore/avdevice-59.dll +0 -0
  20. mindspore/avfilter-8.dll +0 -0
  21. mindspore/avformat-59.dll +0 -0
  22. mindspore/avutil-57.dll +0 -0
  23. mindspore/boost/boost_cell_wrapper.py +136 -41
  24. mindspore/c1.dll +0 -0
  25. mindspore/c1xx.dll +0 -0
  26. mindspore/c2.dll +0 -0
  27. mindspore/common/__init__.py +3 -1
  28. mindspore/common/_register_for_tensor.py +0 -1
  29. mindspore/common/_stub_tensor.py +25 -4
  30. mindspore/common/_tensor_cpp_method.py +17 -0
  31. mindspore/common/_tensor_docs.py +6132 -0
  32. mindspore/common/api.py +98 -21
  33. mindspore/common/dtype.py +34 -34
  34. mindspore/common/dump.py +2 -1
  35. mindspore/common/file_system.py +8 -3
  36. mindspore/common/generator.py +2 -0
  37. mindspore/common/hook_handle.py +3 -1
  38. mindspore/common/initializer.py +3 -4
  39. mindspore/common/lazy_inline.py +8 -2
  40. mindspore/common/mindir_util.py +10 -2
  41. mindspore/common/parameter.py +31 -15
  42. mindspore/common/tensor.py +713 -1337
  43. mindspore/communication/__init__.py +1 -1
  44. mindspore/communication/_comm_helper.py +5 -0
  45. mindspore/communication/comm_func.py +215 -173
  46. mindspore/communication/management.py +23 -20
  47. mindspore/context.py +285 -191
  48. mindspore/dataset/__init__.py +23 -19
  49. mindspore/dataset/callback/ds_callback.py +2 -1
  50. mindspore/dataset/core/config.py +84 -3
  51. mindspore/dataset/engine/cache_admin.py +3 -3
  52. mindspore/dataset/engine/cache_client.py +5 -4
  53. mindspore/dataset/engine/datasets.py +192 -149
  54. mindspore/dataset/engine/datasets_audio.py +14 -0
  55. mindspore/dataset/engine/datasets_standard_format.py +11 -11
  56. mindspore/dataset/engine/datasets_text.py +38 -1
  57. mindspore/dataset/engine/datasets_user_defined.py +100 -66
  58. mindspore/dataset/engine/datasets_vision.py +81 -8
  59. mindspore/dataset/engine/iterators.py +281 -63
  60. mindspore/dataset/engine/obs/util.py +8 -0
  61. mindspore/dataset/engine/queue.py +40 -0
  62. mindspore/dataset/engine/samplers.py +26 -2
  63. mindspore/dataset/engine/serializer_deserializer.py +1 -1
  64. mindspore/dataset/engine/validators.py +43 -11
  65. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  66. mindspore/dataset/transforms/transforms.py +29 -12
  67. mindspore/dataset/vision/validators.py +1 -2
  68. mindspore/device_context/__init__.py +21 -0
  69. mindspore/device_context/ascend/__init__.py +25 -0
  70. mindspore/device_context/ascend/device.py +72 -0
  71. mindspore/device_context/ascend/op_debug.py +94 -0
  72. mindspore/device_context/ascend/op_precision.py +193 -0
  73. mindspore/device_context/ascend/op_tuning.py +127 -0
  74. mindspore/device_context/cpu/__init__.py +25 -0
  75. mindspore/device_context/cpu/device.py +62 -0
  76. mindspore/device_context/cpu/op_tuning.py +43 -0
  77. mindspore/device_context/gpu/__init__.py +21 -0
  78. mindspore/device_context/gpu/device.py +70 -0
  79. mindspore/device_context/gpu/op_precision.py +67 -0
  80. mindspore/device_context/gpu/op_tuning.py +175 -0
  81. mindspore/device_manager.py +134 -0
  82. mindspore/dnnl.dll +0 -0
  83. mindspore/dpcmi.dll +0 -0
  84. mindspore/experimental/llm_boost/__init__.py +1 -0
  85. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  86. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  87. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  90. mindspore/experimental/llm_boost/register.py +1 -0
  91. mindspore/experimental/optim/adadelta.py +26 -22
  92. mindspore/experimental/optim/adam.py +3 -0
  93. mindspore/experimental/optim/lr_scheduler.py +33 -24
  94. mindspore/experimental/optim/radam.py +33 -30
  95. mindspore/hal/device.py +28 -0
  96. mindspore/hal/event.py +17 -0
  97. mindspore/hal/memory.py +94 -3
  98. mindspore/hal/stream.py +91 -6
  99. mindspore/include/api/context.h +0 -1
  100. mindspore/jpeg62.dll +0 -0
  101. mindspore/log.py +12 -0
  102. mindspore/mindrecord/__init__.py +1 -1
  103. mindspore/mindrecord/config.py +17 -316
  104. mindspore/mindrecord/filereader.py +1 -9
  105. mindspore/mindrecord/filewriter.py +5 -15
  106. mindspore/mindrecord/mindpage.py +1 -9
  107. mindspore/mindspore_backend.dll +0 -0
  108. mindspore/mindspore_common.dll +0 -0
  109. mindspore/mindspore_core.dll +0 -0
  110. mindspore/mindspore_glog.dll +0 -0
  111. mindspore/mindspore_ops.dll +0 -0
  112. mindspore/mint/__init__.py +824 -218
  113. mindspore/mint/distributed/__init__.py +66 -4
  114. mindspore/mint/distributed/distributed.py +2594 -44
  115. mindspore/mint/linalg/__init__.py +6 -0
  116. mindspore/mint/nn/__init__.py +473 -14
  117. mindspore/mint/nn/functional.py +486 -11
  118. mindspore/mint/nn/layer/__init__.py +17 -4
  119. mindspore/mint/nn/layer/_functions.py +330 -0
  120. mindspore/mint/nn/layer/activation.py +169 -1
  121. mindspore/mint/nn/layer/basic.py +123 -0
  122. mindspore/mint/nn/layer/conv.py +727 -0
  123. mindspore/mint/nn/layer/normalization.py +215 -19
  124. mindspore/mint/nn/layer/padding.py +797 -0
  125. mindspore/mint/nn/layer/pooling.py +170 -0
  126. mindspore/mint/optim/__init__.py +2 -1
  127. mindspore/mint/optim/adam.py +223 -0
  128. mindspore/mint/optim/adamw.py +26 -19
  129. mindspore/mint/special/__init__.py +2 -1
  130. mindspore/msobj140.dll +0 -0
  131. mindspore/mspdb140.dll +0 -0
  132. mindspore/mspdbcore.dll +0 -0
  133. mindspore/mspdbst.dll +0 -0
  134. mindspore/mspft140.dll +0 -0
  135. mindspore/msvcdis140.dll +0 -0
  136. mindspore/msvcp140_1.dll +0 -0
  137. mindspore/msvcp140_2.dll +0 -0
  138. mindspore/msvcp140_atomic_wait.dll +0 -0
  139. mindspore/msvcp140_codecvt_ids.dll +0 -0
  140. mindspore/multiprocessing/__init__.py +5 -0
  141. mindspore/nn/cell.py +126 -19
  142. mindspore/nn/dynamic_lr.py +2 -1
  143. mindspore/nn/layer/activation.py +6 -6
  144. mindspore/nn/layer/basic.py +35 -25
  145. mindspore/nn/layer/channel_shuffle.py +3 -3
  146. mindspore/nn/layer/embedding.py +3 -3
  147. mindspore/nn/layer/normalization.py +8 -7
  148. mindspore/nn/layer/padding.py +4 -3
  149. mindspore/nn/layer/pooling.py +47 -13
  150. mindspore/nn/layer/rnn_cells.py +1 -1
  151. mindspore/nn/layer/rnns.py +2 -1
  152. mindspore/nn/layer/timedistributed.py +5 -5
  153. mindspore/nn/layer/transformer.py +48 -26
  154. mindspore/nn/learning_rate_schedule.py +5 -3
  155. mindspore/nn/loss/loss.py +31 -36
  156. mindspore/nn/optim/ada_grad.py +1 -0
  157. mindspore/nn/optim/adadelta.py +2 -2
  158. mindspore/nn/optim/adam.py +1 -1
  159. mindspore/nn/optim/lars.py +1 -4
  160. mindspore/nn/optim/optimizer.py +1 -1
  161. mindspore/nn/optim/rprop.py +2 -2
  162. mindspore/nn/optim/thor.py +2 -1
  163. mindspore/nn/utils/init.py +13 -11
  164. mindspore/nn/wrap/cell_wrapper.py +4 -6
  165. mindspore/nn/wrap/loss_scale.py +3 -4
  166. mindspore/numpy/array_creations.py +60 -62
  167. mindspore/numpy/array_ops.py +148 -143
  168. mindspore/numpy/logic_ops.py +41 -42
  169. mindspore/numpy/math_ops.py +361 -359
  170. mindspore/numpy/utils.py +16 -16
  171. mindspore/numpy/utils_const.py +4 -4
  172. mindspore/opencv_core452.dll +0 -0
  173. mindspore/opencv_imgcodecs452.dll +0 -0
  174. mindspore/opencv_imgproc452.dll +0 -0
  175. mindspore/ops/__init__.py +2 -1
  176. mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
  177. mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
  178. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  179. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  180. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  181. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  182. mindspore/ops/_vmap/vmap_array_ops.py +20 -19
  183. mindspore/ops/_vmap/vmap_base.py +0 -2
  184. mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
  185. mindspore/ops/_vmap/vmap_math_ops.py +11 -9
  186. mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
  187. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
  188. mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
  189. mindspore/ops/auto_generate/gen_extend_func.py +554 -60
  190. mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
  191. mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
  192. mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
  193. mindspore/ops/composite/base.py +1 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
  195. mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
  196. mindspore/ops/function/__init__.py +12 -0
  197. mindspore/ops/function/array_func.py +561 -159
  198. mindspore/ops/function/clip_func.py +64 -0
  199. mindspore/ops/function/debug_func.py +28 -20
  200. mindspore/ops/function/image_func.py +1 -1
  201. mindspore/ops/function/linalg_func.py +5 -4
  202. mindspore/ops/function/math_func.py +1659 -290
  203. mindspore/ops/function/nn_func.py +988 -317
  204. mindspore/ops/function/parameter_func.py +3 -56
  205. mindspore/ops/function/random_func.py +243 -33
  206. mindspore/ops/function/sparse_unary_func.py +1 -1
  207. mindspore/ops/functional.py +18 -5
  208. mindspore/ops/functional_overload.py +897 -0
  209. mindspore/ops/operations/__init__.py +3 -2
  210. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  211. mindspore/ops/operations/_grad_ops.py +2 -34
  212. mindspore/ops/operations/_infer_ops.py +2 -1
  213. mindspore/ops/operations/_inner_ops.py +38 -8
  214. mindspore/ops/operations/array_ops.py +45 -303
  215. mindspore/ops/operations/comm_ops.py +19 -16
  216. mindspore/ops/operations/custom_ops.py +11 -55
  217. mindspore/ops/operations/debug_ops.py +42 -47
  218. mindspore/ops/operations/inner_ops.py +6 -4
  219. mindspore/ops/operations/linalg_ops.py +3 -2
  220. mindspore/ops/operations/manually_defined/ops_def.py +185 -104
  221. mindspore/ops/operations/math_ops.py +11 -216
  222. mindspore/ops/operations/nn_ops.py +146 -308
  223. mindspore/ops/primitive.py +23 -21
  224. mindspore/ops/tensor_method.py +1669 -0
  225. mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
  226. mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
  227. mindspore/ops_generate/arg_handler.py +0 -61
  228. mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
  229. mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
  230. mindspore/ops_generate/base_generator.py +11 -0
  231. mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
  232. mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
  233. mindspore/ops_generate/functional_overload_py_generator.py +110 -0
  234. mindspore/ops_generate/functions_cc_generator.py +233 -0
  235. mindspore/ops_generate/gen_aclnn_implement.py +110 -114
  236. mindspore/ops_generate/gen_constants.py +157 -3
  237. mindspore/ops_generate/gen_ops.py +245 -990
  238. mindspore/ops_generate/gen_pyboost_func.py +97 -998
  239. mindspore/ops_generate/gen_utils.py +119 -33
  240. mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
  241. mindspore/ops_generate/op_api_proto.py +206 -0
  242. mindspore/ops_generate/op_def_py_generator.py +131 -0
  243. mindspore/ops_generate/op_prim_py_generator.py +480 -0
  244. mindspore/ops_generate/op_proto.py +373 -108
  245. mindspore/ops_generate/op_template_parser.py +436 -0
  246. mindspore/ops_generate/ops_def_cc_generator.py +288 -0
  247. mindspore/ops_generate/ops_def_h_generator.py +74 -0
  248. mindspore/ops_generate/ops_name_h_generator.py +68 -0
  249. mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
  250. mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
  251. mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
  252. mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
  253. mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
  254. mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
  255. mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
  256. mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
  257. mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
  258. mindspore/ops_generate/pyboost_utils.py +92 -33
  259. mindspore/ops_generate/template.py +294 -44
  260. mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
  261. mindspore/parallel/__init__.py +3 -3
  262. mindspore/parallel/_auto_parallel_context.py +24 -33
  263. mindspore/parallel/_parallel_serialization.py +13 -2
  264. mindspore/parallel/_utils.py +4 -1
  265. mindspore/parallel/algo_parameter_config.py +1 -1
  266. mindspore/parallel/checkpoint_transform.py +44 -0
  267. mindspore/parallel/cluster/process_entity/_api.py +131 -37
  268. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  269. mindspore/parallel/cluster/run.py +20 -3
  270. mindspore/parallel/parameter_broadcast.py +1 -1
  271. mindspore/parallel/shard.py +3 -0
  272. mindspore/parallel/transform_safetensors.py +119 -253
  273. mindspore/pgodb140.dll +0 -0
  274. mindspore/pgort140.dll +0 -0
  275. mindspore/profiler/__init__.py +17 -4
  276. mindspore/profiler/analysis/__init__.py +0 -0
  277. mindspore/profiler/analysis/parser/__init__.py +0 -0
  278. mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
  279. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  280. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  281. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  282. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  283. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  284. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
  285. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  286. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
  287. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  288. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  289. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  290. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  291. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  292. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  293. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  294. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  295. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  296. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  297. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
  298. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  299. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  300. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  301. mindspore/profiler/analysis/task_manager.py +131 -0
  302. mindspore/profiler/analysis/time_converter.py +84 -0
  303. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  304. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
  305. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  306. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
  307. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
  308. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
  309. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
  310. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  311. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  312. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
  313. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  314. mindspore/profiler/analysis/work_flow.py +73 -0
  315. mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
  316. mindspore/profiler/common/command_executor.py +90 -0
  317. mindspore/profiler/common/constant.py +174 -3
  318. mindspore/profiler/common/file_manager.py +208 -0
  319. mindspore/profiler/common/log.py +130 -0
  320. mindspore/profiler/common/msprof_cmd_tool.py +202 -0
  321. mindspore/profiler/common/path_manager.py +371 -0
  322. mindspore/profiler/common/process_bar.py +168 -0
  323. mindspore/profiler/common/process_pool.py +9 -3
  324. mindspore/profiler/common/profiler_context.py +476 -0
  325. mindspore/profiler/common/profiler_info.py +304 -0
  326. mindspore/profiler/common/profiler_output_path.py +284 -0
  327. mindspore/profiler/common/profiler_parameters.py +210 -0
  328. mindspore/profiler/common/profiler_path_manager.py +120 -0
  329. mindspore/profiler/common/record_function.py +76 -0
  330. mindspore/profiler/common/tlv_decoder.py +76 -0
  331. mindspore/profiler/common/util.py +75 -2
  332. mindspore/profiler/dynamic_profiler.py +270 -37
  333. mindspore/profiler/envprofiler.py +138 -0
  334. mindspore/profiler/mstx.py +199 -0
  335. mindspore/profiler/platform/__init__.py +21 -0
  336. mindspore/profiler/platform/base_profiler.py +40 -0
  337. mindspore/profiler/platform/cpu_profiler.py +124 -0
  338. mindspore/profiler/platform/gpu_profiler.py +74 -0
  339. mindspore/profiler/platform/npu_profiler.py +309 -0
  340. mindspore/profiler/profiler.py +580 -93
  341. mindspore/profiler/profiler_action_controller.py +187 -0
  342. mindspore/profiler/profiler_interface.py +114 -0
  343. mindspore/profiler/schedule.py +208 -0
  344. mindspore/rewrite/api/symbol_tree.py +1 -2
  345. mindspore/run_check/_check_version.py +2 -6
  346. mindspore/runtime/__init__.py +37 -0
  347. mindspore/runtime/device.py +27 -0
  348. mindspore/runtime/event.py +209 -0
  349. mindspore/runtime/executor.py +148 -0
  350. mindspore/runtime/memory.py +392 -0
  351. mindspore/runtime/stream.py +460 -0
  352. mindspore/runtime/thread_bind_core.py +401 -0
  353. mindspore/swresample-4.dll +0 -0
  354. mindspore/swscale-6.dll +0 -0
  355. mindspore/tbbmalloc.dll +0 -0
  356. mindspore/tinyxml2.dll +0 -0
  357. mindspore/train/__init__.py +2 -2
  358. mindspore/train/_utils.py +53 -18
  359. mindspore/train/amp.py +8 -4
  360. mindspore/train/callback/_checkpoint.py +32 -18
  361. mindspore/train/callback/_early_stop.py +1 -1
  362. mindspore/train/callback/_flops_collector.py +105 -69
  363. mindspore/train/callback/_history.py +1 -1
  364. mindspore/train/callback/_summary_collector.py +44 -6
  365. mindspore/train/callback/_tft_register.py +31 -10
  366. mindspore/train/dataset_helper.py +11 -11
  367. mindspore/train/metrics/precision.py +4 -5
  368. mindspore/train/mind_ir_pb2.py +167 -46
  369. mindspore/train/model.py +13 -15
  370. mindspore/train/serialization.py +462 -76
  371. mindspore/train/summary/summary_record.py +1 -2
  372. mindspore/train/train_thor/model_thor.py +1 -1
  373. mindspore/turbojpeg.dll +0 -0
  374. mindspore/utils/__init__.py +4 -2
  375. mindspore/utils/dryrun.py +138 -0
  376. mindspore/utils/runtime_execution_order_check.py +550 -0
  377. mindspore/vcmeta.dll +0 -0
  378. mindspore/vcruntime140.dll +0 -0
  379. mindspore/vcruntime140_1.dll +0 -0
  380. mindspore/version.py +1 -1
  381. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
  382. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +385 -261
  383. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
  384. mindspore/common/_tensor_overload.py +0 -139
  385. mindspore/mindspore_np_dtype.dll +0 -0
  386. mindspore/profiler/envprofiling.py +0 -254
  387. mindspore/profiler/profiling.py +0 -1926
  388. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
  389. {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
@@ -24,9 +24,11 @@ import os
  import re
  import shutil
  import stat
+ import atexit
  import threading
  from threading import Thread, RLock
- from multiprocessing import Process
+ from multiprocessing import Pool, active_children
+ import multiprocessing as mp
  from collections import defaultdict, OrderedDict
  from io import BytesIO
@@ -36,6 +38,9 @@ import time
  import google
  import numpy as np

+ from safetensors.numpy import save_file, load_file
+ from safetensors import safe_open
+
  from mindspore.train.checkpoint_pb2 import Checkpoint
  from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
  from mindspore.train.print_pb2 import Print
@@ -44,6 +49,7 @@ import mindspore
  import mindspore.nn as nn
  from mindspore import context
  from mindspore import log as logger
+ from mindspore.log import vlog_print
  from mindspore._checkparam import check_input_data, check_input_dataset
  from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
@@ -73,15 +79,12 @@ from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_w
  from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
  from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
  _extract_pipeline_stage_num
- from mindspore.train._utils import read_proto, get_parameter_redundancy
+ from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
  split_mindir, split_dynamic_mindir
  from mindspore.common.generator import Generator
- from safetensors.numpy import save_file
- from safetensors import safe_open
  from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs

-
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
  "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
  "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
@@ -123,6 +126,31 @@ def init_ckpt_file_system(fs: FileSystem):
  init_ckpt_file_system(_ckpt_fs)


+ def _wait_async_process_save_ckpt():
+ """Waiting for asynchronous saving process of ckpt to complete"""
+ for process in active_children():
+ if process.name == "asyn_save_ckpt":
+ process.join()
+
+
+ def _wait_async_thread_save_ckpt():
+ """Waiting for asynchronous saving thread of ckpt to complete"""
+ thread_list = threading.enumerate()
+ for thread in thread_list:
+ if thread.getName() == "asyn_save_ckpt":
+ thread.join()
+
+
+ def _async_save_close():
+ """Waiting for asynchronous saving of ckpt to complete"""
+ _wait_async_process_save_ckpt()
+ _wait_async_thread_save_ckpt()
+
+
+ # Registering atexit handles asynchronous save
+ atexit.register(_async_save_close)
+
+
  def _get_cur_rank_dp(parameter_layout_dict):
  """ Get dp and tp from layout dict. """
  pp_num = _get_auto_parallel_context("pipeline_stages")
@@ -282,7 +310,8 @@ def _type_convert(param, new_param, strict_load):
  {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
  logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
  f"the type of it in 'net':{param.data.dtype}, then the type convert from "
- f"{new_param.data.dtype} to {param.data.dtype} in the network.")
+ f"{new_param.data.dtype} to {param.data.dtype} in the network. May consume additional memory "
+ f"and time")
  return True
  return False
@@ -329,8 +358,6 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
  file_name_list = list(os.path.splitext(ckpt_file_name))
  file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
  tmp_name = ''.join(file_name_list)
- if _ckpt_fs.backend == "mindio":
- tmp_name = ckpt_file_name
  if os.path.exists(ckpt_file_name):
  os.chmod(ckpt_file_name, stat.S_IWUSR)
  os.remove(ckpt_file_name)
@@ -338,6 +365,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
  os.chmod(tmp_name, stat.S_IWUSR)
  os.remove(tmp_name)
  if format == "ckpt":
+ ckpt_save_time_start = time.time()
  with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
  plain_data = None
  if enc_key is not None:
@@ -378,15 +406,33 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
  block_data = plain_data.read(max_block_size)
  if crc_check:
  f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
+ ckpt_save_time_end = time.time()
+ cost_time = ckpt_save_time_end - ckpt_save_time_start
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save ckpt cost time:{cost_time}.")
  elif format == "safetensors":
  save_dict = {}
- for name, value in data_list.items():
+ crc_num = 0
+ for name in sorted(data_list.keys()):
+ value = data_list[name]
  save_dict[name] = value[2].asnumpy()
- save_file(save_dict, tmp_name)
+
+ if crc_check:
+ crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
+ crc_num = binascii.crc32(
+ bytes(save_dict[name]), crc_num)
+ safetensors_save_time_start = time.time()
+ if crc_check:
+ save_file(save_dict, tmp_name, metadata={
+ "crc_num": str(crc_num)})
+ else:
+ save_file(save_dict, tmp_name)
+ safetensors_save_time_end = time.time()
+ cost_time = safetensors_save_time_end - safetensors_save_time_start
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors cost time:{cost_time}.")
  if not os.path.exists(tmp_name):
  logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
  f"simultaneously modified a file.")
- elif _ckpt_fs.backend != "mindio":
+ else:
  os.rename(tmp_name, ckpt_file_name)
  os.chmod(ckpt_file_name, stat.S_IRUSR)
  except BaseException as e:
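
The safetensors branch above computes a running CRC32 over the sorted parameter names and their raw tensor bytes, then stores it as "crc_num" in the file metadata. A minimal sketch of re-checking that checksum outside of MindSpore, assuming an illustrative file name "model.safetensors" (not taken from the diff):

>>> import binascii
>>> from safetensors import safe_open
>>> crc = 0
>>> with safe_open("model.safetensors", framework='np') as f:   # path is illustrative
...     meta = f.metadata() or {}
...     for name in sorted(f.keys()):
...         crc = binascii.crc32(bytes(name, encoding='utf-8'), crc)
...         crc = binascii.crc32(bytes(f.get_tensor(name)), crc)
>>> "crc_num" not in meta or int(meta["crc_num"]) == crc
True
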
@@ -522,12 +568,58 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
  return ckpt_file_name


- def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
- global_step_num=None):
- param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
- or map_param_inc or global_step_num is not None)
- if format == "safetensors" and param_not_default:
- raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
+ def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
+ """check load checkpoint unsupported param"""
+ if format != "safetensors":
+ return
+ default_params = {
+ "dec_key": None,
+ "dec_mode": "AES-GCM",
+ }
+ for param_name, default_value in default_params.items():
+ current_value = locals()[param_name]
+ if current_value != default_value:
+ raise ValueError(f"For 'load_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
+ f"be set to default value '{default_value}', but got '{current_value}'.")
+
+
+ def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, async_save=False, map_param_inc=False,
+ global_step_num=None):
+ """check save checkpoint unsupported param"""
+ if format != "safetensors":
+ return
+ default_params = {
+ "enc_key": None,
+ "enc_mode": "AES-GCM",
+ "async_save": False,
+ "map_param_inc": False,
+ "global_step_num": None
+ }
+ for param_name, default_value in default_params.items():
+ current_value = locals()[param_name]
+ if current_value != default_value:
+ raise ValueError(f"For 'save_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
+ f"be set to default value '{default_value}', but got '{current_value}'.")
+
+
+ def _check_async_save(async_save):
+ """Check async_save for save_checkpoint."""
+ if not isinstance(async_save, (bool, str)):
+ raise TypeError("For 'save_checkpoint', the parameter 'async_save' must be bool or str, "
+ "but got {}.".format(type(async_save)))
+ if isinstance(async_save, str):
+ if async_save not in ("process", "thread"):
+ raise ValueError("For 'save_checkpoint', the argument 'async_save' can only be 'process' or 'thread',"
+ "but got {}.".format(async_save))
+ return async_save
+
+
+ def _async_process_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False,
+ crc_check=False, format="ckpt", cond=None):
+ """Check whether the process is pulled up successfully, execute the process of saving checkpoint into file."""
+ with cond:
+ cond.notify()
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)


  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
@@ -544,10 +636,13 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
  list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
  elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
  `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
- it can be the returned value of `mindspore.load_checkpoint()`.
+ it can be the returned value of :func:`mindspore.load_checkpoint`.
  ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
  integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
- async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
+ async_save (Union[bool, str]): Whether to use asynchronous saving of the checkpoint file, if True,
+ the asynchronous thread is used by default. If the type is string,
+ the method of asynchronous saving, it can be "process" or "thread".
+ Default: ``False`` .
  append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
  of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
  enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
@@ -567,8 +662,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
  Raises:
  TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
- TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
+ TypeError: If the parameter `integrated_save` is not bool type.
  TypeError: If the parameter `ckpt_file_name` is not string type.
+ TypeError: If the parameter `async_save` is not bool or string type.
+ ValueError: If the parameter `async_save` is string type but not in ["process", "thread"].

  Examples:
  >>> import mindspore as ms
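
Following the updated `async_save` documentation above, a minimal usage sketch; the network and file names here are illustrative, not taken from the diff:

>>> import mindspore as ms
>>> from mindspore import nn
>>> net = nn.Dense(2, 2)
>>> ms.save_checkpoint(net, "./dense_thread.ckpt", async_save="thread")    # background thread, same as True
>>> ms.save_checkpoint(net, "./dense_process.ckpt", async_save="process")  # background process; falls back to serial saving on Windows
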
@@ -598,7 +695,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
  """
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
  integrated_save = Validator.check_bool(integrated_save)
- async_save = Validator.check_bool(async_save)
+ async_save = _check_async_save(async_save)
  append_dict = _check_append_dict(append_dict)
  enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
  enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
@@ -606,7 +703,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
  map_param_inc = kwargs.get('incremental', False)
  logger.info("Execute the process of saving checkpoint files.")
  global_step_num = kwargs.get('global_step_num', None)
- _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
+ _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, async_save, map_param_inc, global_step_num)

  if append_dict and "__exception_save__" in append_dict:
  s1 = mindspore.hal.Stream()
@@ -682,7 +779,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
  data_list[key].append(dims)
  tensor_type = str(param["data"].dtype)
  data_list[key].append(tensor_type)
- data = param["data"]
+ data = param["data"] if async_save != "process" else param["data"].asnumpy()
  data_list[key].append(data)

  if os.getenv("AITURBO") == "1":
@@ -690,11 +787,35 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
  ckpt_name = os.path.basename(ckpt_file_name)
  aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
  elif async_save:
- data_copy = copy.deepcopy(data_list)
- thr = Thread(target=_exec_save,
- args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
- name="asyn_save_ckpt")
- thr.start()
+ if async_save == "process":
+ if sys.platform.startswith("win"):
+ logger.warining("The Win platform currently does not support asynchronous process saving of ckpt, "
+ "so serial saving of ckpt is used now.")
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
+ else:
+ _wait_async_process_save_ckpt()
+ ctx = mp.get_context("fork")
+ cond = ctx.Condition()
+ process_flag = True
+ while process_flag:
+ process = ctx.Process(target=_async_process_save,
+ args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
+ format, cond), daemon=True, name="asyn_save_ckpt")
+ process.start()
+ with cond:
+ wait_flag = cond.wait(timeout=5)
+ if not wait_flag:
+ logger.warning("Async save process fails to create. will kill and recreate")
+ process.kill()
+ else:
+ process_flag = False
+ else:
+ data_copy = copy.deepcopy(data_list)
+ _wait_async_thread_save_ckpt()
+ thr = Thread(target=_exec_save,
+ args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
+ name="asyn_save_ckpt")
+ thr.start()
  else:
  _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
@@ -1201,8 +1322,28 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
  ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
  if format == "safetensors":
  with safe_open(ckpt_file_name, framework='np') as f:
- for k in f.keys():
- parameter_dict[k] = Parameter(f.get_tensor(k))
+ cal_crc_num = 0
+ sf_load_time_start = time.time()
+ for k in sorted(f.keys()):
+ if crc_check:
+ cal_crc_num = binascii.crc32(bytes(k, encoding='utf-8'), cal_crc_num)
+ cal_crc_num = binascii.crc32(bytes(f.get_tensor(k)), cal_crc_num)
+ if choice_func is not None and not choice_func(k):
+ continue
+ parameter_dict[k] = Parameter(Tensor.from_numpy(f.get_tensor(k)))
+ sf_load_time_end = time.time()
+ cost_time = sf_load_time_end - sf_load_time_start
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load safetensors cost time:{cost_time}.")
+ if crc_check:
+ if f.metadata() is None or f.metadata().get("crc_num") is None:
+ logger.warning(
+ "For 'load_checkpoint', the safetensors file do not contain the crc code, "
+ "please check the file.")
+ else:
+ crc_num = int(f.metadata()["crc_num"])
+ if cal_crc_num != crc_num:
+ raise ValueError("For 'load_checkpoint', the crc check has failed. "
+ "Please check whether the ckpt file is damaged.")
  return
  checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
  try:
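
Per the loading branch above, a safetensors checkpoint written with `crc_check=True` is re-verified at load time, and `choice_func` can filter which keys are materialized. A hedged sketch with an illustrative file name (not taken from the diff):

>>> import mindspore as ms
>>> full_dict = ms.load_checkpoint("model.safetensors", format="safetensors", crc_check=True)
>>> backbone_only = ms.load_checkpoint("model.safetensors", format="safetensors",
...                                    choice_func=lambda name: name.startswith("backbone."))
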
@@ -1346,13 +1487,14 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
  - `Saving and Loading the Model - Saving and Loading the Model Weight
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
  """
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
  specify_prefix = _check_prefix(specify_prefix)
  filter_prefix = _check_prefix(filter_prefix)
  dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
  dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
  crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
- _check_format_and_other_params(format, dec_key, dec_mode, crc_check)
+ _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode)
  logger.info("Execute the process of loading checkpoint files.")

  parameter_dict = {}
@@ -1392,6 +1534,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
  if _warm_up_host_cache_enabled(parameter_dict):
  _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)

+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Load checkpoint is finished.")
  return parameter_dict
@@ -1448,7 +1591,8 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
  >>> from mindspore import context
  >>> from mindspore import load_checkpoint_async
  >>> from mindspore import load_param_into_net
- >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+ >>> mindspore.set_device(device_target="Ascend")
+ >>> context.set_context(mode=context.GRAPH_MODE)
  >>> # Create the dataset taking MNIST as an example. Refer to
  >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
  >>> dataset = create_dataset()
@@ -1555,7 +1699,12 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
  try:
  if dec_key is None:
  with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
+ ckpt_load_time_start = time.time()
  pb_content = f.read()
+ ckpt_load_time_end = time.time()
+ cost_time = ckpt_load_time_end - ckpt_load_time_start
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt cost time:{cost_time}.")
+
  else:
  pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
  if pb_content is None:
@@ -1673,8 +1822,6 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
  strict_load = Validator.check_bool(strict_load)
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
  logger.info("Execute the process of loading parameters into net.")
- for _, param in net.parameters_and_names():
- param.from_ckpt = True
  param_not_load = []
  ckpt_not_load = list(parameter_dict.keys())
  for _, param in net.parameters_and_names():
@@ -2096,7 +2243,7 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
  if "obf_config" in kwargs and file_format != "MINDIR":
  raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
- if "custom_func" in kwargs and file_format != "MINDIR":
+ if "custom_func" in kwargs and file_format != "MINDIR" and kwargs["custom_func"] is not None:
  raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
  if file_format == 'AIR':
  _save_air(net, file_name, *inputs, **kwargs)
@@ -2478,6 +2625,9 @@ def check_checkpoint(ckpt_file_name):
  """
  Check whether the checkpoint is valid.

+ Note:
+ The interface is deprecated from version 2.5 and will be removed in a future version.
+
  Args:
  ckpt_file_name (str): Checkpoint file name.
@@ -2491,6 +2641,8 @@ def check_checkpoint(ckpt_file_name):
  >>> print(check_result)
  True
  """
+ logger.warning("The interface 'mindspore.check_checkpoint' is deprecated from version 2.5 "
+ "and will be removed in a future version.")
  if not ckpt_file_name.endswith('.ckpt'):
  return False
  checkpoint_list = Checkpoint()
@@ -2517,6 +2669,9 @@ def parse_print(print_file_name):
  """
  Parse data file generated by :class:`mindspore.ops.Print`.

+ Note:
+ The interface is deprecated from version 2.5 and will be removed in a future version.
+
  Args:
  print_file_name (str): The file name needs to be parsed.
@@ -2551,6 +2706,8 @@ def parse_print(print_file_name):
  [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
  [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
  """
+ logger.warning("The interface 'mindspore.parse_print' is deprecated from version 2.5 "
+ "and will be removed in a future version.")
  print_file_path = os.path.realpath(print_file_name)

  if os.path.getsize(print_file_path) == 0:
@@ -2840,16 +2997,33 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
  return merged_parameter


+ def _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, dst_device_num,
+ output_format, name_map, return_param_dict):
+ """gather transform tasks"""
+ tasks = []
+ for rank in range(0, dst_device_num):
+ tasks.append(
+ (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank, output_format, name_map,
+ return_param_dict))
+ return tasks
+
+
  def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
  train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
- format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
+ format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
+ output_format='safetensors', name_map=None, max_process_num=64,
+ return_param_dict=False):
  """
  Load checkpoint into net for distributed predication. Used in the case of distributed inference.

+ Note:
+ `output_format` will only take effect when `format` is set to `safetensors` and `network` is set to `None`.
+
  Args:
- network (Cell): Network for distributed predication.
+ network (Cell): Network for distributed predication, When the format is `safetensors`, the network parameter
+ can be left blank or passed as None, and the interface will execute save mode.
  checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
- predict_strategy (dict): Strategy of predication process. It means that using one device to predict
+ predict_strategy (Union[dict, str]): Strategy of predication process. It means that using one device to predict
  when setting predict_strategy as None. Default: ``None`` .
  train_strategy_filename (str): The filename of training strategy protocol buffer file.
  When train_strategy_filename is None, the training strategy file will be
@@ -2869,17 +3043,23 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
  It can be set to either "ckpt" or "safetensors". Default: "ckpt".
  unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
  Default: ``None`` .
- dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
+ dst_safetensors_dir (str): In the save mode scenario, the save directory for weights.
  rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
  globally by initializing the network; In save mode, save the file according to the input
  sequence number. If it is not input, save the entire file.
+ output_format (str, optional): Control the format of the output checkpoint after conversion.
+ It can be set to either "ckpt" or "safetensors". Default: "safetensors".
+ name_map (dict): The weight mapping dictionary will modify the weight names according to the mapping
+ dictionary before loading or saving the segmented weights into the network. Default: None.
+ max_process_num (int): Maximum number of processes. Default: 64.
+ return_param_dict (bool): Whether to return the param_dict. Default: ``False``.

  Raises:
  TypeError: The type of inputs do not match the requirements.
  ValueError: Failed to load checkpoint into net.

  Supported Platforms:
- ``Ascend`` ``GPU``
+ ``Ascend`` ``GPU`` ``CPU``

  Examples:
  .. note::
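
Combining the parameters documented above, a sketch of the new save mode (network=None) that converts unified safetensors weights into per-rank files; all paths are illustrative, not taken from the diff:

>>> import mindspore as ms
>>> ms.load_distributed_checkpoint(network=None, predict_strategy="./dst_strategy.ckpt",
...                                format="safetensors", unified_safetensors_dir="./unified_safetensors",
...                                dst_safetensors_dir="./per_rank_out", output_format="ckpt",
...                                max_process_num=8)
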
@@ -2976,9 +3156,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
  ...
  [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
  """
- if format not in ['safetensors', 'ckpt']:
+ if format not in ['safetensors', 'ckpt'] or output_format not in ['safetensors', 'ckpt']:
  raise ValueError(
- f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
+ f"For 'load_distributed_checkpoint', 'format' and 'output_format' "
+ f"must be 'ckpt' or 'safetensors', but got {format}.")

  if format == 'safetensors':
  if unified_safetensors_dir is None:
@@ -2993,36 +3174,32 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
  raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
  f"when format is 'safetensors'.")
  if network is not None:
- rank_id = get_rank()
- _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
+ try:
+ rank_id = get_rank()
+ except RuntimeError:
+ rank_id = 0
+ logger.warning(f"Get rank failed, default loading weight for rank 0.")
+ param_dict = _load_parallel_checkpoint(
+ (unified_safetensors_dir, predict_strategy, network, None, rank_id, output_format, name_map,
+ return_param_dict))
+ return param_dict
+ if dst_safetensors_dir is None:
+ raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
+ f"when network is None.")
+ if rank_id is not None:
+ _load_parallel_checkpoint(
+ (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
+ rank_id, output_format, name_map, return_param_dict))
  else:
- if dst_safetensors_dir is None:
- raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
- f"when network is None.")
- if rank_id is not None:
- _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
- rank_id)
- else:
- dst_strategy_dict = _build_searched_strategy(predict_strategy)
- dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
- dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
- dst_device_num = dst_stage_device_num * dst_stage_num
- processes = []
- activate_processes = 0
- for rank in range(0, dst_device_num):
- p = Process(target=_load_parallel_checkpoint, args=(
- unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
- p.start()
- processes.append(p)
- activate_processes += 1
- max_processes = 64
- if activate_processes >= max_processes:
- p = processes.pop(0)
- p.join()
- activate_processes -= 1
- for p in processes:
- p.join()
- return
+ dst_strategy_dict = _build_searched_strategy(predict_strategy)
+ dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
+ dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
+ dst_device_num = dst_stage_device_num * dst_stage_num
+ tasks = _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
+ dst_device_num, output_format, name_map, return_param_dict)
+ with Pool(processes=max_process_num) as pool:
+ list(pool.imap(_load_parallel_checkpoint, tasks))
+ return True

  network = Validator.check_isinstance("network", network, nn.Cell)
  _check_checkpoint_file(checkpoint_filenames)
@@ -3075,14 +3252,15 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
  if first_dim_shard_idx >= 0:
  first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
  if train_strategy.get(param.name)[5]:
- shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
+ repeat_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
  else:
- shard_size = 0
+ repeat_size = 0
  for rank in param_rank:
  param_total_list = list(range(0, ckpt_file_len))
  if first_dim_shard_size != 1:
  param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
- if shard_size > 0:
+ if repeat_size > 0:
+ shard_size = shard_stride * train_strategy.get(param.name)[5]
  rank_index = param_total_list.index(rank)
  start = rank_index // shard_size * shard_size
  param_total_list = param_total_list[start:start + shard_size]
@@ -3141,12 +3319,16 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
  .format(param_not_in_ckpt))

  load_param_into_net(network, param_dict, strict_load=strict_load)
+ return True


  def async_ckpt_thread_status():
  """
  Get the status of asynchronous save checkpoint thread.

+ Note:
+ The interface is deprecated from version 2.5 and will be removed in a future version.
+
  When performing asynchronous save checkpoint, you can determine whether the asynchronous thread is completed.
  Returns:
@@ -3158,6 +3340,8 @@ def async_ckpt_thread_status():
  >>> ms.async_ckpt_thread_status()
  False
  """
+ logger.warning("The interface 'mindspore.async_ckpt_thread_status' is deprecated from version 2.5 "
+ "and will be removed in a future version.")
  thr_list = threading.enumerate()
  return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
@@ -3288,8 +3472,8 @@ def convert_model(mindir_file, convert_file, file_format):
  """
  Convert mindir model to other format model. The current version only supports conversion to ONNX models.

- .. warning::
- This is an experimental API that is subject to change or deletion.
+ Note:
+ The interface is deprecated from version 2.5 and will be removed in a future version.

  Args:
  mindir_file (str): MindIR file name.
@@ -3305,6 +3489,8 @@ def convert_model(mindir_file, convert_file, file_format):
  >>> import mindspore as ms
  >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
  """
+ logger.warning("The interface 'mindspore.train.serialization.convert_model' is deprecated from version 2.5 "
+ "and will be removed in a future version.")
  Validator.check_file_name_by_regular(mindir_file)
  Validator.check_file_name_by_regular(convert_file)
  if file_format != "ONNX":
@@ -3316,3 +3502,203 @@ def convert_model(mindir_file, convert_file, file_format):
3316
3502
  export(net, net_input, file_name=convert_file, file_format=file_format)
3317
3503
  else:
3318
3504
  export(net, *net_input, file_name=convert_file, file_format=file_format)
3505
+
3506
+
3507
+ def _transform_tensor_to_numpy(path, name_map=None):
3508
+ return _load_and_transform(path, name_map, mindspore.load_checkpoint, lambda v, new_name: v.asnumpy())
3509
+
3510
+
3511
+ def _transform_numpy_to_tensor(path, name_map=None):
3512
+ return _load_and_transform(path, name_map, load_file, lambda v, new_name: mindspore.Parameter(v, name=new_name))
3513
+
3514
+
3515
+ def _process_file(file_info):
3516
+ cur_ckpt_path, name_map, save_path, file = file_info
3517
+ param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
3518
+ safetensors_filename = file.replace(".ckpt", ".safetensors")
3519
+ dst_file = os.path.join(save_path, safetensors_filename)
3520
+ save_file(param_dict_numpy, dst_file)
3521
+
3522
+
3523
+ def _process_file_safetensors(file_info):
3524
+ cur_safe_path, name_map, save_path, file = file_info
3525
+ param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
3526
+ ckpt_filename = file.replace(".safetensors", ".ckpt")
3527
+ dst_file = os.path.join(save_path, ckpt_filename)
3528
+ mindspore.save_checkpoint(param_dict_tensor, dst_file)
3529
+
3530
+
+def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
+    """Gather conversion tasks for every rank directory."""
+    tasks = []
+    for root, dirs, _ in os.walk(file_path):
+        if root != file_path:
+            continue
+
+        rank_dirs = [d for d in dirs if d.startswith('rank')]
+        if not rank_dirs:
+            raise ValueError(
+                f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")
+
+        for rank_dir in rank_dirs:
+            rank_dir_path = os.path.join(root, rank_dir)
+            dst_root = os.path.join(save_path,
+                                    os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
+            os.makedirs(dst_root, exist_ok=True)
+            tasks.extend(
+                (os.path.join(rank_dir_path, file), name_map, dst_root, file)
+                for file in os.listdir(rank_dir_path)
+                if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
+            )
+    return tasks
+
+
+def _gather_tasks_covert(file_path, save_path, file_name_regex, name_map):
+    """Gather conversion tasks for every rank directory."""
+    tasks = []
+    for root, dirs, _ in os.walk(file_path):
+        if root != file_path:
+            continue
+
+        rank_dirs = [d for d in dirs if d.startswith('rank')]
+        if not rank_dirs:
+            raise ValueError(
+                f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
+
+        for rank_dir in rank_dirs:
+            rank_dir_path = os.path.join(root, rank_dir)
+            dst_root = os.path.join(save_path,
+                                    os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
+            os.makedirs(dst_root, exist_ok=True)
+            tasks.extend(
+                (os.path.join(rank_dir_path, file), name_map, dst_root, file)
+                for file in os.listdir(rank_dir_path)
+                if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
+            )
+    return tasks
+
+
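Both gather helpers only inspect the top level of `file_path` and require at least one `rank*` subdirectory, which matches the layout written by distributed checkpoint saving. A small sketch that builds such a layout and converts it in one call; the paths, the parameter name "lin.weight", and the helper `build_demo_ckpts` are illustrative, and the `__main__` guard is needed because directory mode spawns a multiprocessing pool:

    import os
    import mindspore as ms
    from mindspore import Tensor

    def build_demo_ckpts(root="./ckpt_save_path", ranks=2):
        # Create rank0/rank1 subdirectories with one tiny checkpoint each.
        for rank in range(ranks):
            rank_dir = os.path.join(root, f"rank{rank}")
            os.makedirs(rank_dir, exist_ok=True)
            ms.save_checkpoint([{"name": "lin.weight", "data": Tensor([1.0])}],
                               os.path.join(rank_dir, f"checkpoint_{rank}.ckpt"))

    if __name__ == "__main__":
        build_demo_ckpts()
        # Directory mode walks every rank* folder and mirrors it under save_path.
        ms.ckpt_to_safetensors("./ckpt_save_path", save_path="./safetensors_out")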
+def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
+    """
+    Converts MindSpore checkpoint files into safetensors format and saves them to `save_path`.
+    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
+    used for securely storing Tensors with fast speed (zero copy).
+
+    Note:
+        The number of processes should match the capacity of the host; setting it too high is not recommended,
+        as it may cause the host to hang.
+        The safetensors format does not support enc verification. If a checkpoint was saved with enc verification
+        enabled, an error will be raised during conversion.
+        The safetensors format currently does not support crc verification. If a checkpoint contains crc
+        verification information, it will be lost after conversion to safetensors.
+
+    Args:
+        file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
+        save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
+        name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
+        file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
+            Defaults: ``None``.
+        processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
+
+    Raises:
+        ValueError: If the input path is invalid, the save_path is not a directory,
+            or the file_path does not end with '.ckpt'.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.ckpt_to_safetensors("./ckpt_save_path")
+        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
+        >>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
+        >>> namemap = {"lin.weight":"new_name"}
+        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
+    """
+    is_dir = os.path.isdir(file_path)
+    is_file = os.path.isfile(file_path)
+    if not is_dir and not is_file:
+        raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
+    if save_path and os.path.splitext(save_path)[1]:
+        raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
+    if name_map is not None and not isinstance(name_map, dict):
+        raise ValueError(
+            f"For 'ckpt_to_safetensors', the type of 'name_map' must be a dict, but got '{type(name_map)}'")
+
+    if is_dir:
+        tasks = _gather_tasks_covert(file_path, save_path, file_name_regex, name_map)
+        with mp.Pool(processes=processes_num) as pool:
+            list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
+    elif is_file:
+        if not file_path.endswith(".ckpt"):
+            raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
+        if file_name_regex is not None and not re.findall(file_name_regex, file_path):
+            raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
+        if save_path and not os.path.exists(save_path):
+            os.makedirs(save_path, exist_ok=True)
+
+        param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
+        safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
+        dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
+        save_file(param_dict_numpy, dst_file)
+
+
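In single-file mode the regex filter and the rename map compose naturally: the file name must match `file_name_regex` or a ValueError is raised, and `name_map` is applied while the tensors are converted to numpy. A short sketch continuing the illustrative layout from the previous snippet (the paths and the renamed key "backbone.lin.weight" are made up for the example):

    import mindspore as ms

    # Only files whose names match the pattern are converted; the tensor stored as
    # "lin.weight" is written under the new name in the resulting .safetensors file.
    ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt",
                           save_path="./renamed_out",
                           name_map={"lin.weight": "backbone.lin.weight"},
                           file_name_regex=r"checkpoint_\d+")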
+def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
+    """
+    Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
+    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
+    used for securely storing Tensors with fast speed (zero copy).
+
+    Note:
+        The number of processes should match the capacity of the host; setting it too high is not recommended,
+        as it may cause the host to hang.
+
+    Args:
+        file_path (str): Path to the directory containing safetensors files or a single safetensors file
+            (.safetensors).
+        save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
+        name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
+        file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
+            Defaults: ``None``.
+        processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
+
+    Raises:
+        ValueError: If the input path is invalid, the save_path is not a directory,
+            or the file_path does not end with '.safetensors'.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.safetensors_to_ckpt("./safetensors_save_path")
+        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
+        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
+        >>> namemap = {"lin.weight":"new_name"}
+        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
+    """
+    is_dir = os.path.isdir(file_path)
+    is_file = os.path.isfile(file_path)
+    if not is_dir and not is_file:
+        raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
+    if save_path and os.path.splitext(save_path)[1]:
+        raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
+    if name_map is not None and not isinstance(name_map, dict):
+        raise ValueError(
+            f"For 'safetensors_to_ckpt', the type of 'name_map' must be a dict, but got '{type(name_map)}'")
+
+    if is_dir:
+        tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
+        with mp.Pool(processes=processes_num) as pool:
+            list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
+    elif is_file:
+        if not file_path.endswith(".safetensors"):
+            raise ValueError(
+                f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
+        if file_name_regex is not None and not re.findall(file_name_regex, file_path):
+            raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
+        if save_path and not os.path.exists(save_path):
+            os.makedirs(save_path, exist_ok=True)
+
+        param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
+        ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
+        dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
+        mindspore.save_checkpoint(param_dict_tensor, dst_file)
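The two converters are inverses at the level of parameter names and values (minus the crc/enc metadata noted above), so a round trip is a convenient sanity check. A minimal sketch reusing the illustrative paths from the earlier snippets; directory mode uses a multiprocessing pool, hence the `__main__` guard:

    import mindspore as ms

    if __name__ == "__main__":
        # ckpt -> safetensors -> ckpt; the reloaded dict should expose the same parameter names.
        ms.ckpt_to_safetensors("./ckpt_save_path", save_path="./safetensors_out")
        ms.safetensors_to_ckpt("./safetensors_out", save_path="./ckpt_roundtrip")
        params = ms.load_checkpoint("./ckpt_roundtrip/rank0/checkpoint_0.ckpt")
        print(sorted(params.keys()))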