mindspore-2.5.0-cp311-cp311-win_amd64.whl → mindspore-2.6.0-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as possibly problematic.

Files changed (493)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +25 -194
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +109 -75
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +2014 -3386
  46. mindspore/common/api.py +386 -355
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/generator.py +3 -0
  52. mindspore/common/hook_handle.py +5 -3
  53. mindspore/common/initializer.py +10 -6
  54. mindspore/common/jit_begin_end.py +94 -0
  55. mindspore/common/jit_config.py +6 -1
  56. mindspore/common/jit_context.py +76 -0
  57. mindspore/common/jit_trace.py +378 -0
  58. mindspore/common/lazy_inline.py +2 -2
  59. mindspore/common/mutable.py +5 -4
  60. mindspore/common/parameter.py +106 -39
  61. mindspore/common/seed.py +2 -2
  62. mindspore/common/sparse_tensor.py +23 -17
  63. mindspore/common/tensor.py +332 -714
  64. mindspore/communication/__init__.py +7 -5
  65. mindspore/communication/_comm_helper.py +47 -2
  66. mindspore/communication/comm_func.py +70 -53
  67. mindspore/communication/management.py +83 -17
  68. mindspore/context.py +228 -571
  69. mindspore/dataset/__init__.py +44 -20
  70. mindspore/dataset/audio/__init__.py +2 -8
  71. mindspore/dataset/audio/transforms.py +3 -17
  72. mindspore/dataset/core/config.py +3 -3
  73. mindspore/dataset/engine/cache_client.py +1 -1
  74. mindspore/dataset/engine/datasets.py +102 -120
  75. mindspore/dataset/engine/datasets_audio.py +22 -22
  76. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  77. mindspore/dataset/engine/datasets_text.py +78 -85
  78. mindspore/dataset/engine/datasets_user_defined.py +109 -77
  79. mindspore/dataset/engine/datasets_vision.py +111 -108
  80. mindspore/dataset/engine/iterators.py +5 -3
  81. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  82. mindspore/dataset/engine/samplers.py +279 -57
  83. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  84. mindspore/dataset/engine/validators.py +10 -0
  85. mindspore/dataset/text/__init__.py +7 -6
  86. mindspore/dataset/text/transforms.py +6 -5
  87. mindspore/dataset/text/utils.py +3 -3
  88. mindspore/dataset/transforms/__init__.py +0 -9
  89. mindspore/dataset/transforms/transforms.py +3 -3
  90. mindspore/dataset/utils/browse_dataset.py +1 -1
  91. mindspore/dataset/vision/__init__.py +2 -9
  92. mindspore/dataset/vision/transforms.py +202 -158
  93. mindspore/dataset/vision/utils.py +7 -5
  94. mindspore/device_context/ascend/op_debug.py +60 -1
  95. mindspore/device_context/ascend/op_tuning.py +0 -4
  96. mindspore/device_manager.py +39 -3
  97. mindspore/dnnl.dll +0 -0
  98. mindspore/dpcmi.dll +0 -0
  99. mindspore/experimental/es/embedding_service.py +35 -27
  100. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
  101. mindspore/experimental/map_parameter.py +4 -4
  102. mindspore/experimental/optim/adadelta.py +22 -26
  103. mindspore/experimental/optim/adagrad.py +4 -4
  104. mindspore/experimental/optim/adam.py +4 -0
  105. mindspore/experimental/optim/adamax.py +4 -4
  106. mindspore/experimental/optim/adamw.py +4 -0
  107. mindspore/experimental/optim/asgd.py +1 -1
  108. mindspore/experimental/optim/lr_scheduler.py +40 -22
  109. mindspore/experimental/optim/radam.py +5 -5
  110. mindspore/experimental/optim/rprop.py +1 -1
  111. mindspore/experimental/optim/sgd.py +1 -1
  112. mindspore/hal/contiguous_tensors_handle.py +6 -10
  113. mindspore/hal/device.py +55 -81
  114. mindspore/hal/event.py +38 -55
  115. mindspore/hal/memory.py +115 -147
  116. mindspore/hal/stream.py +81 -125
  117. mindspore/include/dataset/constants.h +7 -4
  118. mindspore/include/dataset/execute.h +2 -2
  119. mindspore/jpeg62.dll +0 -0
  120. mindspore/log.py +40 -2
  121. mindspore/mindrecord/__init__.py +20 -7
  122. mindspore/mindspore_backend_common.dll +0 -0
  123. mindspore/mindspore_backend_manager.dll +0 -0
  124. mindspore/mindspore_common.dll +0 -0
  125. mindspore/mindspore_core.dll +0 -0
  126. mindspore/mindspore_dump.dll +0 -0
  127. mindspore/mindspore_frontend.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_memory_pool.dll +0 -0
  130. mindspore/mindspore_ms_backend.dll +0 -0
  131. mindspore/mindspore_ops.dll +0 -0
  132. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  133. mindspore/mindspore_ops_kernel_common.dll +0 -0
  134. mindspore/mindspore_profiler.dll +0 -0
  135. mindspore/mindspore_pyboost.dll +0 -0
  136. mindspore/mindspore_pynative.dll +0 -0
  137. mindspore/mindspore_res_manager.dll +0 -0
  138. mindspore/mindspore_runtime_pipeline.dll +0 -0
  139. mindspore/mint/__init__.py +133 -702
  140. mindspore/mint/distributed/__init__.py +5 -1
  141. mindspore/mint/distributed/distributed.py +198 -113
  142. mindspore/mint/linalg/__init__.py +2 -0
  143. mindspore/mint/nn/__init__.py +280 -18
  144. mindspore/mint/nn/functional.py +282 -64
  145. mindspore/mint/nn/layer/__init__.py +4 -0
  146. mindspore/mint/nn/layer/_functions.py +7 -3
  147. mindspore/mint/nn/layer/activation.py +120 -13
  148. mindspore/mint/nn/layer/conv.py +234 -28
  149. mindspore/mint/nn/layer/normalization.py +15 -16
  150. mindspore/mint/nn/layer/padding.py +1 -1
  151. mindspore/mint/nn/layer/pooling.py +66 -1
  152. mindspore/mint/optim/__init__.py +2 -1
  153. mindspore/mint/optim/sgd.py +171 -0
  154. mindspore/msobj140.dll +0 -0
  155. mindspore/mspdb140.dll +0 -0
  156. mindspore/mspdbcore.dll +0 -0
  157. mindspore/mspdbst.dll +0 -0
  158. mindspore/mspft140.dll +0 -0
  159. mindspore/msvcdis140.dll +0 -0
  160. mindspore/msvcp140_1.dll +0 -0
  161. mindspore/msvcp140_2.dll +0 -0
  162. mindspore/msvcp140_atomic_wait.dll +0 -0
  163. mindspore/msvcp140_codecvt_ids.dll +0 -0
  164. mindspore/nn/__init__.py +4 -1
  165. mindspore/nn/cell.py +1253 -179
  166. mindspore/nn/layer/activation.py +23 -21
  167. mindspore/nn/layer/basic.py +22 -16
  168. mindspore/nn/layer/container.py +1 -1
  169. mindspore/nn/layer/conv.py +53 -42
  170. mindspore/nn/layer/embedding.py +9 -8
  171. mindspore/nn/layer/normalization.py +48 -42
  172. mindspore/nn/layer/pooling.py +75 -31
  173. mindspore/nn/layer/transformer.py +11 -10
  174. mindspore/nn/learning_rate_schedule.py +4 -2
  175. mindspore/nn/loss/loss.py +27 -19
  176. mindspore/nn/optim/ada_grad.py +6 -5
  177. mindspore/nn/optim/adadelta.py +9 -7
  178. mindspore/nn/optim/adafactor.py +1 -1
  179. mindspore/nn/optim/adam.py +18 -14
  180. mindspore/nn/optim/adamax.py +8 -7
  181. mindspore/nn/optim/adasum.py +5 -5
  182. mindspore/nn/optim/asgd.py +3 -1
  183. mindspore/nn/optim/ftrl.py +11 -9
  184. mindspore/nn/optim/lamb.py +1 -1
  185. mindspore/nn/optim/lazyadam.py +12 -10
  186. mindspore/nn/optim/momentum.py +7 -6
  187. mindspore/nn/optim/optimizer.py +2 -2
  188. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  189. mindspore/nn/optim/rmsprop.py +13 -12
  190. mindspore/nn/optim/rprop.py +9 -7
  191. mindspore/nn/optim/sgd.py +9 -6
  192. mindspore/nn/optim/tft_wrapper.py +5 -2
  193. mindspore/nn/probability/bijector/bijector.py +17 -11
  194. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  195. mindspore/nn/probability/bijector/invert.py +2 -2
  196. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  197. mindspore/nn/probability/bijector/softplus.py +3 -2
  198. mindspore/nn/probability/distribution/beta.py +3 -3
  199. mindspore/nn/probability/distribution/categorical.py +1 -1
  200. mindspore/nn/probability/distribution/cauchy.py +4 -2
  201. mindspore/nn/probability/distribution/exponential.py +6 -7
  202. mindspore/nn/probability/distribution/gamma.py +2 -2
  203. mindspore/nn/probability/distribution/gumbel.py +2 -2
  204. mindspore/nn/probability/distribution/half_normal.py +5 -3
  205. mindspore/nn/probability/distribution/logistic.py +5 -3
  206. mindspore/nn/probability/distribution/poisson.py +1 -1
  207. mindspore/nn/probability/distribution/uniform.py +5 -3
  208. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  209. mindspore/nn/reinforcement/tensor_array.py +1 -1
  210. mindspore/nn/wrap/__init__.py +6 -6
  211. mindspore/nn/wrap/cell_wrapper.py +178 -117
  212. mindspore/nn/wrap/grad_reducer.py +45 -36
  213. mindspore/nn/wrap/loss_scale.py +3 -3
  214. mindspore/numpy/array_creations.py +3 -3
  215. mindspore/numpy/array_ops.py +1 -1
  216. mindspore/numpy/utils.py +1 -2
  217. mindspore/numpy/utils_const.py +1 -2
  218. mindspore/opencv_core452.dll +0 -0
  219. mindspore/opencv_imgcodecs452.dll +0 -0
  220. mindspore/opencv_imgproc452.dll +0 -0
  221. mindspore/ops/__init__.py +3 -2
  222. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  223. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  224. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  225. mindspore/ops/_register_for_op.py +0 -11
  226. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  227. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  228. mindspore/ops/_vmap/vmap_array_ops.py +32 -6
  229. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  230. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  231. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  232. mindspore/ops/auto_generate/__init__.py +4 -3
  233. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
  234. mindspore/ops/auto_generate/gen_extend_func.py +286 -208
  235. mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
  236. mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
  237. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  238. mindspore/ops/composite/__init__.py +2 -1
  239. mindspore/ops/composite/base.py +19 -24
  240. mindspore/ops/composite/math_ops.py +6 -16
  241. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  242. mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
  243. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  244. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  248. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  249. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  250. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  251. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  252. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  254. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  255. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  256. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  257. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  259. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  260. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  263. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  264. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  268. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  271. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  272. mindspore/ops/function/__init__.py +28 -2
  273. mindspore/ops/function/_add_attr_func.py +58 -0
  274. mindspore/ops/function/array_func.py +1631 -2347
  275. mindspore/ops/function/clip_func.py +38 -45
  276. mindspore/ops/function/debug_func.py +36 -44
  277. mindspore/ops/function/grad/__init__.py +1 -0
  278. mindspore/ops/function/grad/grad_func.py +104 -71
  279. mindspore/ops/function/image_func.py +1 -1
  280. mindspore/ops/function/linalg_func.py +46 -78
  281. mindspore/ops/function/math_func.py +3024 -3855
  282. mindspore/ops/function/nn_func.py +678 -274
  283. mindspore/ops/function/other_func.py +159 -1
  284. mindspore/ops/function/parameter_func.py +17 -30
  285. mindspore/ops/function/random_func.py +216 -361
  286. mindspore/ops/function/reshard_func.py +4 -70
  287. mindspore/ops/function/sparse_func.py +3 -3
  288. mindspore/ops/function/sparse_unary_func.py +5 -5
  289. mindspore/ops/function/spectral_func.py +25 -58
  290. mindspore/ops/function/vmap_func.py +26 -18
  291. mindspore/ops/functional.py +8 -5
  292. mindspore/ops/functional_overload.py +655 -4
  293. mindspore/ops/op_info_register.py +32 -244
  294. mindspore/ops/operations/__init__.py +21 -14
  295. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  296. mindspore/ops/operations/_grad_ops.py +1 -10
  297. mindspore/ops/operations/_inner_ops.py +5 -76
  298. mindspore/ops/operations/_ms_kernel.py +4 -10
  299. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  300. mindspore/ops/operations/_scalar_ops.py +3 -2
  301. mindspore/ops/operations/_sequence_ops.py +1 -1
  302. mindspore/ops/operations/_tensor_array.py +1 -1
  303. mindspore/ops/operations/array_ops.py +39 -24
  304. mindspore/ops/operations/comm_ops.py +150 -107
  305. mindspore/ops/operations/custom_ops.py +287 -32
  306. mindspore/ops/operations/debug_ops.py +119 -16
  307. mindspore/ops/operations/inner_ops.py +1 -1
  308. mindspore/ops/operations/linalg_ops.py +1 -58
  309. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  310. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  311. mindspore/ops/operations/math_ops.py +21 -18
  312. mindspore/ops/operations/nn_ops.py +67 -224
  313. mindspore/ops/operations/other_ops.py +62 -9
  314. mindspore/ops/operations/random_ops.py +13 -7
  315. mindspore/ops/operations/reshard_ops.py +1 -1
  316. mindspore/ops/operations/sparse_ops.py +2 -2
  317. mindspore/ops/primitive.py +43 -32
  318. mindspore/ops/tensor_method.py +243 -17
  319. mindspore/ops_generate/__init__.py +0 -5
  320. mindspore/ops_generate/aclnn/__init__.py +0 -0
  321. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  322. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  323. mindspore/ops_generate/api/__init__.py +0 -0
  324. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  325. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  326. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  327. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  328. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  329. mindspore/ops_generate/api/gen_api.py +103 -0
  330. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  331. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  332. mindspore/ops_generate/common/__init__.py +0 -0
  333. mindspore/ops_generate/common/gen_constants.py +91 -0
  334. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  335. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  336. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  337. mindspore/ops_generate/gen_ops.py +23 -325
  338. mindspore/ops_generate/op_def/__init__.py +0 -0
  339. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  340. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  341. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
  342. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  343. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  344. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  345. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  346. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  347. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  348. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  349. mindspore/ops_generate/pyboost/__init__.py +0 -0
  350. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  351. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  352. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  353. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  354. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  355. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  356. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  357. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  358. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  359. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  360. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  361. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  362. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  363. mindspore/ops_generate/resources/__init__.py +0 -0
  364. mindspore/ops_generate/resources/resource_list.py +30 -0
  365. mindspore/ops_generate/resources/resource_loader.py +36 -0
  366. mindspore/ops_generate/resources/resource_manager.py +64 -0
  367. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  368. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  369. mindspore/parallel/__init__.py +6 -2
  370. mindspore/parallel/_auto_parallel_context.py +140 -12
  371. mindspore/parallel/_cell_wrapper.py +132 -15
  372. mindspore/parallel/_parallel_serialization.py +95 -4
  373. mindspore/parallel/_ps_context.py +1 -1
  374. mindspore/parallel/_recovery_context.py +7 -2
  375. mindspore/parallel/_tensor.py +142 -18
  376. mindspore/parallel/_utils.py +198 -25
  377. mindspore/parallel/algo_parameter_config.py +3 -3
  378. mindspore/parallel/auto_parallel.py +732 -0
  379. mindspore/parallel/checkpoint_convert.py +159 -0
  380. mindspore/parallel/checkpoint_transform.py +658 -37
  381. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  382. mindspore/parallel/cluster/run.py +1 -1
  383. mindspore/parallel/function/__init__.py +24 -0
  384. mindspore/parallel/function/reshard_func.py +258 -0
  385. mindspore/parallel/nn/__init__.py +25 -0
  386. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  387. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  388. mindspore/parallel/parameter_broadcast.py +24 -13
  389. mindspore/parallel/shard.py +137 -62
  390. mindspore/parallel/transform_safetensors.py +288 -95
  391. mindspore/pgodb140.dll +0 -0
  392. mindspore/pgort140.dll +0 -0
  393. mindspore/profiler/__init__.py +9 -5
  394. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  395. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  397. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
  398. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  399. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  400. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  401. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  402. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  403. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  404. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  405. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  406. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  407. mindspore/profiler/common/constant.py +12 -0
  408. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  409. mindspore/profiler/common/path_manager.py +24 -0
  410. mindspore/profiler/common/profiler_context.py +26 -2
  411. mindspore/profiler/common/profiler_meta_data.py +74 -0
  412. mindspore/profiler/common/profiler_parameters.py +59 -18
  413. mindspore/profiler/common/profiler_path_manager.py +66 -7
  414. mindspore/profiler/dynamic_profiler.py +112 -79
  415. mindspore/profiler/envprofiler.py +26 -1
  416. mindspore/profiler/experimental_config.py +197 -0
  417. mindspore/profiler/mstx.py +57 -14
  418. mindspore/profiler/platform/npu_profiler.py +33 -7
  419. mindspore/profiler/profiler.py +541 -45
  420. mindspore/profiler/profiler_action_controller.py +1 -1
  421. mindspore/profiler/profiler_interface.py +4 -0
  422. mindspore/profiler/schedule.py +57 -22
  423. mindspore/rewrite/api/node.py +15 -13
  424. mindspore/rewrite/api/symbol_tree.py +1 -1
  425. mindspore/run_check/_check_version.py +25 -14
  426. mindspore/run_check/run_check.py +1 -1
  427. mindspore/runtime/__init__.py +2 -2
  428. mindspore/runtime/executor.py +40 -11
  429. mindspore/runtime/memory.py +37 -13
  430. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  431. mindspore/swresample-4.dll +0 -0
  432. mindspore/swscale-6.dll +0 -0
  433. mindspore/tbbmalloc.dll +0 -0
  434. mindspore/tinyxml2.dll +0 -0
  435. mindspore/train/__init__.py +8 -8
  436. mindspore/train/_utils.py +43 -9
  437. mindspore/train/amp.py +1 -1
  438. mindspore/train/callback/__init__.py +2 -2
  439. mindspore/train/callback/_callback.py +2 -16
  440. mindspore/train/callback/_checkpoint.py +24 -40
  441. mindspore/train/callback/_cluster_monitor.py +14 -18
  442. mindspore/train/callback/_flops_collector.py +2 -3
  443. mindspore/train/callback/_history.py +7 -4
  444. mindspore/train/callback/_lambda_callback.py +2 -2
  445. mindspore/train/callback/_landscape.py +0 -3
  446. mindspore/train/callback/_loss_monitor.py +2 -1
  447. mindspore/train/callback/_on_request_exit.py +6 -5
  448. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  449. mindspore/train/callback/_summary_collector.py +8 -13
  450. mindspore/train/callback/_time_monitor.py +2 -1
  451. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
  452. mindspore/train/data_sink.py +25 -2
  453. mindspore/train/dataset_helper.py +4 -5
  454. mindspore/train/loss_scale_manager.py +8 -7
  455. mindspore/train/metrics/accuracy.py +3 -3
  456. mindspore/train/metrics/confusion_matrix.py +9 -9
  457. mindspore/train/metrics/error.py +3 -3
  458. mindspore/train/metrics/hausdorff_distance.py +4 -4
  459. mindspore/train/metrics/mean_surface_distance.py +3 -3
  460. mindspore/train/metrics/metric.py +0 -12
  461. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  462. mindspore/train/metrics/precision.py +8 -6
  463. mindspore/train/metrics/recall.py +9 -9
  464. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  465. mindspore/train/mind_ir_pb2.py +19 -12
  466. mindspore/train/model.py +262 -127
  467. mindspore/train/serialization.py +246 -988
  468. mindspore/train/summary/_summary_adapter.py +2 -2
  469. mindspore/train/summary/summary_record.py +1 -1
  470. mindspore/turbojpeg.dll +0 -0
  471. mindspore/utils/__init__.py +3 -2
  472. mindspore/utils/dryrun.py +4 -2
  473. mindspore/utils/hooks.py +81 -0
  474. mindspore/utils/runtime_execution_order_check.py +2 -0
  475. mindspore/utils/utils.py +138 -4
  476. mindspore/vcmeta.dll +0 -0
  477. mindspore/vcruntime140.dll +0 -0
  478. mindspore/vcruntime140_1.dll +0 -0
  479. mindspore/version.py +1 -1
  480. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
  481. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
  482. mindspore/_install_custom.py +0 -43
  483. mindspore/common/_register_for_adapter.py +0 -74
  484. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  485. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  486. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  487. mindspore/ops_generate/gen_constants.py +0 -190
  488. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  489. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  490. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  492. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
  493. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -18,34 +18,46 @@ from __future__ import absolute_import
  import os
  import glob
  import copy
+ from multiprocessing import Pool
  from collections import defaultdict
  import numpy as np
  import mindspore as ms
+ from mindspore import log as logger
+ from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
- from mindspore.parallel._utils import _is_in_auto_parallel_mode, _get_pipeline_stages
+ from mindspore.common.parameter import Parameter
+ from mindspore.common.tensor import Tensor
+ from mindspore.communication.management import get_rank, get_group_size
+ from mindspore.parallel._tensor import _load_tensor, _reshape_param_data, _reshape_param_data_with_weight, \
+ _get_tensor_slice_index, _get_tensor_strategy
+ from mindspore.parallel._utils import _is_in_auto_parallel_mode, _get_pipeline_stages, _infer_rank_list, \
+ _remove_repeated_slices, _get_auto_parallel_net
  from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
- _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
+ _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, _build_searched_strategy, \
  _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
- _merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src
- from mindspore.parallel.transform_safetensors import _transform_safetensors, _collect_safetensor_files
+ _merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src, _convert_to_list, \
+ _check_checkpoint_file, _check_predict_strategy, _gather_tasks_load_dis, _get_param_list_when_first_dim_sharded, \
+ _convert_to_layout, _restore_group_info_list
  from mindspore._c_expression import AutoParallelContext
+ from mindspore.parallel.transform_safetensors import _transform_safetensors, _collect_safetensor_files, \
+ _load_parallel_checkpoint

  __all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
- "transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints"]
+ "transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints",
+ "load_distributed_checkpoint", "merge_sliced_parameter", "restore_group_info_list",
+ "build_searched_strategy"]


  def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
  """
- Merge parallel strategy between all pipeline stages in pipeline parallel mode.
- For more details about converting distributed Checkpoint, please refer to
- `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
+ Aggregate the sharding strategy files of all pipeline parallel subgraphs to the destination file.

  Note:
  Strategy file of each pipeline stage should be included in src_strategy_dirs.

  Args:
  src_strategy_dirs (str): The directory of strategy files including all pipeline stage which is saved by
- 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
+ :func:`mindspore.parallel.auto_parallel.AutoParallel.save_param_strategy_file`.
  dst_strategy_file (str): The file merged strategy to save.

  Raises:
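Note on the hunk above: besides the new imports, `__all__` now additionally exports `load_distributed_checkpoint`, `merge_sliced_parameter`, `restore_group_info_list` and `build_searched_strategy` from mindspore.parallel. A minimal usage sketch, assuming the re-exports behave as the extended `__all__` and the docstrings later in this diff suggest:

    >>> from mindspore.parallel import (build_searched_strategy, load_distributed_checkpoint,
    ...                                 merge_sliced_parameter, restore_group_info_list)
    >>> # e.g. read back the per-parameter sharding strategies saved during training
    >>> strategy = build_searched_strategy("./strategy_train.ckpt")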
@@ -54,7 +66,7 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
  Examples:
  >>> import mindspore as ms
  >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
- >>> ms.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")
+ >>> ms.parallel.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")

  """
  dst_strategy_dir, _ = os.path.split(dst_strategy_file)
@@ -73,11 +85,211 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
  _merge_json_strategy(src_strategy_files_json, dst_strategy_file)


+ def merge_sliced_parameter(sliced_parameters, strategy=None):
+ """
+ Merge parameter slices into one parameter. Used in the case of distributed inference.
+
+ Args:
+ sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
+ strategy (Optional[dict], optional): Parameter slice strategy, whose key is parameter name and
+ value is slice strategy of this parameter. If strategy is None, just merge
+ parameter slices in 0 axis order. Default: ``None``.
+
+ Returns:
+ Parameter, the merged parameter which has the whole data.
+
+ Raises:
+ ValueError: Failed to merge.
+ TypeError: The sliced_parameters is incorrect or strategy is not dict.
+ KeyError: The parameter name is not in keys of strategy.
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, Parameter
+ >>>
+ >>> sliced_parameters = [
+ ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
+ ... "network.embedding_table"),
+ ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
+ ... "network.embedding_table"),
+ ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
+ ... "network.embedding_table"),
+ ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
+ ... "network.embedding_table")]
+ >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
+ >>> print(merged_parameter)
+ Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
+ """
+ if not isinstance(sliced_parameters, list):
+ raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
+ f"but got {type(sliced_parameters)}.")
+
+ if not sliced_parameters:
+ raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
+
+ if strategy and not isinstance(strategy, dict):
+ raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
+ f"but got {type(strategy)}.")
+
+ try:
+ parameter_name = sliced_parameters[0].name
+ parameter_shape = sliced_parameters[0].data.shape
+ parameter_shape_length = len(parameter_shape)
+ except BaseException as e:
+ raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
+ f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e
+
+ is_even = True
+ for index, parameter in enumerate(sliced_parameters):
+ if not isinstance(parameter, Parameter):
+ raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
+ f"but got {type(parameter)} at index {index}.")
+
+ if parameter.name != parameter_name \
+ or len(parameter.data.shape) != parameter_shape_length \
+ or parameter.data.shape[1:] != parameter_shape[1:]:
+ raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'"
+ f" have the same name, dimension length and shape except 0 axis. The name, dimension "
+ f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
+ f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
+ f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
+ f"at index {index}.")
+
+ if parameter.data.shape != parameter_shape:
+ is_even = False
+
+ layerwise_parallel = sliced_parameters[0].layerwise_parallel
+ requires_grad = sliced_parameters[0].requires_grad
+ sliced_data = []
+ for parameter in sliced_parameters:
+ if parameter.data.dtype == mstype.bfloat16:
+ from mindspore.ops import Cast
+ cpu_cast = Cast().set_device("CPU")
+ sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
+ else:
+ sliced_data.append(parameter.data.asnumpy())
+
+ if not strategy:
+ merged_tensor = Tensor(np.concatenate(sliced_data))
+ merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
+
+ else:
+ if parameter_name not in strategy.keys():
+ raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
+ f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
+ merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
+ merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
+
+ return merged_parameter
+
+
+ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
+ """Merge sliced parameter and split it according to the predict strategy."""
+ merged_param = merge_sliced_parameter(sliced_params, train_strategy)
+ if not predict_strategy:
+ return merged_param
+ param_name = merged_param.name
+ tensor_layout = predict_strategy[param_name]
+ rank = get_rank()
+ split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
+ requires_grad = merged_param.requires_grad
+ layerwise_parallel = merged_param.layerwise_parallel
+ if merged_param.data.dtype == mstype.bfloat16:
+ split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
+ else:
+ split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
+ return split_param
+
+
+ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
+ """
+ Merge data slices to one tensor with whole data when strategy is not None.
+
+ Args:
+ sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
+ parameter_name (str): Name of parameter.
+ strategy (dict): Parameter slice strategy.
+ is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
+
+ Returns:
+ Tensor, the merged Tensor which has the whole data.
+
+ Raises:
+ ValueError: Failed to merge.
+ """
+ layout = strategy.get(parameter_name)
+ try:
+ dev_mat = list(layout.dev_matrix[0].dim)
+ tensor_map = list(layout.tensor_map[0].dim)
+ param_split_shape = list(layout.param_split_shape[0].dim)
+ field_size = int(layout.field)
+ except BaseException as e:
+ raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
+ f", please make sure that 'strategy' is correct.") from e
+
+ device_count = 1
+ for dim in dev_mat:
+ device_count *= dim
+
+ if len(sliced_data) != device_count:
+ raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
+ f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
+ f"device_count is {device_count}.")
+
+ if not param_split_shape:
+ if not is_even:
+ raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
+ "should be the same when slice manner is even.")
+
+ all_gather_tensor = Tensor(np.concatenate(sliced_data))
+
+ if field_size > 0:
+ merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
+ else:
+ merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
+
+ else:
+ tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+
+ slice_count = 1
+ for dim in tensor_strategy:
+ slice_count *= dim
+
+ if len(param_split_shape) != slice_count:
+ raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
+ f"{slice_count}, but got {len(param_split_shape)}.")
+
+ tensor_slices_new = list(range(slice_count))
+ tensor_slices = sliced_data
+ for i in range(device_count):
+ slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
+ if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
+ raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
+ f"{param_split_shape[slice_index]} in 0 axis, but got "
+ f"{tensor_slices[i].shape[0]}.")
+ tensor_slices_new[slice_index] = np.array(tensor_slices[i])
+
+ dim_len = len(tensor_strategy)
+ for i in range(dim_len):
+ ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
+ tensor_slices_new_inner = []
+ for j in range(ele_count):
+ new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
+ for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
+ (j + 1) * tensor_strategy[dim_len - 1 - i]):
+ new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
+ tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
+ tensor_slices_new = tensor_slices_new_inner
+ merged_tensor = Tensor(tensor_slices_new[0])
+
+ return merged_tensor
+
+
  def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
  """
  List of original distributed checkpoint rank index for obtaining the target checkpoint of a rank_id during the
- distributed checkpoint conversion. For more details about converting distributed Checkpoint, please refer to
- `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
+ distributed checkpoint conversion.

  Args:
  rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
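Editorial note on the `_merge_param_with_strategy` helper added above: for unevenly sliced parameters it rebuilds the full tensor by concatenating the slices axis by axis, starting from the last sharded dimension. A NumPy-only sketch of that reverse-axis concatenation follows (illustrative, not MindSpore code; it assumes the slices are linearly indexed with the last sharded axis varying fastest):

    import numpy as np

    # Illustrative only: mimic the reverse-axis concatenation loop from
    # _merge_param_with_strategy with a 2x2 slicing strategy.
    full = np.arange(24).reshape(4, 6)
    tensor_strategy = [2, 2]                        # 2 slices along axis 0, 2 along axis 1
    slices = [full[i * 2:(i + 1) * 2, j * 3:(j + 1) * 3]
              for i in range(2) for j in range(2)]  # linear index = i * 2 + j

    merged = list(slices)
    for axis in reversed(range(len(tensor_strategy))):   # last sharded axis first
        step = tensor_strategy[axis]
        merged = [np.concatenate(merged[g * step:(g + 1) * step], axis=axis)
                  for g in range(len(merged) // step)]

    assert np.array_equal(merged[0], full)          # the loop reassembles the original tensor

In the real helper the linear index comes from `_get_tensor_slice_index`, and the even/uneven cases are dispatched on `param_split_shape`.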
@@ -102,7 +314,7 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
102
314
  Examples:
103
315
  >>> import mindspore as ms
104
316
  >>> rank_id = 0
105
- >>> rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
317
+ >>> rank_list = ms.parallel.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
106
318
  >>> checkpoint_files_map = {}
107
319
  >>> for rank in rank_list:
108
320
  ... checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
@@ -141,8 +353,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
141
353
  src_strategy_file=None, dst_strategy_file=None):
142
354
  """
143
355
  Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank
144
- for a network. For more details about converting distributed Checkpoint, please refer to
145
- `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
356
+ for a network.
146
357
 
147
358
  Args:
148
359
  rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
@@ -150,11 +361,11 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
150
361
  the checkpoint file name.
151
362
  save_checkpoint_file_name (str): The file name to save the converted checkpoint.
152
363
  src_strategy_file (str): Name of source sharding strategy file which saved by
153
- 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
364
+ `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
154
365
  when the `src_strategy_file` is None, it means that the source sharding strategy is
155
366
  without any sharing for each parameter. Default: ``None``.
156
367
  dst_strategy_file (str): Name of destination sharding strategy file which saved by
157
- 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
368
+ `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
158
369
  when the `dst_strategy_file` is ``None``,
159
370
  it means that the destination sharding strategy
160
371
  is without any sharing for each parameter. Default: ``None``.
@@ -362,8 +573,6 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
362
573
  dst_strategy_file=None, process_num=1, output_format="ckpt"):
363
574
  """
364
575
  Transform distributed checkpoint from source sharding strategy to destination sharding strategy for a rank.
365
- For more details about converting distributed Checkpoint, please refer to
366
- `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
367
576
 
368
577
  Note:
369
578
  The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the
@@ -387,7 +596,7 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
387
596
  is without any sharing for each parameter. Default:None.
388
597
  process_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
389
598
  output_format (str, optional): Control the format of the output checkpoint after conversion.
390
- It can be set to either "ckpt" or "safetensors". Default: "ckpt".
599
+ It can be set to either ``"ckpt"`` or ``"safetensors"``. Default: ``"ckpt"``.
391
600
 
392
601
  Raises:
393
602
  ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
@@ -473,18 +682,21 @@ def _sync_params(name, param, layout):
473
682
  shape=param.shape,
474
683
  dtype=param.dtype)(param))
475
684
 
476
-
685
+ # pylint: disable=W0212
477
686
  def sync_pipeline_shared_parameters(net):
478
- """synchronize pipeline parallel stage shared parameters.
479
- Parameters may be shared between different stages. For example, `embedding table` is
687
+ """Synchronization of shared weights between stages for pipeline parallel inference scenarios.
688
+ For example, `embedding table` is
480
689
  shared by `WordEmbedding` layer and `LMHead` layer, which are usually split into different stages. It is necessary
481
690
  to perform synchronization after `embedding table` changes.
482
691
 
483
692
  Note:
484
- The network should be compiled before synchronize pipeline parallel stage shared parameters.
693
+ The network should be compiled before shared parameters are synchronized in the pipeline parallel stage.
485
694
 
486
695
  Args:
487
- net (nn.Cell): the inference network.
696
+ net (Cell): the inference network.
697
+
698
+ Raises:
699
+ TypeError: `net` is not in Cell type.
488
700
 
489
701
  Supported Platforms:
490
702
  ``Ascend``
@@ -494,12 +706,13 @@ def sync_pipeline_shared_parameters(net):
494
706
  Before running the following examples, you need to configure the communication environment variables.
495
707
 
496
708
  For the Ascend device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
497
- Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
709
+ Startup <https://www.mindspore.cn/tutorials/en/master/parallel/dynamic_cluster.html>`_ .
498
710
 
499
711
  >>> import numpy as np
500
712
  >>> import mindspore as ms
501
713
  >>> import mindspore.communication.management as D
502
714
  >>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
715
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
503
716
  >>> context.set_context(mode=context.GRAPH_MODE)
504
717
  >>> class Embedding(nn.Cell):
505
718
  ... def __init__(self, shape):
@@ -547,14 +760,16 @@ def sync_pipeline_shared_parameters(net):
547
760
  ... ret = self.concat(ret)
548
761
  ... return ret
549
762
  >>> D.init()
550
- >>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2)
551
763
  >>> net = Network()
552
764
  >>> net = PipelineCellInference(net, 2)
553
765
  >>> net.set_train(False)
554
766
  >>> x = Tensor(np.ones((2, 4)), ms.float32)
555
767
  >>> net.compile(x)
556
- >>> ms.sync_pipeline_shared_parameters(net)
557
- >>> print(net.network.word_embedding.w.asnumpy())
768
+ >>> pp_net = AutoParallel(net, parallel_mode="semi_auto")
769
+ >>> pp_net.full_batch = True
770
+ >>> pp_net.pipeline(stages=2, scheduler="1f1b")
771
+ >>> ms.parallel.sync_pipeline_shared_parameters(pp_net)
772
+ >>> print(pp_net.network.network.word_embedding.w.asnumpy())
558
773
  [[1. 1. 1. 1.]
559
774
  [1. 1. 1. 1.]
560
775
  [1. 1. 1. 1.]
@@ -567,18 +782,25 @@ def sync_pipeline_shared_parameters(net):
567
782
  "but got {}.".format(type(net)))
568
783
  raise TypeError(msg)
569
784
 
570
- if _get_pipeline_stages() < 2:
785
+ parallel_net = _get_auto_parallel_net(net)
786
+ pipeline_stages = 1
787
+ if type(parallel_net).__name__ != 'AutoParallel':
788
+ pipeline_stages = _get_pipeline_stages()
789
+ else:
790
+ pipeline_stages = parallel_net._pipeline_stages
791
+ if pipeline_stages < 2:
571
792
  return
572
793
 
573
794
  layout_dict = net.parameter_layout_dict
574
- if _is_in_auto_parallel_mode() and not layout_dict:
795
+ if (_is_in_auto_parallel_mode() or (type(parallel_net).__name__ == 'AutoParallel')) and not layout_dict:
575
796
  from mindspore.common.api import _get_parameter_layout
576
797
  layout_dict = _get_parameter_layout()
577
798
 
578
799
  # switch to standalone mode
579
- parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
580
- full_batch = ms.context.get_auto_parallel_context("full_batch")
581
- ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)
800
+ if type(parallel_net).__name__ != 'AutoParallel':
801
+ parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
802
+ full_batch = ms.context.get_auto_parallel_context("full_batch")
803
+ ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)
582
804
 
583
805
  # synchronize shared parameter
584
806
  for name, param in net.parameters_and_names():
@@ -586,7 +808,8 @@ def sync_pipeline_shared_parameters(net):
586
808
  _sync_params(name, param, layout_dict[name])
587
809
 
588
810
  # restore parallel context
589
- ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)
811
+ if type(parallel_net).__name__ != 'AutoParallel':
812
+ ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)
590
813
 
591
814
 
592
815
  def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filter_prefix=None,
@@ -636,6 +859,9 @@ def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filte
636
859
  ValueError: Checkpoint file's format is incorrect.
637
860
  ValueError: Parameter's dict is None after load checkpoint file.
638
861
  TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
862
+
863
+ Supported Platforms:
864
+ ``Ascend``
639
865
  """
640
866
  if not isinstance(ckpt_file_dir, str):
641
867
  raise TypeError("The ckpt_file_dir should be a str.")
@@ -656,8 +882,10 @@ def set_op_strategy_config(mode="SAVE", path=""):
656
882
  Set strategy json configuration when using sharding propagation.
657
883
 
658
884
  .. warning::
659
- This is an experimental interface, may be changed or canceled in the future;
660
- This interface currently doesn't support saving or loading strategies using layout.
885
+ - This is an experimental interface, may be changed or canceled in the future, please use the api
886
+ :func:`mindspore.parallel.auto_parallel.AutoParallel.load_operator_strategy_file` or
887
+ :func:`mindspore.parallel.auto_parallel.AutoParallel.save_operator_strategy_file` instead;
888
+ - This interface currently doesn't support saving or loading strategies using layout.
661
889
 
662
890
  Note:
663
891
  - It only works when `parallel_mode=ParallelMode.AUTO_PARALLEL` and `search_mode='sharding_propagation'`.
@@ -692,3 +920,396 @@ def set_op_strategy_config(mode="SAVE", path=""):
692
920
  AutoParallelContext.get_instance().set_ops_strategy_json_config(mode, path, "all")
693
921
  else:
694
922
  raise KeyError("Type must be 'SAVE' or 'LOAD'")
923
+
924
+
925
+ def build_searched_strategy(strategy_filename):
926
+ """
927
+ Extract the sharding strategy for each parameter in the network from the strategy file
928
+ for distributed inference scenarios.
929
+
930
+ Args:
931
+ strategy_filename (str): Name of strategy file.
932
+
933
+ Returns:
934
+ Dict, whose key is parameter name and value is slice strategy of this parameter.
935
+
936
+ Raises:
937
+ ValueError: Strategy file is incorrect.
938
+ TypeError: `strategy_filename` is not a string.
939
+
940
+ Supported Platforms:
941
+ ``Ascend``
942
+
943
+ Examples:
944
+ >>> from mindspore.parallel import build_searched_strategy
945
+ >>> strategy = build_searched_strategy("./strategy_train.ckpt")
946
+ """
947
+ return _build_searched_strategy(strategy_filename)
948
+
949
+
950
+ # disable pylint too broad Exception
951
+ # pylint: disable=W0212
952
+ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
953
+ train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
954
+ format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
955
+ output_format='safetensors', name_map=None, max_process_num=64,
956
+ return_param_dict=False):
957
+ """
958
+ Load checkpoint into net for distributed predication. Used in the case of distributed inference.
959
+
960
+ Note:
961
+ `output_format` will only take effect when `format` is set to `safetensors` and `network` is set to `None`.
962
+
963
+ Args:
964
+ network (Cell): Network for distributed predication, When the format is `safetensors`, the network parameter
965
+ can be left blank or passed as None, and the interface will execute save mode.
966
+ checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
967
+ predict_strategy (Union[dict, str]): Strategy of predication process. It means that using one device to predict
968
+ when setting predict_strategy as None. Default: ``None`` .
969
+ train_strategy_filename (str): The filename of training strategy protocol buffer file.
970
+ When train_strategy_filename is None, the training strategy file will be
971
+ obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
972
+ Therefore, the training strategy file needs to be specified
973
+ in at least one of them. Default: ``None`` .
974
+ strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
975
+ into net when parameter name's suffix in checkpoint file is the same as the
976
+ parameter in the network. When the types are inconsistent, perform type conversion
977
+ on the parameters of the same type, such as float32 to float16. Default: ``False`` .
978
+ dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
979
+ is not required. Default: ``None`` .
980
+ dec_mode (str): Specifies the decryption
981
+ mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
982
+ This parameter is valid only when dec_key is not set to ``None`` .
983
+ Default: ``'AES-GCM'`` .
984
+ format (str): Input weight format to be loaded into the network.
985
+ It can be set to either "ckpt" or "safetensors". Default: "ckpt".
986
+ unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
987
+ Default: ``None`` .
988
+ dst_safetensors_dir (str): In the save mode scenario, the save directory for weights.
989
+ rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
990
+ globally by initializing the network; In save mode, save the file according to the input
991
+ sequence number. If it is not input, save the entire file.
992
+ output_format (str, optional): Control the format of the output checkpoint after conversion.
993
+ It can be set to either "ckpt" or "safetensors". Default: "safetensors".
994
+ name_map (dict): The weight mapping dictionary will modify the weight names according to the mapping
995
+ dictionary before loading or saving the segmented weights into the network. Default: None.
996
+ max_process_num (int): Maximum number of processes. Default: 64.
997
+ return_param_dict (bool): Whether to return the param_dict. Default: ``False``.
+
+ Raises:
+ TypeError: The types of the inputs do not match the requirements.
+ ValueError: Failed to load checkpoint into net.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+ Please see the `rank table startup
+ <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
+ for more details.
+
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+ Startup <https://www.mindspore.cn/tutorials/en/master/parallel/dynamic_cluster.html>`_ .
+
+ >>> import os
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> import mindspore.dataset as ds
+ >>> from mindspore import nn, ops, train
+ >>> from mindspore.communication import init
+ >>> from mindspore.parallel import load_distributed_checkpoint
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
+ >>> from mindspore.nn.utils import no_init_parameters
+ >>> from mindspore.common.initializer import initializer, One
+ >>> from mindspore.communication.management import get_group_size
+ >>>
+ >>> step_per_epoch = 4
+ >>> device_num = get_group_size()
+ >>>
+ >>> # Define the network structure.
+ >>> class Net(nn.Cell):
+ ... def __init__(self, matmul_size, strategy=None):
+ ... super().__init__()
+ ... self.matmul_weight = ms.Parameter(initializer(One(), matmul_size, ms.float32))
+ ... self.matmul = ops.MatMul()
+ ... self.neg = ops.Neg()
+ ... if strategy is not None:
+ ... self.matmul.shard(strategy)
+ ...
+ ... def construct(self, inputs):
+ ... x = self.matmul(inputs, self.matmul_weight)
+ ... x = self.neg(x)
+ ... return x
+ >>>
+ >>> # Create dataset.
+ >>> def get_dataset(*inputs):
+ ... def generate():
+ ... for _ in range(step_per_epoch):
+ ... yield inputs
+ ... return generate
+ >>>
+ >>> # Train network and save distributed checkpoint.
+ >>> def train_net():
+ ... ms.set_context(mode=ms.GRAPH_MODE)
+ ... init()
+ ... np.random.seed(1)
+ ... input_data = np.random.rand(16, 96).astype(np.float32)
+ ... label_data = np.random.rand(16, 16).astype(np.float32)
+ ... fake_dataset = get_dataset(input_data, label_data)
+ ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+ ...
+ ... # Set parallel strategy.
+ ... strategy = ((1, 4), (4, 1))
+ ... with no_init_parameters():
+ ... network = Net(matmul_size=(96, 16), strategy=strategy)
+ ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+ ...
+ ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
+ ... network = AutoParallel(network, parallel_mode="semi_auto")
+ ... network.save_param_strategy_file(file_path="./train_strategy.ckpt")
+ ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
+ ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=True)
+ ... global_rank_id = int(os.getenv("RANK_ID"))
+ ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
+ ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
+ ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
+ >>>
+ >>> # Load distributed checkpoint and test.
+ >>> def load_model():
+ ... ms.set_context(mode=ms.GRAPH_MODE)
+ ... init()
+ ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
+ ... with no_init_parameters():
+ ... network = Net(matmul_size=(96, 16))
+ ... network = AutoParallel(network, parallel_mode="semi_auto")
+ ... network.dataset_strategy(config="full_batch")
+ ... train_strategy_file = "./train_strategy.ckpt"
+ ... network.save_param_strategy_file(file_path=train_strategy_file)
+ ... model = ms.Model(network)
+ ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
+ ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
+ ... load_distributed_checkpoint(network, ckpt_file_list, predict_layout, None)
+ ... predict_result = model.predict(predict_data)
+ ... print(predict_result)
+ >>>
+ >>> train_net()
+ >>> load_model()
+ [[-9.62929535e+00, -9.76258755e+00, -9.70192051e+00 ... -9.67151260e+00, -9.71998310e+00, -9.64571190e+00],
+ [-4.63218540e-01, -4.07317460e-01, -3.78161550e-01 ... -3.95918339e-01, -2.87363172e-01, -3.48693460e-01],
+ ...
+ [-4.28075647e+00, -4.36630344e+00, -4.25664043e+00 ... -4.32012939e+00, -4.30337954e+00, -4.27571440e+00]]
+ """
+ if format not in ['safetensors', 'ckpt'] or output_format not in ['safetensors', 'ckpt']:
+ raise ValueError(
+ f"For 'load_distributed_checkpoint', 'format' and 'output_format' "
+ f"must be 'ckpt' or 'safetensors', but got format: {format}, output_format: {output_format}.")
+
+ if format == 'safetensors':
+ if unified_safetensors_dir is None:
+ raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
+ f"when format is 'safetensors'.")
+ unsupported_params = {'checkpoint_filenames': checkpoint_filenames,
+ 'train_strategy_filename': train_strategy_filename, 'dec_key': dec_key}
+ for name, param in unsupported_params.items():
+ if param is not None:
+ raise ValueError(f"For 'load_distributed_checkpoint', '{name}' must be None "
+ f"when format is 'safetensors'.")
+ if strict_load or dec_mode != 'AES-GCM':
+ raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
+ f"when format is 'safetensors'.")
+ if network is not None:
+ try:
+ rank_id = get_rank()
+ except RuntimeError:
+ rank_id = 0
+ logger.warning(f"Get rank failed, default loading weight for rank 0.")
+ param_dict = _load_parallel_checkpoint(
+ (unified_safetensors_dir, predict_strategy, network, None, rank_id, output_format, name_map,
+ return_param_dict))
+ return param_dict
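+ # Save mode: no network is provided, so the re-sliced weights are written to
+ # dst_safetensors_dir, either for a single given rank or for all ranks in parallel.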
+ if dst_safetensors_dir is None:
+ raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
+ f"when network is None.")
+ if rank_id is not None:
+ _load_parallel_checkpoint(
+ (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
+ rank_id, output_format, name_map, return_param_dict))
+ else:
+ dst_strategy_dict = _build_searched_strategy(predict_strategy)
+ dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
+ dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
+ dst_device_num = dst_stage_device_num * dst_stage_num
+ tasks = _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
+ dst_device_num, output_format, name_map, return_param_dict)
+ with Pool(processes=max_process_num) as pool:
+ list(pool.imap(_load_parallel_checkpoint, tasks))
+ return True
+
+ network = Validator.check_isinstance("network", network, ms.nn.Cell)
+ _check_checkpoint_file(checkpoint_filenames)
+ _check_predict_strategy(predict_strategy)
+
+ dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
+ dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
+
+ if train_strategy_filename is None:
+ parallel_net = _get_auto_parallel_net(network)
+ if parallel_net.__class__.__name__ == "AutoParallel":
+ train_strategy_filename = parallel_net._save_strategy_file_path
+ else:
+ train_strategy_filename = ms.context.get_auto_parallel_context("strategy_ckpt_load_file")
+
+ _train_strategy = build_searched_strategy(train_strategy_filename)
+ train_strategy = _convert_to_list(_train_strategy)
+
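+ # The product of the dimensions of the device arrangement recorded for any parameter
+ # gives the number of devices used during training; it must match the checkpoint count.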
+ train_dev_count = 1
+ ckpt_file_len = len(checkpoint_filenames)
+ for dim in train_strategy[list(train_strategy.keys())[0]][0]:
+ train_dev_count *= dim
+ if train_dev_count != ckpt_file_len:
+ raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
1174
+ f"equal to the device count of training process. "
1175
+ f"But got the length of 'checkpoint_filenames'"
1176
+ f" is {ckpt_file_len} and the device count is {train_dev_count}.")
+ rank_list = _infer_rank_list(train_strategy, predict_strategy)
+
+ param_total_dict = defaultdict(dict)
+ for file_index, file_name in enumerate(checkpoint_filenames):
+ ckpt_dict = ms.load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
+ for param_name, param in ckpt_dict.items():
+ param_total_dict[param_name][file_index] = param
+
+ param_dict = {}
+ param_not_in_strategy = []
+ param_not_in_ckpt = []
+ for _, param in network.parameters_and_names():
+ sliced_params = []
+ if param.name not in rank_list.keys():
+ param_not_in_strategy.append(param.name)
+ continue
+ if param.name not in param_total_dict:
+ param_not_in_ckpt.append(param.name)
+ continue
+
+ param_rank = rank_list.get(param.name)[0]
+ skip_merge_split = rank_list.get(param.name)[1]
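+ # Fields recorded in the training strategy for this parameter: [0] device arrangement,
+ # [1] tensor map and [4] shard stride; together with field [5] they determine how many
+ # duplicate copies of the parameter exist across the checkpoint files (repeat_size).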
+ shard_stride = train_strategy.get(param.name)[4]
+ tensor_map = train_strategy.get(param.name)[1]
+ first_dim_shard_idx = tensor_map[0] if tensor_map else -1
+ device_arrangement = train_strategy.get(param.name)[0]
+ first_dim_shard_size = 1
+ if first_dim_shard_idx >= 0:
+ first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
+ if train_strategy.get(param.name)[5]:
+ repeat_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
+ else:
+ repeat_size = 0
+ for rank in param_rank:
+ param_total_list = list(range(0, ckpt_file_len))
+ if first_dim_shard_size != 1:
+ param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
+ if repeat_size > 0:
+ shard_size = shard_stride * train_strategy.get(param.name)[5]
+ rank_index = param_total_list.index(rank)
+ start = rank_index // shard_size * shard_size
+ param_total_list = param_total_list[start:start + shard_size]
+ if shard_stride > 0:
+ param_stride = []
+ # collect the checkpoint ranks holding distinct slices of this parameter, stepping by shard_stride on both sides of the current rank
+ param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
+ param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
+ param_index = list(set(param_index))
+ param_index.sort()
+ for rank_num in param_index:
+ if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
+ from mindspore.ops import Cast
+ cpu_cast = Cast().set_device("CPU")
+ param_stride.append(
+ cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
+ else:
+ param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
+
+ sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
+ else:
+ sliced_param = param_total_dict[param.name][rank]
+
+ sliced_params.append(sliced_param)
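+ # Either use the collected slice directly, or merge the slices and re-split them
+ # according to the prediction strategy.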
+ if skip_merge_split:
+ split_param = sliced_params[0]
+ else:
+ param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
+ _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
+ split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
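+ # If the prediction strategy applies optimizer weight sharding for this parameter,
+ # keep only this rank's slice of the merged data within the opt shard group.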
+ opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
+ if opt_shard_group:
+ if split_param.data.dtype == mstype.bfloat16:
+ from mindspore.ops import Cast
+ cpu_cast = Cast().set_device("CPU")
+ data = cpu_cast(split_param.data, mstype.float32).asnumpy()
+ else:
+ data = split_param.data.asnumpy()
+ rank = get_rank(opt_shard_group)
+ size = get_group_size(opt_shard_group)
+ try:
+ data_slice = np.split(data, size)[rank]
+ except BaseException as e:
+ logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
+ " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
+ raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
+ f" in load distributed checkpoint for {param.name}. Data shape is "
+ f"{split_param.data.shape} and group is {opt_shard_group}.") from e
+ split_param = Parameter(Tensor(data_slice), param.name,
+ split_param.requires_grad, split_param.layerwise_parallel)
+ param_dict[param.name] = split_param
+
+ if param_not_in_strategy:
+ logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
+ "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
+ .format(param_not_in_strategy))
+ if param_not_in_ckpt:
+ logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in "
1274
+ "the checkpoint file, please check whether 'checkpoint_filenames' is correct."
1275
+ .format(param_not_in_ckpt))
+
+ ms.load_param_into_net(network, param_dict, strict_load=strict_load)
+ return True
+
+
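The two `safetensors` calling modes described in the docstring above can be summarized with a minimal illustrative sketch (not part of the package source); the network object, strategy file and directories are placeholders, and a communication environment initialized as in the docstring example is assumed.

    from mindspore.parallel import load_distributed_checkpoint

    # Load mode: a network built under no_init_parameters() is passed, so the unified
    # safetensors weights are sliced according to the prediction strategy and loaded into it.
    load_distributed_checkpoint(network=predict_network,                      # placeholder network
                                predict_strategy="./predict_strategy.ckpt",   # placeholder strategy file
                                format="safetensors",
                                unified_safetensors_dir="./unified_safetensors")

    # Save mode: network=None, so the sliced weights are written to dst_safetensors_dir
    # instead of being loaded; output_format controls the format of the produced files.
    load_distributed_checkpoint(network=None,
                                predict_strategy="./predict_strategy.ckpt",
                                format="safetensors",
                                unified_safetensors_dir="./unified_safetensors",
                                dst_safetensors_dir="./sliced_safetensors",
                                output_format="safetensors")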
+ def restore_group_info_list(group_info_file_name):
+ """
+ Extract rank list information from communication domain files. To save the group info file,
+ please set the GROUP_INFO_FILE
+ environment variable, for example "export GROUP_INFO_FILE=/data/group_info.pb".
+
+ Args:
+ group_info_file_name (str): Name of group information file.
+
+ Returns:
+ List, the rank list.
+
+ Raises:
+ ValueError: The group information file is incorrect.
+ TypeError: `group_info_file_name` is not str.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore.parallel import restore_group_info_list
+ >>> restore_list = restore_group_info_list("./group_info.pb")
+ """
+ if not isinstance(group_info_file_name, str):
+ raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
+ f"but got {type(group_info_file_name)}.")
+
+ if not os.path.isfile(group_info_file_name):
+ raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
+
+ if os.path.getsize(group_info_file_name) == 0:
+ raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
+
+ return _restore_group_info_list(group_info_file_name)
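As a minimal sketch of the surrounding workflow (assuming the distributed training job was launched with GROUP_INFO_FILE exported so that the file below exists):

    import os
    from mindspore.parallel import restore_group_info_list

    # GROUP_INFO_FILE must be exported before the distributed training job runs, e.g.
    # "export GROUP_INFO_FILE=/data/group_info.pb"; the file is written during training.
    group_info_file = os.getenv("GROUP_INFO_FILE", "/data/group_info.pb")
    rank_list = restore_group_info_list(group_info_file)
    print(rank_list)  # e.g. [0, 1, 2, 3]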