mindspore 2.4.1__cp311-cp311-win_amd64.whl → 2.5.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the original registry page for more details.

Files changed (395)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -3
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +0 -5
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/compile_config.py +64 -0
  11. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
  13. mindspore/_extends/parse/parser.py +23 -5
  14. mindspore/_extends/parse/standard_method.py +123 -27
  15. mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
  16. mindspore/amp.py +7 -1
  17. mindspore/atlprov.dll +0 -0
  18. mindspore/avcodec-59.dll +0 -0
  19. mindspore/avdevice-59.dll +0 -0
  20. mindspore/avfilter-8.dll +0 -0
  21. mindspore/avformat-59.dll +0 -0
  22. mindspore/avutil-57.dll +0 -0
  23. mindspore/boost/boost_cell_wrapper.py +136 -41
  24. mindspore/c1.dll +0 -0
  25. mindspore/c1xx.dll +0 -0
  26. mindspore/c2.dll +0 -0
  27. mindspore/common/__init__.py +3 -1
  28. mindspore/common/_register_for_tensor.py +0 -1
  29. mindspore/common/_stub_tensor.py +25 -4
  30. mindspore/common/_tensor_cpp_method.py +17 -0
  31. mindspore/common/_tensor_docs.py +6132 -0
  32. mindspore/common/api.py +99 -25
  33. mindspore/common/dtype.py +34 -34
  34. mindspore/common/dump.py +2 -1
  35. mindspore/common/file_system.py +8 -1
  36. mindspore/common/generator.py +2 -0
  37. mindspore/common/hook_handle.py +3 -1
  38. mindspore/common/initializer.py +3 -4
  39. mindspore/common/lazy_inline.py +8 -2
  40. mindspore/common/mindir_util.py +10 -2
  41. mindspore/common/parameter.py +30 -27
  42. mindspore/common/tensor.py +713 -1337
  43. mindspore/communication/__init__.py +1 -1
  44. mindspore/communication/_comm_helper.py +10 -0
  45. mindspore/communication/comm_func.py +215 -173
  46. mindspore/communication/management.py +23 -20
  47. mindspore/context.py +292 -193
  48. mindspore/dataset/__init__.py +23 -19
  49. mindspore/dataset/callback/ds_callback.py +2 -1
  50. mindspore/dataset/core/config.py +84 -3
  51. mindspore/dataset/engine/cache_admin.py +3 -3
  52. mindspore/dataset/engine/cache_client.py +5 -4
  53. mindspore/dataset/engine/datasets.py +192 -149
  54. mindspore/dataset/engine/datasets_audio.py +14 -0
  55. mindspore/dataset/engine/datasets_standard_format.py +28 -11
  56. mindspore/dataset/engine/datasets_text.py +38 -1
  57. mindspore/dataset/engine/datasets_user_defined.py +125 -65
  58. mindspore/dataset/engine/datasets_vision.py +81 -8
  59. mindspore/dataset/engine/iterators.py +281 -63
  60. mindspore/dataset/engine/obs/util.py +8 -0
  61. mindspore/dataset/engine/queue.py +40 -0
  62. mindspore/dataset/engine/samplers.py +26 -2
  63. mindspore/dataset/engine/serializer_deserializer.py +1 -1
  64. mindspore/dataset/engine/validators.py +43 -11
  65. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  66. mindspore/dataset/transforms/transforms.py +29 -12
  67. mindspore/dataset/vision/validators.py +1 -2
  68. mindspore/device_context/__init__.py +21 -0
  69. mindspore/device_context/ascend/__init__.py +25 -0
  70. mindspore/device_context/ascend/device.py +72 -0
  71. mindspore/device_context/ascend/op_debug.py +94 -0
  72. mindspore/device_context/ascend/op_precision.py +193 -0
  73. mindspore/device_context/ascend/op_tuning.py +127 -0
  74. mindspore/device_context/cpu/__init__.py +25 -0
  75. mindspore/device_context/cpu/device.py +62 -0
  76. mindspore/device_context/cpu/op_tuning.py +43 -0
  77. mindspore/device_context/gpu/__init__.py +21 -0
  78. mindspore/device_context/gpu/device.py +70 -0
  79. mindspore/device_context/gpu/op_precision.py +67 -0
  80. mindspore/device_context/gpu/op_tuning.py +175 -0
  81. mindspore/device_manager.py +134 -0
  82. mindspore/dnnl.dll +0 -0
  83. mindspore/dpcmi.dll +0 -0
  84. mindspore/experimental/llm_boost/__init__.py +3 -2
  85. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  86. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  87. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  88. mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
  89. mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
  90. mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
  91. mindspore/experimental/llm_boost/register.py +1 -0
  92. mindspore/experimental/optim/adadelta.py +26 -22
  93. mindspore/experimental/optim/adam.py +3 -0
  94. mindspore/experimental/optim/lr_scheduler.py +33 -24
  95. mindspore/experimental/optim/radam.py +33 -30
  96. mindspore/hal/device.py +28 -0
  97. mindspore/hal/event.py +17 -0
  98. mindspore/hal/memory.py +94 -3
  99. mindspore/hal/stream.py +91 -6
  100. mindspore/include/api/context.h +1 -2
  101. mindspore/include/dataset/constants.h +2 -2
  102. mindspore/jpeg62.dll +0 -0
  103. mindspore/log.py +12 -0
  104. mindspore/mindrecord/__init__.py +1 -1
  105. mindspore/mindrecord/config.py +17 -316
  106. mindspore/mindrecord/filereader.py +1 -9
  107. mindspore/mindrecord/filewriter.py +5 -15
  108. mindspore/mindrecord/mindpage.py +1 -9
  109. mindspore/mindspore_backend.dll +0 -0
  110. mindspore/mindspore_common.dll +0 -0
  111. mindspore/mindspore_core.dll +0 -0
  112. mindspore/mindspore_glog.dll +0 -0
  113. mindspore/mindspore_ops.dll +0 -0
  114. mindspore/mint/__init__.py +824 -218
  115. mindspore/mint/distributed/__init__.py +66 -4
  116. mindspore/mint/distributed/distributed.py +2594 -44
  117. mindspore/mint/linalg/__init__.py +6 -0
  118. mindspore/mint/nn/__init__.py +473 -14
  119. mindspore/mint/nn/functional.py +486 -11
  120. mindspore/mint/nn/layer/__init__.py +17 -4
  121. mindspore/mint/nn/layer/_functions.py +330 -0
  122. mindspore/mint/nn/layer/activation.py +169 -1
  123. mindspore/mint/nn/layer/basic.py +123 -0
  124. mindspore/mint/nn/layer/conv.py +727 -0
  125. mindspore/mint/nn/layer/normalization.py +215 -19
  126. mindspore/mint/nn/layer/padding.py +797 -0
  127. mindspore/mint/nn/layer/pooling.py +170 -0
  128. mindspore/mint/optim/__init__.py +2 -1
  129. mindspore/mint/optim/adam.py +223 -0
  130. mindspore/mint/optim/adamw.py +26 -19
  131. mindspore/mint/special/__init__.py +2 -1
  132. mindspore/msobj140.dll +0 -0
  133. mindspore/mspdb140.dll +0 -0
  134. mindspore/mspdbcore.dll +0 -0
  135. mindspore/mspdbst.dll +0 -0
  136. mindspore/mspft140.dll +0 -0
  137. mindspore/msvcdis140.dll +0 -0
  138. mindspore/msvcp140_1.dll +0 -0
  139. mindspore/msvcp140_2.dll +0 -0
  140. mindspore/msvcp140_atomic_wait.dll +0 -0
  141. mindspore/msvcp140_codecvt_ids.dll +0 -0
  142. mindspore/multiprocessing/__init__.py +5 -0
  143. mindspore/nn/__init__.py +2 -0
  144. mindspore/nn/cell.py +142 -21
  145. mindspore/nn/dynamic_lr.py +2 -1
  146. mindspore/nn/layer/activation.py +6 -6
  147. mindspore/nn/layer/basic.py +35 -25
  148. mindspore/nn/layer/channel_shuffle.py +3 -3
  149. mindspore/nn/layer/conv.py +3 -0
  150. mindspore/nn/layer/embedding.py +3 -3
  151. mindspore/nn/layer/normalization.py +8 -7
  152. mindspore/nn/layer/padding.py +4 -3
  153. mindspore/nn/layer/pooling.py +55 -23
  154. mindspore/nn/layer/rnn_cells.py +1 -1
  155. mindspore/nn/layer/rnns.py +2 -1
  156. mindspore/nn/layer/timedistributed.py +5 -5
  157. mindspore/nn/layer/transformer.py +48 -26
  158. mindspore/nn/learning_rate_schedule.py +5 -3
  159. mindspore/nn/loss/loss.py +31 -36
  160. mindspore/nn/optim/ada_grad.py +1 -0
  161. mindspore/nn/optim/adadelta.py +2 -2
  162. mindspore/nn/optim/adam.py +1 -1
  163. mindspore/nn/optim/lars.py +1 -4
  164. mindspore/nn/optim/optimizer.py +1 -1
  165. mindspore/nn/optim/rprop.py +2 -2
  166. mindspore/nn/optim/thor.py +2 -1
  167. mindspore/nn/utils/__init__.py +22 -0
  168. mindspore/nn/utils/init.py +73 -0
  169. mindspore/nn/wrap/cell_wrapper.py +4 -6
  170. mindspore/nn/wrap/loss_scale.py +3 -4
  171. mindspore/numpy/array_creations.py +60 -62
  172. mindspore/numpy/array_ops.py +148 -143
  173. mindspore/numpy/logic_ops.py +41 -42
  174. mindspore/numpy/math_ops.py +361 -359
  175. mindspore/numpy/utils.py +16 -16
  176. mindspore/numpy/utils_const.py +4 -4
  177. mindspore/opencv_core452.dll +0 -0
  178. mindspore/opencv_imgcodecs452.dll +0 -0
  179. mindspore/opencv_imgproc452.dll +0 -0
  180. mindspore/ops/__init__.py +2 -1
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
  182. mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
  183. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  184. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  185. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  186. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  187. mindspore/ops/_vmap/vmap_array_ops.py +20 -19
  188. mindspore/ops/_vmap/vmap_base.py +0 -2
  189. mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
  190. mindspore/ops/_vmap/vmap_math_ops.py +11 -9
  191. mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
  192. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
  193. mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
  194. mindspore/ops/auto_generate/gen_extend_func.py +554 -60
  195. mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
  196. mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
  197. mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
  198. mindspore/ops/composite/base.py +1 -1
  199. mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
  200. mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
  201. mindspore/ops/function/__init__.py +12 -0
  202. mindspore/ops/function/array_func.py +561 -159
  203. mindspore/ops/function/clip_func.py +64 -0
  204. mindspore/ops/function/debug_func.py +28 -20
  205. mindspore/ops/function/image_func.py +1 -1
  206. mindspore/ops/function/linalg_func.py +5 -4
  207. mindspore/ops/function/math_func.py +1664 -294
  208. mindspore/ops/function/nn_func.py +988 -317
  209. mindspore/ops/function/parameter_func.py +3 -56
  210. mindspore/ops/function/random_func.py +243 -33
  211. mindspore/ops/function/sparse_unary_func.py +1 -1
  212. mindspore/ops/functional.py +18 -5
  213. mindspore/ops/functional_overload.py +897 -0
  214. mindspore/ops/operations/__init__.py +3 -2
  215. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  216. mindspore/ops/operations/_grad_ops.py +2 -34
  217. mindspore/ops/operations/_infer_ops.py +2 -1
  218. mindspore/ops/operations/_inner_ops.py +38 -8
  219. mindspore/ops/operations/array_ops.py +45 -303
  220. mindspore/ops/operations/comm_ops.py +23 -17
  221. mindspore/ops/operations/custom_ops.py +7 -49
  222. mindspore/ops/operations/debug_ops.py +42 -47
  223. mindspore/ops/operations/inner_ops.py +6 -4
  224. mindspore/ops/operations/linalg_ops.py +3 -2
  225. mindspore/ops/operations/manually_defined/ops_def.py +185 -104
  226. mindspore/ops/operations/math_ops.py +11 -216
  227. mindspore/ops/operations/nn_ops.py +153 -310
  228. mindspore/ops/primitive.py +23 -21
  229. mindspore/ops/tensor_method.py +1669 -0
  230. mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
  231. mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
  232. mindspore/ops_generate/arg_handler.py +0 -61
  233. mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
  234. mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
  235. mindspore/ops_generate/base_generator.py +11 -0
  236. mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
  237. mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
  238. mindspore/ops_generate/functional_overload_py_generator.py +110 -0
  239. mindspore/ops_generate/functions_cc_generator.py +233 -0
  240. mindspore/ops_generate/gen_aclnn_implement.py +110 -114
  241. mindspore/ops_generate/gen_constants.py +157 -3
  242. mindspore/ops_generate/gen_ops.py +245 -990
  243. mindspore/ops_generate/gen_pyboost_func.py +97 -998
  244. mindspore/ops_generate/gen_utils.py +119 -33
  245. mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
  246. mindspore/ops_generate/op_api_proto.py +206 -0
  247. mindspore/ops_generate/op_def_py_generator.py +131 -0
  248. mindspore/ops_generate/op_prim_py_generator.py +480 -0
  249. mindspore/ops_generate/op_proto.py +373 -108
  250. mindspore/ops_generate/op_template_parser.py +436 -0
  251. mindspore/ops_generate/ops_def_cc_generator.py +288 -0
  252. mindspore/ops_generate/ops_def_h_generator.py +74 -0
  253. mindspore/ops_generate/ops_name_h_generator.py +68 -0
  254. mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
  255. mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
  256. mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
  257. mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
  258. mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
  259. mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
  260. mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
  261. mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
  262. mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
  263. mindspore/ops_generate/pyboost_utils.py +92 -33
  264. mindspore/ops_generate/template.py +294 -44
  265. mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
  266. mindspore/parallel/__init__.py +3 -3
  267. mindspore/parallel/_auto_parallel_context.py +44 -34
  268. mindspore/parallel/_cell_wrapper.py +22 -3
  269. mindspore/parallel/_parallel_serialization.py +13 -2
  270. mindspore/parallel/_utils.py +4 -2
  271. mindspore/parallel/algo_parameter_config.py +1 -1
  272. mindspore/parallel/checkpoint_transform.py +44 -0
  273. mindspore/parallel/cluster/process_entity/_api.py +131 -37
  274. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  275. mindspore/parallel/cluster/run.py +20 -3
  276. mindspore/parallel/parameter_broadcast.py +1 -1
  277. mindspore/parallel/shard.py +3 -0
  278. mindspore/parallel/transform_safetensors.py +119 -253
  279. mindspore/pgodb140.dll +0 -0
  280. mindspore/pgort140.dll +0 -0
  281. mindspore/profiler/__init__.py +17 -4
  282. mindspore/profiler/analysis/__init__.py +0 -0
  283. mindspore/profiler/analysis/parser/__init__.py +0 -0
  284. mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
  285. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  286. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  287. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  288. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  289. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  290. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
  291. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  292. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
  293. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  294. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  295. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  296. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  297. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  298. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  299. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  300. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  301. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  302. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  303. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
  304. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  305. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  306. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  307. mindspore/profiler/analysis/task_manager.py +131 -0
  308. mindspore/profiler/analysis/time_converter.py +84 -0
  309. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  310. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
  311. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  312. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
  313. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
  314. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
  315. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
  316. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  317. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  318. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
  319. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  320. mindspore/profiler/analysis/work_flow.py +73 -0
  321. mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
  322. mindspore/profiler/common/command_executor.py +90 -0
  323. mindspore/profiler/common/constant.py +174 -3
  324. mindspore/profiler/common/file_manager.py +208 -0
  325. mindspore/profiler/common/log.py +130 -0
  326. mindspore/profiler/common/msprof_cmd_tool.py +202 -0
  327. mindspore/profiler/common/path_manager.py +371 -0
  328. mindspore/profiler/common/process_bar.py +168 -0
  329. mindspore/profiler/common/process_pool.py +9 -3
  330. mindspore/profiler/common/profiler_context.py +476 -0
  331. mindspore/profiler/common/profiler_info.py +304 -0
  332. mindspore/profiler/common/profiler_output_path.py +284 -0
  333. mindspore/profiler/common/profiler_parameters.py +210 -0
  334. mindspore/profiler/common/profiler_path_manager.py +120 -0
  335. mindspore/profiler/common/record_function.py +76 -0
  336. mindspore/profiler/common/tlv_decoder.py +76 -0
  337. mindspore/profiler/common/util.py +75 -2
  338. mindspore/profiler/dynamic_profiler.py +270 -37
  339. mindspore/profiler/envprofiler.py +138 -0
  340. mindspore/profiler/mstx.py +199 -0
  341. mindspore/profiler/platform/__init__.py +21 -0
  342. mindspore/profiler/platform/base_profiler.py +40 -0
  343. mindspore/profiler/platform/cpu_profiler.py +124 -0
  344. mindspore/profiler/platform/gpu_profiler.py +74 -0
  345. mindspore/profiler/platform/npu_profiler.py +309 -0
  346. mindspore/profiler/profiler.py +580 -93
  347. mindspore/profiler/profiler_action_controller.py +187 -0
  348. mindspore/profiler/profiler_interface.py +114 -0
  349. mindspore/profiler/schedule.py +208 -0
  350. mindspore/rewrite/api/symbol_tree.py +1 -2
  351. mindspore/run_check/_check_version.py +18 -13
  352. mindspore/runtime/__init__.py +37 -0
  353. mindspore/runtime/device.py +27 -0
  354. mindspore/runtime/event.py +209 -0
  355. mindspore/runtime/executor.py +148 -0
  356. mindspore/runtime/memory.py +392 -0
  357. mindspore/runtime/stream.py +460 -0
  358. mindspore/runtime/thread_bind_core.py +401 -0
  359. mindspore/swresample-4.dll +0 -0
  360. mindspore/swscale-6.dll +0 -0
  361. mindspore/tbbmalloc.dll +0 -0
  362. mindspore/tinyxml2.dll +0 -0
  363. mindspore/train/__init__.py +2 -2
  364. mindspore/train/_utils.py +53 -18
  365. mindspore/train/amp.py +8 -4
  366. mindspore/train/callback/_checkpoint.py +32 -18
  367. mindspore/train/callback/_early_stop.py +1 -1
  368. mindspore/train/callback/_flops_collector.py +105 -69
  369. mindspore/train/callback/_history.py +1 -1
  370. mindspore/train/callback/_summary_collector.py +44 -6
  371. mindspore/train/callback/_tft_register.py +37 -15
  372. mindspore/train/dataset_helper.py +11 -11
  373. mindspore/train/metrics/precision.py +4 -5
  374. mindspore/train/mind_ir_pb2.py +167 -46
  375. mindspore/train/model.py +13 -14
  376. mindspore/train/serialization.py +461 -72
  377. mindspore/train/summary/summary_record.py +1 -2
  378. mindspore/train/train_thor/model_thor.py +1 -1
  379. mindspore/turbojpeg.dll +0 -0
  380. mindspore/utils/__init__.py +4 -2
  381. mindspore/utils/dryrun.py +138 -0
  382. mindspore/utils/runtime_execution_order_check.py +550 -0
  383. mindspore/vcmeta.dll +0 -0
  384. mindspore/vcruntime140.dll +0 -0
  385. mindspore/vcruntime140_1.dll +0 -0
  386. mindspore/version.py +1 -1
  387. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
  388. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +391 -265
  389. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
  390. mindspore/common/_tensor_overload.py +0 -139
  391. mindspore/mindspore_np_dtype.dll +0 -0
  392. mindspore/profiler/envprofiling.py +0 -254
  393. mindspore/profiler/profiling.py +0 -1926
  394. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
  395. {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
@@ -24,9 +24,11 @@ import os
24
24
  import re
25
25
  import shutil
26
26
  import stat
27
+ import atexit
27
28
  import threading
28
29
  from threading import Thread, RLock
29
- from multiprocessing import Process
30
+ from multiprocessing import Pool, active_children
31
+ import multiprocessing as mp
30
32
  from collections import defaultdict, OrderedDict
31
33
  from io import BytesIO
32
34
 
@@ -36,6 +38,9 @@ import time
36
38
  import google
37
39
  import numpy as np
38
40
 
41
+ from safetensors.numpy import save_file, load_file
42
+ from safetensors import safe_open
43
+
39
44
  from mindspore.train.checkpoint_pb2 import Checkpoint
40
45
  from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
41
46
  from mindspore.train.print_pb2 import Print
@@ -44,6 +49,7 @@ import mindspore
44
49
  import mindspore.nn as nn
45
50
  from mindspore import context
46
51
  from mindspore import log as logger
52
+ from mindspore.log import vlog_print
47
53
  from mindspore._checkparam import check_input_data, check_input_dataset
48
54
  from mindspore import _checkparam as Validator
49
55
  from mindspore.common import dtype as mstype
@@ -73,12 +79,10 @@ from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_w
73
79
  from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
74
80
  from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
75
81
  _extract_pipeline_stage_num
76
- from mindspore.train._utils import read_proto, get_parameter_redundancy
82
+ from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
77
83
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
78
84
  split_mindir, split_dynamic_mindir
79
85
  from mindspore.common.generator import Generator
80
- from safetensors.numpy import save_file
81
- from safetensors import safe_open
82
86
  from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
83
87
 
84
88
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
@@ -122,6 +126,31 @@ def init_ckpt_file_system(fs: FileSystem):
122
126
  init_ckpt_file_system(_ckpt_fs)
123
127
 
124
128
 
129
def _wait_async_process_save_ckpt():
    """Block until every child process named "asyn_save_ckpt" has finished.

    Only children of the current process are inspected, as reported by
    ``multiprocessing.active_children()``.
    """
    for process in active_children():
        if process.name == "asyn_save_ckpt":
            process.join()


def _wait_async_thread_save_ckpt():
    """Block until every live thread named "asyn_save_ckpt" has finished.

    The calling thread is skipped because ``Thread.join`` raises ``RuntimeError``
    when a thread tries to join itself.
    """
    current = threading.current_thread()
    for thread in threading.enumerate():
        # ``Thread.name`` replaces the ``getName()`` accessor deprecated in Python 3.10.
        if thread.name == "asyn_save_ckpt" and thread is not current:
            thread.join()


def _async_save_close():
    """Wait for all asynchronous checkpoint saves (process- and thread-based) to complete."""
    _wait_async_process_save_ckpt()
    _wait_async_thread_save_ckpt()


# Registered so that pending asynchronous checkpoint saves are waited for at interpreter exit.
atexit.register(_async_save_close)
+
153
+
125
154
  def _get_cur_rank_dp(parameter_layout_dict):
126
155
  """ Get dp and tp from layout dict. """
127
156
  pp_num = _get_auto_parallel_context("pipeline_stages")
@@ -281,7 +310,8 @@ def _type_convert(param, new_param, strict_load):
281
310
  {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
282
311
  logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
283
312
  f"the type of it in 'net':{param.data.dtype}, then the type convert from "
284
- f"{new_param.data.dtype} to {param.data.dtype} in the network.")
313
+ f"{new_param.data.dtype} to {param.data.dtype} in the network. May consume additional memory "
314
+ f"and time")
285
315
  return True
286
316
  return False
287
317
 
@@ -335,6 +365,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
335
365
  os.chmod(tmp_name, stat.S_IWUSR)
336
366
  os.remove(tmp_name)
337
367
  if format == "ckpt":
368
+ ckpt_save_time_start = time.time()
338
369
  with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
339
370
  plain_data = None
340
371
  if enc_key is not None:
@@ -375,11 +406,29 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
375
406
  block_data = plain_data.read(max_block_size)
376
407
  if crc_check:
377
408
  f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
409
+ ckpt_save_time_end = time.time()
410
+ cost_time = ckpt_save_time_end - ckpt_save_time_start
411
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save ckpt cost time:{cost_time}.")
378
412
  elif format == "safetensors":
379
413
  save_dict = {}
380
- for name, value in data_list.items():
414
+ crc_num = 0
415
+ for name in sorted(data_list.keys()):
416
+ value = data_list[name]
381
417
  save_dict[name] = value[2].asnumpy()
382
- save_file(save_dict, tmp_name)
418
+
419
+ if crc_check:
420
+ crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
421
+ crc_num = binascii.crc32(
422
+ bytes(save_dict[name]), crc_num)
423
+ safetensors_save_time_start = time.time()
424
+ if crc_check:
425
+ save_file(save_dict, tmp_name, metadata={
426
+ "crc_num": str(crc_num)})
427
+ else:
428
+ save_file(save_dict, tmp_name)
429
+ safetensors_save_time_end = time.time()
430
+ cost_time = safetensors_save_time_end - safetensors_save_time_start
431
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors cost time:{cost_time}.")
383
432
  if not os.path.exists(tmp_name):
384
433
  logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
385
434
  f"simultaneously modified a file.")
@@ -519,12 +568,58 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
519
568
  return ckpt_file_name
520
569
 
521
570
 
522
- def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
523
- global_step_num=None):
524
- param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
525
- or map_param_inc or global_step_num is not None)
526
- if format == "safetensors" and param_not_default:
527
- raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
571
def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
    """Check that parameters unsupported by safetensors loading keep their default values.

    Args:
        format (str): Checkpoint format. Only ``"safetensors"`` is restricted; any other
            value returns immediately.
        dec_key: Decryption key. Must be ``None`` when ``format`` is ``"safetensors"``.
        dec_mode (str): Decryption mode. Must be ``"AES-GCM"`` when ``format`` is ``"safetensors"``.

    Raises:
        ValueError: If ``format`` is ``"safetensors"`` and a restricted parameter differs
            from its default value.
    """
    if format != "safetensors":
        return
    # (name, current value, default value), listed explicitly rather than looked up through
    # locals() so that renaming a parameter cannot silently break the check.
    checks = (
        ("dec_key", dec_key, None),
        ("dec_mode", dec_mode, "AES-GCM"),
    )
    for param_name, current_value, default_value in checks:
        if current_value != default_value:
            raise ValueError(f"For 'load_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
                             f"be set to default value '{default_value}', but got '{current_value}'.")
584
+
585
+
586
def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, async_save=False, map_param_inc=False,
                                             global_step_num=None):
    """Check that parameters unsupported by safetensors saving keep their default values.

    Args:
        format (str): Checkpoint format. Only ``"safetensors"`` is restricted; any other
            value returns immediately.
        enc_key: Encryption key. Must be ``None`` for safetensors.
        enc_mode (str): Encryption mode. Must be ``"AES-GCM"`` for safetensors.
        async_save: Asynchronous-save option. Must be ``False`` for safetensors.
        map_param_inc: Incremental map-parameter saving flag. Must be ``False`` for safetensors.
        global_step_num: Global step number. Must be ``None`` for safetensors.

    Raises:
        ValueError: If ``format`` is ``"safetensors"`` and a restricted parameter differs
            from its default value.
    """
    if format != "safetensors":
        return
    # (name, current value, default value), listed explicitly rather than looked up through
    # locals() so that renaming a parameter cannot silently break the check.
    checks = (
        ("enc_key", enc_key, None),
        ("enc_mode", enc_mode, "AES-GCM"),
        ("async_save", async_save, False),
        ("map_param_inc", map_param_inc, False),
        ("global_step_num", global_step_num, None),
    )
    for param_name, current_value, default_value in checks:
        if current_value != default_value:
            raise ValueError(f"For 'save_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
                             f"be set to default value '{default_value}', but got '{current_value}'.")
603
+
604
+
605
def _check_async_save(async_save):
    """Validate the ``async_save`` argument of ``save_checkpoint``.

    Args:
        async_save (Union[bool, str]): ``True``/``False``, or the asynchronous saving
            method, ``"process"`` or ``"thread"``.

    Returns:
        Union[bool, str]: ``async_save`` unchanged.

    Raises:
        TypeError: If ``async_save`` is neither bool nor str.
        ValueError: If ``async_save`` is a str other than ``"process"`` or ``"thread"``.
    """
    if not isinstance(async_save, (bool, str)):
        raise TypeError("For 'save_checkpoint', the parameter 'async_save' must be bool or str, "
                        "but got {}.".format(type(async_save)))
    if isinstance(async_save, str) and async_save not in ("process", "thread"):
        # Trailing space after the comma: the two literals are concatenated into one message.
        raise ValueError("For 'save_checkpoint', the argument 'async_save' can only be 'process' or 'thread', "
                         "but got {}.".format(async_save))
    return async_save
615
+
616
+
617
def _async_process_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False,
                        crc_check=False, format="ckpt", cond=None):
    """Body of the asynchronous checkpoint-saving process.

    First notifies ``cond``, which presumably lets the launching side confirm that the process
    has started. Then writes the checkpoint via ``_exec_save`` with the given arguments.

    Args:
        ckpt_file_name (str): Target checkpoint file name.
        data_list: Data to save, forwarded unchanged to ``_exec_save``.
        enc_key: Encryption key forwarded to ``_exec_save``. Default: ``None``.
        enc_mode (str): Encryption mode forwarded to ``_exec_save``. Default: ``"AES-GCM"``.
        map_param_inc (bool): Forwarded to ``_exec_save``. Default: ``False``.
        crc_check (bool): Forwarded to ``_exec_save``. Default: ``False``.
        format (str): Checkpoint format forwarded to ``_exec_save``. Default: ``"ckpt"``.
        cond: Condition object that supports ``with`` and ``notify()``. If ``None`` (the
            default), no notification is sent instead of failing on ``with None``.
    """
    if cond is not None:
        with cond:
            cond.notify()
    _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
528
623
 
529
624
 
530
625
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
@@ -541,10 +636,13 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
541
636
  list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
542
637
  elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
543
638
  `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
544
- it can be the returned value of `mindspore.load_checkpoint()`.
639
+ it can be the returned value of :func:`mindspore.load_checkpoint`.
545
640
  ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
546
641
  integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
547
- async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
642
+ async_save (Union[bool, str]): Whether to use asynchronous saving of the checkpoint file, if True,
643
+ the asynchronous thread is used by default. If the type is string,
644
+ the method of asynchronous saving, it can be "process" or "thread".
645
+ Default: ``False`` .
548
646
  append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
549
647
  of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
550
648
  enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
@@ -564,8 +662,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
564
662
 
565
663
  Raises:
566
664
  TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
567
- TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
665
+ TypeError: If the parameter `integrated_save` is not bool type.
568
666
  TypeError: If the parameter `ckpt_file_name` is not string type.
667
+ TypeError: If the parameter `async_save` is not bool or string type.
668
+ ValueError: If the parameter `async_save` is string type but not in ["process", "thread"].
569
669
 
570
670
  Examples:
571
671
  >>> import mindspore as ms
@@ -595,7 +695,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
595
695
  """
596
696
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
597
697
  integrated_save = Validator.check_bool(integrated_save)
598
- async_save = Validator.check_bool(async_save)
698
+ async_save = _check_async_save(async_save)
599
699
  append_dict = _check_append_dict(append_dict)
600
700
  enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
601
701
  enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
@@ -603,7 +703,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
603
703
  map_param_inc = kwargs.get('incremental', False)
604
704
  logger.info("Execute the process of saving checkpoint files.")
605
705
  global_step_num = kwargs.get('global_step_num', None)
606
- _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
706
+ _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, async_save, map_param_inc, global_step_num)
607
707
 
608
708
  if append_dict and "__exception_save__" in append_dict:
609
709
  s1 = mindspore.hal.Stream()
@@ -679,7 +779,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
679
779
  data_list[key].append(dims)
680
780
  tensor_type = str(param["data"].dtype)
681
781
  data_list[key].append(tensor_type)
682
- data = param["data"]
782
+ data = param["data"] if async_save != "process" else param["data"].asnumpy()
683
783
  data_list[key].append(data)
684
784
 
685
785
  if os.getenv("AITURBO") == "1":
@@ -687,11 +787,35 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
687
787
  ckpt_name = os.path.basename(ckpt_file_name)
688
788
  aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
689
789
  elif async_save:
690
- data_copy = copy.deepcopy(data_list)
691
- thr = Thread(target=_exec_save,
692
- args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
693
- name="asyn_save_ckpt")
694
- thr.start()
790
+ if async_save == "process":
791
+ if sys.platform.startswith("win"):
792
+ logger.warining("The Win platform currently does not support asynchronous process saving of ckpt, "
793
+ "so serial saving of ckpt is used now.")
794
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
795
+ else:
796
+ _wait_async_process_save_ckpt()
797
+ ctx = mp.get_context("fork")
798
+ cond = ctx.Condition()
799
+ process_flag = True
800
+ while process_flag:
801
+ process = ctx.Process(target=_async_process_save,
802
+ args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
803
+ format, cond), daemon=True, name="asyn_save_ckpt")
804
+ process.start()
805
+ with cond:
806
+ wait_flag = cond.wait(timeout=5)
807
+ if not wait_flag:
808
+ logger.warning("Async save process fails to create. will kill and recreate")
809
+ process.kill()
810
+ else:
811
+ process_flag = False
812
+ else:
813
+ data_copy = copy.deepcopy(data_list)
814
+ _wait_async_thread_save_ckpt()
815
+ thr = Thread(target=_exec_save,
816
+ args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
817
+ name="asyn_save_ckpt")
818
+ thr.start()
695
819
  else:
696
820
  _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
697
821
 
@@ -1198,8 +1322,28 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1198
1322
  ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
1199
1323
  if format == "safetensors":
1200
1324
  with safe_open(ckpt_file_name, framework='np') as f:
1201
- for k in f.keys():
1202
- parameter_dict[k] = Parameter(f.get_tensor(k))
1325
+ cal_crc_num = 0
1326
+ sf_load_time_start = time.time()
1327
+ for k in sorted(f.keys()):
1328
+ if crc_check:
1329
+ cal_crc_num = binascii.crc32(bytes(k, encoding='utf-8'), cal_crc_num)
1330
+ cal_crc_num = binascii.crc32(bytes(f.get_tensor(k)), cal_crc_num)
1331
+ if choice_func is not None and not choice_func(k):
1332
+ continue
1333
+ parameter_dict[k] = Parameter(Tensor.from_numpy(f.get_tensor(k)))
1334
+ sf_load_time_end = time.time()
1335
+ cost_time = sf_load_time_end - sf_load_time_start
1336
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load safetensors cost time:{cost_time}.")
1337
+ if crc_check:
1338
+ if f.metadata() is None or f.metadata().get("crc_num") is None:
1339
+ logger.warning(
1340
+ "For 'load_checkpoint', the safetensors file do not contain the crc code, "
1341
+ "please check the file.")
1342
+ else:
1343
+ crc_num = int(f.metadata()["crc_num"])
1344
+ if cal_crc_num != crc_num:
1345
+ raise ValueError("For 'load_checkpoint', the crc check has failed. "
1346
+ "Please check whether the ckpt file is damaged.")
1203
1347
  return
1204
1348
  checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
1205
1349
  try:
@@ -1343,13 +1487,14 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1343
1487
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1344
1488
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1345
1489
  """
1490
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
1346
1491
  specify_prefix = _check_prefix(specify_prefix)
1347
1492
  filter_prefix = _check_prefix(filter_prefix)
1348
1493
  dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
1349
1494
  dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
1350
1495
  crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
1351
1496
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1352
- _check_format_and_other_params(format, dec_key, dec_mode, crc_check)
1497
+ _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode)
1353
1498
  logger.info("Execute the process of loading checkpoint files.")
1354
1499
 
1355
1500
  parameter_dict = {}
@@ -1389,6 +1534,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1389
1534
  if _warm_up_host_cache_enabled(parameter_dict):
1390
1535
  _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
1391
1536
 
1537
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Load checkpoint is finished.")
1392
1538
  return parameter_dict
1393
1539
 
1394
1540
 
@@ -1445,7 +1591,8 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
1445
1591
  >>> from mindspore import context
1446
1592
  >>> from mindspore import load_checkpoint_async
1447
1593
  >>> from mindspore import load_param_into_net
1448
- >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
1594
+ >>> mindspore.set_device(device_target="Ascend")
1595
+ >>> context.set_context(mode=context.GRAPH_MODE)
1449
1596
  >>> # Create the dataset taking MNIST as an example. Refer to
1450
1597
  >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
1451
1598
  >>> dataset = create_dataset()
@@ -1552,7 +1699,12 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
1552
1699
  try:
1553
1700
  if dec_key is None:
1554
1701
  with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
1702
+ ckpt_load_time_start = time.time()
1555
1703
  pb_content = f.read()
1704
+ ckpt_load_time_end = time.time()
1705
+ cost_time = ckpt_load_time_end - ckpt_load_time_start
1706
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt cost time:{cost_time}.")
1707
+
1556
1708
  else:
1557
1709
  pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
1558
1710
  if pb_content is None:
@@ -1670,8 +1822,6 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1670
1822
  strict_load = Validator.check_bool(strict_load)
1671
1823
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1672
1824
  logger.info("Execute the process of loading parameters into net.")
1673
- for _, param in net.parameters_and_names():
1674
- param.from_ckpt = True
1675
1825
  param_not_load = []
1676
1826
  ckpt_not_load = list(parameter_dict.keys())
1677
1827
  for _, param in net.parameters_and_names():
@@ -2093,7 +2243,7 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
2093
2243
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
2094
2244
  if "obf_config" in kwargs and file_format != "MINDIR":
2095
2245
  raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
2096
- if "custom_func" in kwargs and file_format != "MINDIR":
2246
+ if "custom_func" in kwargs and file_format != "MINDIR" and kwargs["custom_func"] is not None:
2097
2247
  raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
2098
2248
  if file_format == 'AIR':
2099
2249
  _save_air(net, file_name, *inputs, **kwargs)
@@ -2475,6 +2625,9 @@ def check_checkpoint(ckpt_file_name):
2475
2625
  """
2476
2626
  Check whether the checkpoint is valid.
2477
2627
 
2628
+ Note:
2629
+ The interface is deprecated from version 2.5 and will be removed in a future version.
2630
+
2478
2631
  Args:
2479
2632
  ckpt_file_name (str): Checkpoint file name.
2480
2633
 
@@ -2488,6 +2641,8 @@ def check_checkpoint(ckpt_file_name):
2488
2641
  >>> print(check_result)
2489
2642
  True
2490
2643
  """
2644
+ logger.warning("The interface 'mindspore.check_checkpoint' is deprecated from version 2.5 "
2645
+ "and will be removed in a future version.")
2491
2646
  if not ckpt_file_name.endswith('.ckpt'):
2492
2647
  return False
2493
2648
  checkpoint_list = Checkpoint()
@@ -2514,6 +2669,9 @@ def parse_print(print_file_name):
2514
2669
  """
2515
2670
  Parse data file generated by :class:`mindspore.ops.Print`.
2516
2671
 
2672
+ Note:
2673
+ The interface is deprecated from version 2.5 and will be removed in a future version.
2674
+
2517
2675
  Args:
2518
2676
  print_file_name (str): The file name needs to be parsed.
2519
2677
 
@@ -2548,6 +2706,8 @@ def parse_print(print_file_name):
2548
2706
  [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
2549
2707
  [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
2550
2708
  """
2709
+ logger.warning("The interface 'mindspore.parse_print' is deprecated from version 2.5 "
2710
+ "and will be removed in a future version.")
2551
2711
  print_file_path = os.path.realpath(print_file_name)
2552
2712
 
2553
2713
  if os.path.getsize(print_file_path) == 0:
@@ -2837,16 +2997,33 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
2837
2997
  return merged_parameter
2838
2998
 
2839
2999
 
3000
def _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, dst_device_num,
                           output_format, name_map, return_param_dict):
    """Build one argument tuple per destination rank for `_load_parallel_checkpoint`.

    All tuples share the same arguments except the rank, which runs from ``0`` to ``dst_device_num - 1``.

    Returns:
        list[tuple]: ``(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank,
        output_format, name_map, return_param_dict)`` for every rank, in increasing rank order.
    """
    leading_args = (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir)
    trailing_args = (output_format, name_map, return_param_dict)
    return [leading_args + (rank_idx,) + trailing_args for rank_idx in range(dst_device_num)]
3009
+
3010
+
2840
3011
  def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
2841
3012
  train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
2842
- format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
3013
+ format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
3014
+ output_format='safetensors', name_map=None, max_process_num=64,
3015
+ return_param_dict=False):
2843
3016
  """
2844
3017
  Load checkpoint into net for distributed predication. Used in the case of distributed inference.
2845
3018
 
3019
+ Note:
3020
+ `output_format` will only take effect when `format` is set to `safetensors` and `network` is set to `None`.
3021
+
2846
3022
  Args:
2847
- network (Cell): Network for distributed predication.
3023
+ network (Cell): Network for distributed predication, When the format is `safetensors`, the network parameter
3024
+ can be left blank or passed as None, and the interface will execute save mode.
2848
3025
  checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
2849
- predict_strategy (dict): Strategy of predication process. It means that using one device to predict
3026
+ predict_strategy (Union[dict, str]): Strategy of predication process. It means that using one device to predict
2850
3027
  when setting predict_strategy as None. Default: ``None`` .
2851
3028
  train_strategy_filename (str): The filename of training strategy protocol buffer file.
2852
3029
  When train_strategy_filename is None, the training strategy file will be
@@ -2866,17 +3043,23 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
2866
3043
  It can be set to either "ckpt" or "safetensors". Default: "ckpt".
2867
3044
  unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
2868
3045
  Default: ``None`` .
2869
- dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
3046
+ dst_safetensors_dir (str): In the save mode scenario, the save directory for weights.
2870
3047
  rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
2871
3048
  globally by initializing the network; In save mode, save the file according to the input
2872
3049
  sequence number. If it is not input, save the entire file.
3050
+ output_format (str, optional): Control the format of the output checkpoint after conversion.
3051
+ It can be set to either "ckpt" or "safetensors". Default: "safetensors".
3052
+ name_map (dict): The weight mapping dictionary will modify the weight names according to the mapping
3053
+ dictionary before loading or saving the segmented weights into the network. Default: None.
3054
+ max_process_num (int): Maximum number of processes. Default: 64.
3055
+ return_param_dict (bool): Whether to return the param_dict. Default: ``False``.
2873
3056
 
2874
3057
  Raises:
2875
3058
  TypeError: The type of inputs do not match the requirements.
2876
3059
  ValueError: Failed to load checkpoint into net.
2877
3060
 
2878
3061
  Supported Platforms:
2879
- ``Ascend`` ``GPU``
3062
+ ``Ascend`` ``GPU`` ``CPU``
2880
3063
 
2881
3064
  Examples:
2882
3065
  .. note::
@@ -2973,9 +3156,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
2973
3156
  ...
2974
3157
  [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
2975
3158
  """
2976
- if format not in ['safetensors', 'ckpt']:
3159
+ if format not in ['safetensors', 'ckpt'] or output_format not in ['safetensors', 'ckpt']:
2977
3160
  raise ValueError(
2978
- f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
3161
+ f"For 'load_distributed_checkpoint', 'format' and 'output_format' "
3162
+ f"must be 'ckpt' or 'safetensors', but got {format}.")
2979
3163
 
2980
3164
  if format == 'safetensors':
2981
3165
  if unified_safetensors_dir is None:
@@ -2990,36 +3174,32 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
2990
3174
  raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
2991
3175
  f"when format is 'safetensors'.")
2992
3176
  if network is not None:
2993
- rank_id = get_rank()
2994
- _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
3177
+ try:
3178
+ rank_id = get_rank()
3179
+ except RuntimeError:
3180
+ rank_id = 0
3181
+ logger.warning(f"Get rank failed, default loading weight for rank 0.")
3182
+ param_dict = _load_parallel_checkpoint(
3183
+ (unified_safetensors_dir, predict_strategy, network, None, rank_id, output_format, name_map,
3184
+ return_param_dict))
3185
+ return param_dict
3186
+ if dst_safetensors_dir is None:
3187
+ raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
3188
+ f"when network is None.")
3189
+ if rank_id is not None:
3190
+ _load_parallel_checkpoint(
3191
+ (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3192
+ rank_id, output_format, name_map, return_param_dict))
2995
3193
  else:
2996
- if dst_safetensors_dir is None:
2997
- raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
2998
- f"when network is None.")
2999
- if rank_id is not None:
3000
- _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3001
- rank_id)
3002
- else:
3003
- dst_strategy_dict = _build_searched_strategy(predict_strategy)
3004
- dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
3005
- dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
3006
- dst_device_num = dst_stage_device_num * dst_stage_num
3007
- processes = []
3008
- activate_processes = 0
3009
- for rank in range(0, dst_device_num):
3010
- p = Process(target=_load_parallel_checkpoint, args=(
3011
- unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
3012
- p.start()
3013
- processes.append(p)
3014
- activate_processes += 1
3015
- max_processes = 64
3016
- if activate_processes >= max_processes:
3017
- p = processes.pop(0)
3018
- p.join()
3019
- activate_processes -= 1
3020
- for p in processes:
3021
- p.join()
3022
- return
3194
+ dst_strategy_dict = _build_searched_strategy(predict_strategy)
3195
+ dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
3196
+ dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
3197
+ dst_device_num = dst_stage_device_num * dst_stage_num
3198
+ tasks = _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3199
+ dst_device_num, output_format, name_map, return_param_dict)
3200
+ with Pool(processes=max_process_num) as pool:
3201
+ list(pool.imap(_load_parallel_checkpoint, tasks))
3202
+ return True
3023
3203
 
3024
3204
  network = Validator.check_isinstance("network", network, nn.Cell)
3025
3205
  _check_checkpoint_file(checkpoint_filenames)
@@ -3072,14 +3252,15 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
3072
3252
  if first_dim_shard_idx >= 0:
3073
3253
  first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
3074
3254
  if train_strategy.get(param.name)[5]:
3075
- shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
3255
+ repeat_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
3076
3256
  else:
3077
- shard_size = 0
3257
+ repeat_size = 0
3078
3258
  for rank in param_rank:
3079
3259
  param_total_list = list(range(0, ckpt_file_len))
3080
3260
  if first_dim_shard_size != 1:
3081
3261
  param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
3082
- if shard_size > 0:
3262
+ if repeat_size > 0:
3263
+ shard_size = shard_stride * train_strategy.get(param.name)[5]
3083
3264
  rank_index = param_total_list.index(rank)
3084
3265
  start = rank_index // shard_size * shard_size
3085
3266
  param_total_list = param_total_list[start:start + shard_size]
@@ -3138,12 +3319,16 @@ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_stra
3138
3319
  .format(param_not_in_ckpt))
3139
3320
 
3140
3321
  load_param_into_net(network, param_dict, strict_load=strict_load)
3322
+ return True
3141
3323
 
3142
3324
 
3143
3325
  def async_ckpt_thread_status():
3144
3326
  """
3145
3327
  Get the status of asynchronous save checkpoint thread.
3146
3328
 
3329
+ Note:
3330
+ The interface is deprecated from version 2.5 and will be removed in a future version.
3331
+
3147
3332
  When performing asynchronous save checkpoint, you can determine whether the asynchronous thread is completed.
3148
3333
 
3149
3334
  Returns:
@@ -3155,6 +3340,8 @@ def async_ckpt_thread_status():
3155
3340
  >>> ms.async_ckpt_thread_status()
3156
3341
  False
3157
3342
  """
3343
+ logger.warning("The interface 'mindspore.async_ckpt_thread_status' is deprecated from version 2.5 "
3344
+ "and will be removed in a future version.")
3158
3345
  thr_list = threading.enumerate()
3159
3346
  return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
3160
3347
 
@@ -3285,8 +3472,8 @@ def convert_model(mindir_file, convert_file, file_format):
3285
3472
  """
3286
3473
  Convert mindir model to other format model. The current version only supports conversion to ONNX models.
3287
3474
 
3288
- .. warning::
3289
- This is an experimental API that is subject to change or deletion.
3475
+ Note:
3476
+ The interface is deprecated from version 2.5 and will be removed in a future version.
3290
3477
 
3291
3478
  Args:
3292
3479
  mindir_file (str): MindIR file name.
@@ -3302,6 +3489,8 @@ def convert_model(mindir_file, convert_file, file_format):
3302
3489
  >>> import mindspore as ms
3303
3490
  >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
3304
3491
  """
3492
+ logger.warning("The interface 'mindspore.train.serialization.convert_model' is deprecated from version 2.5 "
3493
+ "and will be removed in a future version.")
3305
3494
  Validator.check_file_name_by_regular(mindir_file)
3306
3495
  Validator.check_file_name_by_regular(convert_file)
3307
3496
  if file_format != "ONNX":
@@ -3313,3 +3502,203 @@ def convert_model(mindir_file, convert_file, file_format):
3313
3502
  export(net, net_input, file_name=convert_file, file_format=file_format)
3314
3503
  else:
3315
3504
  export(net, *net_input, file_name=convert_file, file_format=file_format)
3505
+
3506
+
3507
def _transform_tensor_to_numpy(path, name_map=None):
    """Load the MindSpore checkpoint at `path` and return its values converted with ``asnumpy()``.

    Delegates to `_load_and_transform`, using `mindspore.load_checkpoint` as the loader. `name_map` is
    forwarded unchanged (presumably used there to rename parameters; see `_load_and_transform`).
    """
    def to_numpy(v, new_name):
        # The (possibly renamed) parameter name is not needed to produce a plain array.
        return v.asnumpy()

    return _load_and_transform(path, name_map, mindspore.load_checkpoint, to_numpy)
3509
+
3510
+
3511
def _transform_numpy_to_tensor(path, name_map=None):
    """Load the safetensors file at `path` and wrap every value in a `mindspore.Parameter`.

    Delegates to `_load_and_transform`, using `load_file` as the loader. Each Parameter is named with the
    name supplied by `_load_and_transform`. `name_map` is forwarded unchanged (presumably used there to
    rename parameters).
    """
    def to_parameter(v, new_name):
        return mindspore.Parameter(v, name=new_name)

    return _load_and_transform(path, name_map, load_file, to_parameter)
3513
+
3514
+
3515
def _ckpt_name_to_safetensors_name(file_name):
    """Return `file_name` with its trailing ``.ckpt`` suffix replaced by ``.safetensors``.

    Only the suffix is rewritten, so ``.ckpt`` occurring elsewhere in the name is preserved. A name without
    that suffix gets ``.safetensors`` appended, so the output name always differs from the input name.
    """
    suffix = ".ckpt"
    stem = file_name[:-len(suffix)] if file_name.endswith(suffix) else file_name
    return stem + ".safetensors"


def _process_file(file_info):
    """Convert one MindSpore checkpoint file into a safetensors file.

    Args:
        file_info (tuple): ``(cur_ckpt_path, name_map, save_path, file)`` as built by
            `_gather_tasks_covert`. These are the source file path, the optional parameter-name mapping,
            the destination directory and the source file name.
    """
    cur_ckpt_path, name_map, save_path, file = file_info
    param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
    dst_file = os.path.join(save_path, _ckpt_name_to_safetensors_name(file))
    save_file(param_dict_numpy, dst_file)
3521
+
3522
+
3523
def _safetensors_name_to_ckpt_name(file_name):
    """Return `file_name` with its trailing ``.safetensors`` suffix replaced by ``.ckpt``.

    Only the suffix is rewritten, so ``.safetensors`` occurring elsewhere in the name is preserved. A name
    without that suffix gets ``.ckpt`` appended, so the output name always differs from the input name.
    """
    suffix = ".safetensors"
    stem = file_name[:-len(suffix)] if file_name.endswith(suffix) else file_name
    return stem + ".ckpt"


def _process_file_safetensors(file_info):
    """Convert one safetensors file into a MindSpore checkpoint file.

    Args:
        file_info (tuple): ``(cur_safe_path, name_map, save_path, file)`` as built by
            `_gather_safetensors_tasks`. These are the source file path, the optional parameter-name
            mapping, the destination directory and the source file name.
    """
    cur_safe_path, name_map, save_path, file = file_info
    param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
    dst_file = os.path.join(save_path, _safetensors_name_to_ckpt_name(file))
    mindspore.save_checkpoint(param_dict_tensor, dst_file)
3529
+
3530
+
3531
def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
    """Collect ``.safetensors`` conversion tasks from the ``rank*`` sub-directories of `file_path`.

    Only sub-directories directly under `file_path` whose names start with ``rank`` are scanned, and only
    the files directly inside them. For every scanned rank directory, the destination directory is created
    if needed. It lives under `save_path` when one is given, and is otherwise the rank directory itself.

    Args:
        file_path (Union[str, os.PathLike]): Root directory containing the ``rank*`` sub-directories.
        save_path (str): Root of the output tree, or ``None``/empty to write next to the source files.
        file_name_regex (str): Pattern searched in each file name. ``None`` keeps every ``.safetensors`` file.
        name_map (dict): Forwarded unchanged into every task.

    Returns:
        list[tuple]: ``(source_file_path, name_map, destination_dir, file_name)`` for each selected file.
        Empty when `file_path` cannot be listed.

    Raises:
        ValueError: If `file_path` has no sub-directory whose name starts with ``rank``.
    """
    # Only the top level matters, so take the first os.walk entry instead of walking (and discarding)
    # the whole tree. os.walk yields the top directory as a str even for path objects, so this also
    # works when `file_path` is a pathlib.Path.
    top_entry = next(os.walk(file_path), None)
    if top_entry is None:
        return []
    root, dirs, _ = top_entry

    rank_dirs = [d for d in dirs if d.startswith('rank')]
    if not rank_dirs:
        raise ValueError(
            f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")

    tasks = []
    for rank_dir in rank_dirs:
        rank_dir_path = os.path.join(root, rank_dir)
        if save_path:
            dst_root = os.path.join(save_path, os.path.relpath(rank_dir_path, file_path))
        else:
            dst_root = rank_dir_path
        os.makedirs(dst_root, exist_ok=True)
        tasks.extend(
            (os.path.join(rank_dir_path, file), name_map, dst_root, file)
            for file in os.listdir(rank_dir_path)
            if file.endswith(".safetensors") and (file_name_regex is None or re.search(file_name_regex, file))
        )
    return tasks
3554
+
3555
+
3556
def _gather_tasks_covert(file_path, save_path, file_name_regex, name_map):
    """Collect ``.ckpt`` conversion tasks from the ``rank*`` sub-directories of `file_path`.

    Only sub-directories directly under `file_path` whose names start with ``rank`` are scanned, and only
    the files directly inside them. For every scanned rank directory, the destination directory is created
    if needed. It lives under `save_path` when one is given, and is otherwise the rank directory itself.

    Args:
        file_path (Union[str, os.PathLike]): Root directory containing the ``rank*`` sub-directories.
        save_path (str): Root of the output tree, or ``None``/empty to write next to the source files.
        file_name_regex (str): Pattern searched in each file name. ``None`` keeps every ``.ckpt`` file.
        name_map (dict): Forwarded unchanged into every task.

    Returns:
        list[tuple]: ``(source_file_path, name_map, destination_dir, file_name)`` for each selected file.
        Empty when `file_path` cannot be listed.

    Raises:
        ValueError: If `file_path` has no sub-directory whose name starts with ``rank``.
    """
    # Only the top level matters, so take the first os.walk entry instead of walking (and discarding)
    # the whole tree. os.walk yields the top directory as a str even for path objects, so this also
    # works when `file_path` is a pathlib.Path.
    top_entry = next(os.walk(file_path), None)
    if top_entry is None:
        return []
    root, dirs, _ = top_entry

    rank_dirs = [d for d in dirs if d.startswith('rank')]
    if not rank_dirs:
        raise ValueError(
            f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")

    tasks = []
    for rank_dir in rank_dirs:
        rank_dir_path = os.path.join(root, rank_dir)
        if save_path:
            dst_root = os.path.join(save_path, os.path.relpath(rank_dir_path, file_path))
        else:
            dst_root = rank_dir_path
        os.makedirs(dst_root, exist_ok=True)
        tasks.extend(
            (os.path.join(rank_dir_path, file), name_map, dst_root, file)
            for file in os.listdir(rank_dir_path)
            if file.endswith(".ckpt") and (file_name_regex is None or re.search(file_name_regex, file))
        )
    return tasks
3579
+
3580
+
3581
def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
    """
    Converts MindSpore checkpoint files into safetensors format and saves them to `save_path`.
    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
    used for securely storing Tensors with fast speed (zero copy).

    Note:
        Choose the number of processes according to the resources of the host. Setting it too large is
        not recommended, as it may cause the host to freeze.
        The safetensors format does not support encryption. Converting a checkpoint that was saved with
        encryption enabled raises an error.
        The safetensors format currently does not support crc verification. If the checkpoint contains crc
        verification information, that information is lost after conversion to safetensors.

    Args:
        file_path (Union[str, os.PathLike]): Path to a single checkpoint file (.ckpt), or to a directory whose
            ``rank*`` sub-directories contain the checkpoint files to convert.
        save_path (str, optional): Directory path where safetensors files will be saved. If not set, each
            output file is written next to its source file. Default: ``None``.
        name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
        file_name_regex (str, optional): Regular expression used to select the files to convert. In directory
            mode it is searched in each file name; for a single file it is searched in the full `file_path`.
            Default: ``None``.
        processes_num (int, optional): Number of processes used to convert the files of a directory in
            parallel. Default: ``1``.

    Raises:
        ValueError: If `file_path` is neither an existing directory nor an existing file.
        ValueError: If `save_path` has a file extension, i.e. it does not look like a directory.
        ValueError: If `name_map` is neither ``None`` nor a dict.
        ValueError: If `file_path` is a file that does not end with '.ckpt' or does not match
            `file_name_regex`.
        ValueError: If `file_path` is a directory without any sub-directory whose name starts with 'rank'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> ms.ckpt_to_safetensors("./ckpt_save_path")
        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
        >>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
        >>> namemap = {"lin.weight":"new_name"}
        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
    """
    if isinstance(file_path, os.PathLike):
        # Normalize path objects so the string-based suffix checks below also work for them.
        file_path = os.fspath(file_path)
    is_dir = os.path.isdir(file_path)
    is_file = os.path.isfile(file_path)
    if not is_dir and not is_file:
        raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
    if save_path and os.path.splitext(save_path)[1]:
        raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
    if name_map is not None and not isinstance(name_map, dict):
        raise ValueError(
            f"For 'ckpt_to_safetensors', the type of 'name_map' must be a dict, but got '{type(name_map)}'")

    if is_dir:
        tasks = _gather_tasks_covert(file_path, save_path, file_name_regex, name_map)
        with mp.Pool(processes=processes_num) as pool:
            # Consume the lazy imap iterator so every task finishes before the pool is closed.
            list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
        return

    if not file_path.endswith(".ckpt"):
        raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
    if file_name_regex is not None and not re.search(file_name_regex, file_path):
        raise ValueError("For 'ckpt_to_safetensors', the input file does not match the regular expression.")
    if save_path and not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
    # Replace only the trailing ".ckpt" (guaranteed by the check above); ".ckpt" elsewhere in the name is kept.
    safetensors_filename = os.path.basename(file_path)[:-len(".ckpt")] + ".safetensors"
    dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
    save_file(param_dict_numpy, dst_file)
3643
+
3644
+
3645
def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
    """
    Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
    used for securely storing Tensors with fast speed (zero copy).

    Note:
        Choose the number of processes according to the resources of the host. Setting it too large is
        not recommended, as it may cause the host to freeze.

    Args:
        file_path (Union[str, os.PathLike]): Path to a single safetensors file (.safetensors), or to a directory
            whose ``rank*`` sub-directories contain the safetensors files to convert.
        save_path (str, optional): Directory path where checkpoint files will be saved. If not set, each
            output file is written next to its source file. Default: ``None``.
        name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
        file_name_regex (str, optional): Regular expression used to select the files to convert. In directory
            mode it is searched in each file name; for a single file it is searched in the full `file_path`.
            Default: ``None``.
        processes_num (int, optional): Number of processes used to convert the files of a directory in
            parallel. Default: ``1``.

    Raises:
        ValueError: If `file_path` is neither an existing directory nor an existing file.
        ValueError: If `save_path` has a file extension, i.e. it does not look like a directory.
        ValueError: If `name_map` is neither ``None`` nor a dict.
        ValueError: If `file_path` is a file that does not end with '.safetensors' or does not match
            `file_name_regex`.
        ValueError: If `file_path` is a directory without any sub-directory whose name starts with 'rank'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> ms.safetensors_to_ckpt("./safetensors_save_path")
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
        >>> namemap = {"lin.weight":"new_name"}
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
    """
    if isinstance(file_path, os.PathLike):
        # Normalize path objects so the string-based suffix checks below also work for them.
        file_path = os.fspath(file_path)
    is_dir = os.path.isdir(file_path)
    is_file = os.path.isfile(file_path)
    if not is_dir and not is_file:
        raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
    if save_path and os.path.splitext(save_path)[1]:
        raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
    if name_map is not None and not isinstance(name_map, dict):
        raise ValueError(
            f"For 'safetensors_to_ckpt', the type of 'name_map' must be a dict, but got '{type(name_map)}'")

    if is_dir:
        tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
        with mp.Pool(processes=processes_num) as pool:
            # Consume the lazy imap iterator so every task finishes before the pool is closed.
            list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
        return

    if not file_path.endswith(".safetensors"):
        raise ValueError(
            f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
    if file_name_regex is not None and not re.search(file_name_regex, file_path):
        raise ValueError("For 'safetensors_to_ckpt', the input file does not match the regular expression.")
    if save_path and not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
    # Replace only the trailing ".safetensors" (guaranteed by the check above).
    ckpt_filename = os.path.basename(file_path)[:-len(".safetensors")] + ".ckpt"
    dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
    mindspore.save_checkpoint(param_dict_tensor, dst_file)