mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +46 -197
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +217 -98
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +435 -371
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +2 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +951 -1992
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +314 -566
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/experimental/es/embedding_service.py +35 -27
  112. mindspore/experimental/llm_boost/__init__.py +1 -0
  113. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  114. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  115. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  116. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  117. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  118. mindspore/experimental/llm_boost/register.py +1 -0
  119. mindspore/experimental/map_parameter.py +4 -4
  120. mindspore/experimental/optim/adadelta.py +6 -6
  121. mindspore/experimental/optim/adagrad.py +4 -4
  122. mindspore/experimental/optim/adam.py +7 -0
  123. mindspore/experimental/optim/adamax.py +4 -4
  124. mindspore/experimental/optim/adamw.py +4 -0
  125. mindspore/experimental/optim/asgd.py +1 -1
  126. mindspore/experimental/optim/lr_scheduler.py +73 -46
  127. mindspore/experimental/optim/radam.py +34 -31
  128. mindspore/experimental/optim/rprop.py +1 -1
  129. mindspore/experimental/optim/sgd.py +1 -1
  130. mindspore/hal/contiguous_tensors_handle.py +6 -10
  131. mindspore/hal/device.py +55 -53
  132. mindspore/hal/event.py +52 -52
  133. mindspore/hal/memory.py +157 -117
  134. mindspore/hal/stream.py +150 -109
  135. mindspore/include/api/context.h +0 -1
  136. mindspore/include/dataset/constants.h +7 -4
  137. mindspore/include/dataset/execute.h +2 -2
  138. mindspore/jpeg62.dll +0 -0
  139. mindspore/log.py +50 -0
  140. mindspore/mindrecord/__init__.py +21 -8
  141. mindspore/mindrecord/config.py +17 -316
  142. mindspore/mindrecord/filereader.py +1 -9
  143. mindspore/mindrecord/filewriter.py +5 -15
  144. mindspore/mindrecord/mindpage.py +1 -9
  145. mindspore/mindspore_backend_common.dll +0 -0
  146. mindspore/mindspore_backend_manager.dll +0 -0
  147. mindspore/mindspore_common.dll +0 -0
  148. mindspore/mindspore_core.dll +0 -0
  149. mindspore/mindspore_dump.dll +0 -0
  150. mindspore/mindspore_frontend.dll +0 -0
  151. mindspore/mindspore_memory_pool.dll +0 -0
  152. mindspore/mindspore_ms_backend.dll +0 -0
  153. mindspore/mindspore_ops.dll +0 -0
  154. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  155. mindspore/mindspore_ops_kernel_common.dll +0 -0
  156. mindspore/mindspore_profiler.dll +0 -0
  157. mindspore/mindspore_pyboost.dll +0 -0
  158. mindspore/mindspore_pynative.dll +0 -0
  159. mindspore/mindspore_res_manager.dll +0 -0
  160. mindspore/mindspore_runtime_pipeline.dll +0 -0
  161. mindspore/mint/__init__.py +796 -759
  162. mindspore/mint/distributed/__init__.py +70 -4
  163. mindspore/mint/distributed/distributed.py +2679 -44
  164. mindspore/mint/linalg/__init__.py +8 -0
  165. mindspore/mint/nn/__init__.py +743 -22
  166. mindspore/mint/nn/functional.py +716 -23
  167. mindspore/mint/nn/layer/__init__.py +21 -4
  168. mindspore/mint/nn/layer/_functions.py +334 -0
  169. mindspore/mint/nn/layer/activation.py +276 -1
  170. mindspore/mint/nn/layer/basic.py +123 -0
  171. mindspore/mint/nn/layer/conv.py +921 -0
  172. mindspore/mint/nn/layer/normalization.py +223 -28
  173. mindspore/mint/nn/layer/padding.py +797 -0
  174. mindspore/mint/nn/layer/pooling.py +235 -0
  175. mindspore/mint/optim/__init__.py +3 -1
  176. mindspore/mint/optim/adam.py +223 -0
  177. mindspore/mint/optim/adamw.py +26 -19
  178. mindspore/mint/optim/sgd.py +171 -0
  179. mindspore/mint/special/__init__.py +2 -1
  180. mindspore/multiprocessing/__init__.py +5 -0
  181. mindspore/nn/__init__.py +4 -1
  182. mindspore/nn/cell.py +1370 -189
  183. mindspore/nn/dynamic_lr.py +2 -1
  184. mindspore/nn/layer/activation.py +29 -27
  185. mindspore/nn/layer/basic.py +51 -35
  186. mindspore/nn/layer/channel_shuffle.py +3 -3
  187. mindspore/nn/layer/container.py +1 -1
  188. mindspore/nn/layer/conv.py +22 -17
  189. mindspore/nn/layer/embedding.py +12 -11
  190. mindspore/nn/layer/normalization.py +56 -49
  191. mindspore/nn/layer/padding.py +4 -3
  192. mindspore/nn/layer/pooling.py +120 -42
  193. mindspore/nn/layer/rnn_cells.py +1 -1
  194. mindspore/nn/layer/rnns.py +2 -1
  195. mindspore/nn/layer/timedistributed.py +5 -5
  196. mindspore/nn/layer/transformer.py +59 -36
  197. mindspore/nn/learning_rate_schedule.py +8 -4
  198. mindspore/nn/loss/loss.py +58 -55
  199. mindspore/nn/optim/ada_grad.py +7 -5
  200. mindspore/nn/optim/adadelta.py +11 -9
  201. mindspore/nn/optim/adafactor.py +1 -1
  202. mindspore/nn/optim/adam.py +17 -13
  203. mindspore/nn/optim/adamax.py +8 -7
  204. mindspore/nn/optim/adasum.py +5 -5
  205. mindspore/nn/optim/asgd.py +1 -1
  206. mindspore/nn/optim/ftrl.py +11 -9
  207. mindspore/nn/optim/lamb.py +1 -1
  208. mindspore/nn/optim/lars.py +1 -4
  209. mindspore/nn/optim/lazyadam.py +12 -10
  210. mindspore/nn/optim/momentum.py +7 -6
  211. mindspore/nn/optim/optimizer.py +3 -3
  212. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  213. mindspore/nn/optim/rmsprop.py +13 -12
  214. mindspore/nn/optim/rprop.py +11 -9
  215. mindspore/nn/optim/sgd.py +9 -6
  216. mindspore/nn/optim/tft_wrapper.py +5 -2
  217. mindspore/nn/optim/thor.py +2 -1
  218. mindspore/nn/probability/bijector/bijector.py +17 -11
  219. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  220. mindspore/nn/probability/bijector/invert.py +2 -2
  221. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  222. mindspore/nn/probability/bijector/softplus.py +3 -2
  223. mindspore/nn/probability/distribution/beta.py +3 -3
  224. mindspore/nn/probability/distribution/categorical.py +1 -1
  225. mindspore/nn/probability/distribution/cauchy.py +4 -2
  226. mindspore/nn/probability/distribution/exponential.py +6 -7
  227. mindspore/nn/probability/distribution/gamma.py +2 -2
  228. mindspore/nn/probability/distribution/gumbel.py +2 -2
  229. mindspore/nn/probability/distribution/half_normal.py +5 -3
  230. mindspore/nn/probability/distribution/logistic.py +5 -3
  231. mindspore/nn/probability/distribution/poisson.py +1 -1
  232. mindspore/nn/probability/distribution/uniform.py +5 -3
  233. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  234. mindspore/nn/reinforcement/tensor_array.py +1 -1
  235. mindspore/nn/utils/init.py +13 -11
  236. mindspore/nn/wrap/__init__.py +6 -6
  237. mindspore/nn/wrap/cell_wrapper.py +181 -122
  238. mindspore/nn/wrap/grad_reducer.py +45 -36
  239. mindspore/nn/wrap/loss_scale.py +6 -7
  240. mindspore/numpy/array_creations.py +63 -65
  241. mindspore/numpy/array_ops.py +149 -144
  242. mindspore/numpy/logic_ops.py +41 -42
  243. mindspore/numpy/math_ops.py +365 -363
  244. mindspore/numpy/utils.py +17 -18
  245. mindspore/numpy/utils_const.py +5 -6
  246. mindspore/opencv_core452.dll +0 -0
  247. mindspore/opencv_imgcodecs452.dll +0 -0
  248. mindspore/opencv_imgproc452.dll +0 -0
  249. mindspore/ops/__init__.py +5 -3
  250. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  251. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  252. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  253. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  254. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  255. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  256. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  257. mindspore/ops/_register_for_op.py +0 -11
  258. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  259. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  260. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  261. mindspore/ops/_vmap/vmap_base.py +0 -2
  262. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  263. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  264. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  265. mindspore/ops/auto_generate/__init__.py +4 -3
  266. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  267. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  268. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  269. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  270. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  271. mindspore/ops/composite/__init__.py +2 -1
  272. mindspore/ops/composite/base.py +20 -25
  273. mindspore/ops/composite/math_ops.py +6 -16
  274. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  275. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  276. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  277. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  278. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  279. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  282. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  283. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  284. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  285. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  286. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  287. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  288. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  289. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  291. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  292. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  293. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  294. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  295. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  297. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  301. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  302. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  304. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  305. mindspore/ops/function/__init__.py +40 -2
  306. mindspore/ops/function/_add_attr_func.py +58 -0
  307. mindspore/ops/function/array_func.py +2089 -2403
  308. mindspore/ops/function/clip_func.py +80 -23
  309. mindspore/ops/function/debug_func.py +57 -57
  310. mindspore/ops/function/grad/__init__.py +1 -0
  311. mindspore/ops/function/grad/grad_func.py +104 -71
  312. mindspore/ops/function/image_func.py +2 -2
  313. mindspore/ops/function/linalg_func.py +47 -78
  314. mindspore/ops/function/math_func.py +4501 -3802
  315. mindspore/ops/function/nn_func.py +1726 -620
  316. mindspore/ops/function/other_func.py +159 -1
  317. mindspore/ops/function/parameter_func.py +18 -84
  318. mindspore/ops/function/random_func.py +440 -387
  319. mindspore/ops/function/reshard_func.py +4 -70
  320. mindspore/ops/function/sparse_func.py +3 -3
  321. mindspore/ops/function/sparse_unary_func.py +6 -6
  322. mindspore/ops/function/spectral_func.py +25 -58
  323. mindspore/ops/function/vmap_func.py +24 -17
  324. mindspore/ops/functional.py +22 -7
  325. mindspore/ops/functional_overload.py +1440 -0
  326. mindspore/ops/op_info_register.py +32 -244
  327. mindspore/ops/operations/__init__.py +13 -7
  328. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  329. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  330. mindspore/ops/operations/_grad_ops.py +2 -43
  331. mindspore/ops/operations/_infer_ops.py +2 -1
  332. mindspore/ops/operations/_inner_ops.py +43 -84
  333. mindspore/ops/operations/_ms_kernel.py +4 -10
  334. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  335. mindspore/ops/operations/_scalar_ops.py +3 -2
  336. mindspore/ops/operations/_sequence_ops.py +1 -1
  337. mindspore/ops/operations/_tensor_array.py +1 -1
  338. mindspore/ops/operations/array_ops.py +81 -324
  339. mindspore/ops/operations/comm_ops.py +154 -108
  340. mindspore/ops/operations/custom_ops.py +232 -78
  341. mindspore/ops/operations/debug_ops.py +153 -59
  342. mindspore/ops/operations/inner_ops.py +7 -5
  343. mindspore/ops/operations/linalg_ops.py +1 -57
  344. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  345. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  346. mindspore/ops/operations/math_ops.py +32 -234
  347. mindspore/ops/operations/nn_ops.py +210 -498
  348. mindspore/ops/operations/other_ops.py +62 -9
  349. mindspore/ops/operations/random_ops.py +13 -7
  350. mindspore/ops/operations/reshard_ops.py +1 -1
  351. mindspore/ops/operations/sparse_ops.py +2 -2
  352. mindspore/ops/primitive.py +66 -53
  353. mindspore/ops/tensor_method.py +1888 -0
  354. mindspore/ops_generate/__init__.py +0 -5
  355. mindspore/ops_generate/aclnn/__init__.py +0 -0
  356. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  357. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  358. mindspore/ops_generate/api/__init__.py +0 -0
  359. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  360. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  361. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  362. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  363. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  364. mindspore/ops_generate/api/gen_api.py +103 -0
  365. mindspore/ops_generate/api/op_api_proto.py +235 -0
  366. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  367. mindspore/ops_generate/common/__init__.py +0 -0
  368. mindspore/ops_generate/common/base_generator.py +11 -0
  369. mindspore/ops_generate/common/gen_constants.py +91 -0
  370. mindspore/ops_generate/common/gen_utils.py +348 -0
  371. mindspore/ops_generate/common/op_proto.py +473 -0
  372. mindspore/ops_generate/common/template.py +523 -0
  373. mindspore/ops_generate/gen_ops.py +22 -1069
  374. mindspore/ops_generate/op_def/__init__.py +0 -0
  375. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  376. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  377. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  378. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  379. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  380. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  381. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  382. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  383. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  384. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  385. mindspore/ops_generate/pyboost/__init__.py +0 -0
  386. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  387. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  388. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  389. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  390. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  391. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  393. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  394. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  395. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  396. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  397. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  398. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  399. mindspore/ops_generate/resources/__init__.py +0 -0
  400. mindspore/ops_generate/resources/resource_list.py +30 -0
  401. mindspore/ops_generate/resources/resource_loader.py +36 -0
  402. mindspore/ops_generate/resources/resource_manager.py +64 -0
  403. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  404. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  405. mindspore/parallel/__init__.py +7 -3
  406. mindspore/parallel/_auto_parallel_context.py +152 -34
  407. mindspore/parallel/_cell_wrapper.py +130 -15
  408. mindspore/parallel/_parallel_serialization.py +107 -5
  409. mindspore/parallel/_ps_context.py +1 -1
  410. mindspore/parallel/_recovery_context.py +7 -2
  411. mindspore/parallel/_tensor.py +142 -18
  412. mindspore/parallel/_utils.py +199 -23
  413. mindspore/parallel/algo_parameter_config.py +4 -4
  414. mindspore/parallel/auto_parallel.py +732 -0
  415. mindspore/parallel/checkpoint_convert.py +159 -0
  416. mindspore/parallel/checkpoint_transform.py +698 -35
  417. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  418. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  419. mindspore/parallel/cluster/run.py +21 -4
  420. mindspore/parallel/function/__init__.py +24 -0
  421. mindspore/parallel/function/reshard_func.py +259 -0
  422. mindspore/parallel/nn/__init__.py +25 -0
  423. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  424. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  425. mindspore/parallel/parameter_broadcast.py +25 -14
  426. mindspore/parallel/shard.py +137 -58
  427. mindspore/parallel/transform_safetensors.py +363 -305
  428. mindspore/profiler/__init__.py +22 -5
  429. mindspore/profiler/analysis/__init__.py +0 -0
  430. mindspore/profiler/analysis/parser/__init__.py +0 -0
  431. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  432. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  433. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  434. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  435. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  436. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  437. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  440. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  441. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  446. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  447. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  448. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  449. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  450. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  451. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  452. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  453. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  454. mindspore/profiler/analysis/task_manager.py +131 -0
  455. mindspore/profiler/analysis/time_converter.py +84 -0
  456. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  457. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  458. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  459. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  460. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  461. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  462. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  463. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  464. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  465. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  466. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  467. mindspore/profiler/analysis/work_flow.py +73 -0
  468. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  469. mindspore/profiler/common/command_executor.py +90 -0
  470. mindspore/profiler/common/constant.py +186 -3
  471. mindspore/profiler/common/file_manager.py +208 -0
  472. mindspore/profiler/common/log.py +130 -0
  473. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  474. mindspore/profiler/common/path_manager.py +395 -0
  475. mindspore/profiler/common/process_bar.py +168 -0
  476. mindspore/profiler/common/process_pool.py +9 -3
  477. mindspore/profiler/common/profiler_context.py +500 -0
  478. mindspore/profiler/common/profiler_info.py +304 -0
  479. mindspore/profiler/common/profiler_meta_data.py +74 -0
  480. mindspore/profiler/common/profiler_output_path.py +284 -0
  481. mindspore/profiler/common/profiler_parameters.py +251 -0
  482. mindspore/profiler/common/profiler_path_manager.py +179 -0
  483. mindspore/profiler/common/record_function.py +76 -0
  484. mindspore/profiler/common/tlv_decoder.py +76 -0
  485. mindspore/profiler/common/util.py +75 -2
  486. mindspore/profiler/dynamic_profiler.py +341 -75
  487. mindspore/profiler/envprofiler.py +163 -0
  488. mindspore/profiler/experimental_config.py +197 -0
  489. mindspore/profiler/mstx.py +242 -0
  490. mindspore/profiler/platform/__init__.py +21 -0
  491. mindspore/profiler/platform/base_profiler.py +40 -0
  492. mindspore/profiler/platform/cpu_profiler.py +124 -0
  493. mindspore/profiler/platform/gpu_profiler.py +74 -0
  494. mindspore/profiler/platform/npu_profiler.py +335 -0
  495. mindspore/profiler/profiler.py +1073 -90
  496. mindspore/profiler/profiler_action_controller.py +187 -0
  497. mindspore/profiler/profiler_interface.py +118 -0
  498. mindspore/profiler/schedule.py +243 -0
  499. mindspore/rewrite/api/node.py +15 -13
  500. mindspore/rewrite/api/symbol_tree.py +2 -3
  501. mindspore/run_check/_check_version.py +27 -20
  502. mindspore/run_check/run_check.py +1 -1
  503. mindspore/runtime/__init__.py +37 -0
  504. mindspore/runtime/device.py +27 -0
  505. mindspore/runtime/event.py +209 -0
  506. mindspore/runtime/executor.py +177 -0
  507. mindspore/runtime/memory.py +409 -0
  508. mindspore/runtime/stream.py +460 -0
  509. mindspore/runtime/thread_bind_core.py +401 -0
  510. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  511. mindspore/swresample-4.dll +0 -0
  512. mindspore/swscale-6.dll +0 -0
  513. mindspore/tinyxml2.dll +0 -0
  514. mindspore/train/__init__.py +8 -8
  515. mindspore/train/_utils.py +88 -25
  516. mindspore/train/amp.py +9 -5
  517. mindspore/train/callback/__init__.py +2 -2
  518. mindspore/train/callback/_callback.py +2 -16
  519. mindspore/train/callback/_checkpoint.py +53 -55
  520. mindspore/train/callback/_cluster_monitor.py +14 -18
  521. mindspore/train/callback/_early_stop.py +1 -1
  522. mindspore/train/callback/_flops_collector.py +103 -68
  523. mindspore/train/callback/_history.py +8 -5
  524. mindspore/train/callback/_lambda_callback.py +2 -2
  525. mindspore/train/callback/_landscape.py +0 -3
  526. mindspore/train/callback/_loss_monitor.py +2 -1
  527. mindspore/train/callback/_on_request_exit.py +6 -5
  528. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  529. mindspore/train/callback/_summary_collector.py +52 -19
  530. mindspore/train/callback/_time_monitor.py +2 -1
  531. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  532. mindspore/train/data_sink.py +25 -2
  533. mindspore/train/dataset_helper.py +15 -16
  534. mindspore/train/loss_scale_manager.py +8 -7
  535. mindspore/train/metrics/accuracy.py +3 -3
  536. mindspore/train/metrics/confusion_matrix.py +9 -9
  537. mindspore/train/metrics/error.py +3 -3
  538. mindspore/train/metrics/hausdorff_distance.py +4 -4
  539. mindspore/train/metrics/mean_surface_distance.py +3 -3
  540. mindspore/train/metrics/metric.py +0 -12
  541. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  542. mindspore/train/metrics/precision.py +11 -10
  543. mindspore/train/metrics/recall.py +9 -9
  544. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  545. mindspore/train/mind_ir_pb2.py +174 -46
  546. mindspore/train/model.py +184 -113
  547. mindspore/train/serialization.py +622 -978
  548. mindspore/train/summary/_summary_adapter.py +2 -2
  549. mindspore/train/summary/summary_record.py +2 -3
  550. mindspore/train/train_thor/model_thor.py +1 -1
  551. mindspore/turbojpeg.dll +0 -0
  552. mindspore/utils/__init__.py +6 -3
  553. mindspore/utils/dryrun.py +140 -0
  554. mindspore/utils/hooks.py +81 -0
  555. mindspore/utils/runtime_execution_order_check.py +550 -0
  556. mindspore/utils/utils.py +138 -4
  557. mindspore/version.py +1 -1
  558. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  559. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  561. mindspore/_install_custom.py +0 -43
  562. mindspore/common/_register_for_adapter.py +0 -74
  563. mindspore/common/_tensor_overload.py +0 -139
  564. mindspore/mindspore_np_dtype.dll +0 -0
  565. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  566. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  567. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  568. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  569. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  570. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  571. mindspore/ops_generate/gen_utils.py +0 -209
  572. mindspore/ops_generate/op_proto.py +0 -145
  573. mindspore/ops_generate/template.py +0 -261
  574. mindspore/profiler/envprofiling.py +0 -254
  575. mindspore/profiler/profiling.py +0 -1926
  576. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  577. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -24,10 +24,12 @@ import os
24
24
  import re
25
25
  import shutil
26
26
  import stat
27
+ import atexit
27
28
  import threading
28
29
  from threading import Thread, RLock
29
- from multiprocessing import Process
30
- from collections import defaultdict, OrderedDict
30
+ from multiprocessing import active_children
31
+ import multiprocessing as mp
32
+ from collections import OrderedDict
31
33
  from io import BytesIO
32
34
 
33
35
  import math
@@ -36,6 +38,9 @@ import time
36
38
  import google
37
39
  import numpy as np
38
40
 
41
+ from safetensors.numpy import save_file, load_file
42
+ from safetensors import safe_open
43
+
39
44
  from mindspore.train.checkpoint_pb2 import Checkpoint
40
45
  from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
41
46
  from mindspore.train.print_pb2 import Print
@@ -44,43 +49,37 @@ import mindspore
44
49
  import mindspore.nn as nn
45
50
  from mindspore import context
46
51
  from mindspore import log as logger
52
+ from mindspore.log import vlog_print
47
53
  from mindspore._checkparam import check_input_data, check_input_dataset
48
54
  from mindspore import _checkparam as Validator
49
55
  from mindspore.common import dtype as mstype
56
+ from mindspore.common import np_dtype
50
57
  from mindspore.common.api import _cell_graph_executor as _executor
51
- from mindspore.common.api import _MindsporeFunctionExecutor
58
+ from mindspore.common.api import _JitExecutor
52
59
  from mindspore.common.api import _get_parameter_layout
53
- from mindspore.common.api import _generate_branch_control_input
54
60
  from mindspore.common.initializer import initializer, One
55
61
  from mindspore.common.parameter import Parameter, _offload_if_config
56
62
  from mindspore.common.tensor import Tensor
57
- from mindspore._c_expression import Tensor as Tensor_
63
+ from mindspore._c_expression import TensorPy as Tensor_
58
64
  from mindspore.common._utils import is_shape_unknown
59
65
  from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
60
66
  from mindspore.communication.management import get_rank, get_group_size
61
67
  from mindspore.experimental import MapParameter
62
68
  from mindspore.ops import Cast
63
69
  from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
64
- from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
65
- from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
66
- from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
67
- _get_device_num
68
- from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
69
- from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
70
- _restore_group_info_list, _get_param_list_when_first_dim_sharded
70
+ from mindspore.parallel._tensor import _reshape_param_data
71
+ from mindspore.parallel._utils import _is_in_auto_parallel_mode
71
72
  from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
72
73
  _store_warm_up_ptr_by_tensor_list, _cache_enable
73
74
  from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
74
- from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
75
- _extract_pipeline_stage_num
76
- from mindspore.train._utils import read_proto, get_parameter_redundancy
77
- from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
75
+ from mindspore.parallel.checkpoint_transform import restore_group_info_list as new_restore_group_info_list
76
+ from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint as new_load_distributed_checkpoint
77
+ from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
78
+ from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
79
+ from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
80
+ from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
78
81
  split_mindir, split_dynamic_mindir
79
82
  from mindspore.common.generator import Generator
80
- from safetensors.numpy import save_file
81
- from safetensors import safe_open
82
- from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
83
-
84
83
 
85
84
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
86
85
  "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
@@ -91,6 +90,9 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
91
90
  "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
92
91
  "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
93
92
 
93
+ if hasattr(np_dtype, "bfloat16"):
94
+ tensor_to_np_type["BFloat16"] = np_dtype.bfloat16
95
+
94
96
  np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
95
97
 
96
98
  mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
@@ -123,24 +125,55 @@ def init_ckpt_file_system(fs: FileSystem):
123
125
  init_ckpt_file_system(_ckpt_fs)
124
126
 
125
127
 
128
+ def _wait_async_process_save_ckpt():
129
+ """Waiting for asynchronous saving process of ckpt to complete"""
130
+ for process in active_children():
131
+ if process.name == "asyn_save_ckpt":
132
+ process.join()
133
+
134
+
135
+ def _wait_async_thread_save_ckpt():
136
+ """Waiting for asynchronous saving thread of ckpt to complete"""
137
+ thread_list = threading.enumerate()
138
+ for thread in thread_list:
139
+ if thread.getName() == "asyn_save_ckpt":
140
+ thread.join()
141
+
142
+
143
+ def _async_save_close():
144
+ """Waiting for asynchronous saving of ckpt to complete"""
145
+ _wait_async_process_save_ckpt()
146
+ _wait_async_thread_save_ckpt()
147
+
148
+
149
+ # Registering atexit handles asynchronous save
150
+ atexit.register(_async_save_close)
151
+
152
+
126
153
  def _get_cur_rank_dp(parameter_layout_dict):
127
154
  """ Get dp and tp from layout dict. """
128
- pp_num = _get_auto_parallel_context("pipeline_stages")
129
- dev_num = _get_device_num()
130
155
  global_rank = get_rank()
131
- pipe_size = dev_num // pp_num
132
- initial_rank = (global_rank // pipe_size) * pipe_size
133
- parameter_redundancy_dict = get_parameter_redundancy(
134
- parameter_layout_dict, initial_rank)
156
+ parameter_redundancy_dict = get_parameter_redundancy(parameter_layout_dict)
135
157
  value_len = sys.maxsize
136
158
  min_value = ()
159
+ min_value_set = set()
137
160
  for key, value in parameter_redundancy_dict.items():
138
- if "accu_grads" in key or "inputs" in key:
161
+ if key.startswith("accu_grads") or key.startswith("inputs"):
139
162
  continue
140
163
  for item in value:
141
- if len(item) < value_len and global_rank in item:
164
+ if global_rank not in item:
165
+ continue
166
+ # if item is subset of min_value_set, update min_value_set and min_value
167
+ if len(item) < value_len:
168
+ if min_value_set and not set(item).issubset(min_value_set):
169
+ return (global_rank,)
142
170
  value_len = len(item)
171
+ min_value_set = set(item)
143
172
  min_value = item
173
+ # if value is not smaller than len of min_value len,
174
+ # check if min_value_set is subset of current item
175
+ elif not min_value_set.issubset(set(item)):
176
+ return (global_rank,)
144
177
  return min_value
145
178
 
146
179
 
@@ -160,7 +193,7 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
160
193
  cur_strategy_path (str): strategy file path for current rank.
161
194
 
162
195
  Returns:
163
- - new_ckpt_file (String), if found available checkpoint file , return it.
196
+ - new_ckpt_file (str), if found available checkpoint file , return it.
164
197
  - None, if not found available checkpoint, return None.
165
198
 
166
199
  Examples:
@@ -175,6 +208,9 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
175
208
  >>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
176
209
  >>> print(ckpt_file_new)
177
210
  """
211
+ cur_rank = get_rank()
212
+ if f"rank_{str(cur_rank)}" in cur_ckpt_path and os.path.isfile(cur_ckpt_path):
213
+ return cur_ckpt_path
178
214
  dp = _get_cur_rank_dp(cur_strategy_path)
179
215
  pattern = r'rank_\d+'
180
216
  for i in dp:
@@ -282,7 +318,8 @@ def _type_convert(param, new_param, strict_load):
282
318
  {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
283
319
  logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
284
320
  f"the type of it in 'net':{param.data.dtype}, then the type convert from "
285
- f"{new_param.data.dtype} to {param.data.dtype} in the network.")
321
+ f"{new_param.data.dtype} to {param.data.dtype} in the network. May consume additional memory "
322
+ f"and time")
286
323
  return True
287
324
  return False
288
325
 
@@ -338,6 +375,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
338
375
  os.chmod(tmp_name, stat.S_IWUSR)
339
376
  os.remove(tmp_name)
340
377
  if format == "ckpt":
378
+ ckpt_total_io_time = 0
341
379
  with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
342
380
  plain_data = None
343
381
  if enc_key is not None:
@@ -354,20 +392,26 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
354
392
  if value[0] == "offload_parameter":
355
393
  new_value = value[1:]
356
394
  new_value[2] = value[3]
357
- _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
395
+ _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data, ckpt_total_io_time)
358
396
  _offload_if_config(value[3])
359
397
  continue
360
398
  if value[1] == "str":
361
- crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
399
+ crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
400
+ crc_num, crc_check,
401
+ ckpt_total_io_time)
362
402
  continue
363
403
  if isinstance(value[2], np.ndarray):
364
- crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
404
+ crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
405
+ crc_num, crc_check,
406
+ ckpt_total_io_time)
365
407
  continue
366
408
  if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
367
409
  _write_hugeparameter(name, value, f)
368
410
  continue
369
411
 
370
- crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
412
+ crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
413
+ crc_num, crc_check,
414
+ ckpt_total_io_time)
371
415
 
372
416
  if enc_key is not None:
373
417
  plain_data.seek(0)
@@ -378,11 +422,36 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
378
422
  block_data = plain_data.read(max_block_size)
379
423
  if crc_check:
380
424
  f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
425
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
426
+ f"Save ckpt io cost time:{ckpt_total_io_time}.")
427
+
381
428
  elif format == "safetensors":
382
429
  save_dict = {}
383
- for name, value in data_list.items():
384
- save_dict[name] = value[2].asnumpy()
385
- save_file(save_dict, tmp_name)
430
+ crc_num = 0
431
+ for name in sorted(data_list.keys()):
432
+ value = data_list[name]
433
+ if isinstance(value[2], np.ndarray):
434
+ save_dict[name] = value[2]
435
+ else:
436
+ bytes_data = value[2].get_bytes()
437
+ np_type = tensor_to_np_type.get(value[1])
438
+ np_array = np.frombuffer(bytes_data, np_type)
439
+ new_np_array = np_array.reshape(value[0])
440
+ save_dict[name] = new_np_array
441
+
442
+ if crc_check:
443
+ crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
444
+ crc_num = binascii.crc32(
445
+ bytes(save_dict[name]), crc_num)
446
+ safetensors_save_time_start = time.time()
447
+ if crc_check:
448
+ save_file(save_dict, tmp_name, metadata={
449
+ "crc_num": str(crc_num)})
450
+ else:
451
+ save_file(save_dict, tmp_name)
452
+ safetensors_save_time_end = time.time()
453
+ cost_time = safetensors_save_time_end - safetensors_save_time_start
454
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.")
386
455
  if not os.path.exists(tmp_name):
387
456
  logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
388
457
  f"simultaneously modified a file.")
@@ -407,7 +476,7 @@ def _write_random_seed(name, value, f):
407
476
  f.write(checkpoint_list.SerializeToString())
408
477
 
409
478
 
410
- def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
479
+ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
411
480
  """Write parameter data into protobuf file."""
412
481
  data_size = value[2].nbytes / 1024
413
482
  if data_size > SLICE_SIZE:
@@ -429,14 +498,18 @@ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_ch
429
498
  output_data = checkpoint_list.SerializeToString()
430
499
  if crc_check:
431
500
  crc_num = binascii.crc32(output_data, crc_num)
501
+ io_start_time = time.time()
432
502
  f.write(output_data)
503
+ io_end_time = time.time()
504
+ io_cost_time = io_end_time - io_start_time
505
+ ckpt_total_io_time += io_cost_time
433
506
  else:
434
507
  plain_data.write(checkpoint_list.SerializeToString())
435
508
 
436
- return crc_num
509
+ return crc_num, ckpt_total_io_time
437
510
 
438
511
 
439
- def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
512
+ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
440
513
  """Write parameter bytes data into protobuf file."""
441
514
  bytes_value = value[2].get_bytes()
442
515
  chunk_size = 1024 * SLICE_SIZE
@@ -454,11 +527,15 @@ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0,
454
527
  output_data = checkpoint_list.SerializeToString()
455
528
  if crc_check:
456
529
  crc_num = binascii.crc32(output_data, crc_num)
530
+ io_start_time = time.time()
457
531
  f.write(output_data)
532
+ io_end_time = time.time()
533
+ io_cost_time = io_end_time - io_start_time
534
+ ckpt_total_io_time += io_cost_time
458
535
  else:
459
536
  plain_data.write(checkpoint_list.SerializeToString())
460
537
 
461
- return crc_num
538
+ return crc_num, ckpt_total_io_time
462
539
 
463
540
 
464
541
  def _write_mapparameter(name, value, f, map_param_inc=False):
@@ -522,12 +599,56 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
522
599
  return ckpt_file_name
523
600
 
524
601
 
525
- def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
526
- global_step_num=None):
527
- param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
528
- or map_param_inc or global_step_num is not None)
529
- if format == "safetensors" and param_not_default:
530
- raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
602
+ def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
603
+ """check load checkpoint unsupported param"""
604
+ if format != "safetensors":
605
+ return
606
+ default_params = {
607
+ "dec_key": None,
608
+ "dec_mode": "AES-GCM",
609
+ }
610
+ for param_name, default_value in default_params.items():
611
+ current_value = locals()[param_name]
612
+ if current_value != default_value:
613
+ raise ValueError(f"For 'load_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
614
+ f"be set to default value '{default_value}', but got '{current_value}'.")
615
+
616
+
617
+ def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc=False, global_step_num=None):
618
+ """check save checkpoint unsupported param"""
619
+ if format != "safetensors":
620
+ return
621
+ default_params = {
622
+ "enc_key": None,
623
+ "enc_mode": "AES-GCM",
624
+ "map_param_inc": False,
625
+ "global_step_num": None
626
+ }
627
+ for param_name, default_value in default_params.items():
628
+ current_value = locals()[param_name]
629
+ if current_value != default_value:
630
+ raise ValueError(f"For 'save_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
631
+ f"be set to default value '{default_value}', but got '{current_value}'.")
632
+
633
+
634
+ def _check_async_save(async_save):
635
+ """Check async_save for save_checkpoint."""
636
+ if not isinstance(async_save, (bool, str)):
637
+ raise TypeError("For 'save_checkpoint', the parameter 'async_save' must be bool or str, "
638
+ "but got {}.".format(type(async_save)))
639
+ if isinstance(async_save, str):
640
+ if async_save not in ("process", "thread"):
641
+ raise ValueError("For 'save_checkpoint', the argument 'async_save' can only be 'process' or 'thread',"
642
+ "but got {}.".format(async_save))
643
+ return async_save
644
+
645
+
646
+ def _async_process_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False,
647
+ crc_check=False, format="ckpt", cond=None):
648
+ """Check whether the process is pulled up successfully, execute the process of saving checkpoint into file."""
649
+ with cond:
650
+ cond.notify()
651
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
531
652
 
532
653
 
533
654
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
@@ -541,13 +662,19 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
541
662
 
542
663
  Args:
543
664
  save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
544
- list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
545
- elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
546
- `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
547
- it can be the returned value of `mindspore.load_checkpoint()`.
665
+ list, or dict.
666
+
667
+ - If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
668
+ elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
669
+ `param_name` must be string, and the type of `param_data` must be parameter or Tensor).
670
+ - If dict, it can be the returned value of :func:`mindspore.load_checkpoint`.
671
+
548
672
  ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
549
673
  integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
550
- async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
674
+ async_save (Union[bool, str], optional): Whether to use asynchronous saving of the checkpoint file or
675
+ safetensors file, if True, the asynchronous thread is used by default. If the type
676
+ is string, the method of asynchronous saving, it can be "process" or "thread".
677
+ Default: ``False`` .
551
678
  append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
552
679
  of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
553
680
  enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
@@ -557,9 +684,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
557
684
  Default: ``"AES-GCM"`` .
558
685
  choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
559
686
  a parameter name in string type, and the returned value is a bool.
560
- If returns ``True`` , the Parameter that matching the custom condition will be saved.
561
- If returns ``False`` , the Parameter that not matching the custom condition will not
562
- be saved. Default: ``None`` .
687
+ Default: ``None`` .
688
+
689
+ - If returns ``True`` , the Parameter that matching the custom condition will be saved.
690
+ - If returns ``False`` , the Parameter that not matching the custom condition will not
691
+ be saved.
692
+
563
693
  crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
564
694
  result to the file. Default: ``False`` .
565
695
  format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
@@ -567,8 +697,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
567
697
 
568
698
  Raises:
569
699
  TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
570
- TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
700
+ TypeError: If the parameter `integrated_save` is not bool type.
571
701
  TypeError: If the parameter `ckpt_file_name` is not string type.
702
+ TypeError: If the parameter `async_save` is not bool or string type.
703
+ ValueError: If the parameter `async_save` is string type but not in ["process", "thread"].
572
704
 
573
705
  Examples:
574
706
  >>> import mindspore as ms
@@ -596,9 +728,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
596
728
  - `Saving and Loading the Model - Saving and Loading the Model Weight
597
729
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
598
730
  """
731
+ start_save_time = time.time()
599
732
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
600
733
  integrated_save = Validator.check_bool(integrated_save)
601
- async_save = Validator.check_bool(async_save)
734
+ async_save = _check_async_save(async_save)
602
735
  append_dict = _check_append_dict(append_dict)
603
736
  enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
604
737
  enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
@@ -606,12 +739,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
606
739
  map_param_inc = kwargs.get('incremental', False)
607
740
  logger.info("Execute the process of saving checkpoint files.")
608
741
  global_step_num = kwargs.get('global_step_num', None)
609
- _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
742
+ _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc, global_step_num)
610
743
 
611
744
  if append_dict and "__exception_save__" in append_dict:
612
745
  s1 = mindspore.hal.Stream()
613
746
  with mindspore.hal.StreamCtx(s1):
614
747
  save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
748
+ for k_name, value in append_dict.items():
749
+ if isinstance(value, (Tensor, Parameter)):
750
+ append_dict[k_name] = Tensor(Tensor_.move_to(value, "CPU", False))
615
751
  s1.synchronize()
616
752
  else:
617
753
  save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
@@ -682,23 +818,74 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
682
818
  data_list[key].append(dims)
683
819
  tensor_type = str(param["data"].dtype)
684
820
  data_list[key].append(tensor_type)
685
- data = param["data"]
821
+ data = param["data"] if async_save is False else param["data"].asnumpy()
686
822
  data_list[key].append(data)
687
823
 
824
+ from mindspore.profiler import mstx
825
+ range_id = mstx.range_start('save_checkpoint', None)
688
826
  if os.getenv("AITURBO") == "1":
689
827
  from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
690
828
  ckpt_name = os.path.basename(ckpt_file_name)
691
829
  aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
692
830
  elif async_save:
693
- data_copy = copy.deepcopy(data_list)
694
- thr = Thread(target=_exec_save,
695
- args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
696
- name="asyn_save_ckpt")
697
- thr.start()
831
+ if async_save == "process":
832
+ if sys.platform.startswith("win"):
833
+ logger.warining("The Win platform currently does not support asynchronous process saving of ckpt, "
834
+ "so serial saving of ckpt is used now.")
835
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
836
+ else:
837
+ _wait_async_process_save_ckpt()
838
+ ctx = mp.get_context("fork")
839
+ cond = ctx.Condition()
840
+ process_flag = True
841
+ while process_flag:
842
+ process = ctx.Process(target=_async_process_save,
843
+ args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
844
+ format, cond), daemon=True, name="asyn_save_ckpt")
845
+ process.start()
846
+ with cond:
847
+ wait_flag = cond.wait(timeout=5)
848
+ if not wait_flag:
849
+ logger.warning("Async save process fails to create. will kill and recreate")
850
+ process.kill()
851
+ else:
852
+ process_flag = False
853
+ else:
854
+ data_copy = copy.deepcopy(data_list)
855
+ _wait_async_thread_save_ckpt()
856
+ thr = Thread(target=_exec_save,
857
+ args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
858
+ name="asyn_save_ckpt")
859
+ thr.start()
698
860
  else:
699
861
  _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
700
862
 
863
+ mstx.range_end(range_id)
701
864
  logger.info("Saving checkpoint process is finished.")
865
+ end_save_time = time.time()
866
+ save_checkpoint_cost_time = end_save_time - start_save_time
867
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save checkpoint cost time {save_checkpoint_cost_time}.")
868
+
869
+
870
+ def _handle_shared_param_for_pipeline_parallel(save_obj):
871
+ """ Remove shared param for save_obj """
872
+ filtered_save_obj = []
873
+ for param_dict in save_obj:
874
+ cur_param = param_dict['data']
875
+ if isinstance(cur_param, Parameter):
876
+ if not cur_param.param_info.is_pipeline_shared_param:
877
+ filtered_save_obj.append(param_dict)
878
+ else:
879
+ filtered_save_obj.append(param_dict)
880
+ return filtered_save_obj
881
+
882
+
883
+ def _is_auto_parallel_mode(save_obj):
884
+ """Check if in auto parallel mode by verifying parameter initialization."""
885
+ for _, param in save_obj.parameters_and_names():
886
+ if param.param_info.is_param_init:
887
+ return True
888
+ return False
702
889
 
703
890
 
704
891
  def _convert_list_to_param_list(save_obj, choice_func):
@@ -739,7 +926,7 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
739
926
  """Convert a dict of Parameter to param_list."""
740
927
  param_list = []
741
928
  for (key, value) in save_obj.items():
742
- if isinstance(key, str) and isinstance(value, (Parameter, str)):
929
+ if isinstance(key, str) and (isinstance(value, (Parameter, str)) or _is_buffer_type(value)):
743
930
  if choice_func is not None and not choice_func(key):
744
931
  continue
745
932
  each_param = {"name": key, "data": value}
@@ -751,15 +938,19 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
751
938
  return param_list
752
939
 
753
940
 
754
- def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
941
+ def _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode):
755
942
  """Convert cell.parameters_and_names to OrderedDict."""
756
943
  param_dict = OrderedDict()
757
944
  for _, param in save_obj.parameters_and_names():
945
+ if param.name.startswith("accu_grads") or param.name.endswith("expert_load"):
946
+ continue
758
947
  not_sliced = not param.sliced
759
948
  is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
760
949
  # All parameters are initialized immediately under PyNative mode, skip this judgement.
761
950
  judgment = not_sliced or param.has_init
762
- if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
951
+ if param.param_info.is_pipeline_shared_param:
952
+ continue
953
+ if is_graph_mode and is_parallel_mode and judgment:
763
954
  continue
764
955
  if choice_func is not None and not choice_func(param.name):
765
956
  continue
@@ -777,11 +968,12 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
777
968
  sync_pipeline_shared_parameters(save_obj)
778
969
  param_list = []
779
970
  parameter_layout_dict = save_obj.parameter_layout_dict
780
- if _is_in_auto_parallel_mode() and not parameter_layout_dict:
971
+ is_parallel_mode = _is_auto_parallel_mode(save_obj)
972
+ if is_parallel_mode and not parameter_layout_dict:
781
973
  parameter_layout_dict = _get_parameter_layout()
782
- if not _is_in_auto_parallel_mode():
974
+ if not is_parallel_mode:
783
975
  save_obj.init_parameters_data()
784
- param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
976
+ param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
785
977
  if append_dict and "random_op" in append_dict:
786
978
  phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
787
979
  if phase in save_obj.compile_cache and _executor.has_compiled(phase):
@@ -829,11 +1021,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
829
1021
 
830
1022
  def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
831
1023
  """Convert a save_obj to param_list."""
832
- if isinstance(save_obj, list):
833
- return _convert_list_to_param_list(save_obj, choice_func)
1024
+ if isinstance(save_obj, (list, dict)):
1025
+ if isinstance(save_obj, list):
1026
+ save_obj = _convert_list_to_param_list(save_obj, choice_func)
1027
+
1028
+ if isinstance(save_obj, dict):
1029
+ save_obj = _convert_dict_to_param_dict(save_obj, choice_func)
834
1030
 
835
- if isinstance(save_obj, dict):
836
- return _convert_dict_to_param_dict(save_obj, choice_func)
1031
+ return _handle_shared_param_for_pipeline_parallel(save_obj)
837
1032
 
838
1033
  return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
839
1034
 
@@ -864,11 +1059,8 @@ def _check_append_dict(append_dict):
864
1059
  return append_dict
865
1060
 
866
1061
 
867
- def _check_load_obfuscate(**kwargs):
868
- if 'obf_func' in kwargs.keys():
869
- customized_func = _check_customized_func(kwargs.get('obf_func'))
870
- clean_funcs()
871
- add_opaque_predicate(customized_func.__name__, customized_func)
1062
+ def _is_buffer_type(value):
1063
+ if isinstance(value, Tensor) and getattr(value, "_is_buffer", False):
872
1064
  return True
873
1065
  return False
874
1066
 
@@ -885,20 +1077,18 @@ def load(file_name, **kwargs):
885
1077
  kwargs (dict): Configuration options dictionary.
886
1078
 
887
1079
  - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
888
- - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
1080
+ - dec_mode (Union[str, function], optional):
1081
+ Specifies the decryption mode, to take effect when dec_key is set.
889
1082
 
890
1083
  - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
891
1084
  - For details of using the customized decryption, please check the `tutorial
892
1085
  <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
893
1086
 
894
- - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
895
- `obfuscate_model()
896
- <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
897
-
898
1087
  Returns:
899
1088
  GraphCell, a compiled graph that can executed by `GraphCell`.
900
1089
 
901
1090
  Raises:
1091
+ NotImplementedError: Dynamic model structure obfuscation is no longer supported.
902
1092
  ValueError: MindIR file does not exist or `file_name` is not a string.
903
1093
  RuntimeError: Failed to parse MindIR file.
904
1094
 
@@ -925,6 +1115,8 @@ def load(file_name, **kwargs):
925
1115
  - `Saving and Loading the Model - Saving and Loading MindIR
926
1116
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
927
1117
  """
1118
+ if 'obf_func' in kwargs.keys():
1119
+ raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
928
1120
  if not isinstance(file_name, str):
929
1121
  raise ValueError("For 'load', the argument 'file_name' must be string, but "
930
1122
  "got {}.".format(type(file_name)))
@@ -936,9 +1128,6 @@ def load(file_name, **kwargs):
936
1128
  "please check whether the 'file_name' is correct.")
937
1129
  file_name = os.path.realpath(file_name)
938
1130
 
939
- # set customized functions for dynamic obfuscation
940
- obfuscated = _check_load_obfuscate(**kwargs)
941
-
942
1131
  logger.info("Execute the process of loading mindir.")
943
1132
  if 'dec_key' in kwargs.keys():
944
1133
  dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
@@ -951,9 +1140,9 @@ def load(file_name, **kwargs):
951
1140
  else:
952
1141
  dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
953
1142
  graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
954
- decrypt=dec_func, obfuscated=obfuscated)
1143
+ decrypt=dec_func)
955
1144
  else:
956
- graph = load_mindir(file_name, obfuscated=obfuscated)
1145
+ graph = load_mindir(file_name)
957
1146
 
958
1147
  if graph is None:
959
1148
  if _is_cipher_file(file_name):
@@ -1020,189 +1209,45 @@ def _check_param_type(param_config, key, target_type, requested):
1020
1209
  if key in param_config:
1021
1210
  if not isinstance(param_config[key], target_type):
1022
1211
  raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
1023
- if key == 'obf_random_seed':
1024
- if param_config[key] > INT_64_MAX or param_config[key] <= 0:
1025
- raise ValueError(
1026
- "'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
1027
- param_config[key]))
1028
1212
  return param_config[key]
1029
1213
  if requested:
1030
1214
  raise ValueError("The parameter {} is requested, but not got.".format(key))
1031
- if key == "obf_random_seed":
1032
- return 0
1033
1215
  return None
1034
1216
 
1035
1217
 
1036
- def _check_customized_func(customized_func):
1037
- """ check customized function of dynamic obfuscation """
1038
- if not callable(customized_func):
1039
- raise TypeError(
1040
- "'customized_func' must be a function, but not got {}.".format(type(customized_func)))
1041
- # test customized_func
1042
- try:
1043
- func_result = customized_func(1.0, 1.0)
1044
- except Exception as ex:
1045
- raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
1046
- else:
1047
- if not isinstance(func_result, bool):
1048
- raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
1049
- return customized_func
1050
-
1051
-
1052
- def _check_obfuscate_params(obf_config):
1053
- """Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func"""
1054
- if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
1055
- raise ValueError(
1056
- "At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got None of them.")
1057
- obfuscate_type = _check_param_type(obf_config, "type", str, False)
1058
- if obfuscate_type not in (None, "dynamic"):
1059
- raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
1060
- if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
1061
- if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
1062
- raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
1063
- obf_config['obf_ratio']))
1064
- ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
1065
- obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
1066
- obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
1067
- if (obf_ratio <= 0) or (obf_ratio > 1):
1068
- raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
1069
- customized_funcs = []
1070
- if 'customized_func' in obf_config.keys():
1071
- device_target = context.get_context('device_target')
1072
- if device_target in ["GPU", "Ascend"]:
1073
- raise ValueError(
1074
- "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
1075
- customized_funcs.append(_check_customized_func(obf_config['customized_func']))
1076
- obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
1077
- return obf_ratio, customized_funcs, obf_random_seed
1078
-
1079
-
1080
- def obfuscate_model(obf_config, **kwargs):
1081
- """
1082
- Obfuscate a model of MindIR format. Obfuscation means changing the struct of a network without affecting its
1083
- predict correctness. The obfuscated model can prevent attackers from stealing the model.
1084
-
1085
- Args:
1086
- obf_config (dict): obfuscation config.
1087
-
1088
- - type (str): The type of obfuscation, only 'dynamic' is supported until now.
1089
- - original_model_path (str): The path of MindIR format model that need to be obfuscated. If the original
1090
- model is encrypted, then enc_key and enc_mode should be provided.
1091
- - save_model_path (str): The path to save the obfuscated model.
1092
- - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
1093
- is the same as using :func:`mindspore.export`.
1094
- - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
1095
- should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
1096
- correspond to 0.1, 0.3, and 0.6 respectively.
1097
- - customized_func (function): A python function used for customized function mode, which used for control
1098
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
1099
- Reference to 'my_func()' in
1100
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
1101
- This function needs to ensure that its result is constant for any input. Users can refer to opaque
1102
- predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
1103
- when loading obfuscated model.
1104
- - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1105
- structure of obfuscated models corresponding to different random seeds is different. If
1106
- `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
1107
- interface when loading
1108
- obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1109
- be set, and the latter mode would be applied if both of them are set.
1110
-
1111
- kwargs (dict): Configuration options dictionary.
1112
-
1113
- - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
1114
- - enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
1115
- Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
1116
-
1117
- Raises:
1118
- TypeError: If `obf_config` is not a dict.
1119
- ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
1120
- ValueError: If `original_model_path` is not provided in `obf_config`.
1121
- ValueError: If the model saved in `original_model_path` has been obfuscated.
1122
- ValueError: If `save_model_path` is not provided in `obf_config`.
1123
- ValueError: If `obf_ratio` is not provided in `obf_config`.
1124
- ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
1125
- ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
1126
- ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
1127
-
1128
- Examples:
1129
- >>> import mindspore as ms
1130
- >>> import mindspore.nn as nn
1131
- >>> import numpy as np
1132
- >>> # Download ori_net.mindir
1133
- >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
1134
- >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
1135
- >>> obf_config = {'original_model_path': "./net.mindir",
1136
- ... 'save_model_path': "./obf_net",
1137
- ... 'model_inputs': [input1, ],
1138
- ... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
1139
- >>> ms.obfuscate_model(obf_config)
1140
- >>> obf_func = ms.load("obf_net.mindir")
1141
- >>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
1142
- >>> print(obf_net(input1).asnumpy())
1143
- """
1144
- if not isinstance(obf_config, dict):
1145
- raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
1146
- file_path = _check_param_type(obf_config, "original_model_path", str, True)
1147
- if not file_path.endswith(".mindir"):
1148
- raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', "
1149
- "please input the correct 'file_path'.")
1150
- if not os.path.exists(file_path):
1151
- raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, "
1152
- "please check whether the 'file_path' is correct.")
1153
- saved_path = _check_param_type(obf_config, "save_model_path", str, True)
1154
- model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
1155
- for item in model_inputs:
1156
- if not isinstance(item, Tensor):
1157
- raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
1158
- if -1 in item.shape:
1159
- raise ValueError(
1160
- "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
1161
- obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
1162
- if customized_funcs and obf_random_seed > 0:
1163
- logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
1164
- " applied, remember to set 'obf_random_seed' when loading obfuscated model.")
1165
-
1166
- if obf_random_seed == 0: # apply customized_func mode
1167
- clean_funcs()
1168
- for func in customized_funcs:
1169
- add_opaque_predicate(func.__name__, func)
1170
- branch_control_input = 0
1171
- else: # apply password mode
1172
- branch_control_input = _generate_branch_control_input(obf_random_seed)
1173
-
1174
- if 'enc_key' in kwargs.keys():
1175
- enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
1176
- enc_mode = "AES-GCM"
1177
- if 'enc_mode' in kwargs.keys():
1178
- enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
1179
- if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
1180
- raise ValueError(
1181
- "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
1182
- "obfuscate_model(), but got {}.".format(enc_mode))
1183
- obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
1184
- branch_control_input=branch_control_input, dec_key=enc_key,
1185
- key_len=len(enc_key),
1186
- dec_mode=enc_mode)
1187
- else:
1188
- obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
1189
- branch_control_input=branch_control_input)
1190
-
1191
- obf_net = nn.GraphCell(obf_graph)
1192
- if obf_random_seed != 0:
1193
- append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
1194
- model_inputs += [append_y_tensor]
1195
- export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
1196
-
1197
-
1198
1218
  def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1199
1219
  dec_mode, crc_check, format):
1200
1220
  """load parameter into parameter_dict"""
1201
1221
  ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
1202
1222
  if format == "safetensors":
1203
1223
  with safe_open(ckpt_file_name, framework='np') as f:
1204
- for k in f.keys():
1205
- parameter_dict[k] = Parameter(f.get_tensor(k))
1224
+ cal_crc_num = 0
1225
+ total_io_cost_time = 0
1226
+ for k in sorted(f.keys()):
1227
+ if crc_check:
1228
+ cal_crc_num = binascii.crc32(bytes(k, encoding='utf-8'), cal_crc_num)
1229
+ cal_crc_num = binascii.crc32(bytes(f.get_tensor(k)), cal_crc_num)
1230
+ if choice_func is not None and not choice_func(k):
1231
+ continue
1232
+ io_start_time = time.time()
1233
+ value = f.get_tensor(k)
1234
+ io_end_time = time.time()
1235
+ io_cost_time = io_end_time - io_start_time
1236
+ total_io_cost_time += io_cost_time
1237
+ parameter_dict[k] = Parameter(Tensor.from_numpy(value))
1238
+
1239
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
1240
+ f"Load safetensors io cost time:{total_io_cost_time}.")
1241
+ if crc_check:
1242
+ if f.metadata() is None or f.metadata().get("crc_num") is None:
1243
+ logger.warning(
1244
+ "For 'load_checkpoint', the safetensors file do not contain the crc code, "
1245
+ "please check the file.")
1246
+ else:
1247
+ crc_num = int(f.metadata()["crc_num"])
1248
+ if cal_crc_num != crc_num:
1249
+ raise ValueError("For 'load_checkpoint', the crc check has failed. "
1250
+ "Please check whether the ckpt file is damaged.")
1206
1251
  return
1207
1252
  checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
1208
1253
  try:
@@ -1270,38 +1315,37 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1270
1315
  Load checkpoint info from a specified file.
1271
1316
 
1272
1317
  Note:
1273
- - `specify_prefix` and `filter_prefix` do not affect each other.
1274
- - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
1275
1318
  - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
1276
- `choice_func` is recommended instead.
1319
+ `choice_func` is recommended instead. `specify_prefix` and `filter_prefix` do not affect each other.
1277
1320
  And using either of those two args will override `choice_func` at the same time.
1321
+ - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
1278
1322
  - When loading a checkpoint that has removed redundancy, the network should be compiled.
1279
1323
 
1280
1324
  Args:
1281
1325
  ckpt_file_name (str): Checkpoint file name.
1282
- net (Cell): The network where the parameters will be loaded. Default: ``None`` .
1283
- strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
1284
- into net when parameter name's suffix in checkpoint file is the same as the
1326
+ net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
1327
+ strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
1328
+ parameter into net when parameter name's suffix in checkpoint file is the same as the
1285
1329
  parameter in the network. When the types are inconsistent perform type conversion
1286
1330
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
1287
- filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
1288
- filter_prefix will not be loaded. Default: ``None`` .
1289
- dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
1290
- is not required. Default: ``None`` .
1291
- dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
1292
- mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
1331
+ filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
1332
+ Parameters starting with the filter_prefix will not be loaded. Default: ``None`` .
1333
+ dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
1334
+ the decryption is not required. Default: ``None`` .
1335
+ dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies the
1336
+ decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
1293
1337
  Default: ``"AES-GCM"`` .
1294
- specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
1295
- specify_prefix will be loaded. Default: ``None`` .
1296
- choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
1338
+ specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
1339
+ Parameters starting with the specify_prefix will be loaded. Default: ``None`` .
1340
+ choice_func (Union[None, function], optional) : Input value of the function is a Parameter name of type string,
1297
1341
  and the return value is a bool. If returns ``True`` , the Parameter
1298
1342
  that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
1299
1343
  matches the custom condition will be removed. Default: ``None`` .
1300
- crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
1301
- remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
1344
+ crc_check (bool, optional) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
1345
+ remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
1302
1346
  Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
1303
1347
  redundant-free loading is not enabled.
1304
- format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
1348
+ format (str, optional): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
1305
1349
 
1306
1350
  Returns:
1307
1351
  Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
@@ -1346,13 +1390,15 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1346
1390
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1347
1391
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1348
1392
  """
1393
+ start_load_time = time.time()
1394
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
1349
1395
  specify_prefix = _check_prefix(specify_prefix)
1350
1396
  filter_prefix = _check_prefix(filter_prefix)
1351
1397
  dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
1352
1398
  dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
1353
1399
  crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
1354
1400
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1355
- _check_format_and_other_params(format, dec_key, dec_mode, crc_check)
1401
+ _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode)
1356
1402
  logger.info("Execute the process of loading checkpoint files.")
1357
1403
 
1358
1404
  parameter_dict = {}
@@ -1392,6 +1438,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1392
1438
  if _warm_up_host_cache_enabled(parameter_dict):
1393
1439
  _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
1394
1440
 
1441
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Load checkpoint is finished.")
1442
+ end_load_time = time.time()
1443
+ load_checkpoint_cost_time = end_load_time - start_load_time
1444
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load checkpoint cost time {load_checkpoint_cost_time}.")
1395
1445
  return parameter_dict
1396
1446
 
1397
1447
 
@@ -1411,7 +1461,7 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
1411
1461
  And using either of those two args will override `choice_func` at the same time.
1412
1462
 
1413
1463
  Args:
1414
- ckpt_file_name (str): Checkpoint file name.
1464
+ ckpt_file_name (str): Checkpoint file name. The file extension must be `ckpt` or `safetensors` .
1415
1465
  net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
1416
1466
  strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
1417
1467
  parameter into net when parameter name's suffix in checkpoint file is the
@@ -1448,7 +1498,8 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
1448
1498
  >>> from mindspore import context
1449
1499
  >>> from mindspore import load_checkpoint_async
1450
1500
  >>> from mindspore import load_param_into_net
1451
- >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
1501
+ >>> mindspore.set_device(device_target="Ascend")
1502
+ >>> context.set_context(mode=context.GRAPH_MODE)
1452
1503
  >>> # Create the dataset taking MNIST as an example. Refer to
1453
1504
  >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
1454
1505
  >>> dataset = create_dataset()
@@ -1468,10 +1519,11 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
1468
1519
  >>> model.train(2, dataset)
1469
1520
  >>> print("param dict len: ", len(param_dict), flush=True)
1470
1521
  """
1522
+ format = "safetensors" if ckpt_file_name.endswith(".safetensors") else "ckpt"
1471
1523
  from concurrent.futures import ThreadPoolExecutor
1472
1524
  executor = ThreadPoolExecutor(max_workers=2)
1473
1525
  param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
1474
- dec_key, dec_mode, specify_prefix, choice_func)
1526
+ dec_key, dec_mode, specify_prefix, choice_func, format=format)
1475
1527
  return ParamDictFuture(executor, param_dict_future)
1476
1528
 
1477
1529
 
@@ -1555,7 +1607,12 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
1555
1607
  try:
1556
1608
  if dec_key is None:
1557
1609
  with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
1610
+ ckpt_load_time_start = time.time()
1558
1611
  pb_content = f.read()
1612
+ ckpt_load_time_end = time.time()
1613
+ cost_time = ckpt_load_time_end - ckpt_load_time_start
1614
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt io cost time:{cost_time}.")
1615
+
1559
1616
  else:
1560
1617
  pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
1561
1618
  if pb_content is None:
@@ -1625,17 +1682,18 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1625
1682
  Load parameters into network, return parameter list that are not loaded in the network.
1626
1683
 
1627
1684
  Note:
1628
- - When loading a parameter dict that has removed redundancy, the network should be compiled.
1685
+ When loading a parameter dict that has removed redundancy, the network should be compiled.
1629
1686
 
1630
1687
  Args:
1631
1688
  net (Cell): The network where the parameters will be loaded.
1632
1689
  parameter_dict (dict): The dictionary generated by load checkpoint file,
1633
1690
  it is a dictionary consisting of key: parameters's name, value: parameter.
1634
- strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
1691
+ strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` ,
1692
+ it will load parameter
1635
1693
  into net when parameter name's suffix in checkpoint file is the same as the
1636
1694
  parameter in the network. When the types are inconsistent perform type conversion
1637
1695
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
1638
- remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
1696
+ remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
1639
1697
  Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
1640
1698
  redundant-free loading is not enabled.
1641
1699
 
@@ -1673,11 +1731,11 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1673
1731
  strict_load = Validator.check_bool(strict_load)
1674
1732
  remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1675
1733
  logger.info("Execute the process of loading parameters into net.")
1676
- for _, param in net.parameters_and_names():
1677
- param.from_ckpt = True
1678
1734
  param_not_load = []
1679
1735
  ckpt_not_load = list(parameter_dict.keys())
1680
1736
  for _, param in net.parameters_and_names():
1737
+ if param.param_info.is_pipeline_shared_param:
1738
+ continue
1681
1739
  if param.name in parameter_dict:
1682
1740
  if isinstance(param, MapParameter):
1683
1741
  param.import_data(parameter_dict[param.name])
@@ -1696,31 +1754,24 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1696
1754
  if param_not_load and not strict_load:
1697
1755
  _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
1698
1756
 
1699
- logger.info("Loading parameters into net is finished.")
1700
- if param_not_load:
1701
- logger.warning("For 'load_param_into_net', "
1702
- "{} parameters in the 'net' are not loaded, because they are not in the "
1703
- "'parameter_dict', please check whether the network structure is consistent "
1704
- "when training and loading checkpoint. Another possibility is that "
1705
- "the redundant loading is not enabled, but the loaded checkpoint is saved with "
1706
- "redundancy removed. ".format(len(param_not_load)))
1707
- logger.warning("{} are not loaded.".format(param_not_load))
1708
1757
  if remove_redundancy:
1709
- parallel_mode = context.get_auto_parallel_context("parallel_mode")
1710
- if parallel_mode == "stand_alone":
1758
+ if get_group_size() == 1:
1711
1759
  raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
1712
- f"in parallel scenarios, but got {parallel_mode}.")
1760
+ f"in parallel scenarios, but got stand_alone.")
1713
1761
  if not net.compile_cache and not net.parameter_layout_dict:
1714
1762
  raise ValueError("When loading a parameter dict that has removed redundancy, "
1715
1763
  "the network should be compiled.")
1716
1764
  param_layout = net.parameter_layout_dict
1717
- rank_id = get_rank()
1718
- device_num = _get_device_num()
1719
- stage_num = _get_auto_parallel_context("pipeline_stages")
1720
- chunk_size = device_num // stage_num
1721
- initial_rank = (rank_id // chunk_size) * chunk_size
1722
- _single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
1765
+ _single_parameter_broadcast(net, param_layout, param_not_load)
1766
+ mindspore.hal.synchronize()
1723
1767
 
1768
+ logger.info("Loading parameters into net is finished.")
1769
+ if param_not_load:
1770
+ logger.warning("For 'load_param_into_net', "
1771
+ "{} parameters in the 'net' are not loaded, because they are not in the "
1772
+ "'parameter_dict', please check whether the network structure is consistent "
1773
+ "when training and loading checkpoint.".format(len(param_not_load)))
1774
+ logger.warning("{} are not loaded.".format(param_not_load))
1724
1775
  return param_not_load, ckpt_not_load
1725
1776
 
1726
1777
 
@@ -1903,9 +1954,6 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1903
1954
  elif opt_shard_group:
1904
1955
  allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1905
1956
  tuple(after_reshape_slice_shape))
1906
- elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
1907
- allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1908
- tuple(after_reshape_slice_shape))
1909
1957
  net.parallel_parameter_merge_net_dict[param_name] = allgather_net
1910
1958
  if allgather_net:
1911
1959
  param_data = allgather_net(param_data)
@@ -1959,27 +2007,6 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1959
2007
 
1960
2008
  - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
1961
2009
  preprocessing of the dataset into MindIR.
1962
-
1963
- - obf_config (dict): obfuscation config.
1964
-
1965
- - type (str): The type of obfuscation, only 'dynamic' is supported until now.
1966
- - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
1967
- should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
1968
- correspond to 0.1, 0.3, and 0.6 respectively.
1969
- - customized_func (function): A python function used for customized function mode, which used for control
1970
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
1971
- Reference to 'my_func()' in
1972
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
1973
- This function needs to ensure that its result is constant for any input. Users can refer to opaque
1974
- predicates. If customized_func is set, then it should be passed to `load()` interface when loading
1975
- obfuscated model.
1976
- - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1977
- structure of obfuscated models corresponding to different random seeds is different. If
1978
- `obf_random_seed` is set, then it should be passed
1979
- to :class:`mindspore.nn.GraphCell` interface when loading
1980
- obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1981
- be set, and the latter mode would be applied if both of them are set.
1982
-
1983
2010
  - incremental (bool): export MindIR incrementally.
1984
2011
 
1985
2012
  - custom_func (function): Functions for custom defined export policies. This function will be used to
@@ -2013,6 +2040,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
2013
2040
  - `Saving and Loading the Model - Saving and Loading MindIR
2014
2041
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
2015
2042
  """
2043
+ if 'obf_func' in kwargs.keys():
2044
+ raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
2016
2045
  old_ms_jit_value = context.get_context("jit_syntax_level")
2017
2046
  context.set_context(jit_syntax_level=mindspore.STRICT)
2018
2047
 
@@ -2094,9 +2123,7 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
2094
2123
  It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
2095
2124
  """
2096
2125
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
2097
- if "obf_config" in kwargs and file_format != "MINDIR":
2098
- raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
2099
- if "custom_func" in kwargs and file_format != "MINDIR":
2126
+ if "custom_func" in kwargs and file_format != "MINDIR" and kwargs["custom_func"] is not None:
2100
2127
  raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
2101
2128
  if file_format == 'AIR':
2102
2129
  _save_air(net, file_name, *inputs, **kwargs)
@@ -2309,14 +2336,13 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
2309
2336
  os.chmod(data_file_name, stat.S_IRUSR)
2310
2337
 
2311
2338
 
2312
- def _msfunc_info(net, *inputs):
2339
+ def _msfunc_info(net, jit_executor, *inputs):
2313
2340
  """Get mindir stream and parameter dict of ms_function"""
2314
2341
  # pylint: disable=protected-access
2315
2342
  net_dict = OrderedDict()
2316
- _ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
2317
- graph_id = _ms_func_executor.compile(net.__name__, *inputs)
2318
- mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
2319
- params = _ms_func_executor._graph_executor.get_params(graph_id)
2343
+ graph_id = jit_executor.compile(net.__name__, *inputs)
2344
+ mindir_stream = jit_executor._get_func_graph_proto(net, graph_id, 'mind_ir')
2345
+ params = jit_executor._graph_executor.get_params(graph_id)
2320
2346
  for name, value in params.items():
2321
2347
  net_dict[name] = Parameter(value, name=name)
2322
2348
  return mindir_stream, net_dict
@@ -2328,53 +2354,21 @@ def _cell_info(net, incremental, *inputs):
2328
2354
  graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
2329
2355
  # pylint: disable=protected-access
2330
2356
  mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
2331
- # clean obfuscation config to prevent the next call
2332
- _executor.obfuscate_config = None
2333
-
2334
2357
  net_dict = net.parameters_dict()
2335
2358
  return mindir_stream, net_dict
2336
2359
 
2337
2360
 
2338
- def _set_obfuscate_config(**kwargs):
2339
- """Set obfuscation config for executor."""
2340
- logger.warning("Obfuscate model.")
2341
- if 'enc_mode' in kwargs.keys():
2342
- enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
2343
- if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
2344
- raise ValueError(
2345
- "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
2346
- "obfuscation, but got {}.".format(enc_mode))
2347
- obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
2348
- if customized_funcs and obf_random_seed > 0:
2349
- logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
2350
- " applied, remember to set 'obf_random_seed' when loading obfuscated model.")
2351
-
2352
- if obf_random_seed == 0: # apply customized_func mode
2353
- device_target = context.get_context('device_target')
2354
- if device_target in ["GPU", "Ascend"]:
2355
- raise ValueError(
2356
- "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
2357
- clean_funcs()
2358
- for func in customized_funcs:
2359
- add_opaque_predicate(func.__name__, func)
2360
- _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}
2361
-
2362
-
2363
2361
  def _save_mindir(net, file_name, *inputs, **kwargs):
2364
2362
  """Save MindIR format file."""
2365
- # set obfuscate configs
2366
- if 'obf_config' in kwargs.keys():
2367
- _set_obfuscate_config(**kwargs)
2368
- for item in inputs:
2369
- if -1 in item.shape:
2370
- raise ValueError(
2371
- "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
2363
+ executor = _executor
2364
+ if not isinstance(net, nn.Cell):
2365
+ executor = _JitExecutor(net, time.time() * 1e9)
2372
2366
 
2373
2367
  incremental = kwargs.get('incremental', False)
2374
2368
 
2375
2369
  model = mindir_model()
2376
2370
  if not isinstance(net, nn.Cell):
2377
- mindir_stream, net_dict = _msfunc_info(net, *inputs)
2371
+ mindir_stream, net_dict = _msfunc_info(net, executor, *inputs)
2378
2372
  else:
2379
2373
  mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
2380
2374
  model.ParseFromString(mindir_stream)
@@ -2447,8 +2441,10 @@ def _save_together(net_dict, model):
2447
2441
  if name in net_dict.keys():
2448
2442
  data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
2449
2443
  else:
2450
- raise ValueError("The parameter '{}' is not belongs to any cell,"
2451
- "the data of parameter cannot be exported.".format(param_proto.name))
2444
+ raise ValueError("There's a mindspore.Parameter that wasn't created in nn.Cell, and mindspore.export() "
2445
+ f"does not support exporting such Parameters. The parameter name is: {name}.\n"
2446
+ "You can find the supported syntax range for mindspore.export() at the following link:\n"
2447
+ "https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html")
2452
2448
  if data_total > TOTAL_SAVE:
2453
2449
  return False
2454
2450
  return True
@@ -2478,6 +2474,9 @@ def check_checkpoint(ckpt_file_name):
2478
2474
  """
2479
2475
  Check whether the checkpoint is valid.
2480
2476
 
2477
+ Note:
2478
+ The interface is deprecated from version 2.5 and will be removed in a future version.
2479
+
2481
2480
  Args:
2482
2481
  ckpt_file_name (str): Checkpoint file name.
2483
2482
 
@@ -2491,6 +2490,8 @@ def check_checkpoint(ckpt_file_name):
2491
2490
  >>> print(check_result)
2492
2491
  True
2493
2492
  """
2493
+ logger.warning("The interface 'mindspore.check_checkpoint' is deprecated from version 2.5 "
2494
+ "and will be removed in a future version.")
2494
2495
  if not ckpt_file_name.endswith('.ckpt'):
2495
2496
  return False
2496
2497
  checkpoint_list = Checkpoint()
@@ -2517,6 +2518,9 @@ def parse_print(print_file_name):
2517
2518
  """
2518
2519
  Parse data file generated by :class:`mindspore.ops.Print`.
2519
2520
 
2521
+ Note:
2522
+ The interface is deprecated from version 2.5 and will be removed in a future version.
2523
+
2520
2524
  Args:
2521
2525
  print_file_name (str): The file name needs to be parsed.
2522
2526
 
@@ -2551,6 +2555,8 @@ def parse_print(print_file_name):
2551
2555
  [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
2552
2556
  [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
2553
2557
  """
2558
+ logger.warning("The interface 'mindspore.parse_print' is deprecated from version 2.5 "
2559
+ "and will be removed in a future version.")
2554
2560
  print_file_path = os.path.realpath(print_file_name)
2555
2561
 
2556
2562
  if os.path.getsize(print_file_path) == 0:
@@ -2605,548 +2611,13 @@ def parse_print(print_file_name):
2605
2611
  return tensor_list
2606
2612
 
2607
2613
 
2608
- def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
2609
- """
2610
- Merge data slices to one tensor with whole data when strategy is not None.
2611
-
2612
- Args:
2613
- sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
2614
- parameter_name (str): Name of parameter.
2615
- strategy (dict): Parameter slice strategy.
2616
- is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
2617
-
2618
- Returns:
2619
- Tensor, the merged Tensor which has the whole data.
2620
-
2621
- Raises:
2622
- ValueError: Failed to merge.
2623
- """
2624
- layout = strategy.get(parameter_name)
2625
- try:
2626
- dev_mat = list(layout.dev_matrix[0].dim)
2627
- tensor_map = list(layout.tensor_map[0].dim)
2628
- param_split_shape = list(layout.param_split_shape[0].dim)
2629
- field_size = int(layout.field)
2630
- except BaseException as e:
2631
- raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
2632
- f", please make sure that 'strategy' is correct.") from e
2633
-
2634
- device_count = 1
2635
- for dim in dev_mat:
2636
- device_count *= dim
2637
-
2638
- if len(sliced_data) != device_count:
2639
- raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
2640
- f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
2641
- f"device_count is {device_count}.")
2642
-
2643
- if not param_split_shape:
2644
- if not is_even:
2645
- raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
2646
- "should be the same when slice manner is even.")
2647
-
2648
- all_gather_tensor = Tensor(np.concatenate(sliced_data))
2649
-
2650
- if field_size > 0:
2651
- merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
2652
- else:
2653
- merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
2654
-
2655
- else:
2656
- tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
2657
-
2658
- slice_count = 1
2659
- for dim in tensor_strategy:
2660
- slice_count *= dim
2661
-
2662
- if len(param_split_shape) != slice_count:
2663
- raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
2664
- f"{slice_count}, but got {len(param_split_shape)}.")
2665
-
2666
- tensor_slices_new = list(range(slice_count))
2667
- tensor_slices = sliced_data
2668
- for i in range(device_count):
2669
- slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
2670
- if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
2671
- raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
2672
- f"{param_split_shape[slice_index]} in 0 axis, but got "
2673
- f"{tensor_slices[i].shape[0]}.")
2674
- tensor_slices_new[slice_index] = np.array(tensor_slices[i])
2675
-
2676
- dim_len = len(tensor_strategy)
2677
- for i in range(dim_len):
2678
- ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
2679
- tensor_slices_new_inner = []
2680
- for j in range(ele_count):
2681
- new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
2682
- for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
2683
- (j + 1) * tensor_strategy[dim_len - 1 - i]):
2684
- new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
2685
- tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
2686
- tensor_slices_new = tensor_slices_new_inner
2687
- merged_tensor = Tensor(tensor_slices_new[0])
2688
-
2689
- return merged_tensor
2690
-
2691
-
2692
- def restore_group_info_list(group_info_file_name):
2693
- """
2694
- Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
2695
- who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL
2696
- environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
2697
-
2698
- Args:
2699
- group_info_file_name (str): Name of group information file.
2700
-
2701
- Returns:
2702
- List, the rank list.
2703
-
2704
- Raises:
2705
- ValueError: group information file is incorrect.
2706
- TypeError: `group_info_file_name` is not str.
2707
-
2708
- Examples:
2709
- >>> import mindspore as ms
2710
- >>> ms.restore_list = restore_group_info_list("./group_info.pb")
2711
- """
2712
- if not isinstance(group_info_file_name, str):
2713
- raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
2714
- f"but got {type(group_info_file_name)}.")
2715
-
2716
- if not os.path.isfile(group_info_file_name):
2717
- raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
2718
-
2719
- if os.path.getsize(group_info_file_name) == 0:
2720
- raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
2721
-
2722
- return _restore_group_info_list(group_info_file_name)
2723
-
2724
-
2725
- def build_searched_strategy(strategy_filename):
2726
- """
2727
- Build strategy of every parameter in network. Used in the case of distributed inference.
2728
-
2729
- Args:
2730
- strategy_filename (str): Name of strategy file.
2731
-
2732
- Returns:
2733
- Dict, whose key is parameter name and value is slice strategy of this parameter.
2734
-
2735
- Raises:
2736
- ValueError: Strategy file is incorrect.
2737
- TypeError: `strategy_filename` is not a string.
2738
-
2739
- Examples:
2740
- >>> import mindspore as ms
2741
- >>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
2742
- """
2743
- return _build_searched_strategy(strategy_filename)
2744
-
2745
-
2746
- def merge_sliced_parameter(sliced_parameters, strategy=None):
2747
- """
2748
- Merge parameter slices into one parameter. Used in the case of distributed inference.
2749
-
2750
- Args:
2751
- sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
2752
- strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
2753
- value is slice strategy of this parameter. If strategy is None, just merge
2754
- parameter slices in 0 axis order. Default: ``None``.
2755
-
2756
- Returns:
2757
- Parameter, the merged parameter which has the whole data.
2758
-
2759
- Raises:
2760
- ValueError: Failed to merge.
2761
- TypeError: The sliced_parameters is incorrect or strategy is not dict.
2762
- KeyError: The parameter name is not in keys of strategy.
2763
-
2764
- Examples:
2765
- >>> import numpy as np
2766
- >>> import mindspore as ms
2767
- >>> from mindspore import Tensor, Parameter
2768
- >>>
2769
- >>> sliced_parameters = [
2770
- ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
2771
- ... "network.embedding_table"),
2772
- ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
2773
- ... "network.embedding_table"),
2774
- ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
2775
- ... "network.embedding_table"),
2776
- ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
2777
- ... "network.embedding_table")]
2778
- >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
2779
- >>> print(merged_parameter)
2780
- Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
2781
- """
2782
- if not isinstance(sliced_parameters, list):
2783
- raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
2784
- f"but got {type(sliced_parameters)}.")
2785
-
2786
- if not sliced_parameters:
2787
- raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
2788
-
2789
- if strategy and not isinstance(strategy, dict):
2790
- raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
2791
- f"but got {type(strategy)}.")
2792
-
2793
- try:
2794
- parameter_name = sliced_parameters[0].name
2795
- parameter_shape = sliced_parameters[0].data.shape
2796
- parameter_shape_length = len(parameter_shape)
2797
- except BaseException as e:
2798
- raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
2799
- f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e
2800
-
2801
- is_even = True
2802
- for index, parameter in enumerate(sliced_parameters):
2803
- if not isinstance(parameter, Parameter):
2804
- raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
2805
- f"but got {type(parameter)} at index {index}.")
2806
-
2807
- if parameter.name != parameter_name \
2808
- or len(parameter.data.shape) != parameter_shape_length \
2809
- or parameter.data.shape[1:] != parameter_shape[1:]:
2810
- raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'"
2811
- f" have the same name, dimension length and shape except 0 axis. The name, dimension "
2812
- f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
2813
- f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
2814
- f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
2815
- f"at index {index}.")
2816
-
2817
- if parameter.data.shape != parameter_shape:
2818
- is_even = False
2819
-
2820
- layerwise_parallel = sliced_parameters[0].layerwise_parallel
2821
- requires_grad = sliced_parameters[0].requires_grad
2822
- sliced_data = []
2823
- for parameter in sliced_parameters:
2824
- if parameter.data.dtype == mstype.bfloat16:
2825
- sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
2826
- else:
2827
- sliced_data.append(parameter.data.asnumpy())
2828
-
2829
- if not strategy:
2830
- merged_tensor = Tensor(np.concatenate(sliced_data))
2831
- merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
2832
-
2833
- else:
2834
- if parameter_name not in strategy.keys():
2835
- raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
2836
- f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
2837
- merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
2838
- merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
2839
-
2840
- return merged_parameter
2841
-
2842
-
2843
- def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
2844
- train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
2845
- format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
2846
- """
2847
- Load checkpoint into net for distributed predication. Used in the case of distributed inference.
2848
-
2849
- Args:
2850
- network (Cell): Network for distributed predication.
2851
- checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
2852
- predict_strategy (dict): Strategy of predication process. It means that using one device to predict
2853
- when setting predict_strategy as None. Default: ``None`` .
2854
- train_strategy_filename (str): The filename of training strategy protocol buffer file.
2855
- When train_strategy_filename is None, the training strategy file will be
2856
- obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
2857
- Therefore, the training strategy file needs to be specified
2858
- in at least one of them. Default: ``None`` .
2859
- strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
2860
- into net when parameter name's suffix in checkpoint file is the same as the
2861
- parameter in the network. When the types are inconsistent, perform type conversion
2862
- on the parameters of the same type, such as float32 to float16. Default: ``False`` .
2863
- dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
2864
- is not required. Default: ``None`` .
2865
- dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
2866
- mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
2867
- Default: ``'AES-GCM'`` .
2868
- format (str): Input weight format to be loaded into the network.
2869
- It can be set to either "ckpt" or "safetensors". Default: "ckpt".
2870
- unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
2871
- Default: ``None`` .
2872
- dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
2873
- rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
2874
- globally by initializing the network; In save mode, save the file according to the input
2875
- sequence number. If it is not input, save the entire file.
2876
-
2877
- Raises:
2878
- TypeError: The type of inputs do not match the requirements.
2879
- ValueError: Failed to load checkpoint into net.
2880
-
2881
- Supported Platforms:
2882
- ``Ascend`` ``GPU``
2883
-
2884
- Examples:
2885
- .. note::
2886
- Before running the following examples, you need to configure the communication environment variables.
2887
-
2888
- For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
2889
- Please see the `rank table startup
2890
- <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
2891
- for more details.
2892
-
2893
- For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
2894
- <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
2895
-
2896
- For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
2897
- Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
2898
-
2899
- >>> import os
2900
- >>> import numpy as np
2901
- >>> import mindspore as ms
2902
- >>> import mindspore.dataset as ds
2903
- >>> from mindspore import nn, ops, train
2904
- >>> from mindspore.communication import init
2905
- >>>
2906
- >>> step_per_epoch = 4
2907
- >>> device_num = 8
2908
- >>>
2909
- >>> # Define the network structure.
2910
- >>> class Net(nn.Cell):
2911
- ... def __init__(self, matmul_size, strategy=None):
2912
- ... super().__init__()
2913
- ... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
2914
- ... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
2915
- ... self.matmul = ops.MatMul()
2916
- ... self.neg = ops.Neg()
2917
- ... if strategy is not None:
2918
- ... self.matmul.shard(strategy)
2919
- ...
2920
- ... def construct(self, inputs):
2921
- ... x = self.matmul(inputs, self.matmul_weight)
2922
- ... x = self.neg(x)
2923
- ... return x
2924
- >>>
2925
- >>> # Create dataset.
2926
- >>> def get_dataset(*inputs):
2927
- ... def generate():
2928
- ... for _ in range(step_per_epoch):
2929
- ... yield inputs
2930
- ... return generate
2931
- >>>
2932
- >>> # Train network and save distributed checkpoint.
2933
- >>> def train_net():
2934
- ... ms.set_context(mode=ms.GRAPH_MODE)
2935
- ... init()
2936
- ... np.random.seed(1)
2937
- ... input_data = np.random.rand(16, 96).astype(np.float32)
2938
- ... label_data = np.random.rand(16, 16).astype(np.float32)
2939
- ... fake_dataset = get_dataset(input_data, label_data)
2940
- ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
2941
- ...
2942
- ... # Set parallel strategy.
2943
- ... strategy = ((1, 4), (4, 1))
2944
- ... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
2945
- ... strategy_ckpt_save_file="./train_strategy.ckpt")
2946
- ... network = Net(matmul_size=(96, 16), strategy=strategy)
2947
- ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
2948
- ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
2949
- ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
2950
- ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
2951
- ... global_rank_id = int(os.getenv("RANK_ID"))
2952
- ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
2953
- ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
2954
- ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
2955
- ... ms.reset_auto_parallel_context()
2956
- >>>
2957
- >>> # Load distributed checkpoint and test.
2958
- >>> def load_model():
2959
- ... ms.set_context(mode=ms.GRAPH_MODE)
2960
- ... init()
2961
- ... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
2962
- ... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
2963
- ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
2964
- ... network = Net(matmul_size=(96, 16))
2965
- ... model = ms.Model(network)
2966
- ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
2967
- ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
2968
- ... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
2969
- ... predict_result = model.predict(predict_data)
2970
- ... print(predict_result)
2971
- >>>
2972
- >>> train_net()
2973
- >>> load_model()
2974
- [[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
2975
- [ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
2976
- ...
2977
- [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
2978
- """
2979
- if format not in ['safetensors', 'ckpt']:
2980
- raise ValueError(
2981
- f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
2982
-
2983
- if format == 'safetensors':
2984
- if unified_safetensors_dir is None:
2985
- raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
2986
- f"when format is 'safetensors'.")
2987
- unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
2988
- for param in unsupport_param:
2989
- if param is not None:
2990
- raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
2991
- f"when format is 'safetensors'.")
2992
- if strict_load or dec_mode != 'AES-GCM':
2993
- raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
2994
- f"when format is 'safetensors'.")
2995
- if network is not None:
2996
- rank_id = get_rank()
2997
- _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
2998
- else:
2999
- if dst_safetensors_dir is None:
3000
- raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
3001
- f"when network is None.")
3002
- if rank_id is not None:
3003
- _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3004
- rank_id)
3005
- else:
3006
- dst_strategy_dict = _build_searched_strategy(predict_strategy)
3007
- dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
3008
- dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
3009
- dst_device_num = dst_stage_device_num * dst_stage_num
3010
- processes = []
3011
- activate_processes = 0
3012
- for rank in range(0, dst_device_num):
3013
- p = Process(target=_load_parallel_checkpoint, args=(
3014
- unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
3015
- p.start()
3016
- processes.append(p)
3017
- activate_processes += 1
3018
- max_processes = 64
3019
- if activate_processes >= max_processes:
3020
- p = processes.pop(0)
3021
- p.join()
3022
- activate_processes -= 1
3023
- for p in processes:
3024
- p.join()
3025
- return
3026
-
3027
- network = Validator.check_isinstance("network", network, nn.Cell)
3028
- _check_checkpoint_file(checkpoint_filenames)
3029
- _check_predict_strategy(predict_strategy)
3030
-
3031
- dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
3032
- dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
3033
-
3034
- if train_strategy_filename is None:
3035
- train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
3036
- _train_strategy = build_searched_strategy(train_strategy_filename)
3037
- train_strategy = _convert_to_list(_train_strategy)
3038
-
3039
- train_dev_count = 1
3040
- ckpt_file_len = len(checkpoint_filenames)
3041
- for dim in train_strategy[list(train_strategy.keys())[0]][0]:
3042
- train_dev_count *= dim
3043
- if train_dev_count != ckpt_file_len:
3044
- raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
3045
- f"equal to the device count of training process. "
3046
- f"But got the length of 'checkpoint_filenames'"
3047
- f" is {ckpt_file_len} and the device count is {train_dev_count}.")
3048
- rank_list = _infer_rank_list(train_strategy, predict_strategy)
3049
-
3050
- param_total_dict = defaultdict(dict)
3051
- for file_index, file_name in enumerate(checkpoint_filenames):
3052
- ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
3053
- for param_name, param in ckpt_dict.items():
3054
- param_total_dict[param_name][file_index] = param
3055
-
3056
- param_dict = {}
3057
- param_not_in_strategy = []
3058
- param_not_in_ckpt = []
3059
- for _, param in network.parameters_and_names():
3060
- sliced_params = []
3061
- if param.name not in rank_list.keys():
3062
- param_not_in_strategy.append(param.name)
3063
- continue
3064
- if param.name not in param_total_dict:
3065
- param_not_in_ckpt.append(param.name)
3066
- continue
3067
-
3068
- param_rank = rank_list.get(param.name)[0]
3069
- skip_merge_split = rank_list.get(param.name)[1]
3070
- shard_stride = train_strategy.get(param.name)[4]
3071
- tensor_map = train_strategy.get(param.name)[1]
3072
- first_dim_shard_idx = tensor_map[0] if tensor_map else -1
3073
- device_arrangement = train_strategy.get(param.name)[0]
3074
- first_dim_shard_size = 1
3075
- if first_dim_shard_idx >= 0:
3076
- first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
3077
- if train_strategy.get(param.name)[5]:
3078
- shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
3079
- else:
3080
- shard_size = 0
3081
- for rank in param_rank:
3082
- param_total_list = list(range(0, ckpt_file_len))
3083
- if first_dim_shard_size != 1:
3084
- param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
3085
- if shard_size > 0:
3086
- rank_index = param_total_list.index(rank)
3087
- start = rank_index // shard_size * shard_size
3088
- param_total_list = param_total_list[start:start + shard_size]
3089
- if shard_stride > 0:
3090
- param_stride = []
3091
- # merge pre parameter
3092
- param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
3093
- param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
3094
- param_index = list(set(param_index))
3095
- param_index.sort()
3096
- for rank_num in param_index:
3097
- if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
3098
- param_stride.append(
3099
- cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
3100
- else:
3101
- param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
3102
-
3103
- sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
3104
- else:
3105
- sliced_param = param_total_dict[param.name][rank]
3106
-
3107
- sliced_params.append(sliced_param)
3108
- if skip_merge_split:
3109
- split_param = sliced_params[0]
3110
- else:
3111
- param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
3112
- _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
3113
- split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
3114
- opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
3115
- if opt_shard_group:
3116
- if split_param.data.dtype == mstype.bfloat16:
3117
- data = cpu_cast(split_param.data, mstype.float32).asnumpy()
3118
- else:
3119
- data = split_param.data.asnumpy()
3120
- rank = get_rank(opt_shard_group)
3121
- size = get_group_size(opt_shard_group)
3122
- try:
3123
- data_slice = np.split(data, size)[rank]
3124
- except BaseException as e:
3125
- logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
3126
- " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
3127
- raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
3128
- f" in load distributed checkpoint for {param.name}. Data shape is "
3129
- f"{split_param.data.shape} and group is {opt_shard_group}.") from e
3130
- split_param = Parameter(Tensor(data_slice), param.name,
3131
- split_param.requires_grad, split_param.layerwise_parallel)
3132
- param_dict[param.name] = split_param
3133
-
3134
- if param_not_in_strategy:
3135
- logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
3136
- "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
3137
- .format(param_not_in_strategy))
3138
- if param_not_in_ckpt:
3139
- logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in "
3140
- "the checkpoint file, please check whether 'checkpoint_filenames' is correct."
3141
- .format(param_not_in_ckpt))
3142
-
3143
- load_param_into_net(network, param_dict, strict_load=strict_load)
3144
-
3145
-
3146
2614
  def async_ckpt_thread_status():
3147
2615
  """
3148
2616
  Get the status of asynchronous save checkpoint thread.
3149
2617
 
2618
+ Note:
2619
+ The interface is deprecated from version 2.5 and will be removed in a future version.
2620
+
3150
2621
  When performing asynchronous save checkpoint, you can determine whether the asynchronous thread is completed.
3151
2622
 
3152
2623
  Returns:
@@ -3158,73 +2629,12 @@ def async_ckpt_thread_status():
3158
2629
  >>> ms.async_ckpt_thread_status()
3159
2630
  False
3160
2631
  """
2632
+ logger.warning("The interface 'mindspore.async_ckpt_thread_status' is deprecated from version 2.5 "
2633
+ "and will be removed in a future version.")
3161
2634
  thr_list = threading.enumerate()
3162
2635
  return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
3163
2636
 
3164
2637
 
3165
- def _check_predict_strategy(predict_strategy):
3166
- """Check predict strategy."""
3167
-
3168
- def _check_int_list(arg):
3169
- if not isinstance(arg, list):
3170
- return False
3171
- for item in arg:
3172
- if not isinstance(item, int):
3173
- return False
3174
- return True
3175
-
3176
- if predict_strategy is None:
3177
- return
3178
-
3179
- flag = True
3180
- predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
3181
- for key in predict_strategy.keys():
3182
- if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
3183
- or len(predict_strategy[key]) < 4:
3184
- flag = False
3185
- dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
3186
- if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
3187
- not (_check_int_list(param_split_shape) or not param_split_shape) or \
3188
- not (isinstance(field_size, int) and field_size == 0):
3189
- flag = False
3190
-
3191
- if not flag:
3192
- raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
3193
- f"the key of it must be string, and the value of it must be list or tuple that "
3194
- f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
3195
- f"param_split_shape (list[int]) and field_size (int, which value is 0)."
3196
- f"Please check whether 'predict_strategy' is correct.")
3197
-
3198
-
3199
- def _check_checkpoint_file(checkpoint_filenames):
3200
- """Check checkpoint file name."""
3201
- for index, filename in enumerate(checkpoint_filenames):
3202
- if not isinstance(filename, str) or not os.path.exists(filename) \
3203
- or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
3204
- raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
3205
- f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
3206
- f"be a string ending with '.ckpt', and the checkpoint file it represents must "
3207
- f"be exist and not empty.")
3208
-
3209
-
3210
- def _merge_and_split(sliced_params, train_strategy, predict_strategy):
3211
- """Merge sliced parameter and split it according to the predict strategy."""
3212
- merged_param = merge_sliced_parameter(sliced_params, train_strategy)
3213
- if predict_strategy is None:
3214
- return merged_param
3215
- param_name = merged_param.name
3216
- tensor_layout = predict_strategy[param_name]
3217
- rank = get_rank()
3218
- split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
3219
- requires_grad = merged_param.requires_grad
3220
- layerwise_parallel = merged_param.layerwise_parallel
3221
- if merged_param.data.dtype == mstype.bfloat16:
3222
- split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
3223
- else:
3224
- split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
3225
- return split_param
3226
-
3227
-
3228
2638
  def _calculation_net_size(net):
3229
2639
  """Calculate the size of parameters in the network."""
3230
2640
  data_total = 0
@@ -3288,8 +2698,8 @@ def convert_model(mindir_file, convert_file, file_format):
3288
2698
  """
3289
2699
  Convert mindir model to other format model. The current version only supports conversion to ONNX models.
3290
2700
 
3291
- .. warning::
3292
- This is an experimental API that is subject to change or deletion.
2701
+ Note:
2702
+ The interface is deprecated from version 2.5 and will be removed in a future version.
3293
2703
 
3294
2704
  Args:
3295
2705
  mindir_file (str): MindIR file name.
@@ -3305,6 +2715,8 @@ def convert_model(mindir_file, convert_file, file_format):
3305
2715
  >>> import mindspore as ms
3306
2716
  >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
3307
2717
  """
2718
+ logger.warning("The interface 'mindspore.train.serialization.convert_model' is deprecated from version 2.5 "
2719
+ "and will be removed in a future version.")
3308
2720
  Validator.check_file_name_by_regular(mindir_file)
3309
2721
  Validator.check_file_name_by_regular(convert_file)
3310
2722
  if file_format != "ONNX":
@@ -3316,3 +2728,235 @@ def convert_model(mindir_file, convert_file, file_format):
3316
2728
  export(net, net_input, file_name=convert_file, file_format=file_format)
3317
2729
  else:
3318
2730
  export(net, *net_input, file_name=convert_file, file_format=file_format)
2731
+
2732
+
2733
+ def _transform_tensor_to_numpy(path, name_map=None):
2734
+ return _load_and_transform(path, name_map, mindspore.load_checkpoint, lambda v, new_name: v.asnumpy())
2735
+
2736
+
2737
+ def _transform_numpy_to_tensor(path, name_map=None):
2738
+ return _load_and_transform(path, name_map, load_file, lambda v, new_name: mindspore.Parameter(v, name=new_name))
2739
+
2740
+
2741
+ def _process_file(file_info):
2742
+ cur_ckpt_path, name_map, save_path, file = file_info
2743
+ param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
2744
+ safetensors_filename = file.replace(".ckpt", ".safetensors")
2745
+ dst_file = os.path.join(save_path, safetensors_filename)
2746
+ save_file(param_dict_numpy, dst_file)
2747
+
2748
+
2749
+ def _process_file_safetensors(file_info):
2750
+ cur_safe_path, name_map, save_path, file = file_info
2751
+ param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
2752
+ ckpt_filename = file.replace(".safetensors", ".ckpt")
2753
+ dst_file = os.path.join(save_path, ckpt_filename)
2754
+ mindspore.save_checkpoint(param_dict_tensor, dst_file)
2755
+
2756
+
2757
+ def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
2758
+ """gather transform rank together"""
2759
+ tasks = []
2760
+ for root, dirs, _ in os.walk(file_path):
2761
+ if root != file_path:
2762
+ continue
2763
+
2764
+ rank_dirs = [d for d in dirs if d.startswith('rank')]
2765
+ if not rank_dirs:
2766
+ raise ValueError(
2767
+ f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")
2768
+
2769
+ for rank_dir in rank_dirs:
2770
+ rank_dir_path = os.path.join(root, rank_dir)
2771
+ dst_root = os.path.join(save_path,
2772
+ os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
2773
+ os.makedirs(dst_root, exist_ok=True)
2774
+ tasks.extend(
2775
+ (os.path.join(rank_dir_path, file), name_map, dst_root, file)
2776
+ for file in os.listdir(rank_dir_path)
2777
+ if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
2778
+ )
2779
+ return tasks
2780
+
2781
+
2782
+ def _gather_tasks_covert(file_path, save_path, file_name_regex, name_map):
2783
+ """gather transform rank together"""
2784
+ tasks = []
2785
+ for root, dirs, _ in os.walk(file_path):
2786
+ if root != file_path:
2787
+ continue
2788
+
2789
+ rank_dirs = [d for d in dirs if d.startswith('rank')]
2790
+ if not rank_dirs:
2791
+ raise ValueError(
2792
+ f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
2793
+
2794
+ for rank_dir in rank_dirs:
2795
+ rank_dir_path = os.path.join(root, rank_dir)
2796
+ dst_root = os.path.join(save_path,
2797
+ os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
2798
+ os.makedirs(dst_root, exist_ok=True)
2799
+ tasks.extend(
2800
+ (os.path.join(rank_dir_path, file), name_map, dst_root, file)
2801
+ for file in os.listdir(rank_dir_path)
2802
+ if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
2803
+ )
2804
+ return tasks
2805
+
2806
+
2807
+ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
2808
+ """
2809
+ Converts MindSpore checkpoint files into safetensors format and saves them to `save_path`.
2810
+ Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
2811
+ used for securely storing Tensors with fast speed (zero copy).
2812
+
2813
+ Note:
2814
+ The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
2815
+ too large, otherwise it may cause freezing.
2816
+ The safetensors format does not support the enc verification function. If ckpt is enabled to save enc
2817
+ verification, an error will be generated when performing the conversion.
2818
+ The safetensors format currently does not support crc verification function. If ckpt contains crc verification
2819
+ information, the crc verification information will be lost after conversion to safetensors.
2820
+
2821
+ Args:
2822
+ file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
2823
+ save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
2824
+ name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
2825
+ file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
2826
+ Defaults: ``None``.
2827
+ processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
2828
+ Raises:
2829
+ ValueError: If the input path is invalid or the save_path is not a directory,
2830
+ or the file_path does not end with '.ckpt'.
2831
+
2832
+ Supported Platforms:
2833
+ ``Ascend`` ``GPU`` ``CPU``
2834
+
2835
+ Examples:
2836
+ >>> import mindspore as ms
2837
+ >>> ms.ckpt_to_safetensors("./ckpt_save_path")
2838
+ >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
2839
+ >>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
2840
+ >>> namemap = {"lin.weight":"new_name"}
2841
+ >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
2842
+ """
2843
+ is_dir = os.path.isdir(file_path)
2844
+ is_file = os.path.isfile(file_path)
2845
+ if not is_dir and not is_file:
2846
+ raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
2847
+ if save_path and os.path.splitext(save_path)[1]:
2848
+ raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
2849
+ if name_map is not None and not isinstance(name_map, dict):
2850
+ raise ValueError(
2851
+ f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
2852
+
2853
+ if is_dir:
2854
+ tasks = _gather_tasks_covert(file_path, save_path, file_name_regex, name_map)
2855
+ with mp.Pool(processes=processes_num) as pool:
2856
+ list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
2857
+ elif is_file:
2858
+ if not file_path.endswith(".ckpt"):
2859
+ raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
2860
+ if file_name_regex is not None and not re.findall(file_name_regex, file_path):
2861
+ raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
2862
+ if save_path and not os.path.exists(save_path):
2863
+ os.makedirs(save_path, exist_ok=True)
2864
+
2865
+ param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
2866
+ safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
2867
+ dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
2868
+ save_file(param_dict_numpy, dst_file)
2869
+
2870
+
2871
+ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
2872
+ """
2873
+ Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
2874
+ Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
2875
+ used for securely storing Tensors with fast speed (zero copy).
2876
+
2877
+ Note:
2878
+ The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
2879
+ too large, otherwise it may cause freezing.
2880
+
2881
+ Args:
2882
+ file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
2883
+ save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
2884
+ name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
2885
+ file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
2886
+ Defaults: ``None``.
2887
+ processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
2888
+
2889
+ Raises:
2890
+ ValueError: If the input path is invalid, the save_path is not a directory,
2891
+ or the file_path does not end with '.safetensors'.
2892
+
2893
+ Supported Platforms:
2894
+ ``Ascend`` ``GPU`` ``CPU``
2895
+
2896
+ Examples:
2897
+ >>> import mindspore as ms
2898
+ >>> ms.safetensors_to_ckpt("./safetensors_save_path")
2899
+ >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
2900
+ >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
2901
+ >>> namemap = {"lin.weight":"new_name"}
2902
+ >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
2903
+ """
2904
+ is_dir = os.path.isdir(file_path)
2905
+ is_file = os.path.isfile(file_path)
2906
+ if not is_dir and not is_file:
2907
+ raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
2908
+ if save_path and os.path.splitext(save_path)[1]:
2909
+ raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
2910
+ if name_map is not None and not isinstance(name_map, dict):
2911
+ raise ValueError(
2912
+ f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
2913
+
2914
+ if is_dir:
2915
+ tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
2916
+ with mp.Pool(processes=processes_num) as pool:
2917
+ list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
2918
+ elif is_file:
2919
+ if not file_path.endswith(".safetensors"):
2920
+ raise ValueError(
2921
+ f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
2922
+ if file_name_regex is not None and not re.findall(file_name_regex, file_path):
2923
+ raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
2924
+ if save_path and not os.path.exists(save_path):
2925
+ os.makedirs(save_path, exist_ok=True)
2926
+
2927
+ param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
2928
+ ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
2929
+ dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
2930
+ mindspore.save_checkpoint(param_dict_tensor, dst_file)
2931
+
2932
+
2933
+ def restore_group_info_list(group_info_file_name):
2934
+ """
2935
+ Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
2936
+ who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL
2937
+ environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
2938
+ """
2939
+ return new_restore_group_info_list(group_info_file_name)
2940
+
2941
+
2942
+ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
2943
+ train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
2944
+ format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
2945
+ output_format='safetensors', name_map=None, max_process_num=64,
2946
+ return_param_dict=False):
2947
+ """ Load checkpoint into net for distributed predication. Used in the case of distributed inference. """
2948
+ new_load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy,
2949
+ train_strategy_filename, strict_load, dec_key, dec_mode,
2950
+ format, unified_safetensors_dir, dst_safetensors_dir, rank_id,
2951
+ output_format, name_map, max_process_num,
2952
+ return_param_dict)
2953
+
2954
+
2955
+ def merge_sliced_parameter(sliced_parameters, strategy=None):
2956
+ """ Merge parameter slices into one parameter. Used in the case of distributed inference. """
2957
+ return new_merge_sliced_parameter(sliced_parameters, strategy)
2958
+
2959
+
2960
+ def build_searched_strategy(strategy_filename):
2961
+ """ Build strategy of every parameter in network. Used in the case of distributed inference. """
2962
+ return new_build_searched_strategy(strategy_filename)