mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +46 -197
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +217 -98
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +435 -371
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +2 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +951 -1992
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +314 -566
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/experimental/es/embedding_service.py +35 -27
  112. mindspore/experimental/llm_boost/__init__.py +1 -0
  113. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  114. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  115. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  116. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  117. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  118. mindspore/experimental/llm_boost/register.py +1 -0
  119. mindspore/experimental/map_parameter.py +4 -4
  120. mindspore/experimental/optim/adadelta.py +6 -6
  121. mindspore/experimental/optim/adagrad.py +4 -4
  122. mindspore/experimental/optim/adam.py +7 -0
  123. mindspore/experimental/optim/adamax.py +4 -4
  124. mindspore/experimental/optim/adamw.py +4 -0
  125. mindspore/experimental/optim/asgd.py +1 -1
  126. mindspore/experimental/optim/lr_scheduler.py +73 -46
  127. mindspore/experimental/optim/radam.py +34 -31
  128. mindspore/experimental/optim/rprop.py +1 -1
  129. mindspore/experimental/optim/sgd.py +1 -1
  130. mindspore/hal/contiguous_tensors_handle.py +6 -10
  131. mindspore/hal/device.py +55 -53
  132. mindspore/hal/event.py +52 -52
  133. mindspore/hal/memory.py +157 -117
  134. mindspore/hal/stream.py +150 -109
  135. mindspore/include/api/context.h +0 -1
  136. mindspore/include/dataset/constants.h +7 -4
  137. mindspore/include/dataset/execute.h +2 -2
  138. mindspore/jpeg62.dll +0 -0
  139. mindspore/log.py +50 -0
  140. mindspore/mindrecord/__init__.py +21 -8
  141. mindspore/mindrecord/config.py +17 -316
  142. mindspore/mindrecord/filereader.py +1 -9
  143. mindspore/mindrecord/filewriter.py +5 -15
  144. mindspore/mindrecord/mindpage.py +1 -9
  145. mindspore/mindspore_backend_common.dll +0 -0
  146. mindspore/mindspore_backend_manager.dll +0 -0
  147. mindspore/mindspore_common.dll +0 -0
  148. mindspore/mindspore_core.dll +0 -0
  149. mindspore/mindspore_dump.dll +0 -0
  150. mindspore/mindspore_frontend.dll +0 -0
  151. mindspore/mindspore_memory_pool.dll +0 -0
  152. mindspore/mindspore_ms_backend.dll +0 -0
  153. mindspore/mindspore_ops.dll +0 -0
  154. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  155. mindspore/mindspore_ops_kernel_common.dll +0 -0
  156. mindspore/mindspore_profiler.dll +0 -0
  157. mindspore/mindspore_pyboost.dll +0 -0
  158. mindspore/mindspore_pynative.dll +0 -0
  159. mindspore/mindspore_res_manager.dll +0 -0
  160. mindspore/mindspore_runtime_pipeline.dll +0 -0
  161. mindspore/mint/__init__.py +796 -759
  162. mindspore/mint/distributed/__init__.py +70 -4
  163. mindspore/mint/distributed/distributed.py +2679 -44
  164. mindspore/mint/linalg/__init__.py +8 -0
  165. mindspore/mint/nn/__init__.py +743 -22
  166. mindspore/mint/nn/functional.py +716 -23
  167. mindspore/mint/nn/layer/__init__.py +21 -4
  168. mindspore/mint/nn/layer/_functions.py +334 -0
  169. mindspore/mint/nn/layer/activation.py +276 -1
  170. mindspore/mint/nn/layer/basic.py +123 -0
  171. mindspore/mint/nn/layer/conv.py +921 -0
  172. mindspore/mint/nn/layer/normalization.py +223 -28
  173. mindspore/mint/nn/layer/padding.py +797 -0
  174. mindspore/mint/nn/layer/pooling.py +235 -0
  175. mindspore/mint/optim/__init__.py +3 -1
  176. mindspore/mint/optim/adam.py +223 -0
  177. mindspore/mint/optim/adamw.py +26 -19
  178. mindspore/mint/optim/sgd.py +171 -0
  179. mindspore/mint/special/__init__.py +2 -1
  180. mindspore/multiprocessing/__init__.py +5 -0
  181. mindspore/nn/__init__.py +4 -1
  182. mindspore/nn/cell.py +1370 -189
  183. mindspore/nn/dynamic_lr.py +2 -1
  184. mindspore/nn/layer/activation.py +29 -27
  185. mindspore/nn/layer/basic.py +51 -35
  186. mindspore/nn/layer/channel_shuffle.py +3 -3
  187. mindspore/nn/layer/container.py +1 -1
  188. mindspore/nn/layer/conv.py +22 -17
  189. mindspore/nn/layer/embedding.py +12 -11
  190. mindspore/nn/layer/normalization.py +56 -49
  191. mindspore/nn/layer/padding.py +4 -3
  192. mindspore/nn/layer/pooling.py +120 -42
  193. mindspore/nn/layer/rnn_cells.py +1 -1
  194. mindspore/nn/layer/rnns.py +2 -1
  195. mindspore/nn/layer/timedistributed.py +5 -5
  196. mindspore/nn/layer/transformer.py +59 -36
  197. mindspore/nn/learning_rate_schedule.py +8 -4
  198. mindspore/nn/loss/loss.py +58 -55
  199. mindspore/nn/optim/ada_grad.py +7 -5
  200. mindspore/nn/optim/adadelta.py +11 -9
  201. mindspore/nn/optim/adafactor.py +1 -1
  202. mindspore/nn/optim/adam.py +17 -13
  203. mindspore/nn/optim/adamax.py +8 -7
  204. mindspore/nn/optim/adasum.py +5 -5
  205. mindspore/nn/optim/asgd.py +1 -1
  206. mindspore/nn/optim/ftrl.py +11 -9
  207. mindspore/nn/optim/lamb.py +1 -1
  208. mindspore/nn/optim/lars.py +1 -4
  209. mindspore/nn/optim/lazyadam.py +12 -10
  210. mindspore/nn/optim/momentum.py +7 -6
  211. mindspore/nn/optim/optimizer.py +3 -3
  212. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  213. mindspore/nn/optim/rmsprop.py +13 -12
  214. mindspore/nn/optim/rprop.py +11 -9
  215. mindspore/nn/optim/sgd.py +9 -6
  216. mindspore/nn/optim/tft_wrapper.py +5 -2
  217. mindspore/nn/optim/thor.py +2 -1
  218. mindspore/nn/probability/bijector/bijector.py +17 -11
  219. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  220. mindspore/nn/probability/bijector/invert.py +2 -2
  221. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  222. mindspore/nn/probability/bijector/softplus.py +3 -2
  223. mindspore/nn/probability/distribution/beta.py +3 -3
  224. mindspore/nn/probability/distribution/categorical.py +1 -1
  225. mindspore/nn/probability/distribution/cauchy.py +4 -2
  226. mindspore/nn/probability/distribution/exponential.py +6 -7
  227. mindspore/nn/probability/distribution/gamma.py +2 -2
  228. mindspore/nn/probability/distribution/gumbel.py +2 -2
  229. mindspore/nn/probability/distribution/half_normal.py +5 -3
  230. mindspore/nn/probability/distribution/logistic.py +5 -3
  231. mindspore/nn/probability/distribution/poisson.py +1 -1
  232. mindspore/nn/probability/distribution/uniform.py +5 -3
  233. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  234. mindspore/nn/reinforcement/tensor_array.py +1 -1
  235. mindspore/nn/utils/init.py +13 -11
  236. mindspore/nn/wrap/__init__.py +6 -6
  237. mindspore/nn/wrap/cell_wrapper.py +181 -122
  238. mindspore/nn/wrap/grad_reducer.py +45 -36
  239. mindspore/nn/wrap/loss_scale.py +6 -7
  240. mindspore/numpy/array_creations.py +63 -65
  241. mindspore/numpy/array_ops.py +149 -144
  242. mindspore/numpy/logic_ops.py +41 -42
  243. mindspore/numpy/math_ops.py +365 -363
  244. mindspore/numpy/utils.py +17 -18
  245. mindspore/numpy/utils_const.py +5 -6
  246. mindspore/opencv_core452.dll +0 -0
  247. mindspore/opencv_imgcodecs452.dll +0 -0
  248. mindspore/opencv_imgproc452.dll +0 -0
  249. mindspore/ops/__init__.py +5 -3
  250. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  251. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  252. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  253. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  254. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  255. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  256. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  257. mindspore/ops/_register_for_op.py +0 -11
  258. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  259. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  260. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  261. mindspore/ops/_vmap/vmap_base.py +0 -2
  262. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  263. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  264. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  265. mindspore/ops/auto_generate/__init__.py +4 -3
  266. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  267. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  268. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  269. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  270. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  271. mindspore/ops/composite/__init__.py +2 -1
  272. mindspore/ops/composite/base.py +20 -25
  273. mindspore/ops/composite/math_ops.py +6 -16
  274. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  275. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  276. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  277. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  278. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  279. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  282. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  283. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  284. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  285. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  286. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  287. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  288. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  289. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  291. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  292. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  293. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  294. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  295. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  297. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  301. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  302. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  304. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  305. mindspore/ops/function/__init__.py +40 -2
  306. mindspore/ops/function/_add_attr_func.py +58 -0
  307. mindspore/ops/function/array_func.py +2089 -2403
  308. mindspore/ops/function/clip_func.py +80 -23
  309. mindspore/ops/function/debug_func.py +57 -57
  310. mindspore/ops/function/grad/__init__.py +1 -0
  311. mindspore/ops/function/grad/grad_func.py +104 -71
  312. mindspore/ops/function/image_func.py +2 -2
  313. mindspore/ops/function/linalg_func.py +47 -78
  314. mindspore/ops/function/math_func.py +4501 -3802
  315. mindspore/ops/function/nn_func.py +1726 -620
  316. mindspore/ops/function/other_func.py +159 -1
  317. mindspore/ops/function/parameter_func.py +18 -84
  318. mindspore/ops/function/random_func.py +440 -387
  319. mindspore/ops/function/reshard_func.py +4 -70
  320. mindspore/ops/function/sparse_func.py +3 -3
  321. mindspore/ops/function/sparse_unary_func.py +6 -6
  322. mindspore/ops/function/spectral_func.py +25 -58
  323. mindspore/ops/function/vmap_func.py +24 -17
  324. mindspore/ops/functional.py +22 -7
  325. mindspore/ops/functional_overload.py +1440 -0
  326. mindspore/ops/op_info_register.py +32 -244
  327. mindspore/ops/operations/__init__.py +13 -7
  328. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  329. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  330. mindspore/ops/operations/_grad_ops.py +2 -43
  331. mindspore/ops/operations/_infer_ops.py +2 -1
  332. mindspore/ops/operations/_inner_ops.py +43 -84
  333. mindspore/ops/operations/_ms_kernel.py +4 -10
  334. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  335. mindspore/ops/operations/_scalar_ops.py +3 -2
  336. mindspore/ops/operations/_sequence_ops.py +1 -1
  337. mindspore/ops/operations/_tensor_array.py +1 -1
  338. mindspore/ops/operations/array_ops.py +81 -324
  339. mindspore/ops/operations/comm_ops.py +154 -108
  340. mindspore/ops/operations/custom_ops.py +232 -78
  341. mindspore/ops/operations/debug_ops.py +153 -59
  342. mindspore/ops/operations/inner_ops.py +7 -5
  343. mindspore/ops/operations/linalg_ops.py +1 -57
  344. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  345. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  346. mindspore/ops/operations/math_ops.py +32 -234
  347. mindspore/ops/operations/nn_ops.py +210 -498
  348. mindspore/ops/operations/other_ops.py +62 -9
  349. mindspore/ops/operations/random_ops.py +13 -7
  350. mindspore/ops/operations/reshard_ops.py +1 -1
  351. mindspore/ops/operations/sparse_ops.py +2 -2
  352. mindspore/ops/primitive.py +66 -53
  353. mindspore/ops/tensor_method.py +1888 -0
  354. mindspore/ops_generate/__init__.py +0 -5
  355. mindspore/ops_generate/aclnn/__init__.py +0 -0
  356. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  357. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  358. mindspore/ops_generate/api/__init__.py +0 -0
  359. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  360. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  361. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  362. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  363. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  364. mindspore/ops_generate/api/gen_api.py +103 -0
  365. mindspore/ops_generate/api/op_api_proto.py +235 -0
  366. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  367. mindspore/ops_generate/common/__init__.py +0 -0
  368. mindspore/ops_generate/common/base_generator.py +11 -0
  369. mindspore/ops_generate/common/gen_constants.py +91 -0
  370. mindspore/ops_generate/common/gen_utils.py +348 -0
  371. mindspore/ops_generate/common/op_proto.py +473 -0
  372. mindspore/ops_generate/common/template.py +523 -0
  373. mindspore/ops_generate/gen_ops.py +22 -1069
  374. mindspore/ops_generate/op_def/__init__.py +0 -0
  375. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  376. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  377. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  378. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  379. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  380. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  381. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  382. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  383. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  384. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  385. mindspore/ops_generate/pyboost/__init__.py +0 -0
  386. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  387. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  388. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  389. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  390. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  391. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  393. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  394. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  395. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  396. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  397. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  398. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  399. mindspore/ops_generate/resources/__init__.py +0 -0
  400. mindspore/ops_generate/resources/resource_list.py +30 -0
  401. mindspore/ops_generate/resources/resource_loader.py +36 -0
  402. mindspore/ops_generate/resources/resource_manager.py +64 -0
  403. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  404. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  405. mindspore/parallel/__init__.py +7 -3
  406. mindspore/parallel/_auto_parallel_context.py +152 -34
  407. mindspore/parallel/_cell_wrapper.py +130 -15
  408. mindspore/parallel/_parallel_serialization.py +107 -5
  409. mindspore/parallel/_ps_context.py +1 -1
  410. mindspore/parallel/_recovery_context.py +7 -2
  411. mindspore/parallel/_tensor.py +142 -18
  412. mindspore/parallel/_utils.py +199 -23
  413. mindspore/parallel/algo_parameter_config.py +4 -4
  414. mindspore/parallel/auto_parallel.py +732 -0
  415. mindspore/parallel/checkpoint_convert.py +159 -0
  416. mindspore/parallel/checkpoint_transform.py +698 -35
  417. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  418. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  419. mindspore/parallel/cluster/run.py +21 -4
  420. mindspore/parallel/function/__init__.py +24 -0
  421. mindspore/parallel/function/reshard_func.py +259 -0
  422. mindspore/parallel/nn/__init__.py +25 -0
  423. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  424. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  425. mindspore/parallel/parameter_broadcast.py +25 -14
  426. mindspore/parallel/shard.py +137 -58
  427. mindspore/parallel/transform_safetensors.py +363 -305
  428. mindspore/profiler/__init__.py +22 -5
  429. mindspore/profiler/analysis/__init__.py +0 -0
  430. mindspore/profiler/analysis/parser/__init__.py +0 -0
  431. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  432. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  433. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  434. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  435. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  436. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  437. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  440. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  441. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  446. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  447. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  448. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  449. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  450. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  451. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  452. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  453. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  454. mindspore/profiler/analysis/task_manager.py +131 -0
  455. mindspore/profiler/analysis/time_converter.py +84 -0
  456. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  457. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  458. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  459. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  460. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  461. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  462. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  463. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  464. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  465. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  466. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  467. mindspore/profiler/analysis/work_flow.py +73 -0
  468. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  469. mindspore/profiler/common/command_executor.py +90 -0
  470. mindspore/profiler/common/constant.py +186 -3
  471. mindspore/profiler/common/file_manager.py +208 -0
  472. mindspore/profiler/common/log.py +130 -0
  473. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  474. mindspore/profiler/common/path_manager.py +395 -0
  475. mindspore/profiler/common/process_bar.py +168 -0
  476. mindspore/profiler/common/process_pool.py +9 -3
  477. mindspore/profiler/common/profiler_context.py +500 -0
  478. mindspore/profiler/common/profiler_info.py +304 -0
  479. mindspore/profiler/common/profiler_meta_data.py +74 -0
  480. mindspore/profiler/common/profiler_output_path.py +284 -0
  481. mindspore/profiler/common/profiler_parameters.py +251 -0
  482. mindspore/profiler/common/profiler_path_manager.py +179 -0
  483. mindspore/profiler/common/record_function.py +76 -0
  484. mindspore/profiler/common/tlv_decoder.py +76 -0
  485. mindspore/profiler/common/util.py +75 -2
  486. mindspore/profiler/dynamic_profiler.py +341 -75
  487. mindspore/profiler/envprofiler.py +163 -0
  488. mindspore/profiler/experimental_config.py +197 -0
  489. mindspore/profiler/mstx.py +242 -0
  490. mindspore/profiler/platform/__init__.py +21 -0
  491. mindspore/profiler/platform/base_profiler.py +40 -0
  492. mindspore/profiler/platform/cpu_profiler.py +124 -0
  493. mindspore/profiler/platform/gpu_profiler.py +74 -0
  494. mindspore/profiler/platform/npu_profiler.py +335 -0
  495. mindspore/profiler/profiler.py +1073 -90
  496. mindspore/profiler/profiler_action_controller.py +187 -0
  497. mindspore/profiler/profiler_interface.py +118 -0
  498. mindspore/profiler/schedule.py +243 -0
  499. mindspore/rewrite/api/node.py +15 -13
  500. mindspore/rewrite/api/symbol_tree.py +2 -3
  501. mindspore/run_check/_check_version.py +27 -20
  502. mindspore/run_check/run_check.py +1 -1
  503. mindspore/runtime/__init__.py +37 -0
  504. mindspore/runtime/device.py +27 -0
  505. mindspore/runtime/event.py +209 -0
  506. mindspore/runtime/executor.py +177 -0
  507. mindspore/runtime/memory.py +409 -0
  508. mindspore/runtime/stream.py +460 -0
  509. mindspore/runtime/thread_bind_core.py +401 -0
  510. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  511. mindspore/swresample-4.dll +0 -0
  512. mindspore/swscale-6.dll +0 -0
  513. mindspore/tinyxml2.dll +0 -0
  514. mindspore/train/__init__.py +8 -8
  515. mindspore/train/_utils.py +88 -25
  516. mindspore/train/amp.py +9 -5
  517. mindspore/train/callback/__init__.py +2 -2
  518. mindspore/train/callback/_callback.py +2 -16
  519. mindspore/train/callback/_checkpoint.py +53 -55
  520. mindspore/train/callback/_cluster_monitor.py +14 -18
  521. mindspore/train/callback/_early_stop.py +1 -1
  522. mindspore/train/callback/_flops_collector.py +103 -68
  523. mindspore/train/callback/_history.py +8 -5
  524. mindspore/train/callback/_lambda_callback.py +2 -2
  525. mindspore/train/callback/_landscape.py +0 -3
  526. mindspore/train/callback/_loss_monitor.py +2 -1
  527. mindspore/train/callback/_on_request_exit.py +6 -5
  528. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  529. mindspore/train/callback/_summary_collector.py +52 -19
  530. mindspore/train/callback/_time_monitor.py +2 -1
  531. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  532. mindspore/train/data_sink.py +25 -2
  533. mindspore/train/dataset_helper.py +15 -16
  534. mindspore/train/loss_scale_manager.py +8 -7
  535. mindspore/train/metrics/accuracy.py +3 -3
  536. mindspore/train/metrics/confusion_matrix.py +9 -9
  537. mindspore/train/metrics/error.py +3 -3
  538. mindspore/train/metrics/hausdorff_distance.py +4 -4
  539. mindspore/train/metrics/mean_surface_distance.py +3 -3
  540. mindspore/train/metrics/metric.py +0 -12
  541. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  542. mindspore/train/metrics/precision.py +11 -10
  543. mindspore/train/metrics/recall.py +9 -9
  544. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  545. mindspore/train/mind_ir_pb2.py +174 -46
  546. mindspore/train/model.py +184 -113
  547. mindspore/train/serialization.py +622 -978
  548. mindspore/train/summary/_summary_adapter.py +2 -2
  549. mindspore/train/summary/summary_record.py +2 -3
  550. mindspore/train/train_thor/model_thor.py +1 -1
  551. mindspore/turbojpeg.dll +0 -0
  552. mindspore/utils/__init__.py +6 -3
  553. mindspore/utils/dryrun.py +140 -0
  554. mindspore/utils/hooks.py +81 -0
  555. mindspore/utils/runtime_execution_order_check.py +550 -0
  556. mindspore/utils/utils.py +138 -4
  557. mindspore/version.py +1 -1
  558. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  559. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  561. mindspore/_install_custom.py +0 -43
  562. mindspore/common/_register_for_adapter.py +0 -74
  563. mindspore/common/_tensor_overload.py +0 -139
  564. mindspore/mindspore_np_dtype.dll +0 -0
  565. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  566. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  567. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  568. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  569. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  570. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  571. mindspore/ops_generate/gen_utils.py +0 -209
  572. mindspore/ops_generate/op_proto.py +0 -145
  573. mindspore/ops_generate/template.py +0 -261
  574. mindspore/profiler/envprofiling.py +0 -254
  575. mindspore/profiler/profiling.py +0 -1926
  576. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  577. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -18,33 +18,46 @@ from __future__ import absolute_import
 import os
 import glob
 import copy
+from multiprocessing import Pool
 from collections import defaultdict
 import numpy as np
 import mindspore as ms
+from mindspore import log as logger
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
-from mindspore.parallel._utils import _is_in_auto_parallel_mode, _get_pipeline_stages
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.communication.management import get_rank, get_group_size
+from mindspore.parallel._tensor import _load_tensor, _reshape_param_data, _reshape_param_data_with_weight, \
+    _get_tensor_slice_index, _get_tensor_strategy
+from mindspore.parallel._utils import _is_in_auto_parallel_mode, _get_pipeline_stages, _infer_rank_list, \
+    _remove_repeated_slices, _get_auto_parallel_net
 from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
-    _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
+    _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, _build_searched_strategy, \
     _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
-    _merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src
-from mindspore.parallel.transform_safetensors import _transform_safetensors, _collect_safetensor_files
+    _merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src, _convert_to_list, \
+    _check_checkpoint_file, _check_predict_strategy, _gather_tasks_load_dis, _get_param_list_when_first_dim_sharded, \
+    _convert_to_layout, _restore_group_info_list
+from mindspore._c_expression import AutoParallelContext
+from mindspore.parallel.transform_safetensors import _transform_safetensors, _collect_safetensor_files, \
+    _load_parallel_checkpoint

 __all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
-           "transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints"]
+           "transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints",
+           "load_distributed_checkpoint", "merge_sliced_parameter", "restore_group_info_list",
+           "build_searched_strategy"]


 def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
     """
-    Merge parallel strategy between all pipeline stages in pipeline parallel mode.
-    For more details about converting distributed Checkpoint, please refer to
-    `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
+    Aggregate the sharding strategy files of all pipeline parallel subgraphs to the destination file.

     Note:
         Strategy file of each pipeline stage should be included in src_strategy_dirs.

     Args:
         src_strategy_dirs (str): The directory of strategy files including all pipeline stage which is saved by
-            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            :func:`mindspore.parallel.auto_parallel.AutoParallel.save_param_strategy_file`.
         dst_strategy_file (str): The file merged strategy to save.

     Raises:
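Note on the hunk above: the expanded import block and __all__ suggest that the distributed-checkpoint helpers which 2.4.10 shipped in mindspore/train/serialization.py (load_distributed_checkpoint, merge_sliced_parameter, restore_group_info_list, build_searched_strategy) are now defined and exported from this module, mindspore/parallel/checkpoint_transform.py; this matches the large deletion listed for mindspore/train/serialization.py (+622 -978) in the file table. A minimal sketch of what 2.6-style caller code might look like, assuming the wheel re-exports this module's names under mindspore.parallel as the updated docstring examples (ms.parallel.*) in this diff suggest:

    import mindspore as ms
    from mindspore.parallel import build_searched_strategy  # assumed new import location

    # Per-parameter slice layouts read from a training strategy file.
    strategy = build_searched_strategy("./strategy_train.ckpt")
    # Source ranks whose checkpoints are needed to rebuild rank 0 under the new sharding.
    rank_list = ms.parallel.rank_list_for_transform(0, "./src_strategy.ckpt", "./dst_strategy.ckpt")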
@@ -53,7 +66,7 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
     Examples:
         >>> import mindspore as ms
         >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
-        >>> ms.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")
+        >>> ms.parallel.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")

     """
     dst_strategy_dir, _ = os.path.split(dst_strategy_file)
@@ -72,11 +85,211 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
     _merge_json_strategy(src_strategy_files_json, dst_strategy_file)


+def merge_sliced_parameter(sliced_parameters, strategy=None):
+    """
+    Merge parameter slices into one parameter. Used in the case of distributed inference.
+
+    Args:
+        sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
+        strategy (Optional[dict], optional): Parameter slice strategy, whose key is parameter name and
+            value is slice strategy of this parameter. If strategy is None, just merge
+            parameter slices in 0 axis order. Default: ``None``.
+
+    Returns:
+        Parameter, the merged parameter which has the whole data.
+
+    Raises:
+        ValueError: Failed to merge.
+        TypeError: The sliced_parameters is incorrect or strategy is not dict.
+        KeyError: The parameter name is not in keys of strategy.
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor, Parameter
+        >>>
+        >>> sliced_parameters = [
+        ...                      Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
+        ...                                "network.embedding_table"),
+        ...                      Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
+        ...                                "network.embedding_table"),
+        ...                      Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
+        ...                                "network.embedding_table"),
+        ...                      Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
+        ...                                "network.embedding_table")]
+        >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
+        >>> print(merged_parameter)
+        Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
+    """
+    if not isinstance(sliced_parameters, list):
+        raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
+                        f"but got {type(sliced_parameters)}.")
+
+    if not sliced_parameters:
+        raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
+
+    if strategy and not isinstance(strategy, dict):
+        raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
+                        f"but got {type(strategy)}.")
+
+    try:
+        parameter_name = sliced_parameters[0].name
+        parameter_shape = sliced_parameters[0].data.shape
+        parameter_shape_length = len(parameter_shape)
+    except BaseException as e:
+        raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
+                        f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e
+
+    is_even = True
+    for index, parameter in enumerate(sliced_parameters):
+        if not isinstance(parameter, Parameter):
+            raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
+                            f"but got {type(parameter)} at index {index}.")
+
+        if parameter.name != parameter_name \
+                or len(parameter.data.shape) != parameter_shape_length \
+                or parameter.data.shape[1:] != parameter_shape[1:]:
+            raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'"
+                             f" have the same name, dimension length and shape except 0 axis. The name, dimension "
+                             f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
+                             f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
+                             f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
+                             f"at index {index}.")
+
+        if parameter.data.shape != parameter_shape:
+            is_even = False
+
+    layerwise_parallel = sliced_parameters[0].layerwise_parallel
+    requires_grad = sliced_parameters[0].requires_grad
+    sliced_data = []
+    for parameter in sliced_parameters:
+        if parameter.data.dtype == mstype.bfloat16:
+            from mindspore.ops import Cast
+            cpu_cast = Cast().set_device("CPU")
+            sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
+        else:
+            sliced_data.append(parameter.data.asnumpy())
+
+    if not strategy:
+        merged_tensor = Tensor(np.concatenate(sliced_data))
+        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
+
+    else:
+        if parameter_name not in strategy.keys():
+            raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
+                           f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
+        merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
+        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
+
+    return merged_parameter
+
+
+def _merge_and_split(sliced_params, train_strategy, predict_strategy):
+    """Merge sliced parameter and split it according to the predict strategy."""
+    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
+    if not predict_strategy:
+        return merged_param
+    param_name = merged_param.name
+    tensor_layout = predict_strategy[param_name]
+    rank = get_rank()
+    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
+    requires_grad = merged_param.requires_grad
+    layerwise_parallel = merged_param.layerwise_parallel
+    if merged_param.data.dtype == mstype.bfloat16:
+        split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
+    else:
+        split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
+    return split_param
+
+
+def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
+    """
+    Merge data slices to one tensor with whole data when strategy is not None.
+
+    Args:
+        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
+        parameter_name (str): Name of parameter.
+        strategy (dict): Parameter slice strategy.
+        is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
+
+    Returns:
+        Tensor, the merged Tensor which has the whole data.
+
+    Raises:
+        ValueError: Failed to merge.
+    """
+    layout = strategy.get(parameter_name)
+    try:
+        dev_mat = list(layout.dev_matrix[0].dim)
+        tensor_map = list(layout.tensor_map[0].dim)
+        param_split_shape = list(layout.param_split_shape[0].dim)
+        field_size = int(layout.field)
+    except BaseException as e:
+        raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
+                         f", please make sure that 'strategy' is correct.") from e
+
+    device_count = 1
+    for dim in dev_mat:
+        device_count *= dim
+
+    if len(sliced_data) != device_count:
+        raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
+                         f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
+                         f"device_count is {device_count}.")
+
+    if not param_split_shape:
+        if not is_even:
+            raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
+                             "should be the same when slice manner is even.")
+
+        all_gather_tensor = Tensor(np.concatenate(sliced_data))
+
+        if field_size > 0:
+            merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
+        else:
+            merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
+
+    else:
+        tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+
+        slice_count = 1
+        for dim in tensor_strategy:
+            slice_count *= dim
+
+        if len(param_split_shape) != slice_count:
+            raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
+                             f"{slice_count}, but got {len(param_split_shape)}.")
+
+        tensor_slices_new = list(range(slice_count))
+        tensor_slices = sliced_data
+        for i in range(device_count):
+            slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
+            if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
+                raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
+                                 f"{param_split_shape[slice_index]} in 0 axis, but got "
+                                 f"{tensor_slices[i].shape[0]}.")
+            tensor_slices_new[slice_index] = np.array(tensor_slices[i])
+
+        dim_len = len(tensor_strategy)
+        for i in range(dim_len):
+            ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
+            tensor_slices_new_inner = []
+            for j in range(ele_count):
+                new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
+                for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
+                               (j + 1) * tensor_strategy[dim_len - 1 - i]):
+                    new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
+                tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
+            tensor_slices_new = tensor_slices_new_inner
+        merged_tensor = Tensor(tensor_slices_new[0])
+
+    return merged_tensor
+
+
 def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
     """
     List of original distributed checkpoint rank index for obtaining the target checkpoint of a rank_id during the
-    distributed checkpoint conversion. For more details about converting distributed Checkpoint, please refer to
-    `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
+    distributed checkpoint conversion.

     Args:
         rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
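In the hunk above, when param_split_shape is provided, _merge_param_with_strategy places each slice at the index computed by _get_tensor_slice_index and then rebuilds the full array by concatenating along the innermost split dimension first and the outermost last. A simplified, self-contained sketch of that final loop for a parameter split 2x2 (tensor_strategy == [2, 2]); the real code derives the slice ordering from the device matrix and tensor map, which this sketch assumes has already produced a row-major ordering:

    import numpy as np

    # Four (2, 2) blocks of a (4, 4) parameter, assumed already ordered row-major by slice index.
    slices = [np.full((2, 2), v) for v in range(4)]

    # Concatenate along the innermost split dimension first (axis 1), then the outer one (axis 0),
    # mirroring the reversed loop over tensor_strategy in _merge_param_with_strategy.
    rows = [np.concatenate(slices[j * 2:(j + 1) * 2], axis=1) for j in range(2)]
    merged = np.concatenate(rows, axis=0)
    print(merged.shape)  # (4, 4)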
@@ -101,7 +314,7 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
101
314
  Examples:
102
315
  >>> import mindspore as ms
103
316
  >>> rank_id = 0
104
- >>> rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
317
+ >>> rank_list = ms.parallel.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
105
318
  >>> checkpoint_files_map = {}
106
319
  >>> for rank in rank_list:
107
320
  ... checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
@@ -140,8 +353,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
140
353
  src_strategy_file=None, dst_strategy_file=None):
141
354
  """
142
355
  Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank
143
- for a network. For more details about converting distributed Checkpoint, please refer to
144
- `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
356
+ for a network.
145
357
 
146
358
  Args:
147
359
  rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
@@ -149,11 +361,11 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
149
361
  the checkpoint file name.
150
362
  save_checkpoint_file_name (str): The file name to save the converted checkpoint.
151
363
  src_strategy_file (str): Name of source sharding strategy file which saved by
152
- 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
364
+ `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
153
365
  when the `src_strategy_file` is None, it means that the source sharding strategy is
154
366
  without any sharing for each parameter. Default: ``None``.
155
367
  dst_strategy_file (str): Name of destination sharding strategy file which saved by
156
- 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
368
+ `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
157
369
  when the `dst_strategy_file` is ``None``,
158
370
  it means that the destination sharding strategy
159
371
  is without any sharing for each parameter. Default: ``None``.
@@ -361,8 +573,6 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
361
573
  dst_strategy_file=None, process_num=1, output_format="ckpt"):
362
574
  """
363
575
  Transform distributed checkpoint from source sharding strategy to destination sharding strategy for a rank.
364
- For more details about converting distributed Checkpoint, please refer to
365
- `Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
366
576
 
367
577
  Note:
368
578
  The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the
@@ -386,7 +596,7 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
386
596
  is without any sharing for each parameter. Default:None.
387
597
  process_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
388
598
  output_format (str, optional): Control the format of the output checkpoint after conversion.
389
- It can be set to either "ckpt" or "safetensors". Default: "ckpt".
599
+ It can be set to either ``"ckpt"`` or ``"safetensors"``. Default: ``"ckpt"``.
390
600
 
391
601
  Raises:
392
602
  ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
@@ -472,18 +682,21 @@ def _sync_params(name, param, layout):
472
682
  shape=param.shape,
473
683
  dtype=param.dtype)(param))
474
684
 
475
-
685
+ # pylint: disable=W0212
476
686
  def sync_pipeline_shared_parameters(net):
477
- """synchronize pipeline parallel stage shared parameters.
478
- Parameters may be shared between different stages. For example, `embedding table` is
687
+ """Synchronization of shared weights between stages for pipeline parallel inference scenarios.
688
+ For example, `embedding table` is
479
689
  shared by `WordEmbedding` layer and `LMHead` layer, which are usually split into different stages. It is necessary
480
690
  to perform synchronization after `embedding table` changes.
481
691
 
482
692
  Note:
483
- The network should be compiled before synchronize pipeline parallel stage shared parameters.
693
+ The network should be compiled before shared parameters are synchronized in the pipeline parallel stage.
484
694
 
485
695
  Args:
486
- net (nn.Cell): the inference network.
696
+ net (Cell): the inference network.
697
+
698
+ Raises:
699
+ TypeError: `net` is not in Cell type.
487
700
 
488
701
  Supported Platforms:
489
702
  ``Ascend``
@@ -493,12 +706,13 @@ def sync_pipeline_shared_parameters(net):
493
706
  Before running the following examples, you need to configure the communication environment variables.
494
707
 
495
708
  For the Ascend device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
496
- Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
709
+ Startup <https://www.mindspore.cn/tutorials/en/master/parallel/dynamic_cluster.html>`_ .
497
710
 
498
711
  >>> import numpy as np
499
712
  >>> import mindspore as ms
500
713
  >>> import mindspore.communication.management as D
501
714
  >>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
715
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
502
716
  >>> context.set_context(mode=context.GRAPH_MODE)
503
717
  >>> class Embedding(nn.Cell):
504
718
  ... def __init__(self, shape):
@@ -546,14 +760,16 @@ def sync_pipeline_shared_parameters(net):
546
760
  ... ret = self.concat(ret)
547
761
  ... return ret
548
762
  >>> D.init()
549
- >>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2)
550
763
  >>> net = Network()
551
764
  >>> net = PipelineCellInference(net, 2)
552
765
  >>> net.set_train(False)
553
766
  >>> x = Tensor(np.ones((2, 4)), ms.float32)
554
767
  >>> net.compile(x)
555
- >>> ms.sync_pipeline_shared_parameters(net)
556
- >>> print(net.network.word_embedding.w.asnumpy())
768
+ >>> pp_net = AutoParallel(net, parallel_mode="semi_auto")
769
+ >>> pp_net.full_batch = True
770
+ >>> pp_net.pipeline(stages=2, scheduler="1f1b")
771
+ >>> ms.parallel.sync_pipeline_shared_parameters(pp_net)
772
+ >>> print(pp_net.network.network.word_embedding.w.asnumpy())
557
773
  [[1. 1. 1. 1.]
558
774
  [1. 1. 1. 1.]
559
775
  [1. 1. 1. 1.]
@@ -566,18 +782,25 @@ def sync_pipeline_shared_parameters(net):
566
782
  "but got {}.".format(type(net)))
567
783
  raise TypeError(msg)
568
784
 
569
- if _get_pipeline_stages() < 2:
785
+ parallel_net = _get_auto_parallel_net(net)
786
+ pipeline_stages = 1
787
+ if type(parallel_net).__name__ != 'AutoParallel':
788
+ pipeline_stages = _get_pipeline_stages()
789
+ else:
790
+ pipeline_stages = parallel_net._pipeline_stages
791
+ if pipeline_stages < 2:
570
792
  return
571
793
 
572
794
  layout_dict = net.parameter_layout_dict
573
- if _is_in_auto_parallel_mode() and not layout_dict:
795
+ if (_is_in_auto_parallel_mode() or (type(parallel_net).__name__ == 'AutoParallel')) and not layout_dict:
574
796
  from mindspore.common.api import _get_parameter_layout
575
797
  layout_dict = _get_parameter_layout()
576
798
 
577
799
  # switch to standalone mode
578
- parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
579
- full_batch = ms.context.get_auto_parallel_context("full_batch")
580
- ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)
800
+ if type(parallel_net).__name__ != 'AutoParallel':
801
+ parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
802
+ full_batch = ms.context.get_auto_parallel_context("full_batch")
803
+ ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)
581
804
 
582
805
  # synchronize shared parameter
583
806
  for name, param in net.parameters_and_names():
@@ -585,7 +808,8 @@ def sync_pipeline_shared_parameters(net):
585
808
  _sync_params(name, param, layout_dict[name])
586
809
 
587
810
  # restore parallel context
588
- ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)
811
+ if type(parallel_net).__name__ != 'AutoParallel':
812
+ ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)
589
813
 
590
814
 
591
815
  def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filter_prefix=None,
@@ -635,6 +859,9 @@ def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filte
635
859
  ValueError: Checkpoint file's format is incorrect.
636
860
  ValueError: Parameter's dict is None after load checkpoint file.
637
861
  TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
862
+
863
+ Supported Platforms:
864
+ ``Ascend``
638
865
  """
639
866
  if not isinstance(ckpt_file_dir, str):
640
867
  raise TypeError("The ckpt_file_dir should be a str.")
@@ -648,3 +875,439 @@ def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filte
648
875
  parameter_dict.update(ms.load_checkpoint(checkpoint_file, net, strict_load, filter_prefix, dec_key,
649
876
  dec_mode, specify_prefix, choice_func))
650
877
  return parameter_dict
+
+
+ def set_op_strategy_config(mode="SAVE", path=""):
+ """
+ Set the strategy json configuration used with sharding propagation.
+
+ .. warning::
+ - This is an experimental interface that may be changed or removed in the future. Please use
+ :func:`mindspore.parallel.auto_parallel.AutoParallel.load_operator_strategy_file` or
+ :func:`mindspore.parallel.auto_parallel.AutoParallel.save_operator_strategy_file` instead.
+ - This interface currently doesn't support saving or loading strategies that use layout.
+
+ Note:
+ - It only works when `parallel_mode=ParallelMode.AUTO_PARALLEL` and `search_mode='sharding_propagation'`.
+ - It only supports saving and reloading with the same configuration for the same network. If the network
+ or training hyperparameters are modified after the operator strategies were saved to the json file
+ with the `SAVE` mode, loading them back with the `LOAD` mode may fail.
+ - When performing distributed training, users can first save the strategies with a dry run on a single
+ device and then load them to perform distributed training.
+
+ Args:
+ mode (str): Whether to save or load the .json strategy file. Default value: ``"SAVE"`` .
+ path (str): Path to save or load the parallel strategy json; must be an absolute path. Default value: ``""`` .
+
+ Raises:
+ KeyError: When `mode` is not ``"SAVE"`` or ``"LOAD"`` .
+ KeyError: When `path` does not end in ``".json"`` .
+ KeyError: When `path` is not an absolute path.
+ """
+ if not os.path.isabs(path):
+ raise KeyError("File path must be an absolute path")
+ _, file_type = os.path.splitext(path)
+ if file_type != ".json":
+ raise KeyError("File type must be .json")
+ dir_path = os.path.dirname(path)
+ if dir_path and not os.path.exists(dir_path):
+ os.makedirs(dir_path, mode=0o700, exist_ok=True)
+ check_mode_type = ["SAVE", "LOAD"]
+ if mode in check_mode_type:
+ if AutoParallelContext.get_instance() is None:
+ raise ValueError("Get AutoParallelContext instance failed!!!")
+ AutoParallelContext.get_instance().set_ops_strategy_json_config(mode, path, "all")
+ else:
+ raise KeyError("Type must be 'SAVE' or 'LOAD'")
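`set_op_strategy_config` ships without an Examples block, so here is a minimal save-then-load sketch under the constraints listed in its Note (sharding propagation, identical network and hyperparameters in both runs); the json path is illustrative and the import location is assumed to match the other helpers in this file:

import mindspore as ms
from mindspore.parallel import set_op_strategy_config

ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation")

# First run (for example a single-device dry run): record the searched operator strategies.
set_op_strategy_config(mode="SAVE", path="/tmp/op_strategy.json")
# ... build and compile the network; the strategies are written to the json file ...

# Later distributed run with the unchanged network: reuse the saved strategies.
set_op_strategy_config(mode="LOAD", path="/tmp/op_strategy.json")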
923
+
924
+
+ def build_searched_strategy(strategy_filename):
+ """
+ Extract the sharding strategy for each parameter in the network
+ from the strategy file for distributed inference scenarios.
+
+ Args:
+ strategy_filename (str): Name of strategy file.
+
+ Returns:
+ Dict, whose key is parameter name and value is slice strategy of this parameter.
+
+ Raises:
+ ValueError: Strategy file is incorrect.
+ TypeError: `strategy_filename` is not a string.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> from mindspore.parallel import build_searched_strategy
+ >>> strategy = build_searched_strategy("./strategy_train.ckpt")
+ """
+ return _build_searched_strategy(strategy_filename)
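The returned mapping can be inspected directly before any distributed loading is attempted; a short sketch, reusing the file name from the docstring example above:

from mindspore.parallel import build_searched_strategy

strategy = build_searched_strategy("./strategy_train.ckpt")
# Keys are parameter names, values describe how each parameter was sliced during training.
for param_name, slice_strategy in strategy.items():
    print(param_name, slice_strategy)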
+
+
+ # disable pylint too broad Exception
+ # pylint: disable=W0212
+ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
+ train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
+ format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
+ output_format='safetensors', name_map=None, max_process_num=64,
+ return_param_dict=False):
+ """
+ Load checkpoint into net for distributed prediction. Used in the case of distributed inference.
+
+ Note:
+ `output_format` will only take effect when `format` is set to `safetensors` and `network` is set to `None`.
+
+ Args:
+ network (Cell): Network for distributed prediction. When the format is `safetensors`, the network parameter
+ can be left blank or passed as None, and the interface will execute save mode.
+ checkpoint_filenames (list[str]): The names of Checkpoint files in order of rank id. Default: ``None`` .
+ predict_strategy (Union[dict, str]): Strategy of the prediction process. Setting predict_strategy to None
+ means predicting with a single device. Default: ``None`` .
+ train_strategy_filename (str): The filename of the training strategy protocol buffer file.
+ When train_strategy_filename is None, the training strategy file will be
+ obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
+ Therefore, the training strategy file needs to be specified
+ in at least one of them. Default: ``None`` .
+ strict_load (bool): Whether to strictly load the parameters into net. If ``False`` , it will load a parameter
+ into net when the parameter name's suffix in the checkpoint file is the same as the
+ parameter in the network. When the types are inconsistent, perform type conversion
+ on the parameters of the same type, such as float32 to float16. Default: ``False`` .
+ dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
+ is not required. Default: ``None`` .
+ dec_mode (str): Specifies the decryption
+ mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
+ This parameter is valid only when dec_key is not set to ``None`` .
+ Default: ``'AES-GCM'`` .
+ format (str): Input weight format to be loaded into the network.
+ It can be set to either "ckpt" or "safetensors". Default: "ckpt".
+ unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
+ Default: ``None`` .
+ dst_safetensors_dir (str): In the save mode scenario, the save directory for weights.
+ rank_id (int): The logical sequence number of the card. In non-save mode, it is obtained automatically
+ and globally by initializing the network; in save mode, the file is saved according to the
+ given sequence number, and if it is not given, the entire file is saved.
+ output_format (str, optional): Controls the format of the output checkpoint after conversion.
+ It can be set to either "ckpt" or "safetensors". Default: "safetensors".
+ name_map (dict): The weight mapping dictionary. Weight names are modified according to this mapping
+ dictionary before the segmented weights are loaded or saved into the network. Default: None.
+ max_process_num (int): Maximum number of processes. Default: 64.
+ return_param_dict (bool): Whether to return the param_dict. Default: ``False``.
+
+ Raises:
+ TypeError: The types of the inputs do not match the requirements.
+ ValueError: Failed to load checkpoint into net.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+ Please see the `rank table startup
+ <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
+ for more details.
+
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+ Startup <https://www.mindspore.cn/tutorials/en/master/parallel/dynamic_cluster.html>`_ .
+
+ >>> import os
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> import mindspore.dataset as ds
+ >>> from mindspore import nn, ops, train
+ >>> from mindspore.communication import init
+ >>> from mindspore.parallel import load_distributed_checkpoint
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
+ >>> from mindspore.nn.utils import no_init_parameters
+ >>> from mindspore.common.initializer import initializer, One
+ >>>
+ >>> step_per_epoch = 4
+ >>>
+ >>> # Define the network structure.
+ >>> class Net(nn.Cell):
+ ... def __init__(self, matmul_size, strategy=None):
+ ... super().__init__()
+ ... self.matmul_weight = ms.Parameter(initializer(One(), matmul_size, ms.float32))
+ ... self.matmul = ops.MatMul()
+ ... self.neg = ops.Neg()
+ ... if strategy is not None:
+ ... self.matmul.shard(strategy)
+ ...
+ ... def construct(self, inputs):
+ ... x = self.matmul(inputs, self.matmul_weight)
+ ... x = self.neg(x)
+ ... return x
+ >>>
+ >>> # Create dataset.
+ >>> def get_dataset(*inputs):
+ ... def generate():
+ ... for _ in range(step_per_epoch):
+ ... yield inputs
+ ... return generate
+ >>>
+ >>> # Train network and save distributed checkpoint.
+ >>> def train_net():
+ ... ms.set_context(mode=ms.GRAPH_MODE)
+ ... init()
+ ... np.random.seed(1)
+ ... input_data = np.random.rand(16, 96).astype(np.float32)
+ ... label_data = np.random.rand(16, 16).astype(np.float32)
+ ... fake_dataset = get_dataset(input_data, label_data)
+ ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+ ...
+ ... # Set parallel strategy.
+ ... strategy = ((1, 4), (4, 1))
+ ... with no_init_parameters():
+ ... network = Net(matmul_size=(96, 16), strategy=strategy)
+ ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+ ...
+ ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
+ ... network = AutoParallel(network, parallel_mode="semi_auto")
+ ... network.save_param_strategy_file(file_path="./train_strategy.ckpt")
+ ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
+ ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
+ ... global_rank_id = int(os.getenv("RANK_ID"))
+ ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
+ ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
+ ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
+ >>>
+ >>> # Load distributed checkpoint and test.
+ >>> def load_model():
+ ... ms.set_context(mode=ms.GRAPH_MODE)
+ ... init()
+ ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
+ ... with no_init_parameters():
+ ... network = Net(matmul_size=(96, 16))
+ ... network = AutoParallel(network, parallel_mode="semi_auto")
+ ... network.dataset_strategy(config="full_batch")
+ ... train_strategy_file = "./train_strategy.ckpt"
+ ... network.save_param_strategy_file(file_path=train_strategy_file)
+ ... model = ms.Model(network)
+ ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
+ ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
+ ... load_distributed_checkpoint(network, ckpt_file_list, predict_layout, None)
+ ... predict_result = model.predict(predict_data)
+ ... print(predict_result)
+ >>>
+ >>> train_net()
+ >>> load_model()
+ [[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
+ [ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
+ ...
+ [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
+ """
+ if format not in ['safetensors', 'ckpt'] or output_format not in ['safetensors', 'ckpt']:
+ raise ValueError(
+ f"For 'load_distributed_checkpoint', 'format' and 'output_format' "
+ f"must be 'ckpt' or 'safetensors', but got {format}.")
+
+ if format == 'safetensors':
+ if unified_safetensors_dir is None:
+ raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
+ f"when format is 'safetensors'.")
+ unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
+ for param in unsupport_param:
+ if param is not None:
+ raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
+ f"when format is 'safetensors'.")
+ if strict_load or dec_mode != 'AES-GCM':
+ raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
+ f"when format is 'safetensors'.")
+ if network is not None:
+ try:
+ rank_id = get_rank()
+ except RuntimeError:
+ rank_id = 0
+ logger.warning(f"Get rank failed, default loading weight for rank 0.")
+ param_dict = _load_parallel_checkpoint(
+ (unified_safetensors_dir, predict_strategy, network, None, rank_id, output_format, name_map,
+ return_param_dict))
+ return param_dict
+ if dst_safetensors_dir is None:
+ raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
+ f"when network is None.")
+ if rank_id is not None:
+ _load_parallel_checkpoint(
+ (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
+ rank_id, output_format, name_map, return_param_dict))
+ else:
+ dst_strategy_dict = _build_searched_strategy(predict_strategy)
+ dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
+ dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
+ dst_device_num = dst_stage_device_num * dst_stage_num
+ tasks = _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
+ dst_device_num, output_format, name_map, return_param_dict)
+ with Pool(processes=max_process_num) as pool:
+ list(pool.imap(_load_parallel_checkpoint, tasks))
+ return True
+
+ network = Validator.check_isinstance("network", network, ms.nn.Cell)
+ _check_checkpoint_file(checkpoint_filenames)
+ _check_predict_strategy(predict_strategy)
+
+ dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
+ dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
+
+ if train_strategy_filename is None:
+ parallel_net = _get_auto_parallel_net(network)
+ if parallel_net.__class__.__name__ == "AutoParallel":
+ train_strategy_filename = parallel_net._save_strategy_file_path
+ else:
+ train_strategy_filename = ms.context.get_auto_parallel_context("strategy_ckpt_load_file")
+
+ _train_strategy = build_searched_strategy(train_strategy_filename)
+ train_strategy = _convert_to_list(_train_strategy)
+
+ train_dev_count = 1
+ ckpt_file_len = len(checkpoint_filenames)
+ for dim in train_strategy[list(train_strategy.keys())[0]][0]:
+ train_dev_count *= dim
+ if train_dev_count != ckpt_file_len:
+ raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
+ f"equal to the device count of training process. "
+ f"But got the length of 'checkpoint_filenames'"
+ f" is {ckpt_file_len} and the device count is {train_dev_count}.")
+ rank_list = _infer_rank_list(train_strategy, predict_strategy)
+
+ param_total_dict = defaultdict(dict)
+ for file_index, file_name in enumerate(checkpoint_filenames):
+ ckpt_dict = ms.load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
+ for param_name, param in ckpt_dict.items():
+ param_total_dict[param_name][file_index] = param
+
+ param_dict = {}
+ param_not_in_strategy = []
+ param_not_in_ckpt = []
+ for _, param in network.parameters_and_names():
+ sliced_params = []
+ if param.name not in rank_list.keys():
+ param_not_in_strategy.append(param.name)
+ continue
+ if param.name not in param_total_dict:
+ param_not_in_ckpt.append(param.name)
+ continue
+
+ param_rank = rank_list.get(param.name)[0]
+ skip_merge_split = rank_list.get(param.name)[1]
+ shard_stride = train_strategy.get(param.name)[4]
+ tensor_map = train_strategy.get(param.name)[1]
+ first_dim_shard_idx = tensor_map[0] if tensor_map else -1
+ device_arrangement = train_strategy.get(param.name)[0]
+ first_dim_shard_size = 1
+ if first_dim_shard_idx >= 0:
+ first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
+ if train_strategy.get(param.name)[5]:
+ repeat_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
+ else:
+ repeat_size = 0
+ for rank in param_rank:
+ param_total_list = list(range(0, ckpt_file_len))
+ if first_dim_shard_size != 1:
+ param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
+ if repeat_size > 0:
+ shard_size = shard_stride * train_strategy.get(param.name)[5]
+ rank_index = param_total_list.index(rank)
+ start = rank_index // shard_size * shard_size
+ param_total_list = param_total_list[start:start + shard_size]
+ if shard_stride > 0:
+ param_stride = []
+ # merge pre parameter
+ param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
+ param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
+ param_index = list(set(param_index))
+ param_index.sort()
+ for rank_num in param_index:
+ if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
+ from mindspore.ops import Cast
+ cpu_cast = Cast().set_device("CPU")
+ param_stride.append(
+ cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
+ else:
+ param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
+
+ sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
+ else:
+ sliced_param = param_total_dict[param.name][rank]
+
+ sliced_params.append(sliced_param)
+ if skip_merge_split:
+ split_param = sliced_params[0]
+ else:
+ param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
+ _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
+ split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
+ opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
+ if opt_shard_group:
+ if split_param.data.dtype == mstype.bfloat16:
+ from mindspore.ops import Cast
+ cpu_cast = Cast().set_device("CPU")
+ data = cpu_cast(split_param.data, mstype.float32).asnumpy()
+ else:
+ data = split_param.data.asnumpy()
+ rank = get_rank(opt_shard_group)
+ size = get_group_size(opt_shard_group)
+ try:
+ data_slice = np.split(data, size)[rank]
+ except BaseException as e:
+ logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
+ " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
+ raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
+ f" in load distributed checkpoint for {param.name}. Data shape is "
+ f"{split_param.data.shape} and group is {opt_shard_group}.") from e
+ split_param = Parameter(Tensor(data_slice), param.name,
+ split_param.requires_grad, split_param.layerwise_parallel)
+ param_dict[param.name] = split_param
+
+ if param_not_in_strategy:
+ logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
+ "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
+ .format(param_not_in_strategy))
+ if param_not_in_ckpt:
+ logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in "
+ "the checkpoint file, please check whether 'checkpoint_filenames' is correct."
+ .format(param_not_in_ckpt))
+
+ ms.load_param_into_net(network, param_dict, strict_load=strict_load)
+ return True
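The merge loop above stitches per-rank slices back into a full tensor with `np.concatenate` before re-splitting it for the predict layout and, where needed, for the optimizer-shard group. A toy numpy illustration of that core step, with purely illustrative shapes (a (96, 16) weight sharded along dim 0 across 4 training ranks, re-split two ways for inference); this is not the function's API:

import numpy as np

ckpt_file_len = 4                  # one checkpoint file per training rank
# Each training rank holds a contiguous (24, 16) slice of the (96, 16) weight.
rank_slices = [np.full((24, 16), r, dtype=np.float32) for r in range(ckpt_file_len)]

# Counterpart of np.concatenate(param_stride): rebuild the full tensor in rank order.
merged = np.concatenate(rank_slices)
assert merged.shape == (96, 16)

# For a 2-way predict layout, the merged tensor is then split again so each
# inference rank keeps only its own (48, 16) piece, conceptually what
# _merge_and_split does with the real layouts.
predict_slices = np.split(merged, 2)
assert predict_slices[0].shape == (48, 16)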
+
+
+ def restore_group_info_list(group_info_file_name):
+ """
+ Extract the rank list from a communication group information file. To save the group info file,
+ export the `GROUP_INFO_FILE` environment variable before training,
+ e.g. ``export GROUP_INFO_FILE=/data/group_info.pb``.
+
+ Args:
+ group_info_file_name (str): Name of the group information file.
+
+ Returns:
+ List, the rank list.
+
+ Raises:
+ ValueError: The group information file is incorrect.
+ TypeError: `group_info_file_name` is not str.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore.parallel import restore_group_info_list
+ >>> restore_list = restore_group_info_list("./group_info.pb")
+ """
+ if not isinstance(group_info_file_name, str):
+ raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
+ f"but got {type(group_info_file_name)}.")
+
+ if not os.path.isfile(group_info_file_name):
+ raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
+
+ if os.path.getsize(group_info_file_name) == 0:
+ raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
+
+ return _restore_group_info_list(group_info_file_name)
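A minimal sketch of consuming the returned rank list, reusing the group info file name from the docstring example; the membership check is a hypothetical recovery step, not part of this module:

import os
from mindspore.parallel import restore_group_info_list

# Written during training because GROUP_INFO_FILE was exported beforehand.
rank_list = restore_group_info_list("./group_info.pb")

current_rank = int(os.getenv("RANK_ID", "0"))
if current_rank in rank_list:
    print(f"rank {current_rank} belongs to the recorded communication group")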