mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ from mindspore.communication import GlobalComm, get_group_rank_from_world_rank,
  from mindspore.communication.management import _get_group
  from mindspore.communication._comm_helper import _get_group_rank_from_world_rank_from_cache_helper
  from mindspore.common.tensor import Tensor
- from mindspore._c_expression import Tensor as Tensor_
+ from mindspore._c_expression import TensorPy as Tensor_
  from mindspore.ops import ReduceOp, cat
  from mindspore.ops._primitive_cache import _get_cache_prim
  from mindspore.ops.primitive import _primexpr
@@ -28,7 +28,9 @@ from mindspore.ops.auto_generate.gen_ops_prim import (inner_comm_all_reduce_op,
  inner_comm_all_to_all_v_op, inner_comm_irecv_op,
  inner_comm_isend_op, inner_comm_reduce_scatter_op)
  from mindspore._c_expression import CommHandle as CommHandle_
+ from mindspore._c_expression.typing import Type
  from mindspore import jit_class
+ import mindspore as ms

  __all__ = [
  'all_reduce',
@@ -61,6 +63,12 @@ class CommHandle(CommHandle_):
  handles will be created using Python.
  """

+ def __init__(self, handle=None, exec_sync=False):
+ super(CommHandle, self).__init__()
+ self.handle = handle
+ self.exec_sync = exec_sync
+
+
  def wait(self):
  r"""
  The wait for asynchronous handles will not take effect for handles created on the Python side.
@@ -78,6 +86,10 @@ class CommHandle(CommHandle_):
  [[2. 2. 2. 2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2. 2. 2. 2.]]
  """
+ if self.handle:
+ self.handle.wait()
+ if self.exec_sync:
+ ms.runtime.synchronize()


  default_handle = CommHandle()
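
The two CommHandle hunks above add an optional wrapped backend handle and an exec_sync flag, with wait() delegating to the wrapped handle and optionally calling ms.runtime.synchronize(). A minimal usage sketch of the resulting async flow (assumptions: a 2-device msrun launch, and the (tensor, handle) return shape shown in the docstring changes further down in this file):

import numpy as np
import mindspore as ms
import mindspore.communication as comm

comm.init()
x = ms.Tensor(np.ones([2, 8]).astype(np.float32))

# async_op=True returns the reduced tensor together with a CommHandle (assumed from the
# docstring updates below); wait() blocks until the collective completes on this rank.
out, handle = comm.comm_func.all_reduce(x, async_op=True)
handle.wait()
print(out)  # expected [[2. 2. ...]] on a 2-device world
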
@@ -218,19 +230,18 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

  >>> import numpy as np
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import all_reduce
- >>> from mindspore import Tensor
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
  >>>
- >>> init()
- >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
- >>> output = all_reduce(input_tensor)
+ >>> comm.init()
+ >>> input_tensor = ms.Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> output, _ = comm.comm_func.all_reduce(input_tensor)
  >>> print(output)
  [[2. 2. 2. 2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2. 2. 2. 2.]]
@@ -284,22 +295,18 @@ def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=F
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

  >>> import numpy as np
  >>> import mindspore as ms
- >>> from mindspore import ops
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import all_gather_into_tensor
- >>> from mindspore import Tensor
+ >>> import mindspore.communication as comm
  >>>
- >>> ms.set_context(mode=ms.GRAPH_MODE)
- >>> init()
- >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
- >>> output = all_gather_into_tensor(input_tensor)
+ >>> comm.init()
+ >>> input_tensor = ms.Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> output, _ = comm.comm_func.all_gather_into_tensor(input_tensor)
  >>> print(output)
  [[1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1.]
@@ -358,21 +365,18 @@ def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_G
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> import mindspore as ms
- >>> from mindspore import Tensor
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import reduce_scatter_tensor
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
  >>>
- >>> ms.set_context(mode=ms.GRAPH_MODE)
- >>> init()
- >>> input_tensor = Tensor(np.ones([8, 8]).astype(np.float32))
- >>> output = reduce_scatter_tensor(input_tensor)
+ >>> comm.init()
+ >>> input_tensor = ms.Tensor(np.ones([8, 8]).astype(np.float32))
+ >>> output, _ = comm.comm_func.reduce_scatter_tensor(input_tensor)
  >>> print(output)
  [[2. 2. 2. 2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2. 2. 2. 2.]
@@ -430,22 +434,20 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
  without any third-party or configuration file dependencies.

  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 4 devices.

- >>> from mindspore import ops
- >>> import mindspore.nn as nn
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import reduce
- >>> from mindspore import Tensor
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
+ >>>
  >>> # Launch 4 processes.
- >>> init()
+ >>> comm.init()
  >>> dest_rank=1
- >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
- >>> output = reduce(input_tensor)
+ >>> input_tensor = ms.Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> output = comm.comm_func.reduce(input_tensor, dst=dest_rank)
  >>> print(output)
  Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4. 4. 4. 4.]],
@@ -494,27 +496,36 @@ class P2POp:

  Examples:
  >>> import numpy as np
- >>> import mindspore
- >>> from mindspore.communication.comm_func import P2POp, isend, irecv
- >>> from mindspore import Tensor
- >>> send_tensor = Tensor(1.)
- >>> send_op = P2POp('isend', send_tensor, 1)
- >>> send_op = P2POp(isend, send_tensor, 1)
- >>> recv_tensor = Tensor(0.)
- >>> recv_op = P2POp('irecv', recv_tensor, 0)
- >>> recv_op = P2POp(irecv, recv_tensor, 0)
- >>> recv_op = P2POp('irecv', (), 0, recv_dtype=mindspore.float32)
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
+ >>>
+ >>> send_tensor = ms.Tensor(1.)
+ >>> send_op = comm.comm_func.P2POp('isend', send_tensor, 1)
+ >>> send_op = comm.comm_func.P2POp(comm.comm_func.isend, send_tensor, 1)
+ >>> recv_tensor = ms.Tensor(0.)
+ >>> recv_op = comm.comm_func.P2POp('irecv', recv_tensor, 0)
+ >>> recv_op = comm.comm_func.P2POp(comm.comm_func.irecv, recv_tensor, 0)
+ >>> recv_op = comm.comm_func.P2POp('irecv', (), 0, recv_dtype=ms.float32)
  """

  def __init__(self, op, tensor, peer, group=None, tag=0, *, recv_dtype=None):
  self.op = op
  self.tensor = tensor
+ if not isinstance(peer, int):
+ raise TypeError(f"peer must be type of int, but got type of {type(peer)}")
+
+ if recv_dtype and not isinstance(recv_dtype, Type):
+ raise TypeError(f"recv_dtype must be type of mindspore dtype, but got type of {type(recv_dtype)}")
+
  self.peer = peer
  self.group = group
  self.tag = tag
  self.recv_dtype = recv_dtype

  def __new__(cls, op, tensor, peer, group=None, tag=0, recv_dtype=None):
+ if not (isinstance(op, str) or callable(op)):
+ raise TypeError(f"op must be type of string or function, but got type of {type(op)}")
+
  if isinstance(op, str):
  op_name = op
  else:
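
The constructor changes above add eager validation to P2POp: op must be a string or callable, peer must be an int, and recv_dtype, when given, must be a MindSpore dtype. A short sketch of the new failure mode (hypothetical values; behavior inferred from the checks shown above):

import mindspore as ms
import mindspore.communication as comm

ok = comm.comm_func.P2POp('isend', ms.Tensor(1.), 1)      # peer is an int: accepted
try:
    comm.comm_func.P2POp('isend', ms.Tensor(1.), '1')     # peer is a str: rejected
except TypeError as err:
    print(err)  # peer must be type of int, but got type of <class 'str'>
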
@@ -560,31 +571,29 @@ def batch_isend_irecv(p2p_op_list):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

  >>> import numpy as np
- >>> import mindspore
- >>> from mindspore.communication import init, get_rank, get_group_size
- >>> from mindspore.communication.comm_func import batch_isend_irecv, P2POp
- >>> from mindspore import Tensor
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
  >>>
- >>> init()
- >>> this_rank = get_rank()
- >>> world_size = get_group_size()
+ >>> comm.init()
+ >>> this_rank = comm.get_rank()
+ >>> world_size = comm.get_group_size()
  >>> next_rank = (this_rank + 1) % world_size
  >>> prev_rank = (this_rank + world_size - 1) % world_size
  >>>
- >>> send_tensor = Tensor(this_rank + 1, dtype=mindspore.float32)
- >>> recv_tensor = Tensor(0., dtype=mindspore.float32)
+ >>> send_tensor = ms.Tensor(this_rank + 1, dtype=ms.float32)
+ >>> recv_tensor = ms.Tensor(0., dtype=ms.float32)
  >>>
- >>> send_op = P2POp('isend', send_tensor, next_rank)
- >>> recv_op = P2POp('irecv', recv_tensor, prev_rank)
+ >>> send_op = comm.comm_func.P2POp('isend', send_tensor, next_rank)
+ >>> recv_op = comm.comm_func.P2POp('irecv', recv_tensor, prev_rank)
  >>>
  >>> p2p_op_list = [send_op, recv_op]
- >>> output = batch_isend_irecv(p2p_op_list)
+ >>> output = comm.comm_func.batch_isend_irecv(p2p_op_list)
  >>> print(output)
  rank 0:
  (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 2))
@@ -597,6 +606,10 @@ def batch_isend_irecv(p2p_op_list):
  receive_shapes = []
  receive_dtypes = []
  tags = []
+
+ if not isinstance(p2p_op_list, list):
+ raise TypeError(f"p2p_op_list must be type of list, but got type of {p2p_op_list}.")
+
  if not p2p_op_list:
  raise TypeError(f"p2p_op_list can not be empty list.")
  group = p2p_op_list[0].group
@@ -676,20 +689,20 @@ def scatter_tensor(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> import mindspore as ms
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import scatter_tensor
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
+ >>>
  >>> # Launch 2 processes.
  >>>
- >>> init()
+ >>> comm.init()
  >>> input = ms.Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
- >>> out = scatter_tensor(tensor=data, src=0)
+ >>> out = comm.comm_func.scatter_tensor(tensor=input, src=0)
  >>> print(out)
  # rank_0
  [[0. 1.]
@@ -741,22 +754,20 @@ def gather_into_tensor(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

  >>> import numpy as np
  >>> import mindspore as ms
- >>> import mindspore.nn as nn
- >>> from mindspore.communication import init
- >>> from mindspore import Tensor
- >>> from mindspore.communication.comm_func import gather_into_tensor
+ >>> import mindspore.communication as comm
+ >>>
  >>> # Launch 2 processes.
  >>>
- >>> init()
- >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
- >>> output = gather_into_tensor(tensor=data, dst=0)
+ >>> comm.init()
+ >>> input = ms.Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
+ >>> output = comm.comm_func.gather_into_tensor(tensor=input, dst=0)
  >>> print(output)
  Process with rank 0: [[0. 1.],
  [2. 3.],
@@ -804,21 +815,21 @@ def broadcast(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> import mindspore as ms
- >>> from mindspore import Tensor
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import broadcast
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
+ >>>
  >>> # Launch 2 processes.
  >>>
- >>> init()
+ >>> comm.init()
  >>> data = ms.Tensor(np.arange(8).reshape([2, 4]).astype(np.float32))
- >>> out = broadcast(tensor=data, src=0)
+ >>> out = comm.comm_func.broadcast(tensor=data, src=0)
+ >>> print(out)
  [[0. 1. 2. 3.]
  [4. 5. 6. 7.]]

@@ -858,31 +869,41 @@ def barrier(group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import barrier
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
+ >>>
  >>> # Launch 2 processes.
- >>> init()
- >>> barrier()
+ >>> comm.init()
+ >>> comm.comm_func.barrier()
+ >>> print("barrier finish!")
+ barrier finish!

  Tutorial Examples:
  - `Distributed Set Communication Primitives - Barrier
  <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#barrier>`_
  """
+ if not isinstance(group, str):
+ raise TypeError(f"group must be type of string, but got {type(group)}")
  _op = _get_cache_prim(P.Barrier)(group)
  return _op()


- def _deal_comm_outputs(output, async_op):
+ def _deal_comm_outputs(output, async_op, exec_sync=False):
+ """
+ deal with comm ops outputs.
+ """
  if isinstance(output, tuple):
  if not async_op:
  output[1].wait()
+ if exec_sync:
+ ms.runtime.synchronize()
  return (output[0], None)
- return output
+ return (output[0], CommHandle(output[1], exec_sync))

  if not async_op:
  return (output, None)
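
With the _deal_comm_outputs change above, collectives that accept async_op hand back a (tensor, handle) pair: handle is None for synchronous calls, and a Python CommHandle wrapping the backend handle (optionally forcing a runtime synchronize) for async_op=True. A hedged sketch of that calling convention (assuming a 2-device msrun launch):

import numpy as np
import mindspore as ms
import mindspore.communication as comm

comm.init()
x = ms.Tensor(np.ones([2, 8]).astype(np.float32))

out, handle = comm.comm_func.all_gather_into_tensor(x)            # sync: handle is None
out2, handle2 = comm.comm_func.all_gather_into_tensor(x, async_op=True)
handle2.wait()  # CommHandle built by _deal_comm_outputs around the backend handle
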
@@ -918,21 +939,35 @@ def send(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> from mindspore import ops
- >>> import mindspore.nn as nn
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import send
- >>> from mindspore import Tensor
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import send, recv
+ >>> from mindspore.communication import get_rank, get_group_size
  >>>
+ >>> np.random.seed(1)
  >>> init()
- >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
- >>> send(input_, 0)
+ >>> rank = get_rank()
+ >>> size = get_group_size()
+ >>> x = np.ones([2, 2]).astype(np.float32) * 0.01 * (rank + 1)
+ >>> x2 = np.ones([2, 2]).astype(np.float32)
+ >>>
+ >>>
+ >>> if rank < size / 2:
+ ... _x = ms.Tensor(x)
+ ... send(_x, rank + size // 2)
+ ... else:
+ ... _x2 = ms.Tensor(x2)
+ ... output = recv(_x2, rank - size // 2)
+ ... print(output)
+ rank1:
+ [[0.01 0.01]
+ [0.01 0.01]]
  """
  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For send, the input tensor must be tensor")
@@ -979,29 +1014,35 @@ def recv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> from mindspore import ops
- >>> import mindspore.nn as nn
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import recv
- >>> from mindspore import Tensor
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import send, recv
+ >>> from mindspore.communication import get_rank, get_group_size
  >>>
- # Launch 2 processes.
- Process 0 send the following array to Process 1
- [[ 0. 1.]
- [ 2. 3.]]
+ >>> np.random.seed(1)
  >>> init()
- >>> x = ms.Tensor(np.zeros([2, 2]))
- # Process 1 receive tensor from Process 0.
- >>> out = recv(x, src=0)
- >>> print(out)
- [[ 0. 1.]
- [ 2. 3.]]
+ >>> rank = get_rank()
+ >>> size = get_group_size()
+ >>> x = np.ones([2, 2]).astype(np.float32) * 0.01 * (rank + 1)
+ >>> x2 = np.ones([2, 2]).astype(np.float32)
+ >>>
+ >>>
+ >>> if rank < size / 2:
+ ... _x = ms.Tensor(x)
+ ... send(_x, rank + size // 2)
+ ... else:
+ ... _x2 = ms.Tensor(x2)
+ ... output = recv(_x2, rank - size // 2)
+ ... print(output)
+ rank1:
+ [[0.01 0.01]
+ [0.01 0.01]]
  """
  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For recv, the input tensor must be tensor")
@@ -1049,22 +1090,36 @@ def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> from mindspore import ops
- >>> import mindspore.nn as nn
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import isend
- >>> from mindspore import Tensor
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import isend, irecv
+ >>> from mindspore.communication import get_rank, get_group_size
  >>>
+ >>> np.random.seed(1)
  >>> init()
- >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
- >>> handle = isend(input_, 0)
- >>> handle.wait()
+ >>> rank = get_rank()
+ >>> size = get_group_size()
+ >>> x = np.ones([2, 2]).astype(np.float32) * 0.01 * (rank + 1)
+ >>> x2 = np.ones([2, 2]).astype(np.float32)
+ >>>
+ >>>
+ >>> if rank < size / 2:
+ ... _x = ms.Tensor(x)
+ ... isend(_x, rank + size // 2)
+ ... else:
+ ... _x2 = ms.Tensor(x2)
+ ... output, handle = irecv(_x2, rank - size // 2)
+ ... handle.wait()
+ ... print(output)
+ rank1:
+ [[0.01 0.01]
+ [0.01 0.01]]
  """
  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For isend, the input tensor must be tensor")
@@ -1114,30 +1169,36 @@ def irecv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

- >>> from mindspore import ops
- >>> import mindspore.nn as nn
- >>> from mindspore.communication import init
- >>> from mindspore.communication.comm_func import irecv
- >>> from mindspore import Tensor
  >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import isend, irecv
+ >>> from mindspore.communication import get_rank, get_group_size
  >>>
- # Launch 2 processes.
- Process 0 send the following array to Process 1
- [[ 0. 1.]
- [ 2. 3.]]
+ >>> np.random.seed(1)
  >>> init()
- >>> x = ms.Tensor(np.zeros([2, 2]))
- # Process 1 receive tensor from Process 0.
- >>> out, handle = irecv(x, src=0)
- >>> handle.wait()
- >>> print(out)
- [[ 0. 1.]
- [ 2. 3.]]
+ >>> rank = get_rank()
+ >>> size = get_group_size()
+ >>> x = np.ones([2, 2]).astype(np.float32) * 0.01 * (rank + 1)
+ >>> x2 = np.ones([2, 2]).astype(np.float32)
+ >>>
+ >>>
+ >>> if rank < size / 2:
+ ... _x = ms.Tensor(x)
+ ... isend(_x, rank + size // 2)
+ ... else:
+ ... _x2 = ms.Tensor(x2)
+ ... output, handle = irecv(_x2, rank - size // 2)
+ ... handle.wait()
+ ... print(output)
+ rank1:
+ [[0.01 0.01]
+ [0.01 0.01]]
  """
  group = _get_group(group)
  _src = _get_group_rank_from_world_rank_from_cache_helper(src, group)
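`recv` and `irecv` now translate the caller-supplied world rank into the rank inside the target communication group through a cached helper. A rough sketch of that memoization idea, assuming the public get_group_rank_from_world_rank API (the cache and wrapper names are illustrative, not the private helper itself):

    from mindspore.communication import get_group_rank_from_world_rank

    _RANK_CACHE = {}  # (world_rank, group_name) -> rank inside that group

    def group_rank_cached(world_rank, group):
        # Requires init() and the target group to already exist.
        key = (world_rank, group)
        if key not in _RANK_CACHE:
            _RANK_CACHE[key] = get_group_rank_from_world_rank(world_rank, group)
        return _RANK_CACHE[key]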
@@ -1185,27 +1246,24 @@ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=Non
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

  >>> import numpy as np
- >>> import mindspore
- >>> from mindspore.communication import init, get_rank, get_group_size
- >>> from mindspore.communication.comm_func import all_to_all_with_output_shape
- >>> from mindspore import Tensor
- >>> from mindspore.ops import zeros
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
  >>>
- >>> init()
- >>> this_rank = get_rank()
+ >>> comm.init()
+ >>> this_rank = comm.get_rank()
  >>> if this_rank == 0:
- >>> send_tensor_list = [Tensor(1.), Tensor([[2, 3], [4, 5.]])]
- >>> recv_tensor_list = [(), (2,)]
+ ... send_tensor_list = [ms.Tensor(1.), ms.Tensor([[2, 3], [4, 5.]])]
+ ... recv_tensor_list = [(), (2,)]
  >>> if this_rank == 1:
- >>> send_tensor_list = [Tensor([2, 2.]), Tensor([4, 5, 6, 7.])]
- >>> recv_tensor_list = [(2, 2), (4,)]
- >>> output = all_to_all_with_output_shape(recv_tensor_list, send_tensor_list)
+ ... send_tensor_list = [ms.Tensor([2, 2.]), ms.Tensor([4, 5, 6, 7.])]
+ ... recv_tensor_list = [(2, 2), (4,)]
+ >>> output, _ = comm.comm_func.all_to_all_with_output_shape(recv_tensor_list, send_tensor_list)
  >>> print(output)
  rank 0:
  (Tensor(shape=[], dtype=Float32, value= 1),
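The updated example unpacks an (outputs, handle) pair, since the wrapper now also returns a communication handle alongside the tuple of received tensors. A compact standalone sketch of the same exchange on two ranks, assuming the shapes used in the example above (the second return value may be None for the default synchronous call):

    # all_to_all_demo.py -- 2 ranks; tensor i in the send list goes to rank i.
    import mindspore as ms
    import mindspore.communication as comm
    from mindspore.communication.comm_func import all_to_all_with_output_shape

    comm.init()
    rank = comm.get_rank()
    if rank == 0:
        send_list = [ms.Tensor(1.), ms.Tensor([[2., 3.], [4., 5.]])]
        recv_shapes = [(), (2,)]       # scalar from rank 0, length-2 vector from rank 1
    else:
        send_list = [ms.Tensor([2., 2.]), ms.Tensor([4., 5., 6., 7.])]
        recv_shapes = [(2, 2), (4,)]   # 2x2 block from rank 0, length-4 vector from rank 1
    outputs, handle = all_to_all_with_output_shape(recv_shapes, send_list)
    print([t.asnumpy() for t in outputs])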
@@ -1239,7 +1297,6 @@ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=Non
  recv_shape_list.append(_shape)

  send_flatten_tensor = cat(send_flatten_tensor)
- send_flatten_tensor = _contiguous(send_flatten_tensor)
  group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
  global _GROPU_SIZE_CACHE
  if group not in _GROPU_SIZE_CACHE:
@@ -1256,17 +1313,17 @@ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=Non
  return (tuple(result), handle)


- def _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes, input_split_sizes, group):
+ def _get_all_to_all_single_numel_list(tensor_shape, output_shape, output_split_sizes, input_split_sizes, group):
  """get numel list for all_to_all_single."""
  global _GROPU_SIZE_CACHE
  if _is_split_sizes_empty(input_split_sizes):
  if group not in _GROPU_SIZE_CACHE:
  _GROPU_SIZE_CACHE[group] = get_group_size(group)
  _world_size = _GROPU_SIZE_CACHE[group]
- if tensor.shape[0] % _world_size != 0:
+ if tensor_shape[0] % _world_size != 0:
  raise ValueError("input shape at dim 0 must be divided by world_size, "
- f"but got {tensor.shape[0]} and {_world_size}.")
- _split_size = tensor.shape[0] // _world_size
+ f"but got {tensor_shape[0]} and {_world_size}.")
+ _split_size = tensor_shape[0] // _world_size
  input_split_sizes = (_split_size,) * _world_size
  if _is_split_sizes_empty(output_split_sizes):
  if group not in _GROPU_SIZE_CACHE:
@@ -1283,7 +1340,7 @@ def _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes,
  _split_size = shape_dim_0 // _world_size
  output_split_sizes = (_split_size,) * _world_size

- send_size_without_first_dim = _get_size(tensor.shape[1:])
+ send_size_without_first_dim = _get_size(tensor_shape[1:])
  send_numel_list = [size * send_size_without_first_dim for size in input_split_sizes]

  recv_size_without_first_dim = None
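`_get_all_to_all_single_numel_list` now receives the input's shape rather than the tensor itself, so its result depends only on hashable values. The element counts it produces follow directly from the first-dimension split: each rank's chunk is its number of rows times the size of one row. A tiny self-contained sketch of that arithmetic (the function name is illustrative):

    from math import prod

    def numel_per_rank(tensor_shape, split_sizes):
        # Elements exchanged with each rank: rows for that rank * elements per row.
        row_numel = prod(tensor_shape[1:])  # prod(()) == 1 covers the 1-D case
        return [rows * row_numel for rows in split_sizes]

    # A (6, 3) input split as [2, 4] rows across 2 ranks -> [6, 12] elements.
    assert numel_per_rank((6, 3), [2, 4]) == [6, 12]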
@@ -1298,10 +1355,14 @@ def _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes,
  return send_numel_list, recv_numel_list, recv_shape_without_first_dim


+ _ALL_TO_ALL_CACHE = {}
+
+
  def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes=None,
  input_split_sizes=None, group=None, async_op=False):
  """
- scatter and gather input with split size to/from all rank, and return result in a single tensor.
+ Based on the slice size of the user input, the input `tensor` is sliced and sent to other devices
+ and receives the sliced chunks from the other devices, which are then merged into an output Tensor.

  Note:
  'output_shape' and 'tensor' shape should be match across ranks.
@@ -1321,8 +1382,8 @@ def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes

  Returns:
  Tuple(Tensor, CommHandle), the output tensor is gathered concatenated from remote ranks.
- If the numel of tensor gathered from remote is zero, it will return a Tensor will value 0,
- which has no actual meanning. CommHandle is an async work handle, if `async_op` is set to True.
+ If the numel of tensor gathered from remote is zero, it will return a Tensor with shape `()`,
+ and value has no actual meanning. CommHandle is an async work handle, if `async_op` is set to True.
  CommHandle will be None, when `async_op` is False.

  Raises:
@@ -1339,36 +1400,25 @@ def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.

  >>> import numpy as np
- >>> import mindspore
- >>> from mindspore.communication import init, get_rank, get_group_size
- >>> from mindspore.communication.comm_func import all_to_all_single_with_output_shape
- >>> from mindspore import Tensor
- >>> from mindspore.ops import zeros
+ >>> import mindspore as ms
+ >>> import mindspore.communication as comm
  >>>
- >>> init()
- >>> this_rank = get_rank()
- >>> if this_rank == 0:
- >>> output_shape = (3, 3)
- >>> tensor = Tensor([[0, 1, 2.], [3, 4, 5], [6, 7, 8]])
- >>> result = all_to_all_single_with_output_shape(output_shape, tensor, [2, 1], [2, 1])
- >>> if this_rank == 1:
- >>> output_shape = (2, 3)
- >>> tensor = Tensor([[9, 10., 11], [12, 13, 14]])
- >>> result = all_to_all_single_with_output_shape(output_shape, tensor)
+ >>> comm.init()
+ >>> rank = comm.get_rank()
+ >>> input = ms.Tensor([0, 1]) + rank * 2
+ >>> output_shape = (2,)
+ >>> result, _ = comm.comm_func.all_to_all_single_with_output_shape(output_shape, input)
  >>> print(result)
  rank 0:
- [[ 0. 1. 2.]
- [ 3. 4. 5.]
- [ 9. 10. 11.]]
+ [ 0. 2.]
  rank 1:
- [[ 6. 7. 8.]
- [12. 13. 14.]]
+ [ 1. 3.]

  """

@@ -1378,8 +1428,17 @@ def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes
  group = GlobalComm.WORLD_COMM_GROUP

  split_sizes_empty = _is_split_sizes_empty(output_split_sizes) and _is_split_sizes_empty(input_split_sizes)
- send_numel_list, recv_numel_list, recv_shape_without_first_dim = \
- _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes, input_split_sizes, group)
+ if isinstance(output_split_sizes, list):
+ output_split_sizes = tuple(output_split_sizes)
+ if isinstance(input_split_sizes, list):
+ input_split_sizes = tuple(input_split_sizes)
+ global _ALL_TO_ALL_CACHE
+ tensor_shape = output_shape
+ cache_key = (tensor_shape, output_shape, output_split_sizes, input_split_sizes, group)
+ if cache_key not in _ALL_TO_ALL_CACHE:
+ _ALL_TO_ALL_CACHE[cache_key] = _get_all_to_all_single_numel_list(*cache_key)
+ send_numel_list, recv_numel_list, recv_shape_without_first_dim = _ALL_TO_ALL_CACHE[cache_key]
+
  tensor = _contiguous(tensor)
  _input = tensor.reshape(-1)
  group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
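The new fast path memoizes the numel bookkeeping: list-valued split sizes are normalized to tuples so that, together with the shapes and group, they form a hashable cache key, and _get_all_to_all_single_numel_list is evaluated only once per distinct key. A standalone sketch of that shape-keyed memoization pattern, including an uneven split like the removed docstring example (all names and numbers below are illustrative, not the module's internals):

    from math import prod

    _NUMEL_CACHE = {}   # (shapes, split sizes, group) -> (send numel list, recv numel list)

    def cached_numel_lists(tensor_shape, output_shape, in_splits, out_splits, group, world_size):
        key = (tuple(tensor_shape), tuple(output_shape),
               tuple(in_splits) if in_splits else None,
               tuple(out_splits) if out_splits else None, group)
        if key not in _NUMEL_CACHE:
            # Default to an even first-dimension split when no sizes are given.
            in_splits = in_splits or [tensor_shape[0] // world_size] * world_size
            out_splits = out_splits or [output_shape[0] // world_size] * world_size
            row = prod(tensor_shape[1:])
            out_row = prod(output_shape[1:])
            _NUMEL_CACHE[key] = ([s * row for s in in_splits],
                                 [s * out_row for s in out_splits])
        return _NUMEL_CACHE[key]

    # Uneven exchange on 2 ranks: keep 2 rows locally, send 1 row of a (3, 3) input.
    send, recv = cached_numel_lists((3, 3), (3, 3), [2, 1], [2, 1], "world", 2)
    print(send, recv)   # [6, 3] [6, 3]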