mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
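
The hunk below is the new module added as entry 200 above, mindspore/ops/function/comm_func.py (+3883 lines), which backs the mindspore.ops.communication API (TCPStore, init_process_group, the collectives, and the object collectives). A minimal usage sketch assembled from the docstring examples inside the hunk, assuming a cluster launched with msrun and configured communication environment variables:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops.communication import init_process_group, all_reduce, destroy_process_group

ms.set_device(device_target="Ascend")
init_process_group()                            # creates the default "hccl_world_group"
x = Tensor(np.ones([2, 8]).astype(np.float32))
out, handle = all_reduce(x, async_op=True)      # async collectives return (Tensor, CommHandle)
handle.wait()                                   # block until the all-reduce has finished
print(out)                                      # with 2 ranks: a [2, 8] tensor filled with 2.
destroy_process_group()
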
@@ -0,0 +1,3883 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """Communication management API"""
+ from __future__ import absolute_import
+ import hashlib
+ import builtins
+ import io
+ import sys
+ import pickle
+ from datetime import timedelta
+ import numpy as np
+ from mindspore import log as logger
+ from mindspore.common import dtype as mstype
+ from mindspore._checkparam import args_type_check
+ from mindspore.common.api import jit_class
+ from mindspore.common.api import _pynative_executor
+ from mindspore.runtime.stream import synchronize
+ from mindspore.ops.operations.comm_ops import ReduceOp
+ from mindspore.ops.auto_generate import cat
+ from mindspore.common.tensor import Tensor
+ from mindspore._c_expression import TensorPy as Tensor_
+ from mindspore.ops.primitive import _primexpr
+ from mindspore.communication.management import _init_without_sched
+ from mindspore.communication._comm_helper import (
+     _destroy_group_helper,
+     _get_rank_helper,
+     _get_size_helper,
+     _get_backend,
+     _get_group_ranks,
+     _is_available,
+     _is_initialized,
+     _ExistingGroup,
+     _get_group_rank_from_world_rank_from_cache_helper,
+ )
+ from mindspore.communication import (
+     init,
+     get_group_size,
+     get_world_rank_from_group_rank,
+     create_group,
+     GlobalComm,
+     get_group_rank_from_world_rank,
+ )
+ from mindspore.ops.auto_generate.gen_ops_prim import (
+     inner_comm_all_reduce_op,
+     inner_comm_all_gather_op,
+     inner_comm_all_to_all_v_op,
+     inner_comm_irecv_op,
+     inner_comm_reduce_scatter_op
+ )
+ from mindspore.ops.auto_generate.gen_ops_prim import (
+     dist_comm_all_gather_op,
+     dist_comm_all_reduce_op,
+     dist_comm_reduce_scatter_op,
+     dist_comm_isend_op,
+     dist_comm_all_to_all_v_op,
+     dist_comm_reduce_scatter_tensor_op,
+     dist_comm_reduce_scatter_tensor_uneven_op,
+     dist_comm_all_to_all_v_single_op,
+     dist_comm_broadcast_op,
+     dist_comm_all_gather_into_tensor_op,
+     dist_comm_all_gather_into_tensor_uneven_op,
+     dist_comm_irecv_op,
+     dist_comm_scatter_tensor_op,
+     dist_comm_gather_into_tensor_op,
+     dist_comm_gather_op,
+     dist_comm_reduce_op,
+     dist_comm_scatter_op,
+     dist_comm_barrier_op,
+     dist_comm_batch_isend_irecv_op,
+ )
+ from mindspore._c_expression import TCPStoreClient, GroupOptions, _finalize_collective
+ from mindspore._c_expression import CommHandle as CommHandle_
+
+ __all__ = [
+     "TCPStore",
+     "init_process_group",
+     "destroy_process_group",
+     "get_rank",
+     "get_world_size",
+     "new_group",
+     "get_backend",
+     "get_global_rank",
+     "get_process_group_ranks",
+     "get_group_rank",
+     "all_reduce",
+     "all_gather_into_tensor",
+     "all_gather_into_tensor_uneven",
+     "all_to_all",
+     "all_to_all_single",
+     "reduce_scatter_tensor",
+     "reduce_scatter_tensor_uneven",
+     "isend",
+     "irecv",
+     "send",
+     "recv",
+     "gather",
+     "scatter",
+     "all_gather",
+     "reduce_scatter",
+     "barrier",
+     "broadcast",
+     "reduce",
+     "P2POp",
+     "batch_isend_irecv",
+     "all_gather_object",
+     "broadcast_object_list",
+     "gather_object",
+     "scatter_object_list",
+     "is_available",
+     "is_initialized",
+     'gather_into_tensor',
+     'scatter_tensor',
+     'set_comm_ops_inplace',
+     'all_to_all_v_c'
+ ]
+
+ _pickler = pickle.Pickler
+ _unpickler = pickle.Unpickler
+ BACKEND_HCCL = "hccl"
+ BACKEND_MCCL = "mccl"
+ _GROPU_SIZE_CACHE = {}
+ _GROPU_RANK_CACHE = {}
+ _ALL_TO_ALL_CACHE = {}
+
+ safe_builtins = {
+     'range',
+     'complex',
+     'set',
+     'frozenset',
+     'slice',
+ }
+
+
+ def get_cache_group_size(group=GlobalComm.WORLD_COMM_GROUP):
+     """get cache group size."""
+     global _GROPU_SIZE_CACHE
+     if group not in _GROPU_SIZE_CACHE:
+         _GROPU_SIZE_CACHE[group] = _get_size_helper(group)
+     group_size = _GROPU_SIZE_CACHE[group]
+     return group_size
+
+
+ def get_cache_group_rank(group=GlobalComm.WORLD_COMM_GROUP):
+     """get cache rank id."""
+     global _GROPU_RANK_CACHE
+     if group not in _GROPU_RANK_CACHE:
+         _GROPU_RANK_CACHE[group] = _get_rank_helper(group)
+     group_rank = _GROPU_RANK_CACHE[group]
+     return group_rank
+
+
+ class RestrictedUnpickler(pickle.Unpickler):
+     # Override find_class method.
+     def find_class(self, module, name):
+         # Only allow safe classes from builtins.
+         if module == "builtins" and name in safe_builtins:
+             return getattr(builtins, name)
+         # Forbid everything else.
+         raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
+                                      (module, name))
+
+
+ def restricted_loads(s):
+     """Helper function analogous to pickle.loads()."""
+     return RestrictedUnpickler(io.BytesIO(s)).load()
+
+
+ def _object_to_tensor(obj, size=0):
+     f = io.BytesIO()
+     _pickler(f).dump(obj)
+     buf = np.frombuffer(f.getvalue(), dtype=np.int8)
+     tensor_size = buf.size
+     if size > tensor_size:
+         buf = np.resize(buf, size)
+         tensor_size = size
+     return Tensor(buf), tensor_size
+
+
+ def _tensor_to_object(tensor, tensor_size):
+     buf = tensor.asnumpy().tobytes()[:tensor_size]
+     return restricted_loads(buf)
+
+
+ comm_funcs = [
+     "all_reduce",
+     "all_gather_into_tensor",
+     "all_gather_into_tensor_uneven",
+     "all_to_all",
+     "all_to_all_single",
+     "reduce_scatter_tensor",
+     "reduce_scatter_tensor_uneven",
+     "isend",
+     "irecv",
+     "send",
+     "recv",
+     "gather",
+     "scatter",
+     "all_gather",
+     "reduce_scatter",
+     "barrier",
+     "broadcast",
+     "reduce",
+     "batch_isend_irecv",
+     "all_gather_object",
+     "broadcast_object_list",
+     "gather_object",
+     "scatter_object_list",
+     'gather_into_tensor',
+     'scatter_tensor',
+     'all_to_all_v_c'
+ ]
+
+ _COMM_ENABLE_PLACE = {item: True for item in comm_funcs}
+
+
+ def is_inplace_func():
+     """if is inplace func name."""
+     global _COMM_ENABLE_PLACE
+     caller_name = sys._getframe(1).f_code.co_name # pylint: disable=protected-access
+     if caller_name in _COMM_ENABLE_PLACE:
+         return _COMM_ENABLE_PLACE[caller_name]
+     return False
+
+
+ def set_comm_ops_inplace(is_enable, func_list=None):
+     """
+     Set inplace attribute to communication function.
+
+     Args:
+         is_enable (bool): Whether to enable inplace.
+         func_list (list): Indicates which functions have their inplace attributes set.
+
+     Raises:
+         TypeError: If `is_enable` is not bool.
+         TypeError: If `func_list` is not None and not list.
+         ValueError: The function name in `func_list` is invalid.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         >>> from mindspore.ops.communication import set_comm_ops_inplace
+         >>> set_comm_ops_inplace(True)
+     """
+     global _COMM_ENABLE_PLACE
+     if not isinstance(is_enable, bool):
+         raise TypeError(
+             "For 'set_comm_ops_inplace', the argument 'is_enable' must be type of bool, "
+             "but got 'is_enable' type : {}.".format(type(is_enable))
+         )
+     if func_list is None:
+         for func in _COMM_ENABLE_PLACE:
+             _COMM_ENABLE_PLACE[func] = is_enable
+         return
+     if not isinstance(func_list, (list, tuple)):
+         raise TypeError(f"Expected list or tuple, but got {type(func_list)}.")
+     for func in func_list:
+         if func not in _COMM_ENABLE_PLACE:
+             raise ValueError(f"The function name in `func_list` must be correct, but got {func}.")
+         _COMM_ENABLE_PLACE[func] = is_enable
+
+
+ @jit_class
+ class CommHandle(CommHandle_):
+     r"""
+     Usually, handles are created in C++during the execution of communication operators and returned to the Python
+     layer. It will not be created directly in Python. Only in scenarios where graph patterns are compatible,
+     handles will be created using Python.
+     """
+
+     def __init__(self, handle=None, exec_sync=False):
+         super(CommHandle, self).__init__()
+         self.handle = handle
+         self.exec_sync = exec_sync
+
+
+     def wait(self):
+         r"""
+         The wait for asynchronous handles will not take effect for handles created on the Python side.
+
+         >>> import numpy as np
+         >>> from mindspore.communication import init
+         >>> from mindspore.ops.communication import all_reduce
+         >>> from mindspore import Tensor
+         >>>
+         >>> init()
+         >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+         >>> output, handle = all_reduce(input_tensor, async_op=True)
+         >>> handle.wait()
+         >>> print(output)
+         [[2. 2. 2. 2. 2. 2. 2. 2.]
+          [2. 2. 2. 2. 2. 2. 2. 2.]]
+         """
+         if self.handle:
+             self.handle.wait()
+         if self.exec_sync:
+             synchronize()
+
+
+ default_handle = CommHandle()
+
+
+ def _deal_comm_outputs(output, async_op, exec_sync=False):
+     """
+     deal with comm ops outputs.
+     """
+     if isinstance(output, tuple):
+         if not async_op:
+             output[1].wait()
+             if exec_sync:
+                 synchronize()
+             return (output[0], None)
+         return (output[0], CommHandle(output[1], exec_sync))
+
+     if not async_op:
+         return (output, None)
+     return (output, default_handle)
+
+
+ @_primexpr
+ def _check_all_tensors(tensor_list):
+     """check all elements in tensor_list are type of Tensor"""
+     if not isinstance(tensor_list, (list, tuple)):
+         raise TypeError(f"Expected list or tuple, but got {type(tensor_list)}.")
+     for t in tensor_list:
+         if not isinstance(t, Tensor):
+             raise TypeError(f"Expected tensor, but got {type(t)}")
+
+
+ @_primexpr
+ def _check_all_tensors_or_tuple(tensor_list):
+     """check all elements in tensor_list are type of Tensor or tuple or list"""
+     if not isinstance(tensor_list, (list, tuple)):
+         raise TypeError(f"Expected list or tuple, but got {type(tensor_list)}.")
+     for t in tensor_list:
+         if not isinstance(t, (Tensor, tuple, list)):
+             raise TypeError(f"Expected tensor or tuple, but got {type(t)}")
+
+
+ @_primexpr
+ def _check_all_tensor_same_dtype(*tensor_lists):
+     """check all the input tensor has same dtype"""
+     consistent_dtype = None
+     for list_ in tensor_lists:
+         if not isinstance(list_, (list, tuple)):
+             list_ = [list_]
+         for tensor_ in list_:
+             if not isinstance(tensor_, Tensor):
+                 continue
+
+             dtype = tensor_.dtype
+             if consistent_dtype is None:
+                 consistent_dtype = dtype
+             else:
+                 if dtype != consistent_dtype:
+                     raise TypeError("all_to_all input dtype must be the same, "
+                                     f"but got {consistent_dtype} and {dtype}.")
+
+
+ def _get_size(shape):
+     numel = 1
+     for s in shape:
+         numel *= s
+     return numel
+
+
+ def _is_split_sizes_empty(split_sizes):
+     return split_sizes is None or not split_sizes
+
+
+ class TCPStore:
+     """
+     A TCP-based distributed key-value store implementation.
+
+     Note:
+         - The function is implemented by CPU and does not involve any hardware operations related to Ascend.
+         - Currently, all parameters provided by the TCPStore class constructor are not supported
+           except for `host_name`, `port`, `world_size`, `is_master`, `timeout` and `wait_for_workers`,
+           which are reserved parameters and invalid settings.
+         - The current TCPStore function is limited and only supports scenarios where the key is
+           less than 4k and the value is less than 1G. Complex scenarios are to be supported.
+
+     Args:
+         host_name (str): The hostname or IP Address the server store should run on.
+             Currently only supports user input IP addresses.
+         port (int): The port on which the server store should listen for incoming requests.
+         world_size (int, optional): The total number of store users (number of clients + 1 for the server).
+             Default is ``None``, indicates a non-fixed number of store users. This parameter is
+             only valid for the server.
+         is_master (bool, optional): True when initializing the server store and False for client stores.
+             Default is ``False``.
+         timeout (timedelta, optional): Timeout used by the store during initialization. Default is
+             ``timedelta(seconds=300)``.
+         wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server
+             store. This is only applicable when `world_size` is a fixed value. Default is ``True``. This
+             parameter is only valid for the server.
+         multi_tenant (bool, invalid, optional): If ``True``, all ``TCPStore`` instances in the current process with
+             the same host/port will use the same underlying ``TCPServer``. Default is ``False``.
+         master_listen_fd (int, invalid, optional): If specified, the underlying ``TCPServer`` will listen on this file
+             descriptor, which must be a socket already bound to ``port``. Useful to avoid port assignment races
+             in some scenarios. Default is ``None`` (meaning the server creates a new socket and attempts to bind it
+             to ``port``).
+         use_libuv (bool, invalid, optional): If True, use libuv for ``TCPServer`` backend. Default is ``True``.
+
+     Returns:
+         TCPStore Object.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         >>> from mindspore.ops.communication import TCPStore
+         >>> store = TCPStore("127.0.0.1", 1234)
+     """
+
+     def __init__(self, host_name, port, world_size=None, is_master=False, timeout=timedelta(seconds=300),
+                  wait_for_workers=True, multi_tenant=False, master_listen_fd=None, use_libuv=True):
+         if not isinstance(host_name, str):
+             raise TypeError(
+                 "For 'TCPStore', the argument 'host_name' must be type of string, "
+                 "but got 'host_name' type : {}.".format(type(host_name))
+             )
+         if not isinstance(port, int):
+             raise TypeError(
+                 "For 'TCPStore', the argument 'port' must be type of int, "
+                 "but got 'port' type : {}.".format(type(port))
+             )
+         if not isinstance(is_master, bool):
+             raise TypeError(
+                 "For 'TCPStore', the argument 'is_master' must be type of bool, "
+                 "but got 'is_master' type : {}.".format(type(is_master))
+             )
+         if not isinstance(timeout, timedelta):
+             raise TypeError(
+                 "For 'TCPStore', the argument 'timeout' must be type of timedelta, "
+                 "but got 'timeout' type : {}.".format(type(timeout))
+             )
+         if not isinstance(wait_for_workers, bool):
+             raise TypeError(
+                 "For 'TCPStore', the argument 'wait_for_workers' must be type of bool, "
+                 "but got 'wait_for_workers' type : {}.".format(type(wait_for_workers))
+             )
+         if world_size is None:
+             world_size = 1
+         if not isinstance(world_size, int):
+             raise TypeError(
+                 "For 'TCPStore', the argument 'world_size' must be type of int, "
+                 "but got 'world_size' type : {}.".format(type(world_size))
+             )
+         if port < 0 or port > 65535:
+             raise ValueError(
+                 "For 'TCPStore', the argument 'port' must be legal, "
+                 f"but got {port}."
+             )
+         if world_size <= 0:
+             raise ValueError(
+                 "For 'TCPStore', the argument 'world_size' must be legal, "
+                 f"but got {world_size}."
+             )
+         timeout_ms = int(timeout.total_seconds() * 1000)
+         self.instance = TCPStoreClient(host_name, port, is_master, timeout_ms, world_size, wait_for_workers)
+         self.host = host_name
+         self.port = port
+
+
+     def add(self, key, amount):
+         """
+         When the `add` function is called for the first time with a given key, it creates a counter in
+         the storage corresponding to that key, with the initial value set to `amount`. Subsequent calls
+         to `add` with the same key increment the counter by amount.
+
+         Args:
+             key (str): The key whose counter value will be incremented.
+             amount (int): The amount by which the counter will be incremented.
+
+         Returns:
+             int, value of counter with `key`.
+
+         Raises:
+             TypeError: If `key` is not string.
+             TypeError: If `amount` is not int.
+             RuntimeError: If the `add` and `set` pass the same `key` and the `value` passed by `set` cannot
+                 be correctly converted to a numerical value, calling `add` will result in an error.
+
+         Supported Platforms:
+             ``Ascend``
+
+         Examples:
+             .. note::
+                 Before running the following examples, you need to configure the communication environment variables.
+
+                 For Ascend devices, it is recommended to use the msrun startup method
+                 without any third-party or configuration file dependencies.
+                 Please see the `msrun start up
+                 <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+                 for more details.
+
+             >>> from mindspore.ops.communication import TCPStore
+             >>> store = TCPStore("127.0.0.1", 1234)
+             >>> store.add("first_key", 1)
+         """
+         if not isinstance(key, str):
+             raise TypeError(
+                 "For 'TCPStore.add', the argument 'key' must be type of string, "
+                 "but got 'key' type : {}.".format(type(key))
+             )
+         if not isinstance(amount, int):
+             raise TypeError(
+                 "For 'TCPStore.add', the argument 'amount' must be type of string or int, "
+                 "but got 'amount' type : {}.".format(type(amount))
+             )
+         return self.instance.add(key, amount)
+
+
+     def set(self, key, value):
+         """
+         Inserts the key-value pair into the store based on the supplied `key` and
+         `value`. If `key` already exists in the store, it will overwrite the old
+         value with the new supplied `value`.
+
+         Args:
+             key (str): The key to be added to the store.
+             value (Union[bytes, str]): The value associated with `key` to be added to the store.
+
+         Raises:
+             TypeError: If `key` is not string.
+             TypeError: If `value` is not string or bytes.
+
+         Supported Platforms:
+             ``Ascend``
+
+         Examples:
+             .. note::
+                 Before running the following examples, you need to configure the communication environment variables.
+
+                 For Ascend devices, it is recommended to use the msrun startup method
+                 without any third-party or configuration file dependencies.
+                 Please see the `msrun start up
+                 <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+                 for more details.
+
+             >>> from mindspore.ops.communication import TCPStore
+             >>> store = TCPStore("127.0.0.1", 1234)
+             >>> store.set("first_key", "first_value")
+         """
+         if not isinstance(key, str):
+             raise TypeError(
+                 "For 'TCPStore.set', the argument 'key' must be type of string, "
+                 "but got 'key' type : {}.".format(type(key))
+             )
+         if not isinstance(value, (str, bytes)):
+             raise TypeError(
+                 "For 'TCPStore.set', the argument 'value' must be type of string or bytes, "
+                 "but got 'value' type : {}.".format(type(value))
+             )
+         return self.instance.set(key, value)
+
+
+     def get(self, key):
+         """
+         Retrieves the value associated with the given `key` in the store. If the `key` does not exist
+         in the storage, this function will wait for the `timeout` set by the class initialization and then
+         throw an exception.
+
+         Args:
+             key (str): The function will return the value associated with this key.
+
+         Returns:
+             bytes, Value associated with `key` if `key` is in the store.
+
+         Raises:
+             TypeError: If `key` is not string.
+             RuntimeError: If `get` runs out of time.
+
+         Supported Platforms:
+             ``Ascend``
+
+         Examples:
+             .. note::
+                 Before running the following examples, you need to configure the communication environment variables.
+
+                 For Ascend devices, it is recommended to use the msrun startup method
+                 without any third-party or configuration file dependencies.
+                 Please see the `msrun start up
+                 <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+                 for more details.
+
+             >>> from mindspore.ops.communication import TCPStore
+             >>> store = TCPStore("127.0.0.1", 1234)
+             >>> store.set("first_key", "first_value")
+             >>> data = store.get("first_key")
+             >>> print(data)
+         """
+         if not isinstance(key, str):
+             raise TypeError(
+                 "For 'TCPStore.get', the argument 'key' must be type of string, "
+                 "but got 'key' type : {}.".format(type(key))
+             )
+         byte_data = self.instance.get(key)
+         return byte_data
+
+
+     def delete_key(self, key):
+         """
+         Deletes the key-value pair associated with `key` from the store.
+
+         Args:
+             key (str): The key to be deleted from the store.
+
+         Returns:
+             bool, ``True`` if `key` was deleted, otherwise ``False``.
+
+         Raises:
+             TypeError: If `key` is not string.
+
+         Supported Platforms:
+             ``Ascend``
+
+         Examples:
+             .. note::
+                 Before running the following examples, you need to configure the communication environment variables.
+
+                 For Ascend devices, it is recommended to use the msrun startup method
+                 without any third-party or configuration file dependencies.
+                 Please see the `msrun start up
+                 <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+                 for more details.
+
+             >>> from mindspore.ops.communication import TCPStore
+             >>> store = TCPStore("127.0.0.1", 1234)
+             >>> store.set("first_key", "first_value")
+             >>> # This should return true
+             >>> store.delete_key("first_key")
+         """
+         if not isinstance(key, str):
+             raise TypeError(
+                 "For 'TCPStore.delete_key', the argument 'key' must be type of string, "
+                 "but got 'key' type : {}.".format(type(key))
+             )
+         return self.instance.delete_key(key)
+
+
+ def is_available():
+     """
+     Checks if distributed module is available.
+
+     Note:
+         Always returns `True` because MindSpore always has distributed ability on all platforms.
+
+     Returns:
+         bool, whether this distributed module is available.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         >>> import mindspore as ms
+         >>> from mindspore.ops.communication import is_available
+         >>> ms.set_device(device_target="Ascend")
+         >>> is_available()
+         True
+     """
+     return _is_available()
+
+
+ def is_initialized():
+     """
+     Checks if default process group has been initialized.
+
+     Returns:
+         bool, whether the default process group has been initialized.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         >>> import mindspore as ms
+         >>> from mindspore.ops.communication import init_process_group, is_initialized
+         >>> ms.set_device(device_target="Ascend")
+         >>> init_process_group()
+         >>> print(is_initialized())
+         True
+     """
+     return _is_initialized()
+
+
+ @args_type_check(init_method=str, timeout=timedelta, world_size=int, rank=int, store=TCPStore)
+ def init_process_group(backend="hccl",
+                        init_method=None,
+                        timeout=None,
+                        world_size=-1,
+                        rank=-1,
+                        store=None,
+                        pg_options=None,
+                        device_id=None):
+     """
+     Init collective communication lib. And create a default collective communication group.
+
+     Note:
+         This method isn't supported in GPU and CPU versions of MindSpore.
+         In Ascend hardware platforms, this API should be set before the definition of any Tensor and Parameter,
+         and the instantiation and execution of any operation and net.
+
+     Args:
+         backend (str, optional): The backend to ues. Default is ``"hccl"`` and now only support hccl.
+         init_method (str, optional): URL specifying how to init collective communication group. Default is ``None``.
+         timeout (timedelta, optional): Timeout for API executed. Default is ``None``. Currently, this parameter is
+             only supported for host-side cluster network configuration using `init_method` or `store`.
+         world_size (int, optional): Number of the processes participating in the job. Default is ``-1``.
+         rank (int, optional): Rank of the current process. Default is ``-1``.
+         store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process
+             communication addresses and connection information. Default is ``None``. Currently, only the
+             ``TCPStore`` type is supported.
+         pg_options (ProcessGroupOptions, invalid): process group options specifying what additional options need to be
+             passed in during the construction of specific process group. The provided parameter is a reserved
+             parameter, and the current setting does not take effect.
+         device_id (int, invalid): the device id to exeute. The provided parameter is a reserved parameter,
+             and the current setting does not take effect.
+
+     Raises:
+         ValueError: If `backend` is not hccl.
+         ValueError: If `world_size` is not equal to -1 or process group number.
+         ValueError: If both `init_method` and `store` are set.
+         ValueError: `world_size` is not correctly set as a positive integer value, when using the initialization
+             method `init_method` or `store`.
+         ValueError: `rank` is not correctly set as a non-negative integer, when using the initialization method
+             `init_method` or `store`.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails,
+             or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
+             have not been exported when backend is HCCL.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         >>> import mindspore as ms
+         >>> from mindspore.ops.communication import init_process_group, destroy_process_group
+         >>> ms.set_device(device_target="Ascend")
+         >>> init_process_group()
+         >>> destroy_process_group()
+     """
+     if pg_options is not None:
+         logger.warning("pg_options is ignored, setting is invalid")
+     if device_id is not None:
+         logger.warning("device_id is ignored, setting is invalid")
+     if backend != "hccl":
+         raise ValueError(
+             "Only support hccl now, please setting backend to hccl or using default value"
+         )
+
+     if init_method is not None and store is not None:
+         raise ValueError(
+             "Only one of init_method and store is supported."
+         )
+     if init_method is not None or store is not None:
+         if world_size <= 0:
+             raise ValueError(
+                 "Specified world_size must be a positive integer."
+             )
+         if rank < 0:
+             raise ValueError(
+                 "Specified rank must be a non-negative integer."
+             )
+         if timeout is None:
+             timeout = timedelta(seconds=300)
+         timeout_ms = int(timeout.total_seconds() * 1000)
+         _init_without_sched(backend, init_method, timeout_ms, world_size, rank, store)
+     else:
+         init(backend)
+
+     if world_size != -1 and world_size != get_group_size():
+         raise ValueError(
+             "world_size is wrong, please using default value or setting: ",
+             get_group_size(),
+         )
+
+
+ def destroy_process_group(group=None):
+     """
+     Destroy the user collective communication group.
+     If group is None or "hccl_world_group", Destroy all group and release collective communication lib.
+
+     Note:
+         - This method isn't supported in GPU and CPU versions of MindSpore.
+         - This method should be used after :func:`mindspore.ops.communication.init_process_group`.
+
+     Args:
+         group (str, optional): The communication group to work on. Normally, the group should be created by
+             :func:`mindspore.ops.communication.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
+             Default: ``None``.
+
+     Raises:
+         TypeError: If group is not a string.
+         RuntimeError: If HCCL is not available or MindSpore is GPU/CPU version.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         >>> import mindspore as ms
+         >>> from mindspore.ops.communication import init_process_group, destroy_process_group
+         >>> ms.set_device(device_target="Ascend")
+         >>> init_process_group()
+         >>> destroy_process_group()
+     """
+
+     if group == GlobalComm.WORLD_COMM_GROUP or group is None:
+         _pynative_executor.sync()
+         _finalize_collective()
+         _ExistingGroup.ITEMS.clear()
+         _ExistingGroup.GROUP_RANKS.clear()
+     elif not isinstance(group, str):
+         raise TypeError(
+             "For 'destroy_group', the argument 'group' must be type of string or None, "
+             "but got 'group' type : {}.".format(type(group))
+         )
+     else:
+         _destroy_group_helper(group)
+
+
+ def get_rank(group=None):
+     """
+     Get the rank ID for the current device in the specified collective communication group.
+
+     Note:
+         This method should be used after :func:`mindspore.ops.communication.init_process_group`.
+
+     Args:
+         group (str, optional): The communication group to work on. Normally, the group should be created by
+             :func:`mindspore.ops.communication.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
+             Default: ``None``.
+
+     Returns:
+         int, the rank ID of the calling process within the group.
+         return -1, if not part of the group
+
+     Raises:
+         TypeError: If group is not a string.
+
+     Supported Platforms:
+         ``Ascend`` ``CPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         >>> import mindspore as ms
+         >>> from mindspore.ops.communication import init_process_group, get_rank
+         >>> ms.set_device(device_target="Ascend")
+         >>> init_process_group()
+         >>> rank_id = get_rank()
+         >>> print(rank_id)
+         >>> # the result is the rank_id in world_group
+         #rank 0: 0
+         #rank 1: 1
+     """
+     if group is None:
+         group = GlobalComm.WORLD_COMM_GROUP
+     if not isinstance(group, str):
+         raise TypeError(
+             "For 'get_rank', the argument 'group' must be type of string, "
+             "but got 'group' type : {}.".format(type(group))
+         )
+     try:
+         ret = _get_rank_helper(group)
+     except RuntimeError as e:
+         logger.warning(e)
+         ret = -1
+     return ret
+
+
+ def get_world_size(group=None):
+     """
+     Get the rank size of the specified collective communication group.
+
+     Note:
+         This method should be used after :func:`mindspore.ops.communication.init_process_group`.
+
+     Args:
+         group (str, optional): The communication group to work on. Normally, the group should be created by
+             :func:`mindspore.ops.communication.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
+             Default: ``None``.
+
+     Returns:
+         int, the rank size of the group.
+         return -1, if the group is not available.
+
+     Raises:
+         TypeError: If group is not a string.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+             for more details.
+
+         This example should be run with 8 devices.
+
+         >>> import mindspore as ms
+         >>> from mindspore.ops.communication import init_process_group, get_world_size
+         >>> ms.set_device(device_target="Ascend")
+         >>> init_process_group()
+         >>> group_size = get_world_size()
+         >>> print("group_size is: ", group_size)
+         group_size is: 8
+     """
+     ret = -1
+     if group is None:
+         group = GlobalComm.WORLD_COMM_GROUP
+     if not isinstance(group, str):
+         raise TypeError(
+             "For 'get_world_size', the argument 'group' must be type of string, "
+             "but got 'group' type : {}.".format(type(group))
+         )
+     try:
+         ret = _get_size_helper(group)
+     except RuntimeError as e:
+         logger.warning(e)
+         ret = -1
+     return ret
+
+
1002
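+ # --- Illustrative sketch (hypothetical helper, not part of the public API): a
+ # common way to combine get_rank() and get_world_size() to shard a workload.
+ # It assumes an msrun-launched job in which init_process_group() has already
+ # been called; the helper name and `num_samples` argument are placeholders.
+ def _example_shard_bounds(num_samples):
+     """Return the [start, end) slice of sample indices owned by this process."""
+     rank = get_rank()
+     world_size = get_world_size()
+     if rank < 0 or world_size <= 0:
+         # Not part of the group (or the group is unavailable): use the full range.
+         return 0, num_samples
+     per_rank = num_samples // world_size
+     start = rank * per_rank
+     # The last rank also picks up any remainder.
+     end = num_samples if rank == world_size - 1 else start + per_rank
+     return start, end
+
+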
+ def new_group(ranks=None,
1003
+ timeout=None,
1004
+ backend=None,
1005
+ pg_options=None,
1006
+ use_local_synchronization=False,
1007
+ group_desc=None):
1008
+ """
1009
+ Create a new distributed group.
1010
+
1011
+ Note:
1012
+ This method should be used after :func:`mindspore.ops.communication.init_process_group`.
1013
+
1014
+ Args:
1015
+ ranks (list[int], optional): List of ranks of group members. If ``None``,
1016
+ the name of the world communication group is returned. Default is ``None``.
1017
+ timeout (int, invalid): Currently it is a reserved parameter.
1018
+ backend (str, optional): The backend library to use. Currently ``"hccl"`` and ``"mccl"`` are supported.
1019
+ When backend is ``"hccl"``, the Huawei Collective Communication Library (HCCL) is used.
1020
+ When backend is ``"mccl"``, the MindSpore Collective Communication Library (MCCL) is used.
1021
+ If ``None``, which means ``"hccl"`` in Ascend. Default is ``None``.
1022
+ pg_options (GroupOptions, optional): Additional communication group configuration parameters.
1023
+ The backend will automatically select supported parameters and apply them during group
1024
+ initialization. For example, for the ``HCCL`` backend, ``hccl_config`` can be specified so that
1025
+ group initialization configurations can be applied. Default is ``None``.
1026
+
1027
+ `GroupOptions` is defined as a class that can be instantiated as a python object.
1028
+
1029
+ .. code-block::
1030
+
1031
+ GroupOptions {
1032
+ hccl_config(dict)
1033
+ }
1034
+
1035
+
1036
+ `hccl_config` currently only supports "hccl_buffer_size" or "hccl_comm".
1037
+
1038
+ - hccl_buffer_size (uint32): specifies the size of the HCCL communication buffer.
1039
+ - hccl_comm (int64): specifies an existing HcclComm pointer. If "hccl_comm" is set,
1040
+ "hccl_buffer_size" will be ignored.
1041
+
1042
+ use_local_synchronization (bool, invalid): Currently it is a reserved parameter.
1043
+ group_desc (str, invalid): Currently it is a reserved parameter.
1044
+
1045
+ Returns:
1046
+ A string with the group name. Returns ``""`` in abnormal scenarios.
1047
+
1048
+ Raises:
1049
+ TypeError: If the `ranks` list contains a duplicate rank ID.
1050
+
1051
+ Supported Platforms:
1052
+ ``Ascend`` ``CPU``
1053
+
1054
+ Examples:
1055
+ .. note::
1056
+ Before running the following examples, you need to configure the communication environment variables.
1057
+ For Ascend devices, it is recommended to use the msrun startup method
1058
+ without any third-party or configuration file dependencies.
1059
+ Please see the `msrun start up
1060
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1061
+ for more details.
1062
+
1063
+ >>> import mindspore as ms
1064
+ >>> from mindspore.ops.communication import init_process_group, new_group
1065
+ >>> ms.set_device(device_target="Ascend")
1066
+ >>> init_process_group()
1067
+ >>> group = new_group()
1068
+ >>> print("group is: ", group)
1069
+ group is: hccl_world_group
1070
+ """
1071
+ if ranks is not None:
1072
+ if not isinstance(ranks, list):
1073
+ raise TypeError("ranks must be list, but got {}".format(type(ranks)))
1074
+ ranks = sorted(ranks)
1075
+ else:
1076
+ return GlobalComm.WORLD_COMM_GROUP
1077
+ if backend is None:
1078
+ backend = "hccl"
1079
+ if not isinstance(backend, str) or backend not in ("hccl", "mccl"):
1080
+ raise TypeError(f"the input backend must be hccl or mccl, but got {backend}")
1081
+ group = backend + "_" + str(len(ranks)) + "_" + hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
1082
+ if pg_options is not None:
1083
+ if not isinstance(pg_options, GroupOptions):
1084
+ raise TypeError("pg_options must be type GroupOptions, but got {}".format(type(pg_options)))
1085
+ try:
1086
+ create_group(group, ranks, pg_options)
1087
+ except RuntimeError as e:
1088
+ logger.warning(e)
1089
+ group = ""
1090
+ return group
1091
+
1092
+
1093
+ def get_backend(group=None):
1094
+ """
1095
+ Get the backend of communication process groups.
1096
+
1097
+ Note:
1098
+ Only one communication backend is supported by MindSpore for each process.
1099
+ It should be one of `hccl`/`nccl`/`mccl`. Currently only `hccl` and `mccl` are supported.
1100
+
1101
+ Args:
1102
+ group (str, optional): The communication group to work on.
1103
+ Normally, the group should be created by :func:`mindspore.ops.communication.new_group`. If ``None``,
1104
+ which means ``"hccl_world_group"`` in Ascend. Default: ``None``.
1105
+
1106
+ Returns:
1107
+ string, the backend of the group.
1108
+
1109
+ Raises:
1110
+ TypeError: If the `group` is not a str.
1111
+
1112
+ Supported Platforms:
1113
+ ``Ascend`` ``CPU``
1114
+
1115
+ Examples:
1116
+ .. note::
1117
+ Before running the following examples, you need to configure the communication environment variables.
1118
+ For Ascend devices, it is recommended to use the msrun startup method
1119
+ without any third-party or configuration file dependencies.
1120
+ Please see the `msrun start up
1121
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1122
+ for more details.
1123
+
1124
+ >>> import mindspore as ms
1125
+ >>> from mindspore.ops.communication import init_process_group, get_backend
1126
+ >>> ms.set_device(device_target="Ascend")
1127
+ >>> init_process_group()
1128
+ >>> backend = get_backend()
1129
+ >>> print("backend is: ", backend)
1130
+ backend is: hccl
1131
+ """
1132
+ if group is None:
1133
+ return BACKEND_HCCL
1134
+ if not isinstance(group, str):
1135
+ raise TypeError(
1136
+ "For 'get_backend', the argument 'group' must be type of string or None, "
1137
+ "but got 'group' type : {}.".format(type(group))
1138
+ )
1139
+ if BACKEND_HCCL in group:
1140
+ return BACKEND_HCCL
1141
+ if BACKEND_MCCL in group:
1142
+ return BACKEND_MCCL
1143
+ return _get_backend()
1144
+
1145
+
1146
+ def get_global_rank(group, group_rank):
1147
+ """
1148
+ A function that returns the rank id in the world group corresponding to the
1149
+ rank whose ID is `group_rank` in the user group.
1150
+
1151
+ Note:
1152
+ This method should be used after :func:`mindspore.ops.communication.init_process_group`.
1153
+
1154
+ Args:
1155
+ group (str): The communication group to work on. Normally, the group should
1156
+ be created by :func:`mindspore.ops.communication.new_group`. If ``None``, which
1157
+ means ``"hccl_world_group"`` in Ascend.
1158
+ group_rank (int): Group rank to query.
1159
+
1160
+ Returns:
1161
+ An integer scalar with the rank id in the world group.
1162
+
1163
+ Raises:
1164
+ TypeError: If the `group` is not a str.
1165
+ TypeError: If the `group_rank` is not an integer.
1166
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1167
+
1168
+ Supported Platforms:
1169
+ ``Ascend``
1170
+
1171
+ Examples:
1172
+ .. note::
1173
+ Before running the following examples, you need to configure the communication environment variables.
1174
+
1175
+ For Ascend devices, it is recommended to use the msrun startup method
1176
+ without any third-party or configuration file dependencies.
1177
+
1178
+ Please see the `msrun start up
1179
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1180
+ for more details.
1181
+
1182
+ This example should be run with 8 devices.
1183
+
1184
+ >>> import mindspore as ms
1185
+ >>> from mindspore.ops.communication import init_process_group, get_global_rank, new_group, get_rank
1186
+ >>> ms.set_device(device_target="Ascend")
1187
+ >>> # Launch 8 processes.
1188
+ >>> init_process_group()
1189
+ >>> rank_ids = [0,4]
1190
+ >>> if get_rank() in rank_ids:
1191
+ ... group = new_group(rank_ids)
1192
+ ... world_rank_id = get_global_rank(group, 1)
1193
+ ... print("world_rank_id is: ", world_rank_id)
1194
+ #rank 0 and 4:
1195
+ world_rank_id is: 4
1196
+ """
1197
+ if not isinstance(group_rank, int):
1198
+ raise TypeError(
1199
+ f"The group_rank argument must be integer, but got {type(group_rank)}."
1200
+ )
1201
+
1202
+ if group is None or group is GlobalComm.WORLD_COMM_GROUP:
1203
+ return group_rank
1204
+
1205
+ if not isinstance(group, str):
1206
+ raise TypeError(
1207
+ "For 'get_global_rank', the argument 'group' must be type of string or None, "
1208
+ "but got 'group' type : {}.".format(type(group))
1209
+ )
1210
+ return get_world_rank_from_group_rank(group, group_rank)
1211
+
1212
+
1213
+ def get_group_rank(group, global_rank):
1214
+ """
1215
+ Get the rank ID in the specified user communication group corresponding to
1216
+ the rank ID in the world communication group.
1217
+
1218
+ Note:
1219
+ This method should be used after :func:`mindspore.ops.communication.init_process_group`.
1220
+
1221
+ Args:
1222
+ group (str): The communication group to work on. Normally, the group should be
1223
+ created by :func:`mindspore.ops.communication.new_group`. If ``None``, which means
1224
+ ``"hccl_world_group"`` in Ascend.
1225
+ global_rank (int): A rank ID in the world communication group.
1226
+
1227
+ Returns:
1228
+ int, the rank ID in the user communication group.
1229
+
1230
+ Raises:
1231
+ TypeError: If global_rank is not an integer or the group is not a string.
1232
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1233
+
1234
+ Supported Platforms:
1235
+ ``Ascend``
1236
+
1237
+ Examples:
1238
+ .. note::
1239
+ Before running the following examples, you need to configure the communication environment variables.
1240
+
1241
+ For Ascend devices, it is recommended to use the msrun startup method
1242
+ without any third-party or configuration file dependencies.
1243
+ Please see the `msrun start up
1244
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1245
+ for more details.
1246
+
1247
+ This example should be run with 8 devices.
1248
+
1249
+ >>> import mindspore as ms
1250
+ >>> from mindspore.ops.communication import init_process_group, new_group, get_group_rank, get_rank
1251
+ >>> ms.set_device(device_target="Ascend")
1252
+ >>> # Launch 8 processes.
1253
+ >>> init_process_group()
1254
+ >>> rank_ids = [0,4]
1255
+ >>> if get_rank() in rank_ids:
1256
+ ... group = new_group(rank_ids)
1257
+ ... group_rank_id = get_group_rank(group, 4)
1258
+ ... print("group_rank_id is: ", group_rank_id)
1259
+ #rank 0 and 4:
1260
+ group_rank_id is: 1
1261
+ """
1262
+ if not isinstance(global_rank, int):
1263
+ raise TypeError(
1264
+ f"The global_rank argument must be integer, but got {type(global_rank)}."
1265
+ )
1266
+ if group is None:
1267
+ group = GlobalComm.WORLD_COMM_GROUP
1268
+ if not isinstance(group, str):
1269
+ raise TypeError(
1270
+ "For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, "
1271
+ "but got 'group' type : {}.".format(type(group))
1272
+ )
1273
+ return _get_group_rank_from_world_rank_from_cache_helper(
1274
+ world_rank_id=global_rank, group=group
1275
+ )
1276
+
1277
+
1278
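+ # --- Illustrative sketch (hypothetical helper, mirroring the docstring examples
+ # above): get_group_rank() and get_global_rank() invert each other for ranks that
+ # belong to the sub-group. It assumes an msrun-launched 8-device job in which
+ # init_process_group() has already been called.
+ def _example_rank_round_trip():
+     """Map world rank 4 into the [0, 4] sub-group and back again."""
+     rank_ids = [0, 4]
+     if get_rank() not in rank_ids:
+         return None
+     sub_group = new_group(rank_ids)
+     group_rank = get_group_rank(sub_group, 4)             # expected: 1
+     world_rank = get_global_rank(sub_group, group_rank)   # expected: 4
+     return group_rank, world_rank
+
+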
+ def get_process_group_ranks(group=None):
1279
+ """
1280
+ Gets the ranks of the specific group and returns the process ranks in the communication group as a list.
1281
+
1282
+ Args:
1283
+ group (str, optional): The communication group to work on. Normally, the group should be created by
1284
+ :func:`mindspore.ops.communication.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
1285
+ Default: ``None``.
1286
+
1287
+ Returns:
1288
+ List (List[int]), List of process ranks in the specified communication group.
1289
+
1290
+ Raises:
1291
+ TypeError: If the `group` is not a str.
1292
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1293
+
1294
+ Supported Platforms:
1295
+ ``Ascend`` ``CPU``
1296
+
1297
+ Examples:
1298
+ .. note::
1299
+ Before running the following examples, you need to configure the communication environment variables.
1300
+
1301
+ For Ascend devices, it is recommended to use the msrun startup method
1302
+ without any third-party or configuration file dependencies.
1303
+
1304
+ Please see the `msrun start up
1305
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1306
+ for more details.
1307
+
1308
+ This example should be run with 4 devices.
1309
+
1310
+ >>> import mindspore as ms
1311
+ >>> from mindspore.ops.communication import init_process_group, get_process_group_ranks
1312
+ >>> # Launch 4 processes.
1313
+ >>> ms.set_device(device_target="Ascend")
1314
+ >>> init_process_group()
1315
+ >>> output = get_process_group_ranks()
1316
+ >>> print(output)
1317
+ [0, 1, 2, 3]
1318
+
1319
+ """
1320
+ if group is None:
1321
+ group = GlobalComm.WORLD_COMM_GROUP
1322
+
1323
+ if not isinstance(group, str):
1324
+ raise TypeError(
1325
+ "For 'get_process_group_ranks', the argument 'group' must be type of string or None, "
1326
+ "but got 'group' type : {}.".format(type(group))
1327
+ )
1328
+ return _get_group_ranks(group)
1329
+
1330
+
1331
+ @_primexpr
1332
+ def _check_all_tensor_same_dtype_and_shape(*tensor_lists):
1333
+ """check all the input tensor has same dtype and shape"""
1334
+ consistent_dtype = None
1335
+ consistent_shape = None
1336
+ for list_ in tensor_lists:
1337
+ if not isinstance(list_, (list, tuple)):
1338
+ list_ = [list_]
1339
+ for tensor_ in list_:
1340
+ if not isinstance(tensor_, Tensor):
1341
+ continue
1342
+ dtype = tensor_.dtype
1343
+ shape = tensor_.shape
1344
+ if consistent_dtype is None:
1345
+ consistent_dtype = dtype
1346
+ consistent_shape = shape
1347
+ else:
1348
+ if dtype != consistent_dtype:
1349
+ raise TypeError(
1350
+ "tensor_lists dtype must be the same, "
1351
+ f"but got {consistent_dtype} and {dtype}."
1352
+ )
1353
+ if shape != consistent_shape:
1354
+ raise TypeError(
1355
+ "tensor_lists shape must be the same, "
1356
+ f"but got {consistent_shape} and {shape}."
1357
+ )
1358
+
1359
+
1360
+ @_primexpr
1361
+ def _check_output_shape(output, expected_shape, op_name):
1362
+ if output.shape != expected_shape:
1363
+ raise TypeError(
1364
+ f"For {op_name}, the output shape should be {expected_shape}, "
1365
+ f"but got {output.shape}.")
1366
+
1367
+
1368
+ @_primexpr
1369
+ def _check_output_dtype(output, expected_dtype, op_name):
1370
+ if output.dtype != expected_dtype:
1371
+ raise TypeError(
1372
+ f"For {op_name}, the output dtype should be {expected_dtype}, "
1373
+ f"but got {output.dtype}.")
1374
+
1375
+
1376
+ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
1377
+ """
1378
+ Reduces tensors across all devices in such a way that all devices will get the same final result,
1379
+ and returns the tensor which is all reduced.
1380
+
1381
+ Note:
1382
+ The tensors must have the same shape and format in all processes of the collection.
1383
+
1384
+ Args:
1385
+ tensor (Tensor): The input tensor of collective. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1386
+ If the function operates in-place, this also means output of collective.
1387
+ op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
1388
+ Default: ``ReduceOp.SUM`` .
1389
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
1390
+ Ascend. Default: ``None``.
1391
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
1392
+
1393
+ Returns:
1394
+ - CommHandle, if the function operates in-place, return it. CommHandle is an async work handle,
1395
+ if `async_op` is set to True. CommHandle will be None, when `async_op` is False.
1396
+ - Tuple(Tensor, CommHandle), if the function operates non in-place, return it.
1397
+ the output tensor has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
1398
+ The contents depend on the specified operation. CommHandle is an async work handle,
1399
+ if `async_op` is set to True. CommHandle will be None, when `async_op` is False.
1400
+
1401
+ Raises:
1402
+ TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str,
1403
+ `op` range is illegal or async_op is not bool.
1404
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1405
+
1406
+ Supported Platforms:
1407
+ ``Ascend`` ``CPU``
1408
+
1409
+ Examples:
1410
+ .. note::
1411
+ Before running the following examples, you need to configure the communication environment variables.
1412
+
1413
+ For Ascend devices, it is recommended to use the msrun startup method
1414
+ without any third-party or configuration file dependencies.
1415
+ Please see the `msrun start up
1416
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1417
+ for more details.
1418
+
1419
+ This example should be run with 2 devices.
1420
+
1421
+ >>> import numpy as np
1422
+ >>> from mindspore.ops.communication import init_process_group
1423
+ >>> from mindspore.ops.communication import all_reduce
1424
+ >>> from mindspore import Tensor
1425
+ >>>
1426
+ >>> init_process_group()
1427
+ >>> tensor = Tensor(np.ones([2, 8]).astype(np.float32))
1428
+ >>> output = all_reduce(tensor)
1429
+ >>> print(tensor)
1430
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
1431
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
1432
+
1433
+ """
1434
+ if not isinstance(tensor, (Tensor, Tensor_)):
1435
+ raise TypeError("For all_reduce, the input tensor must be tensor")
1436
+ if not isinstance(op, str):
1437
+ raise TypeError("For all_reduce, the input op type must be str")
1438
+ if op not in ("sum", "prod", "min", "max"):
1439
+ raise TypeError(
1440
+ "For all_reduce, the input op value must be one of sum, prod, min, max"
1441
+ )
1442
+
1443
+ if group is None:
1444
+ group = GlobalComm.WORLD_COMM_GROUP
1445
+
1446
+ if not isinstance(group, str):
1447
+ raise TypeError(
1448
+ "The argument 'group' must be type of string, "
1449
+ "but got 'group' type : {}.".format(type(group))
1450
+ )
1451
+ if not isinstance(async_op, bool):
1452
+ raise TypeError(
1453
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
1454
+ )
1455
+
1456
+ if is_inplace_func() is True:
1457
+ output = dist_comm_all_reduce_op(tensor, op, group)
1458
+ _, handle = _deal_comm_outputs(output, async_op)
1459
+ return handle
1460
+ out = inner_comm_all_reduce_op(tensor, op, group)
1461
+ return _deal_comm_outputs(out, async_op)
1462
+
1463
+
1464
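+ # --- Illustrative sketch (hypothetical helper): using the async form of
+ # all_reduce() and handling both documented return shapes. It assumes a
+ # multi-device msrun job with init_process_group() already called, and that the
+ # returned CommHandle exposes wait(), as MindSpore async work handles do.
+ def _example_async_all_reduce(tensor):
+     """Launch an asynchronous all_reduce and block until it completes."""
+     result = all_reduce(tensor, op=ReduceOp.SUM, async_op=True)
+     if isinstance(result, tuple):
+         # Non in-place mode returns (reduced_tensor, handle).
+         reduced, handle = result
+     else:
+         # In-place mode updates `tensor` and returns only the handle.
+         reduced, handle = tensor, result
+     if handle is not None:
+         handle.wait()
+     return reduced
+
+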
+ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
1465
+ """
1466
+ Gathers tensors from the specified communication group and returns the tensor which is all gathered.
1467
+
1468
+ Note:
1469
+ The tensors must have the same shape and format in all processes of the collection.
1470
+
1471
+ Args:
1472
+ output_tensor (Tensor): The output tensor that holds the all-gathered result. If the number of devices
1473
+ in the group is N, then the shape of output tensor is :math:`(N*x_1, x_2, ..., x_R)`.
1474
+ If the function operates non in-place, this parameter is not used.
1475
+ input_tensor (Tensor): The input tensor to be all gathered into tensor.
1476
+ The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1477
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
1478
+ Ascend. Default: ``None``.
1479
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
1480
+
1481
+ Returns:
1482
+ - CommHandle, if the function operates in-place, return it. CommHandle is an async work handle,
1483
+ if `async_op` is set to True. CommHandle will be None, when `async_op` is False.
1484
+ - Tuple(Tensor, CommHandle), if the function operates non in-place, return it. If the number of devices in the group is N,
1485
+ then the shape of output tensor is :math:`(N, x_1, x_2, ..., x_R)`.
1486
+ CommHandle is an async work handle, if `async_op` is set to True.
1487
+ CommHandle will be None, when `async_op` is False.
1488
+
1489
+ Raises:
1490
+ TypeError: If the type of the input_tensor or output_tensor parameter is not Tensor,
1491
+ `group` is not a str, or async_op is not bool.
1492
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1493
+
1494
+ Supported Platforms:
1495
+ ``Ascend``
1496
+
1497
+ Examples:
1498
+ .. note::
1499
+ Before running the following examples, you need to configure the communication environment variables.
1500
+
1501
+ For Ascend devices, it is recommended to use the msrun startup method
1502
+ without any third-party or configuration file dependencies.
1503
+ Please see the `msrun start up
1504
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1505
+ for more details.
1506
+
1507
+ This example should be run with 2 devices.
1508
+
1509
+ >>> import numpy as np
1510
+ >>> import mindspore as ms
1511
+ >>> from mindspore import ops
1512
+ >>> from mindspore.ops.communication import init_process_group
1513
+ >>> from mindspore.ops.communication import all_gather_into_tensor
1514
+ >>> from mindspore import Tensor
1515
+ >>>
1516
+ >>> ms.set_device(device_target="Ascend")
1517
+ >>> init_process_group()
1518
+ >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
1519
+ >>> out_tensor = Tensor(np.zeros([4, 8]).astype(np.float32))
1520
+ >>> output = all_gather_into_tensor(out_tensor, input_tensor)
1521
+ >>> print(out_tensor)
1522
+ [[1. 1. 1. 1. 1. 1. 1. 1.]
1523
+ [1. 1. 1. 1. 1. 1. 1. 1.]
1524
+ [1. 1. 1. 1. 1. 1. 1. 1.]
1525
+ [1. 1. 1. 1. 1. 1. 1. 1.]]
1526
+
1527
+ """
1528
+ if not isinstance(input_tensor, (Tensor, Tensor_)):
1529
+ raise TypeError("For all_gather_into_tensor, the input tensor must be tensor")
1530
+ if is_inplace_func() is True and \
1531
+ not isinstance(output_tensor, (Tensor, Tensor_)):
1532
+ raise TypeError("For all_gather_into_tensor, the output tensor must be tensor")
1533
+ if group is None:
1534
+ group = GlobalComm.WORLD_COMM_GROUP
1535
+ if not isinstance(group, str):
1536
+ raise TypeError(
1537
+ "The argument 'group' must be type of string, "
1538
+ "but got 'group' type : {}.".format(type(group))
1539
+ )
1540
+ if not isinstance(async_op, bool):
1541
+ raise TypeError(
1542
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
1543
+ )
1544
+ group_size = get_cache_group_size(group)
1545
+ if is_inplace_func() is True:
1546
+ output = dist_comm_all_gather_into_tensor_op(
1547
+ output_tensor, input_tensor, group_size, group
1548
+ )
1549
+ _, handle = _deal_comm_outputs(output, async_op)
1550
+ return handle
1551
+ output = inner_comm_all_gather_op(input_tensor, group_size, group)
1552
+ return _deal_comm_outputs(output, async_op)
1553
+
1554
+
1555
+ def all_gather_into_tensor_uneven(output, input, output_split_sizes=None, group=None, async_op=False):
1556
+ r"""
1557
+ Gathers and concatenates tensors across devices with uneven first dimensions.
1558
+
1559
+ Note:
1560
+ - Input tensors must have identical shapes except for the first dimension.
1561
+ - Output tensor's first dimension should equal to the sum of all devices' input first dimensions.
1562
+
1563
+ Args:
1564
+ output (Tensor): Concatenated output tensor with shape :math:`(\sum_{i=0}^{N-1} x_{i1}, x_2, ..., x_R)`,
1565
+ where N is the number of devices in the group.
1566
+ input (Tensor): Local input tensor with shape :math:`(x_{k1}, x_2, ..., x_R)`, where k is current device's rank.
1567
+ output_split_sizes (list[int], optional): Specifies first dimension sizes from each device.
1568
+ Must match actual input dimensions when provided.
1569
+ If ``None``, assumes equal split sizes across devices. Default: ``None``.
1570
+ group (str, optional): The communication group to work on. If ``None``,
1571
+ which means ``"hccl_world_group"`` in Ascend. Default: ``None``.
1572
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False``.
1573
+
1574
+ Returns:
1575
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
1576
+ CommHandle will be None, when `async_op` is False.
1577
+
1578
+ Raises:
1579
+ ValueError: If the shape of `input` does not match the constraints of `output_split_sizes`.
1580
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1581
+
1582
+ Supported Platforms:
1583
+ ``Ascend``
1584
+
1585
+ Examples:
1586
+ .. note::
1587
+ Before running the following examples, you need to configure the communication environment variables.
1588
+
1589
+ For Ascend devices, it is recommended to use the msrun startup method
1590
+ without any third-party or configuration file dependencies.
1591
+ Please see the `msrun start up
1592
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1593
+ for more details.
1594
+
1595
+ This example should be run with 2 devices.
1596
+
1597
+ >>> import numpy as np
1598
+ >>> import mindspore as ms
1599
+ >>> from mindspore import ops
1600
+ >>> from mindspore.ops.communication import init_process_group, get_rank
1601
+ >>> from mindspore.ops.communication import all_gather_into_tensor_uneven
1602
+ >>> from mindspore import Tensor
1603
+ >>>
1604
+ >>> ms.set_device(device_target="Ascend")
1605
+ >>> init_process_group()
1606
+ >>> if get_rank() == 0:
1607
+ ...     input_tensor = Tensor(np.ones([3, 4]).astype(np.float32))
1608
+ ... else:
1609
+ ...     input_tensor = Tensor(np.ones([2, 4]).astype(np.float32))
1610
+ >>> out_tensor = Tensor(np.zeros([5, 4]).astype(np.float32))
1611
+ >>> output_split_sizes = [3, 2]
1612
+ >>> output = all_gather_into_tensor_uneven(out_tensor, input_tensor, output_split_sizes)
1613
+ >>> print(out_tensor)
1614
+ [[1. 1. 1. 1.]
1615
+ [1. 1. 1. 1.]
1616
+ [1. 1. 1. 1.]
1617
+ [1. 1. 1. 1.]
1618
+ [1. 1. 1. 1.]]
1619
+ """
1620
+ if is_inplace_func() is False:
1621
+ raise ValueError("Non-inplace mode is currently not supported.")
1622
+ if group is None:
1623
+ group = GlobalComm.WORLD_COMM_GROUP
1624
+ if not isinstance(group, str):
1625
+ raise TypeError(
1626
+ "The argument 'group' must be type of string, "
1627
+ "but got 'group' type : {}.".format(type(group))
1628
+ )
1629
+ if not isinstance(async_op, bool):
1630
+ raise TypeError(
1631
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
1632
+ )
1633
+ group_size = get_cache_group_size(group)
1634
+ output_split_sizes = [] if output_split_sizes is None else output_split_sizes
1635
+ result = dist_comm_all_gather_into_tensor_uneven_op(
1636
+ output, input, output_split_sizes, group_size, group
1637
+ )
1638
+ _, handle = _deal_comm_outputs(result, async_op)
1639
+ return handle
1640
+
1641
+
1642
+ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
1643
+ r"""
1644
+ Reduces and scatters tensors from the specified communication group and
1645
+ returns the tensor which is reduced and scattered.
1646
+
1647
+ Note:
1648
+ The tensors must have the same shape and format in all processes of the collection.
1649
+
1650
+ Args:
1651
+ output(Tensor): the output tensor has the same dtype as `input` with a shape of :math:`(N/rank\_size, *)`.
1652
+ If the function operates non in-place, this parameter is not used.
1653
+ input(Tensor): The input tensor to be reduced and scattered, suppose it has a shape :math:`(N, *)`, where `*`
1654
+ means any number of additional dimensions. N must be divisible by rank_size.
1655
+ rank_size refers to the number of cards in the communication group.
1656
+ op (str, optional): Specifies an operation used for element-wise reductions,
1657
+ like SUM and MAX. Default: ``ReduceOp.SUM`` .
1658
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
1659
+ Ascend. Default: ``None``.
1660
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
1661
+
1662
+ Returns:
1663
+ - CommHandle, if the function operates in-place, CommHandle is an async work handle,
1664
+ if `async_op` is set to True. CommHandle will be None, when `async_op` is False.
1665
+ - Tuple(Tensor, CommHandle), if the function operates non in-place, return it.
1666
+ the output tensor has the same dtype as `input` with a shape of
1667
+ :math:`(N/rank\_size, *)`. CommHandle is an async work handle, if `async_op` is set to True.
1668
+ CommHandle will be None, when `async_op` is False.
1669
+
1670
+ Raises:
1671
+ TypeError: If the type of the input or output parameter is not Tensor, any of `op` and `group` is not a str,
1672
+ async_op is not bool or 'op' is invalid.
1673
+ ValueError: If the first dimension of the input cannot be divided by the rank_size.
1674
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1675
+
1676
+ Supported Platforms:
1677
+ ``Ascend``
1678
+
1679
+ Examples:
1680
+ .. note::
1681
+ Before running the following examples, you need to configure the communication environment variables.
1682
+
1683
+ For Ascend devices, it is recommended to use the msrun startup method
1684
+ without any third-party or configuration file dependencies.
1685
+ Please see the `msrun start up
1686
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1687
+ for more details.
1688
+
1689
+ This example should be run with 2 devices.
1690
+
1691
+ >>> import mindspore as ms
1692
+ >>> from mindspore import Tensor
1693
+ >>> from mindspore.ops.communication import init_process_group
1694
+ >>> from mindspore.ops.communication import reduce_scatter_tensor
1695
+ >>> import numpy as np
1696
+ >>>
1697
+ >>> ms.set_device(device_target="Ascend")
1698
+ >>> init_process_group()
1699
+ >>> input_tensor = Tensor(np.ones([8, 8]).astype(np.float32))
1700
+ >>> output_tensor = Tensor(np.ones([4, 8]).astype(np.float32))
1701
+ >>> output = reduce_scatter_tensor(output_tensor ,input_tensor)
1702
+ >>> print(output_tensor)
1703
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
1704
+ [2. 2. 2. 2. 2. 2. 2. 2.]
1705
+ [2. 2. 2. 2. 2. 2. 2. 2.]
1706
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
1707
+
1708
+ """
1709
+ if not isinstance(input, (Tensor, Tensor_)):
1710
+ raise TypeError("For reduce_scatter_tensor, the input tensor must be tensor")
1711
+ if is_inplace_func() is True and \
1712
+ not isinstance(output, (Tensor, Tensor_)):
1713
+ raise TypeError("For reduce_scatter_tensor, the output tensor must be tensor")
1714
+ if not isinstance(op, str):
1715
+ raise TypeError("For reduce_scatter_tensor, the input op type must be str")
1716
+ if op not in ("sum", "prod", "min", "max"):
1717
+ raise TypeError(
1718
+ "For reduce_scatter_tensor, the input op value must be one of sum, prod, min, max"
1719
+ )
1720
+ if group is None:
1721
+ group = GlobalComm.WORLD_COMM_GROUP
1722
+ if not isinstance(group, str):
1723
+ raise TypeError(
1724
+ "The argument 'group' must be type of string, "
1725
+ "but got 'group' type : {}.".format(type(group))
1726
+ )
1727
+ if not isinstance(async_op, bool):
1728
+ raise TypeError(
1729
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
1730
+ )
1731
+ rank_size = get_cache_group_size(group)
1732
+ if is_inplace_func() is True:
1733
+ out = dist_comm_reduce_scatter_tensor_op(output, input, rank_size, op, group)
1734
+ _, handle = _deal_comm_outputs(out, async_op)
1735
+ return handle
1736
+ out = inner_comm_reduce_scatter_op(input, rank_size, op, group)
1737
+ return _deal_comm_outputs(out, async_op)
1738
+
1739
+
1740
+ def reduce_scatter_tensor_uneven(output, input, input_split_sizes=None, op=ReduceOp.SUM, group=None, async_op=False):
1741
+ r"""
1742
+ Reduce tensors from the specified communication group and scatter to the output tensor
1743
+ according to `input_split_sizes`.
1744
+
1745
+ Note:
1746
+ - The input tensor must have identical shape and format across all processes.
1747
+ - The first dimension of input tensor should equal to the sum of `input_split_sizes`.
1748
+
1749
+ Args:
1750
+ output(Tensor): the output tensor has the same dtype as `input` with a shape of
1751
+ :math:`(input\_split\_sizes[rank], *)`, where rank is the local rank id of the device.
1752
+ input(Tensor): The input tensor to be reduced and scattered, Expected shape :math:`(N, *)`, where `*`
1753
+ means any number of additional dimensions. N must equal the sum of `input_split_sizes` across ranks.
1754
+ input_split_sizes (list[int], optional): List specifying how to split the first dimension of input tensor.
1755
+ If ``None``, splits evenly according to group size. Default: ``None``.
1756
+ op (str, optional): Specifies an operation used for element-wise reductions,
1757
+ One of ReduceOp: 'SUM', 'MIN', 'MAX'. Default: ``ReduceOp.SUM``.
1758
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
1759
+ Ascend. Default: ``None``.
1760
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False``.
1761
+
1762
+ Returns:
1763
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
1764
+ CommHandle will be None, when `async_op` is False.
1765
+
1766
+ Raises:
1767
+ ValueError: If the shape of `output` does not match the constraints of `input_split_sizes`.
1768
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1769
+
1770
+ Supported Platforms:
1771
+ ``Ascend``
1772
+
1773
+ Examples:
1774
+ .. note::
1775
+ Before running the following examples, you need to configure the communication environment variables.
1776
+
1777
+ For Ascend devices, it is recommended to use the msrun startup method
1778
+ without any third-party or configuration file dependencies.
1779
+ Please see the `msrun start up
1780
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1781
+ for more details.
1782
+
1783
+ This example should be run with 2 devices.
1784
+
1785
+ >>> import mindspore as ms
1786
+ >>> from mindspore import Tensor
1787
+ >>> from mindspore.ops.communication import init_process_group, get_rank
1788
+ >>> from mindspore.ops.communication import reduce_scatter_tensor_uneven
1789
+ >>> import numpy as np
1790
+ >>>
1791
+ >>> ms.set_device(device_target="Ascend")
1792
+ >>> init_process_group()
1793
+ >>> input_tensor = Tensor(np.ones([5, 8]).astype(np.float32))
1794
+ >>> if get_rank() == 0:
1795
+ ...     output_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
1796
+ ... else:
1797
+ ...     output_tensor = Tensor(np.ones([3, 8]).astype(np.float32))
1798
+ >>> input_split_sizes = [2, 3]
1799
+ >>> output = reduce_scatter_tensor_uneven(output_tensor, input_tensor, input_split_sizes)
1800
+ >>> print(output_tensor)
1801
+ rank 0:
1802
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
1803
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
1804
+ rank 1:
1805
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
1806
+ [2. 2. 2. 2. 2. 2. 2. 2.]
1807
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
1808
+ """
1809
+ if is_inplace_func() is False:
1810
+ raise ValueError("Non-inplace mode is currently not supported.")
1811
+ if not isinstance(op, str):
1812
+ raise TypeError("For reduce_scatter_tensor_uneven, the input op type must be str")
1813
+ if op not in ("sum", "min", "max"):
1814
+ raise TypeError(
1815
+ "For reduce_scatter_tensor_uneven, the input op value must be one of sum, prod, min, max"
1816
+ )
1817
+ if group is None:
1818
+ group = GlobalComm.WORLD_COMM_GROUP
1819
+ if not isinstance(group, str):
1820
+ raise TypeError(
1821
+ "The argument 'group' must be type of string, "
1822
+ "but got 'group' type : {}.".format(type(group))
1823
+ )
1824
+ if not isinstance(async_op, bool):
1825
+ raise TypeError(
1826
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
1827
+ )
1828
+ input_split_sizes = [] if input_split_sizes is None else input_split_sizes
1829
+ rank_size = get_cache_group_size(group)
1830
+ result = dist_comm_reduce_scatter_tensor_uneven_op(
1831
+ output, input, input_split_sizes, rank_size, op, group
1832
+ )
1833
+ _, handle = _deal_comm_outputs(result, async_op)
1834
+ return handle
1835
+
1836
+
1837
+ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
1838
+ """
1839
+ Reduces tensors across the processes in the specified communication group, sends the result
1840
+ to the target dst(global rank), and returns the tensor which is sent to the target process.
1841
+
1842
+ Note:
1843
+ - Only process with destination rank receives the reduced output.
1844
+ - Only support PyNative mode, Graph mode is not currently supported.
1845
+ - Other processes only get a tensor with shape [1], which has no mathematical meaning.
1846
+
1847
+ Args:
1848
+ tensor (Tensor): Input and output of the collective. The function operates in-place.
1849
+ dst (int): The target rank of the process(global rank) that receives the reduced output.
1850
+ op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
1851
+ Default: ``ReduceOp.SUM`` .
1852
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
1853
+ Ascend. Default: ``None``.
1854
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
1855
+
1856
+ Returns:
1857
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to ``True``.
1858
+ CommHandle will be None, when `async_op` is ``False``.
1859
+
1860
+ Raises:
1861
+ TypeError: If the type of `tensor` is not Tensor, any of `op` and `group` is not a str,
1862
+ async_op is not bool or 'op' is invalid.
1863
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1864
+
1865
+ Supported Platforms:
1866
+ ``Ascend``
1867
+
1868
+ Examples:
1869
+ .. note::
1870
+ Before running the following examples, you need to configure the communication environment variables.
1871
+
1872
+ For Ascend devices, it is recommended to use the msrun startup method
1873
+ without any third-party or configuration file dependencies.
1874
+
1875
+ Please see the `msrun start up
1876
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
1877
+ for more details.
1878
+
1879
+ This example should be run with 2 devices.
1880
+
1881
+ >>> from mindspore import ops
1882
+ >>> import mindspore.nn as nn
1883
+ >>> from mindspore.ops.communication import init_process_group, reduce
1884
+ >>> from mindspore import Tensor
1885
+ >>> import numpy as np
1886
+ >>> # Launch 2 processes.
1887
+ >>> init_process_group()
1888
+ >>> dest_rank=1
1889
+ >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
1890
+ >>> output = reduce(input_tensor, dest_rank)
1891
+ >>> print(input_tensor)
1892
+ Process with rank 0: [[1. 1. 1. 1. 1. 1. 1. 1.]
1893
+ [1. 1. 1. 1. 1. 1. 1. 1.]],
1894
+ Process with rank 1: [[2. 2. 2. 2. 2. 2. 2. 2.]
1895
+ [2. 2. 2. 2. 2. 2. 2. 2.]],
1896
+ """
1897
+ if is_inplace_func() is False:
1898
+ raise ValueError("Non-inplace mode is currently not supported.")
1899
+ if not isinstance(tensor, (Tensor, Tensor_)):
1900
+ raise TypeError("For reduce, the input tensor must be tensor")
1901
+ if not isinstance(dst, int):
1902
+ raise TypeError("For reduce, the dst must be int")
1903
+ if not isinstance(op, str):
1904
+ raise TypeError("For reduce, the input op type must be str")
1905
+ if op not in ("sum", "prod", "min", "max"):
1906
+ raise TypeError(
1907
+ "For reduce, the input op value must be one of sum, prod, min, max"
1908
+ )
1909
+ if group is None:
1910
+ group = GlobalComm.WORLD_COMM_GROUP
1911
+ if not isinstance(group, str):
1912
+ raise TypeError(
1913
+ "The argument 'group' must be type of string, "
1914
+ "but got 'group' type : {}.".format(type(group))
1915
+ )
1916
+ if not isinstance(async_op, bool):
1917
+ raise TypeError(
1918
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
1919
+ )
1920
+ result = dist_comm_reduce_op(tensor, op, dst, group)
1921
+ _, handle = _deal_comm_outputs(result, async_op)
1922
+ return handle
1923
+
1924
+
1925
+ class P2POp:
1926
+ """
1927
+ Object for `batch_isend_irecv` input, to store information of ``"isend"`` and ``"irecv"``.
1928
+
1929
+ Note:
1930
+ `tensor` will be modified in-place by final result when `op` is ``"irecv"``.
1931
+
1932
+ Args:
1933
+ op(Union[str, function]): Only string of ``"isend"`` and ``"irecv"`` are allowed.
1934
+ Or function of ``ops.isend`` and ``ops.irecv`` are allowed.
1935
+ tensor(Tensor): tensor for sending/receiving.
1936
+ peer(int): remote global rank for send/receive.
1937
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
1938
+ Ascend. Default: ``None``.
1939
+ tag(int, optional): currently not supported yet. Default: ``0``.
1940
+
1941
+ Returns:
1942
+ P2POp Object.
1943
+
1944
+ Raises:
1945
+ TypeError: when `op` is not string or function of 'isend' and 'irecv'.
1946
+ TypeError: when `tensor` is not type of Tensor or 'peer' is not int.
1947
+ NotImplementedError: when `tag` is not 0.
1948
+
1949
+ Supported Platforms:
1950
+ ``Ascend``
1951
+
1952
+ Examples:
1953
+ >>> import numpy as np
1954
+ >>> import mindspore
1955
+ >>> from mindspore.ops.communication import P2POp, isend, irecv
1956
+ >>> from mindspore import Tensor
1957
+ >>> # Launch 2 processes.
1958
+ >>> send_tensor = Tensor(1.)
1959
+ >>> send_op = P2POp('isend', send_tensor, 1)
1960
+ >>> send_op = P2POp(isend, send_tensor, 1)
1961
+ >>> recv_tensor = Tensor(0.)
1962
+ >>> recv_op = P2POp('irecv', recv_tensor, 0)
1963
+ >>> recv_op = P2POp(irecv, recv_tensor, 0)
1964
+ """
1965
+
1966
+ def __init__(self, op, tensor, peer, group=None, tag=0):
1967
+ self.op = op
1968
+ self.tensor = tensor
1969
+ self.peer = peer
1970
+ self.group = group
1971
+ self.tag = tag
1972
+
1973
+ def __new__(cls, op, tensor, peer, group=None, tag=0):
1974
+ if isinstance(op, str):
1975
+ op_name = op
1976
+ if op_name not in ["isend", "irecv"]:
1977
+ raise TypeError(
1978
+ f"Expected op to be of type isend or irecv, but got {op_name}"
1979
+ )
1980
+ else:
1981
+ if op not in [isend, irecv]:
1982
+ raise TypeError(
1983
+ f"Expected op to be of type isend or irecv, but got {op}"
1984
+ )
1985
+ op_name = op.__name__
1986
+
1987
+ if not isinstance(tensor, (Tensor, Tensor_)):
1988
+ raise TypeError(
1989
+ f"Expected tensor to be Tensor, but got {type(tensor)}."
1990
+ )
1991
+ if not isinstance(peer, int):
1992
+ raise TypeError("For P2POp, the peer must be int")
1993
+ if tag != 0:
1994
+ raise NotImplementedError("tag is not supported yet.")
1995
+ return object.__new__(cls)
1996
+
1997
+
1998
+ TYPE_ISEND = 0
1999
+ TYPE_IRECV = 1
2000
+
2001
+
2002
+ def batch_isend_irecv(p2p_op_list):
2003
+ """
2004
+ Batch send and recv tensors asynchronously.
2005
+
2006
+ Note:
2007
+ - The 'isend' and 'irecv' of `P2POp` in `p2p_op_list` between ranks need to match each other.
2008
+ - `P2POp` in `p2p_op_list` can only use the same communication group.
2009
+ - `tag` of `P2POp` in `p2p_op_list` is not supported yet.
2010
+ - `tensor` of `P2POp` in `p2p_op_list` will not be modified by result inplace.
2011
+ - Only support PyNative mode, Graph mode is not currently supported.
2012
+
2013
+ Args:
2014
+ p2p_op_list(list[P2POp]): list contains `P2POp`. `P2POp` is type of :class:`mindspore.ops.communication.P2POp`
2015
+
2016
+ Returns:
2017
+ list[CommHandle], CommHandle is an async work handle. Currently only one wrapping handle for the whole batch is supported.
2018
+
2019
+ Raises:
2020
+ TypeError: If `p2p_op_list` is empty or `p2p_op_list` are not all type of `P2POp`.
2021
+ TypeError: The group name in `p2p_op_list` are not consistent.
2022
+ TypeError: The `tensor` in `p2p_op_list` are not Tensor.
2023
+ TypeError: The `op` in `p2p_op_list` are not isend or irecv.
2024
+
2025
+ Supported Platforms:
2026
+ ``Ascend``
2027
+
2028
+ Examples:
2029
+ .. note::
2030
+ Before running the following examples, you need to configure the communication environment variables.
2031
+
2032
+ For Ascend devices, it is recommended to use the msrun startup method
2033
+ without any third-party or configuration file dependencies.
2034
+ Please see the `msrun start up
2035
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2036
+ for more details.
2037
+
2038
+ This example should be run with 2 devices.
2039
+
2040
+ >>> import numpy as np
2041
+ >>> import mindspore
2042
+ >>> from mindspore.ops.communication import init_process_group, get_rank, get_world_size
2043
+ >>> from mindspore.ops.communication import batch_isend_irecv, P2POp
2044
+ >>> from mindspore import Tensor
2045
+ >>>
2046
+ >>> init_process_group()
2047
+ >>> this_rank = get_rank()
2048
+ >>> world_size = get_world_size()
2049
+ >>> next_rank = (this_rank + 1) % world_size
2050
+ >>> prev_rank = (this_rank + world_size - 1) % world_size
2051
+ >>>
2052
+ >>> send_tensor = Tensor(this_rank + 1, dtype=mindspore.float32)
2053
+ >>> recv_tensor = Tensor(0., dtype=mindspore.float32)
2054
+ >>>
2055
+ >>> send_op = P2POp('isend', send_tensor, next_rank)
2056
+ >>> recv_op = P2POp('irecv', recv_tensor, prev_rank)
2057
+ >>>
2058
+ >>> p2p_op_list = [send_op, recv_op]
2059
+ >>> output = batch_isend_irecv(p2p_op_list)
2060
+ >>> print(recv_tensor)
2061
+ rank 0:
2062
+ 2.0
2063
+ rank 1:
2064
+ 1.0
2065
+ """
2066
+ if is_inplace_func() is False:
2067
+ raise ValueError("Non-inplace mode is currently not supported.")
2068
+ tensors = []
2069
+ op_types = []
2070
+ remotes_ranks = []
2071
+ tags = []
2072
+ if not p2p_op_list:
2073
+ raise TypeError(f"p2p_op_list can not be empty list.")
2074
+ for p2p_op in p2p_op_list:
2075
+ if not isinstance(p2p_op, P2POp):
2076
+ raise TypeError("The elements in p2p_op_list must be type of P2POp.")
2077
+ group = p2p_op_list[0].group
2078
+
2079
+ type_ = None
2080
+ for p2p_op in p2p_op_list:
2081
+ if group != p2p_op.group:
2082
+ raise TypeError("The group name in p2p_op_list must be consistent.")
2083
+ if isinstance(p2p_op.op, str):
2084
+ type_ = p2p_op.op
2085
+ else:
2086
+ type_ = p2p_op.op.__name__
2087
+ rank_ = (
2088
+ p2p_op.peer
2089
+ if p2p_op.group is None
2090
+ else get_group_rank_from_world_rank(p2p_op.peer, p2p_op.group)
2091
+ )
2092
+ remotes_ranks.append(rank_)
2093
+ tags.append(p2p_op.tag)
2094
+ if type_ == "isend":
2095
+ tensors.append(p2p_op.tensor)
2096
+ op_types.append(TYPE_ISEND)
2097
+ elif type_ == "irecv":
2098
+ if isinstance(p2p_op.tensor, Tensor):
2099
+ tensors.append(p2p_op.tensor)
2100
+ op_types.append(TYPE_IRECV)
2101
+ else:
2102
+ raise TypeError("p2p_op.tensor must be tensor")
2103
+ else:
2104
+ raise TypeError("p2p_op.op must be isend or irecv")
2105
+
2106
+ if group is None:
2107
+ group = GlobalComm.WORLD_COMM_GROUP
2108
+ output = dist_comm_batch_isend_irecv_op(tensors, group, op_types, remotes_ranks)
2109
+ _, handle = _deal_comm_outputs(output, True)
2110
+ return [handle]
2111
+
2112
+
2113
+ def scatter_tensor(output_tensor, input_tensor, src=0, group=None, async_op=False):
2114
+ r"""
2115
+ Scatters the tensor evenly across the processes in the specified communication group.
2116
+
2117
+ Note:
2118
+ - The interface behavior only supports Tensor input and even scattering, which
2119
+ is different from that of `torch.distributed.scatter`.
2120
+ - Only the tensor in process `src` (global rank) will do scatter.
2121
+ - Only support PyNative mode, Graph mode is not currently supported.
2122
+
2123
+ Args:
2124
+ output_tensor (Tensor): Output tensor. It should have the same size across all ranks.
2125
+ input_tensor (Tensor): The input tensor to be scattered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
2126
+ src (int, optional): Specifies the rank (global rank) of the process that sends the tensor.
2127
+ And only process `src` will send the tensor. Default is 0.
2128
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2129
+ Ascend. Default: ``None``.
2130
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
2131
+
2132
+ Returns:
2133
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
2134
+ CommHandle will be None, when `async_op` is False.
2135
+
2136
+ Raises:
2137
+ TypeError: If the type of `output_tensor` or `input_tensor` is not Tensor, `src` is not an int, or `group` is not a str.
2138
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
2139
+
2140
+ Supported Platforms:
2141
+ ``Ascend``
2142
+
2143
+ Examples:
2144
+ .. note::
2145
+ Before running the following examples, you need to configure the communication environment variables.
2146
+
2147
+ For Ascend devices, it is recommended to use the msrun startup method
2148
+ without any third-party or configuration file dependencies.
2149
+ Please see the `msrun start up
2150
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2151
+ for more details.
2152
+
2153
+ This example should be run with 2 devices.
2154
+
2155
+ >>> import mindspore as ms
2156
+ >>> from mindspore.ops.communication import init_process_group
2157
+ >>> from mindspore.communication.comm_func import scatter_tensor
2158
+ >>> import numpy as np
2159
+ >>> # Launch 2 processes.
2160
+ >>>
2161
+ >>> init_process_group()
2162
+ >>> input = ms.Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
2163
+ >>> output = ms.Tensor(np.zeros([2, 2]).astype(np.float32))
2164
+ >>> out = scatter_tensor(output, input, src=0)
2165
+ >>> print(output)
2166
+ # rank_0
2167
+ [[0. 1.]
2168
+ [2. 3.]]
2169
+ # rank_1
2170
+ [[4. 5.]
2171
+ [6. 7.]]
2172
+ """
2173
+ if is_inplace_func() is False:
2174
+ raise ValueError("Non-inplace mode is currently not supported.")
2175
+ if not isinstance(input_tensor, (Tensor, Tensor_)):
2176
+ raise TypeError("For scatter_tensor, the input tensor must be tensor")
2177
+ if not isinstance(output_tensor, (Tensor, Tensor_)):
2178
+ raise TypeError("For scatter_tensor, the output tensor must be tensor")
2179
+ if not isinstance(src, int):
2180
+ raise TypeError("For scatter_tensor, the src must be int")
2181
+ if group is None:
2182
+ group = GlobalComm.WORLD_COMM_GROUP
2183
+ if not isinstance(group, str):
2184
+ raise TypeError(
2185
+ "The argument 'group' must be type of string, "
2186
+ "but got 'group' type : {}.".format(type(group))
2187
+ )
2188
+ if not isinstance(async_op, bool):
2189
+ raise TypeError(
2190
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
2191
+ )
2192
+ src = get_group_rank_from_world_rank(src, group)
2193
+ rank_size = get_cache_group_size(group)
2194
+ rank_id = get_cache_group_rank(group)
2195
+ output = dist_comm_scatter_tensor_op(
2196
+ output_tensor, input_tensor, rank_size, src, rank_id, group
2197
+ )
2198
+ _, handle = _deal_comm_outputs(output, async_op)
2199
+ return handle
2200
+
2201
+
2202
+ def gather_into_tensor(output_tensor, input_tensor, dst=0, group=None, async_op=False):
2203
+ r"""
2204
+ Gathers tensors from the specified communication group. The operation will gather the tensor
2205
+ from processes according to dimension 0.
2206
+
2207
+ Note:
2208
+ - Only the tensor in process `dst` (global rank) will keep the gathered tensor. The other processes
2209
+ will keep a tensor with shape [1], which has no mathematical meaning.
2210
+ - Only support PyNative mode, Graph mode is not currently supported.
2211
+
2212
+ Args:
2213
+ output_tensor (Tensor): Output tensor to accommodate tensor elements from all ranks.
2214
+ input_tensor (Tensor): The tensor to be gathered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
2215
+ the input tensors in this API must have the same size across all ranks.
2216
+ dst (int, optional): Specifies the rank (global rank) of the process that receives the tensor.
2217
+ And only process `dst` will receive the gathered tensor. Default: 0.
2218
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2219
+ Ascend. Default: ``None``.
2220
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
2221
+
2222
+ Returns:
2223
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
2224
+ CommHandle will be None, when `async_op` is False.
2225
+
2226
+ Raises:
2227
+ TypeError: If the type of the `input_tensor` or `output_tensor` parameter is not Tensor,
2228
+ `dst` is not an int, or `group` is not a str.
2229
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
2230
+
2231
+ Supported Platforms:
2232
+ ``Ascend``
2233
+
2234
+ Examples:
2235
+ .. note::
2236
+ Before running the following examples, you need to configure the communication environment variables.
2237
+
2238
+ For Ascend devices, it is recommended to use the msrun startup method
2239
+ without any third-party or configuration file dependencies.
2240
+ Please see the `msrun start up
2241
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2242
+ for more details.
2243
+
2244
+ This example should be run with 2 devices.
2245
+
2246
+ >>> import numpy as np
2247
+ >>> import mindspore as ms
2248
+ >>> import mindspore.nn as nn
2249
+ >>> from mindspore.ops.communication import init_process_group
2250
+ >>> from mindspore import Tensor
2251
+ >>> from mindspore.communication.comm_func import gather_into_tensor
2252
+ >>> # Launch 2 processes.
2253
+ >>>
2254
+ >>> init_process_group()
2255
+ >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
2256
+ >>> output = Tensor(np.zeros([4, 2]).astype(np.float32))
2257
+ >>> handle = gather_into_tensor(output, input, dst=0)
2258
+ >>> print(output)
2259
+ Process with rank 0: [[0. 1.],
2260
+ [2. 3.],
2261
+ [0. 1.],
2262
+ [2. 3.]]
2263
+ Process with rank 1: [[0. 0.],
2264
+ [0. 0.],
2265
+ [0. 0.],
2266
+ [0. 0.]]
2267
+ """
2268
+ if is_inplace_func() is False:
2269
+ raise ValueError("Non-inplace mode is currently not supported.")
2270
+ if not isinstance(input_tensor, (Tensor, Tensor_)):
2271
+ raise TypeError("For gather_into_tensor, the input tensor must be tensor")
2272
+ if not isinstance(output_tensor, (Tensor, Tensor_)):
2273
+ raise TypeError("For gather_into_tensor, the output tensor must be tensor")
2274
+ if not isinstance(dst, int):
2275
+ raise TypeError("For gather_into_tensor, the dst must be int")
2276
+ if group is None:
2277
+ group = GlobalComm.WORLD_COMM_GROUP
2278
+ if not isinstance(group, str):
2279
+ raise TypeError(
2280
+ "The argument 'group' must be type of string, "
2281
+ "but got 'group' type : {}.".format(type(group))
2282
+ )
2283
+ if not isinstance(async_op, bool):
2284
+ raise TypeError(
2285
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
2286
+ )
2287
+ group_size = get_cache_group_size(group)
2288
+ dst = get_group_rank_from_world_rank(dst, group)
2289
+ rank_id = get_cache_group_rank(group)
2290
+ output = dist_comm_gather_into_tensor_op(
2291
+ output_tensor, input_tensor, group_size, dst, rank_id, group
2292
+ )
2293
+ _, handle = _deal_comm_outputs(output, async_op)
2294
+ return handle
2295
+
2296
+
2297
+ def broadcast(tensor, src, group=None, async_op=False):
2298
+ """
2299
+ Broadcasts the tensor to the whole group.
2300
+
2301
+ Note:
2302
+ - The tensors must have the same shape and format in all processes of the collection.
2303
+ - Only support PyNative mode, Graph mode is not currently supported.
2304
+
2305
+ Args:
2306
+ tensor (Tensor): Data to be sent if src is the rank of current process,
2307
+ and tensor to be used to save received data otherwise.
2308
+ src (int): Specifies the rank (global rank) of the process that broadcasts the tensor.
2309
+ And only process `src` will broadcast the tensor.
2310
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2311
+ Ascend. Default: ``None``.
2312
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
2313
+
2314
+ Returns:
2315
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
2316
+ CommHandle will be None, when `async_op` is False.
2317
+
2318
+ Raises:
2319
+ TypeError: If the type of the `tensor` parameter is not Tensor, `src` is not an integer,
2320
+ `group` is not a string or `async_op` is not bool.
2321
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
2322
+
2323
+ Supported Platforms:
2324
+ ``Ascend`` ``CPU``
2325
+
2326
+ Examples:
2327
+ .. note::
2328
+ Before running the following examples, you need to configure the communication environment variables.
2329
+
2330
+ For Ascend devices, it is recommended to use the msrun startup method
2331
+ without any third-party or configuration file dependencies.
2332
+ Please see the `msrun start up
2333
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2334
+ for more details.
2335
+
2336
+ This example should be run with 2 devices.
2337
+
2338
+ >>> import mindspore as ms
2339
+ >>> from mindspore.ops.communication import init_process_group, broadcast
2340
+ >>> import numpy as np
2341
+ >>> # Launch 2 processes.
2342
+ >>>
2343
+ >>> init_process_group()
2344
+ >>> data = ms.Tensor(np.arange(8).reshape([2, 4]).astype(np.float32))
2345
+ >>> handle = broadcast(tensor=data, src=0)
2346
+ >>> print(data)
2347
+ [[0. 1. 2. 3.]
2348
+ [4. 5. 6. 7.]]
2349
+ """
2350
+ if is_inplace_func() is False:
2351
+ raise ValueError("Non-inplace mode is currently not supported.")
2352
+ if not isinstance(tensor, (Tensor, Tensor_)):
2353
+ raise TypeError("For broadcast, the input tensor must be tensor")
2354
+ if not isinstance(src, int):
2355
+ raise TypeError("For broadcast, the src must be int")
2356
+ if group is None:
2357
+ group = GlobalComm.WORLD_COMM_GROUP
2358
+ if not isinstance(group, str):
2359
+ raise TypeError(
2360
+ "The argument 'group' must be type of string, "
2361
+ "but got 'group' type : {}.".format(type(group))
2362
+ )
2363
+ if not isinstance(async_op, bool):
2364
+ raise TypeError(
2365
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
2366
+ )
2367
+ src_rank = get_group_rank_from_world_rank(src, group)
2368
+ rank_id = get_cache_group_rank(group)
2369
+ output = dist_comm_broadcast_op(tensor, src_rank, rank_id, group)
2370
+ _, handle = _deal_comm_outputs(output, async_op)
2371
+ return handle
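+
+ # A minimal async usage sketch for broadcast (illustrative only, not part of the public
+ # API; the helper name is hypothetical). It assumes 2 ranks launched via msrun with the
+ # communication environment already initialized.
+ def _broadcast_async_sketch():
+     data = Tensor(np.arange(8).reshape([2, 4]).astype(np.float32))
+     # Every rank passes a tensor of the same shape and dtype; rank 0 provides the values.
+     handle = broadcast(data, src=0, async_op=True)
+     handle.wait()  # after this, `data` holds rank 0's values on every rank
+     return data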
2372
+
2373
+
2374
+ def barrier(group=None, async_op=False, device_ids=None):
2375
+ """
2376
+ Synchronizes all processes in the specified group. Once a process calls this operation, it is blocked until
+ all processes in the group have called it. After that, the blocked processes
+ are woken up and continue their tasks.
2379
+
2380
+ Args:
2381
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2382
+ Ascend. Default: ``None``.
2383
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
2384
+ device_ids (list[int], optional): Currently it is a reserved parameter.
2385
+
2386
+ Returns:
2387
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
2388
+ CommHandle will be None, when `async_op` is False.
2389
+
2390
+ Raises:
2391
+ TypeError: `group` is not a str or `async_op` is not a bool.
2392
+ RuntimeError: If backend is invalid, or distributed initialization fails.
2393
+
2394
+ Supported Platforms:
2395
+ ``Ascend`` ``CPU``
2396
+
2397
+ Examples:
2398
+ .. note::
2399
+ Before running the following examples, you need to configure the communication environment variables.
2400
+
2401
+ For Ascend devices, it is recommended to use the msrun startup method
2402
+ without any third-party or configuration file dependencies.
2403
+ Please see the `msrun start up
2404
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2405
+ for more details.
2406
+
2407
+ This example should be run with 2 devices.
2408
+
2409
+ >>> from mindspore.ops.communication import init_process_group
2410
+ >>> from mindspore.communication.comm_func import barrier
2411
+ >>> # Launch 2 processes.
2412
+ >>> init_process_group()
2413
+ >>> barrier()
2414
+ >>> print("barrier finish!")
2415
+ barrier finish!
2416
+ """
2417
+ if group is None:
2418
+ group = GlobalComm.WORLD_COMM_GROUP
2419
+ if not isinstance(group, str):
2420
+ raise TypeError(
2421
+ "The argument 'group' must be type of string, "
2422
+ "but got 'group' type : {}.".format(type(group))
2423
+ )
2424
+ if not isinstance(async_op, bool):
2425
+ raise TypeError(
2426
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
2427
+ )
2428
+ output = dist_comm_barrier_op(group)
2429
+ _, handle = _deal_comm_outputs(output, async_op, True)
2430
+ return handle
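+
+ # A minimal sketch of using barrier asynchronously (illustrative only; the helper name
+ # is hypothetical). The handle lets the caller overlap rank-local work with the
+ # synchronization before waiting on it.
+ def _barrier_async_sketch():
+     handle = barrier(async_op=True)
+     # ... do rank-local work here that does not depend on the other ranks ...
+     handle.wait()  # all ranks have reached the barrier once this returns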
2431
+
2432
+
2433
+ def send(tensor, dst=0, group=None, tag=0):
2434
+ """
2435
+ Send tensors to the specified dest_rank.
2436
+
2437
+ Note:
2438
+ Only support PyNative mode, Graph mode is not currently supported.
2439
+
2440
+ Args:
2441
+ tensor (Tensor): Tensor to send.
2442
+ dst (int, optional): A required integer identifying the destination rank(global rank). Default: 0.
2443
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2444
+ Ascend. Default: ``None``.
2445
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
2446
+ be received by the Receive op with the same "tag". Default: 0. It is a reserved parameter currently.
2447
+
2448
+ Raises:
2449
+ TypeError: If the `tensor` is not Tensor, `dst` is not an int or `group` is not a str.
2450
+ ValueError: If the `dst` process rank id is same as the current process.
2451
+
2452
+ Supported Platforms:
2453
+ ``Ascend`` ``CPU``
2454
+
2455
+ Examples:
2456
+ .. note::
2457
+ Before running the following examples, you need to configure the communication environment variables.
2458
+
2459
+ For Ascend devices, it is recommended to use the msrun startup method
2460
+ without any third-party or configuration file dependencies.
2461
+ Please see the `msrun start up
2462
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2463
+ for more details.
2464
+
2465
+ This example should be run with 2 devices.
2466
+
2467
+ >>> from mindspore.ops.communication import init_process_group
2468
+ >>> from mindspore.ops.communication import send, recv, get_rank
2469
+ >>> from mindspore import Tensor
2470
+ >>> import numpy as np
2471
+ >>>
2472
+ >>> # Launch 2 processes, Process 0 sends the array to Process 1.
2473
+ >>> init_process_group()
2474
+ >>> this_rank = get_rank()
2475
+ >>> if this_rank == 0:
2476
+ ... input_ = Tensor(np.ones([2, 8]).astype(np.float32))
2477
+ ... send(input_, 1)
2478
+ >>> if this_rank == 1:
2479
+ ... x = Tensor(np.zeros([2, 8]).astype(np.float32))
2480
+ ... out = recv(x, src=0)
2481
+ ... print(x)
2482
+ rank 1:
2483
+ [[1. 1. 1. 1. 1. 1. 1. 1.]
2484
+ [1. 1. 1. 1. 1. 1. 1. 1.]]
2485
+ """
2486
+ if not isinstance(tensor, (Tensor, Tensor_)):
2487
+ raise TypeError("For send, the input tensor must be tensor")
2488
+ if not isinstance(dst, int):
2489
+ raise TypeError("For send, the dst must be int")
2490
+ if group is None:
2491
+ group = GlobalComm.WORLD_COMM_GROUP
2492
+ if not isinstance(group, str):
2493
+ raise TypeError(
2494
+ "The argument 'group' must be type of string, "
2495
+ "but got 'group' type : {}.".format(type(group))
2496
+ )
2497
+ if get_cache_group_rank() == dst:
2498
+ raise ValueError(
2499
+ "Invalid destination rank: destination rank should not be the same as "
2500
+ "the rank of the current process."
2501
+ )
2502
+ _dst = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
2503
+ output = dist_comm_isend_op(tensor, _dst, group, tag)
2504
+ _deal_comm_outputs(output, False)
2505
+
2506
+
2507
+
2508
+ def recv(tensor, src=0, group=None, tag=0):
2509
+ """
2510
+ Receive tensors from src.
2511
+
2512
+ Note:
2513
+ Only support PyNative mode, Graph mode is not currently supported.
2514
+
2515
+ Args:
2516
+ tensor (Tensor): Tensor to be filled with the received data if the function operates in-place. Otherwise,
+ only the shape and dtype of this tensor are used to construct the received tensor, and the value of
+ the input `tensor` does not take effect.
2519
+ src (int, optional): A required integer identifying the source rank(global rank). Default: ``0``.
2520
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2521
+ Ascend. Default: ``None``.
2522
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
2523
+ be received by the Send op with the same "tag". Default: ``0``. It is a reserved parameter currently.
2524
+
2525
+ Returns:
2526
+ - int, if the function operates in-place. ``0`` is returned on success.
+ - Tensor, if the function operates out-of-place.
+ The shape of the output is :math:`(x_1, x_2, ..., x_R)`.
2529
+
2530
+ Raises:
2531
+ TypeError: If the `tensor` is not Tensor, `src` is not an int or `group` is not a str.
2532
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
2533
+
2534
+ Supported Platforms:
2535
+ ``Ascend`` ``CPU``
2536
+
2537
+ Examples:
2538
+ .. note::
2539
+ Before running the following examples, you need to configure the communication environment variables.
2540
+
2541
+ For Ascend devices, it is recommended to use the msrun startup method
2542
+ without any third-party or configuration file dependencies.
2543
+ Please see the `msrun start up
2544
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2545
+ for more details.
2546
+
2547
+ This example should be run with 2 devices.
2548
+
2549
+ >>> from mindspore.ops.communication import init_process_group
2550
+ >>> from mindspore.ops.communication import send, recv, get_rank
2551
+ >>> from mindspore import Tensor
2552
+ >>> import numpy as np
2553
+ >>>
2554
+ >>> # Launch 2 processes, Process 0 sends the array to Process 1.
2555
+ >>> init_process_group()
2556
+ >>> this_rank = get_rank()
2557
+ >>> if this_rank == 0:
2558
+ ... input_ = Tensor(np.ones([2, 8]).astype(np.float32))
2559
+ ... send(input_, 1)
2560
+ >>> if this_rank == 1:
2561
+ ... x = Tensor(np.zeros([2, 8]).astype(np.float32))
2562
+ ... out = recv(x, src=0)
2563
+ ... print(x)
2564
+ rank 1:
2565
+ [[1. 1. 1. 1. 1. 1. 1. 1.]
2566
+ [1. 1. 1. 1. 1. 1. 1. 1.]]
2567
+ """
2568
+ if not isinstance(tensor, (Tensor, Tensor_)):
2569
+ raise TypeError("For recv, the input tensor must be tensor")
2570
+ if not isinstance(src, int):
2571
+ raise TypeError("For recv, the src must be int")
2572
+ if group is None:
2573
+ group = GlobalComm.WORLD_COMM_GROUP
2574
+ if not isinstance(group, str):
2575
+ raise TypeError(
2576
+ "The argument 'group' must be type of string, "
2577
+ "but got 'group' type : {}.".format(type(group))
2578
+ )
2579
+ _src = _get_group_rank_from_world_rank_from_cache_helper(src, group)
2580
+
2581
+ if is_inplace_func() is True:
2582
+ output = dist_comm_irecv_op(tensor, tag, _src, group)
2583
+ _deal_comm_outputs(output, False)
2584
+ return 0
2585
+ shape = tensor.shape
2586
+ dtype = tensor.dtype
2587
+ output = inner_comm_irecv_op(tag, _src, shape, group, dtype)
2588
+ output, _ = _deal_comm_outputs(output, False)
2589
+ return output
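+
+ # A minimal sketch contrasting the two return conventions of recv (illustrative only;
+ # the helper name is hypothetical and 2 initialized ranks are assumed): in-place mode
+ # fills the passed tensor and returns 0, while out-of-place mode returns a freshly
+ # constructed tensor with the same shape and dtype.
+ def _recv_return_convention_sketch():
+     buf = Tensor(np.zeros([2, 8]).astype(np.float32))
+     result = recv(buf, src=0)
+     if isinstance(result, int):
+         return buf      # in-place: `buf` now holds the received data, result == 0
+     return result       # out-of-place: `result` is the received tensor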
2590
+
2591
+
2592
+ def isend(tensor, dst=0, group=None, tag=0):
2593
+ """
2594
+ Send tensors to the specified dest_rank asynchronously.
2595
+
2596
+ Note:
2597
+ Only support PyNative mode, Graph mode is not currently supported.
2598
+
2599
+ Args:
2600
+ tensor (Tensor): Tensor to send.
2601
+ dst (int, optional): A required integer identifying the destination rank(global rank). Default: 0.
2602
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2603
+ Ascend. Default: ``None``.
2604
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
2605
+ be received by the Receive op with the same "tag". Default: 0. It is a reserved parameter currently.
2606
+
2607
+ Returns:
2608
+ CommHandle, it is an async work handle.
2609
+
2610
+ Raises:
2611
+ TypeError: If the `tensor` is not Tensor, `dst` is not an int or `group` is not a str.
2612
+ ValueError: If the `dst` process rank id is same as the current process.
2613
+
2614
+ Supported Platforms:
2615
+ ``Ascend``
2616
+
2617
+ Examples:
2618
+ .. note::
2619
+ Before running the following examples, you need to configure the communication environment variables.
2620
+
2621
+ For Ascend devices, it is recommended to use the msrun startup method
2622
+ without any third-party or configuration file dependencies.
2623
+ Please see the `msrun start up
2624
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2625
+ for more details.
2626
+
2627
+ This example should be run with 2 devices.
2628
+
2629
+ >>> from mindspore.ops.communication import init_process_group
2630
+ >>> from mindspore.ops.communication import isend, irecv, get_rank
2631
+ >>> from mindspore import Tensor
2632
+ >>> import numpy as np
2633
+ >>>
2634
+ >>> # Launch 2 processes, Process 0 sends the array to Process 1.
2635
+ >>> init_process_group()
2636
+ >>> this_rank = get_rank()
2637
+ >>> if this_rank == 0:
2638
+ ... input_ = Tensor(np.ones([2, 8]).astype(np.float32))
2639
+ ... handle = isend(input_, 1)
2640
+ ... handle.wait()
2641
+ >>> if this_rank == 1:
2642
+ ... x = Tensor(np.zeros([2, 8]).astype(np.float32))
2643
+ ... handle = irecv(x, src=0)
2644
+ ... handle.wait()
2645
+ ... print(x)
2646
+ rank 1:
2647
+ [[1. 1. 1. 1. 1. 1. 1. 1.]
2648
+ [1. 1. 1. 1. 1. 1. 1. 1.]]
2649
+ """
2650
+ if not isinstance(tensor, (Tensor, Tensor_)):
2651
+ raise TypeError("For isend, the input tensor must be tensor")
2652
+ if not isinstance(dst, int):
2653
+ raise TypeError("For isend, the dst must be int")
2654
+ if group is None:
2655
+ group = GlobalComm.WORLD_COMM_GROUP
2656
+ if not isinstance(group, str):
2657
+ raise TypeError(
2658
+ "The argument 'group' must be type of string, "
2659
+ "but got 'group' type : {}.".format(type(group))
2660
+ )
2661
+ if get_cache_group_rank() == dst:
2662
+ raise ValueError(
2663
+ "Invalid destination rank: destination rank should not be the same as "
2664
+ "the rank of the current process."
2665
+ )
2666
+ _dst = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
2667
+ output = dist_comm_isend_op(tensor, _dst, group, tag)
2668
+ _, handle = _deal_comm_outputs(output, True)
2669
+ return handle
2670
+
2671
+
2672
+ def irecv(tensor, src=0, group=None, tag=0):
2673
+ """
2674
+ Receive tensors from src asynchronously.
2675
+
2676
+ Note:
2677
+ Only support PyNative mode, Graph mode is not currently supported.
2678
+
2679
+ Args:
2680
+ tensor (Tensor): Tensor to be filled with the received data if the function operates in-place. Otherwise,
+ only the shape and dtype of this tensor are used to construct the received tensor, and the value of
+ the input `tensor` does not take effect.
2683
+ src (int, optional): A required integer identifying the source rank(global rank). Default: ``0``.
2684
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2685
+ Ascend. Default: ``None``.
2686
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
2687
+ be received by the Send op with the same "tag". Default: ``0``. It is a reserved parameter currently.
2688
+
2689
+ Returns:
2690
+ - CommHandle, if the function operates in-place. CommHandle is an async work handle.
+ - Tuple(Tensor, CommHandle), if the function operates out-of-place. The shape of the output tensor
+ is :math:`(x_1, x_2, ..., x_R)`, and CommHandle is an async work handle.
2695
+
2696
+ Raises:
2697
+ TypeError: If the type of `tensor` is not Tensor, If `src` is not an int or `group` is not a str.
2698
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
2699
+
2700
+ Supported Platforms:
2701
+ ``Ascend``
2702
+
2703
+ Examples:
2704
+ .. note::
2705
+ Before running the following examples, you need to configure the communication environment variables.
2706
+
2707
+ For Ascend devices, it is recommended to use the msrun startup method
2708
+ without any third-party or configuration file dependencies.
2709
+ Please see the `msrun start up
2710
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2711
+ for more details.
2712
+
2713
+ This example should be run with 2 devices.
2714
+
2715
+ >>> from mindspore.ops.communication import init_process_group
2716
+ >>> from mindspore.ops.communication import isend, irecv, get_rank
2717
+ >>> from mindspore import Tensor
2718
+ >>> import numpy as np
2719
+ >>>
2720
+ >>> # Launch 2 processes, Process 0 sends the array to Process 1.
2721
+ >>> init_process_group()
2722
+ >>> this_rank = get_rank()
2723
+ >>> if this_rank == 0:
2724
+ ... input_ = Tensor(np.ones([2, 8]).astype(np.float32))
2725
+ ... handle = isend(input_, 1)
2726
+ ... handle.wait()
2727
+ >>> if this_rank == 1:
2728
+ ... x = Tensor(np.zeros([2, 8]).astype(np.float32))
2729
+ ... handle = irecv(x, src=0)
2730
+ ... handle.wait()
2731
+ ... print(x)
2732
+ rank 1:
2733
+ [[1. 1. 1. 1. 1. 1. 1. 1.]
2734
+ [1. 1. 1. 1. 1. 1. 1. 1.]]
2735
+ """
2736
+ if not isinstance(tensor, (Tensor, Tensor_)):
2737
+ raise TypeError("For irecv, the input tensor must be tensor")
2738
+ if group is None:
2739
+ group = GlobalComm.WORLD_COMM_GROUP
2740
+ if not isinstance(group, str):
2741
+ raise TypeError(
2742
+ "The argument 'group' must be type of string, "
2743
+ "but got 'group' type : {}.".format(type(group))
2744
+ )
2745
+ if not isinstance(src, int):
2746
+ raise TypeError("For irecv, the src must be int")
2747
+ _src = _get_group_rank_from_world_rank_from_cache_helper(src, group)
2748
+
2749
+ if is_inplace_func() is True:
2750
+ output = dist_comm_irecv_op(tensor, tag, _src, group)
2751
+ _, handle = _deal_comm_outputs(output, True)
2752
+ return handle
2753
+ shape = tensor.shape
2754
+ dtype = tensor.dtype
2755
+ output = inner_comm_irecv_op(tag, _src, shape, group, dtype)
2756
+ return _deal_comm_outputs(output, True)
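+
+ # A minimal sketch of handling the irecv return value (illustrative only; the helper
+ # name is hypothetical and 2 initialized ranks are assumed): out-of-place calls return
+ # a (tensor, handle) tuple, and the tensor is only safe to read after handle.wait().
+ def _irecv_out_of_place_sketch():
+     template = Tensor(np.zeros([2, 8]).astype(np.float32))  # shape/dtype template only
+     result = irecv(template, src=0)
+     if isinstance(result, tuple):
+         received, handle = result            # out-of-place: new tensor plus handle
+     else:
+         received, handle = template, result  # in-place: the template is filled
+     handle.wait()
+     return received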
2757
+
2758
+
2759
+ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
2760
+ """
2761
+ Scatter and gather a list of tensors to/from all ranks according to the input/output tensor lists.
2762
+
2763
+ Note:
2764
+ - The tensor shapes in `output_tensor_list` and `input_tensor_list` must match across ranks.
2765
+ - Only support PyNative mode, Graph mode is not currently supported.
2766
+
2767
+ Args:
2768
+ output_tensor_list(Union[List(Tensor), List(Tuple(int))]): List of tensors to be gathered
+ from remote ranks if the function operates in-place. Otherwise, a list of tensors or shapes
+ that indicates the shapes of the tensors gathered from remote ranks.
2771
+ input_tensor_list (List[Tensor]): List of tensors to scatter to the remote rank.
2772
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2773
+ Ascend. Default: ``None``.
2774
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
2775
+
2776
+ Returns:
2777
+ - CommHandle, if the function operates in-place. CommHandle is an async work handle if `async_op`
+ is set to True, and None when `async_op` is False.
+ - Tuple(Tensor, CommHandle), if the function operates out-of-place, where the tensors are gathered
+ from remote ranks. CommHandle is an async work handle if `async_op` is set to True,
+ and None when `async_op` is False.
2782
+
2783
+ Raises:
2784
+ TypeError: If not all elements in `input_tensor_list` or `output_tensor_list` are Tensor.
2785
+ TypeError: If tensors in `input_tensor_list` or `output_tensor_list` are not the same type.
2786
+ TypeError: If `group` is not str or `async_op` is not bool.
2787
+
2788
+ Supported Platforms:
2789
+ ``Ascend``
2790
+
2791
+ Examples:
2792
+ .. note::
2793
+ Before running the following examples, you need to configure the communication environment variables.
2794
+
2795
+ For Ascend devices, it is recommended to use the msrun startup method
2796
+ without any third-party or configuration file dependencies.
2797
+ Please see the `msrun start up
2798
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2799
+ for more details.
2800
+
2801
+ This example should be run with 2 devices.
2802
+
2803
+ >>> import mindspore as ms
2804
+ >>> from mindspore.ops.communication import init_process_group, get_rank
2805
+ >>> from mindspore.ops.communication import all_to_all
2806
+ >>> from mindspore import Tensor
2807
+ >>>
2808
+ >>> init_process_group()
2809
+ >>> this_rank = get_rank()
2810
+ >>> if this_rank == 0:
2811
+ ... send_tensor_list = [Tensor(1.), Tensor([[2, 3], [4, 5.]])]
2812
+ ... recv_tensor_list = [Tensor((0), dtype=ms.float32), Tensor([0, 0.])]
2813
+ >>> if this_rank == 1:
2814
+ ... send_tensor_list = [Tensor([2, 2.]), Tensor([4, 5, 6, 7.])]
2815
+ ... recv_tensor_list = [Tensor([[0, 0.],[0, 0]]), Tensor([0, 0, 0, 0.])]
2816
+ >>> handle = all_to_all(recv_tensor_list, send_tensor_list)
2817
+ >>> print(recv_tensor_list)
2818
+ rank 0:
2819
+ (Tensor(shape=[], dtype=Float32, value= 1),
2820
+ Tensor(shape=[2], dtype=Float32, value= [2.00000000e+00, 2.00000000e+00]))
2821
+ rank 1:
2822
+ (Tensor(shape=[2, 2], dtype=Float32, value=
2823
+ [[2.00000000e+00, 3.00000000e+00],
2824
+ [4.00000000e+00, 5.00000000e+00]]),
2825
+ Tensor(shape=[4], dtype=Float32, value=[4.00000000e+00, 5.00000000e+00, 6.00000000e+00, 7.00000000e+00]))
2826
+
2827
+ """
2828
+ if group is None:
2829
+ group = GlobalComm.WORLD_COMM_GROUP
2830
+ if not isinstance(group, str):
2831
+ raise TypeError(
2832
+ "The argument 'group' must be type of string, "
2833
+ "but got 'group' type : {}.".format(type(group))
2834
+ )
2835
+ if not isinstance(async_op, bool):
2836
+ raise TypeError(
2837
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
2838
+ )
2839
+
2840
+ _check_all_tensors(input_tensor_list)
2841
+ _check_all_tensor_same_dtype(input_tensor_list)
2842
+
2843
+ send_numel_list = []
2844
+ send_flatten_tensor = []
2845
+ recv_numel_list = []
2846
+ recv_shape_list = []
2847
+
2848
+ for tensor in input_tensor_list:
2849
+ send_numel_list.append(tensor.numel())
2850
+ send_flatten_tensor.append(tensor.reshape(-1))
2851
+ send_flatten_tensor = cat(send_flatten_tensor)
2852
+ rank_size = get_cache_group_size(group)
2853
+
2854
+ if is_inplace_func() is False:
2855
+ _check_all_tensors_or_tuple(output_tensor_list)
2856
+ for tensor in output_tensor_list:
2857
+ if isinstance(tensor, Tensor):
2858
+ recv_numel_list.append(tensor.size)
2859
+ recv_shape_list.append(tensor.shape)
2860
+ else:
2861
+ _shape = tensor
2862
+ recv_numel_list.append(_get_size(_shape))
2863
+ recv_shape_list.append(_shape)
2864
+ output = inner_comm_all_to_all_v_op(send_flatten_tensor, group, send_numel_list, recv_numel_list,
2865
+ rank_size, False)
2866
+ output, handle = _deal_comm_outputs(output, async_op)
2867
+ result = []
2868
+ offset = 0
2869
+ for numel, shape in zip(recv_numel_list, recv_shape_list):
2870
+ result.append(output[offset:offset + numel].reshape(shape))
2871
+ offset = offset + numel
2872
+ return (tuple(result), handle)
2873
+
2874
+ _check_all_tensors(output_tensor_list)
2875
+ _check_all_tensor_same_dtype(output_tensor_list)
2876
+ for tensor in output_tensor_list:
2877
+ recv_numel_list.append(tensor.numel())
2878
+ recv_shape_list.append(tensor.shape)
2879
+
2880
+ output = dist_comm_all_to_all_v_op(
2881
+ output_tensor_list,
2882
+ send_flatten_tensor,
2883
+ group,
2884
+ send_numel_list,
2885
+ recv_numel_list,
2886
+ rank_size,
2887
+ )
2888
+ _, handle = _deal_comm_outputs(output, async_op)
2889
+ return handle
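+
+ # A sketch of calling all_to_all with shape tuples instead of pre-allocated output
+ # tensors (illustrative only; the helper name is hypothetical). This form only applies
+ # when the non-inplace variant of this module is in effect and assumes 2 initialized
+ # ranks; in that mode the gathered tensors come back in the returned tuple.
+ def _all_to_all_shapes_sketch():
+     send_list = [Tensor(np.ones([2]).astype(np.float32)),
+                  Tensor(np.ones([2]).astype(np.float32))]
+     recv_shapes = [(2,), (2,)]  # one expected shape per remote rank
+     result, handle = all_to_all(recv_shapes, send_list, async_op=True)
+     handle.wait()
+     return result  # tuple of tensors gathered from the remote ranks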
2890
+
2891
+
2892
+ def _get_all_to_all_single_numel_list(tensor, output, output_split_sizes,
2893
+ input_split_sizes, group):
2894
+ """get numel list for all_to_all_single."""
2895
+ if _is_split_sizes_empty(input_split_sizes):
2896
+ _world_size = get_cache_group_size(group)
2897
+ if tensor.shape[0] % _world_size != 0:
2898
+ raise ValueError(
2899
+ "input shape at dim 0 must be divided by world_size, "
2900
+ f"but got {tensor.shape[0]} and {_world_size}."
2901
+ )
2902
+ _split_size = tensor.shape[0] // _world_size
2903
+ input_split_sizes = (_split_size,) * _world_size
2904
+ if _is_split_sizes_empty(output_split_sizes):
2905
+ _world_size = get_cache_group_size(group)
2906
+ shape_dim_0 = output.shape[0]
2907
+
2908
+ if shape_dim_0 % _world_size != 0:
2909
+ raise ValueError(
2910
+ "output shape at dim 0 must be divided by world_size, "
2911
+ f"but got {shape_dim_0} and {_world_size}."
2912
+ )
2913
+ _split_size = shape_dim_0 // _world_size
2914
+ output_split_sizes = (_split_size,) * _world_size
2915
+
2916
+ send_size_without_first_dim = _get_size(tensor.shape[1:])
2917
+ send_numel_list = [size * send_size_without_first_dim for size in input_split_sizes]
2918
+
2919
+ recv_shape_without_first_dim = output.shape[1:]
2920
+ recv_size_without_first_dim = _get_size(recv_shape_without_first_dim)
2921
+ recv_numel_list = [
2922
+ size * recv_size_without_first_dim for size in output_split_sizes
2923
+ ]
2924
+ return send_numel_list, recv_numel_list, recv_shape_without_first_dim
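+
+ # A small worked sketch of the numel-list computation above (illustrative only; the
+ # helper name is hypothetical). With explicit split sizes the group argument is never
+ # consulted, so the arithmetic can be checked locally without any communication set up.
+ def _all_to_all_single_numel_sketch():
+     tensor = Tensor(np.zeros([4, 3]).astype(np.float32))   # input, split as 3 + 1 rows
+     output = Tensor(np.zeros([4, 3]).astype(np.float32))   # output, split as 2 + 2 rows
+     send, recv_, tail_shape = _get_all_to_all_single_numel_list(
+         tensor, output, output_split_sizes=(2, 2), input_split_sizes=(3, 1), group=None)
+     # send == [9, 3] (rows of 3 elements), recv_ == [6, 6], tail_shape == (3,)
+     return send, recv_, tail_shape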
2925
+
2926
+
2927
+ def all_to_all_single(output,
2928
+ input,
2929
+ output_split_sizes=None,
2930
+ input_split_sizes=None,
2931
+ group=None,
2932
+ async_op=False):
2933
+ """
2934
+ Scatter and gather the input with the given split sizes to/from all ranks, and return the result in a single tensor.
2935
+
2936
+ Note:
2937
+ - Only support PyNative mode, Graph mode is not currently supported.
2938
+
2939
+ Args:
2940
+ output (Union(Tensor, Tuple(int))): the output tensor gathered and concatenated from remote ranks,
+ if the function operates in-place. Otherwise, a tensor or shape that indicates the shape
+ of the tensor gathered and concatenated from remote ranks.
2943
+ input (Tensor): tensor to be scattered to remote rank.
2944
+ output_split_sizes (Union(Tuple(int), List(int)), optional): output split size at dim 0. If set to None,
2945
+ it means equally split by ``world_size``. Default: ``None``.
2946
+ input_split_sizes (Union(Tuple(int), List(int)), optional): input split size at dim 0. If set to None,
2947
+ it means equally split by ``world_size``. Default: ``None``.
2948
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
2949
+ Ascend. Default: ``None``.
2950
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
2951
+
2952
+ Returns:
2953
+ - CommHandle, if the function operates in-place, return it. CommHandle is an async work handle,
2954
+ if `async_op` is set to True. CommHandle will be None, when `async_op` is False.
2955
+ - Tuple(Tensor, CommHandle), if the function operates non in-place, return it.
2956
+ The output tensor is gathered concatenated from remote ranks.
2957
+ If the numel of tensor gathered from remote is zero, it will return a Tensor with shape `()`,
2958
+ and value has no actual meanning. CommHandle is an async work handle, if `async_op` is set to True.
2959
+ CommHandle will be None, when `async_op` is False.
2960
+
2961
+ Raises:
2962
+ TypeError: If `input` or `output` is not tensor. `group` is not a str, or async_op is not bool.
2963
+ ValueError: When `input_split_sizes` is empty, input dim 0 can not be divided by ``world_size``.
2964
+ ValueError: When `output_split_sizes` is empty, output dim 0 can not be divided by ``world_size``.
2965
+
2966
+ Supported Platforms:
2967
+ ``Ascend``
2968
+
2969
+ Examples:
2970
+ .. note::
2971
+ Before running the following examples, you need to configure the communication environment variables.
2972
+
2973
+ For Ascend devices, it is recommended to use the msrun startup method
2974
+ without any third-party or configuration file dependencies.
2975
+ Please see the `msrun start up
2976
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
2977
+ for more details.
2978
+
2979
+ This example should be run with 2 devices.
2980
+
2981
+ >>> import numpy as np
2982
+ >>> import mindspore
2983
+ >>> from mindspore.ops.communication import init_process_group, get_rank
2984
+ >>> from mindspore.ops.communication import all_to_all_single
2985
+ >>> from mindspore import Tensor
2986
2987
+ >>>
2988
+ >>> init_process_group()
2989
+ >>> this_rank = get_rank()
2990
+ >>> if this_rank == 0:
2991
+ ... output = Tensor(np.zeros([3, 3]).astype(np.float32))
2992
+ ... tensor = Tensor([[0, 1, 2.], [3, 4, 5], [6, 7, 8]])
2993
+ ... result = all_to_all_single(output, tensor, [2, 1], [2, 1])
2994
+ ... print(output)
2995
+ >>> if this_rank == 1:
2996
+ ... output = Tensor(np.zeros([2, 3]).astype(np.float32))
2997
+ ... tensor = Tensor([[9, 10., 11], [12, 13, 14]])
2998
+ ... result = all_to_all_single(output, tensor, [1, 1], [1, 1])
2999
+ ... print(output)
3000
+ rank 0:
3001
+ [[ 0. 1. 2.]
3002
+ [ 3. 4. 5.]
3003
+ [ 9. 10. 11.]]
3004
+ rank 1:
3005
+ [[ 6. 7. 8.]
3006
+ [12. 13. 14.]]
3007
+
3008
+ """
3009
+ _check_all_tensors([input])
3010
+ _check_all_tensors([output])
3011
+ if group is None:
3012
+ group = GlobalComm.WORLD_COMM_GROUP
3013
+ if not isinstance(group, str):
3014
+ raise TypeError(
3015
+ "The argument 'group' must be type of string, "
3016
+ "but got 'group' type : {}.".format(type(group))
3017
+ )
3018
+ if not isinstance(async_op, bool):
3019
+ raise TypeError(
3020
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
3021
+ )
3022
+ split_sizes_empty = _is_split_sizes_empty(output_split_sizes) and _is_split_sizes_empty(input_split_sizes)
3023
+ _input = input.reshape(-1)
3024
+ rank_size = get_cache_group_size(group)
3025
+
3026
+ if is_inplace_func() is False:
3027
+ if isinstance(output_split_sizes, list):
3028
+ output_split_sizes = tuple(output_split_sizes)
3029
+ if isinstance(input_split_sizes, list):
3030
+ input_split_sizes = tuple(input_split_sizes)
3031
+ global _ALL_TO_ALL_CACHE
3032
+ tensor_shape = output
3033
+ cache_key = (tensor_shape, output, output_split_sizes, input_split_sizes, group)
3034
+ if cache_key not in _ALL_TO_ALL_CACHE:
3035
+ _ALL_TO_ALL_CACHE[cache_key] = _get_all_to_all_single_numel_list(*cache_key)
3036
+ send_numel_list, recv_numel_list, recv_shape_without_first_dim = _ALL_TO_ALL_CACHE[cache_key]
3037
+ result = \
3038
+ inner_comm_all_to_all_v_op(_input, group, send_numel_list, recv_numel_list, rank_size, split_sizes_empty)
3039
+ result, handle = _deal_comm_outputs(result, async_op)
3040
+ if any(recv_numel_list):
3041
+ result = result.reshape((-1,) + recv_shape_without_first_dim)
3042
+ return result, handle
3043
+
3044
+ send_numel_list, recv_numel_list, _ = \
3045
+ _get_all_to_all_single_numel_list(input, output, output_split_sizes, input_split_sizes, group)
3046
+ result = dist_comm_all_to_all_v_single_op(
3047
+ output,
3048
+ _input,
3049
+ group,
3050
+ send_numel_list,
3051
+ recv_numel_list,
3052
+ rank_size,
3053
+ split_sizes_empty,
3054
+ )
3055
+ _, handle = _deal_comm_outputs(result, async_op)
3056
+ return handle
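+
+ # A minimal sketch of the equal-split default of all_to_all_single (illustrative only;
+ # the helper name is hypothetical, 2 initialized ranks and the in-place variant of this
+ # module are assumed): when no split sizes are given, dim 0 of both tensors must be
+ # divisible by the group size and is split evenly across the ranks.
+ def _all_to_all_single_equal_split_sketch():
+     data = Tensor(np.arange(12).reshape([4, 3]).astype(np.float32))
+     out = Tensor(np.zeros([4, 3]).astype(np.float32))
+     handle = all_to_all_single(out, data, async_op=True)  # each rank sends two rows to every rank
+     handle.wait()
+     return out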
3057
+
3058
+
3059
+ def _check_tensor_list(tensor_list, tensor, group_size):
3060
+ """check all elements in tensor_list are type of Tensor or tuple or list"""
3061
+ _check_group_tensor_list(tensor_list, group_size)
3062
+ if tensor.dtype != tensor_list[0].dtype:
3063
+ raise TypeError(
3064
+ f"The argument list tensor type must be equal to tensor type, but got {tensor_list[0].dtype}."
3065
+ )
3066
+ if tensor.shape != tensor_list[0].shape:
3067
+ raise TypeError(
3068
+ f"The argument list tensor shape must be equal to tensor shape, but got {tensor_list[0].shape}."
3069
+ )
3070
+
3071
+
3072
+ def _check_group_tensor_list(tensor_list, group_size):
3073
+ if not tensor_list or len(tensor_list) != group_size:
3074
+ raise TypeError(
3075
+ f"The argument list tensor len must be equal to group rank size, but got {len(tensor_list)}."
3076
+ )
3077
+
3078
+
3079
+ def all_gather(tensor_list, tensor, group=None, async_op=False):
3080
+ """
3081
+ Gathers tensors from the specified communication group into the output `tensor_list`.
3082
+
3083
+ Args:
3084
+ tensor_list (list[Tensor]): Output list.
3085
+ tensor (Tensor): The input tensor to be all gathered into tensor.
3086
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
3087
+ Ascend. Default: ``None``.
3088
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
3089
+
3090
+ Returns:
3091
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
3092
+ CommHandle will be None, when `async_op` is False.
3093
+
3094
+ Raises:
3095
+ TypeError: If the type of input `tensor` is not Tensor, `tensor_list` is not Tensor List,
3096
+ `group` is not a str or async_op is not bool.
3097
+ TypeError: If the size of `tensor_list` is not equal to the group size.
+ TypeError: If the type or shape of `tensor` is not equal to that of the members of `tensor_list`.
3099
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3100
+
3101
+ Supported Platforms:
3102
+ ``Ascend`` ``CPU``
3103
+
3104
+ Examples:
3105
+ .. note::
3106
+ Before running the following examples, you need to configure the communication environment variables.
3107
+
3108
+ For Ascend devices, it is recommended to use the msrun startup method
3109
+ without any third-party or configuration file dependencies.
3110
+ Please see the `msrun start up
3111
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3112
+ for more details.
3113
+
3114
+ This example should be run with 2 devices.
3115
+
3116
+ >>> import numpy as np
3117
+ >>> import mindspore as ms
3118
+ >>> from mindspore.ops.communication import init_process_group
3119
+ >>> from mindspore.ops.communication import all_gather
3120
+ >>> from mindspore import Tensor
3121
+ >>>
3122
+ >>> init_process_group()
3123
+ >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
3124
+ >>> out_tensors = [Tensor(np.zeros([2, 8]).astype(np.float32)), Tensor(np.zeros([2, 8]).astype(np.float32))]
3125
+ >>> output = all_gather(out_tensors, input_tensor)
3126
+ >>> print(out_tensors)
3127
+ [Tensor(shape=[2, 8], dtype=Float32, value=
3128
+ [[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
3129
+ [ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]]),
3130
+ Tensor(shape=[2, 8], dtype=Float32, value=
3131
+ [[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
3132
+ [ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]])]
3133
+
3134
+
3135
+ """
3136
+ if is_inplace_func() is False:
3137
+ raise ValueError("Non-inplace mode is currently not supported.")
3138
+ _check_all_tensors(tensor_list)
3139
+ _check_all_tensor_same_dtype(tensor_list)
3140
+ if not isinstance(tensor, (Tensor, Tensor_)):
3141
+ raise TypeError("For all_gather_into_tensor, the input tensor must be tensor")
3142
+ if group is None:
3143
+ group = GlobalComm.WORLD_COMM_GROUP
3144
+ if not isinstance(group, str):
3145
+ raise TypeError(
3146
+ "The argument 'group' must be type of string, "
3147
+ "but got 'group' type : {}.".format(type(group))
3148
+ )
3149
+ if not isinstance(async_op, bool):
3150
+ raise TypeError(
3151
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
3152
+ )
3153
+ group_size = get_cache_group_size(group)
3154
+ _check_group_tensor_list(tensor_list, group_size)
3155
+ rank_id = get_group_rank_from_world_rank(get_rank(), group)
3156
+ _check_output_shape(tensor, tensor_list[rank_id].shape, "all_gather")
3157
+ _check_output_dtype(tensor, tensor_list[0].dtype, "all_gather")
3158
+ result = dist_comm_all_gather_op(tensor_list, tensor, group_size, group)
3159
+ _, handle = _deal_comm_outputs(result, async_op)
3160
+ return handle
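+
+ # A minimal sketch of sizing the output list for all_gather (illustrative only; the
+ # helper name is hypothetical and an initialized communication environment is assumed):
+ # the list needs one tensor per rank, each matching the input's shape and dtype.
+ def _all_gather_list_sketch(input_tensor, group=None):
+     group_size = get_cache_group_size(GlobalComm.WORLD_COMM_GROUP if group is None else group)
+     out_list = [Tensor(np.zeros(input_tensor.shape), dtype=input_tensor.dtype)
+                 for _ in range(group_size)]
+     all_gather(out_list, input_tensor, group=group)
+     return out_list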
3161
+
3162
+
3163
+ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
3164
+ r"""
3165
+ Reduces and scatters tensors from the specified communication group and
3166
+ returns the tensor which is reduced and scattered.
3167
+
3168
+ Args:
3169
+ output (Tensor): the output tensor.
3170
+ input_list (list[Tensor]): List of tensors to reduce and scatter.
3171
+ op (str, optional): Specifies an operation used for element-wise reductions,
3172
+ like SUM and MAX. Default: ``ReduceOp.SUM`` .
3173
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
3174
+ Ascend. Default: ``None``.
3175
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
3176
+
3177
+ Returns:
3178
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
3179
+ CommHandle will be None, when `async_op` is False.
3180
+
3181
+ Raises:
3182
+ TypeError: If the type of `output` parameter is not Tensor, `input_list` is not Tensor List.
3183
+ TypeError: If `op` or `group` is not a str, `async_op` is not a bool, or `op` is invalid.
3184
+ TypeError: If size of `input_list` is not equal to group size.
3185
+ TypeError: If the type or shape of `output` not equal to the member of `input_list`.
3186
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3187
+
3188
+ Supported Platforms:
3189
+ ``Ascend``
3190
+
3191
+ Examples:
3192
+ .. note::
3193
+ Before running the following examples, you need to configure the communication environment variables.
3194
+
3195
+ For Ascend devices, it is recommended to use the msrun startup method
3196
+ without any third-party or configuration file dependencies.
3197
+ Please see the `msrun start up
3198
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3199
+ for more details.
3200
+
3201
+ This example should be run with 2 devices.
3202
+
3203
+ >>> from mindspore import Tensor
3204
+ >>> from mindspore.ops.communication import init_process_group
3205
+ >>> from mindspore.ops.communication import reduce_scatter
3206
+ >>> import numpy as np
3207
+ >>>
3208
+ >>> init_process_group()
3209
+ >>> input_tensors = [Tensor(np.ones([4, 8]).astype(np.float32)), Tensor(np.ones([4, 8]).astype(np.float32))]
3210
+ >>> output_tensor = Tensor(np.zeros([4, 8]).astype(np.float32))
3211
+ >>> output = reduce_scatter(output_tensor, input_tensors)
3212
+ >>> print(output_tensor)
3213
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
3214
+ [2. 2. 2. 2. 2. 2. 2. 2.]
3215
+ [2. 2. 2. 2. 2. 2. 2. 2.]
3216
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
3217
+
3218
+ """
3219
+ if is_inplace_func() is False:
3220
+ raise ValueError("Non-inplace mode is currently not supported.")
3221
+ _check_all_tensors(input_list)
3222
+ _check_all_tensor_same_dtype(input_list)
3223
+ if not isinstance(output, (Tensor, Tensor_)):
3224
+ raise TypeError("For reduce_scatter, the output tensor must be tensor")
3225
+ if group is None:
3226
+ group = GlobalComm.WORLD_COMM_GROUP
3227
+ if not isinstance(group, str):
3228
+ raise TypeError(
3229
+ "The argument 'group' must be type of string, "
3230
+ "but got 'group' type : {}.".format(type(group))
3231
+ )
3232
+ if not isinstance(async_op, bool):
3233
+ raise TypeError(
3234
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
3235
+ )
3236
+ if not isinstance(op, str):
3237
+ raise TypeError("For reduce_scatter, the input op type must be str")
3238
+ if op not in ("sum", "prod", "min", "max"):
3239
+ raise TypeError(
3240
+ "For reduce_scatter, the input op value must be one of sum, prod, min, max"
3241
+ )
3242
+ rank_size = get_cache_group_size(group)
3243
+ _check_group_tensor_list(input_list, rank_size)
3244
+
3245
+ rank_id = get_group_rank_from_world_rank(get_rank(), group)
3246
+ _check_output_shape(output, input_list[rank_id].shape, "reduce_scatter")
3247
+ _check_output_dtype(output, input_list[0].dtype, "reduce_scatter")
3248
+ result = dist_comm_reduce_scatter_op(output, input_list, rank_size, op, group)
3249
+ _, handle = _deal_comm_outputs(result, async_op)
3250
+ return handle
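+
+ # A minimal sketch of reduce_scatter with a non-default reduction (illustrative only;
+ # the helper name is hypothetical and 2 initialized ranks are assumed): the op string
+ # must be one of "sum", "prod", "min" or "max", as validated above.
+ def _reduce_scatter_max_sketch():
+     inputs = [Tensor(np.ones([4, 8]).astype(np.float32)),
+               Tensor(np.ones([4, 8]).astype(np.float32))]
+     out = Tensor(np.zeros([4, 8]).astype(np.float32))
+     handle = reduce_scatter(out, inputs, op="max", async_op=True)
+     handle.wait()
+     return out  # element-wise max across ranks of the slice scattered to this rank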
3251
+
3252
+
3253
+ def scatter(tensor, scatter_list, src=0, group=None, async_op=False):
3254
+ r"""
3255
+ Scatter a tensor evenly across the processes in the specified communication group.
3256
+
3257
+ Note:
3258
+ - The interface only supports Tensor list input and scatters evenly.
3259
+ - Only the tensor in process `src` (global rank) will do scatter.
3260
+ - Only support PyNative mode, Graph mode is not currently supported.
3261
+
3262
+ Args:
3263
+ tensor (Tensor): the output tensor.
3264
+ scatter_list (list[Tensor]): List of same-sized tensors to scatter.
+ It must be specified on the source rank.
+ src (int, optional): Specifies the rank (global rank) of the process that sends the tensor.
+ Only process `src` will send the tensor. Default: ``0`` .
3268
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
3269
+ Ascend. Default: ``None``.
3270
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
3271
+
3272
+ Returns:
3273
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
3274
+ CommHandle will be None, when `async_op` is False.
3275
+
3276
+ Raises:
3277
+ TypeError: If the type of the `tensor` parameter is not Tensor, or `scatter_list` is not a Tensor list.
+ TypeError: If `src` is not an int, `group` is not a str, or `async_op` is not a bool.
3279
+ TypeError: If size of `scatter_list` is not equal to group size.
3280
+ TypeError: If the type or shape of `tensor` not equal to the member of `scatter_list`.
3281
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3282
+
3283
+ Supported Platforms:
3284
+ ``Ascend`` ``CPU``
3285
+
3286
+ Examples:
3287
+ .. note::
3288
+ Before running the following examples, you need to configure the communication environment variables.
3289
+
3290
+ For Ascend devices, it is recommended to use the msrun startup method
3291
+ without any third-party or configuration file dependencies.
3292
+ Please see the `msrun start up
3293
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3294
+ for more details.
3295
+
3296
+ This example should be run with 2 devices.
3297
+
3298
+ >>> from mindspore import Tensor
3299
+ >>> from mindspore.ops.communication import init_process_group, scatter
3300
+ >>> import numpy as np
3301
+ >>> # Launch 2 processes.
3302
+ >>>
3303
+ >>> init_process_group()
3304
+ >>> inputs = [Tensor(np.ones([2, 2]).astype(np.float32)), Tensor(np.ones([2, 2]).astype(np.float32))]
3305
+ >>> output = Tensor(np.zeros([2, 2]).astype(np.float32))
3306
+ >>> scatter(output, inputs, src=0)
3307
+ >>> print(output)
3308
+ # rank_0
3309
+ [[1. 1.]
3310
+ [1. 1.]]
3311
+ # rank_1
3312
+ [[1. 1.]
3313
+ [1. 1.]]
3314
+ """
3315
+ if is_inplace_func() is False:
3316
+ raise ValueError("Non-inplace mode is currently not supported.")
3317
+ _check_all_tensors(scatter_list)
3318
+ _check_all_tensor_same_dtype_and_shape(scatter_list)
3319
+ if not isinstance(tensor, (Tensor, Tensor_)):
3320
+ raise TypeError("For scatter_tensor, the output tensor must be tensor")
3321
+ if not isinstance(src, int):
3322
+ raise TypeError("For scatter_tensor, the src must be int")
3323
+ if group is None:
3324
+ group = GlobalComm.WORLD_COMM_GROUP
3325
+ if not isinstance(group, str):
3326
+ raise TypeError(
3327
+ "The argument 'group' must be type of string, "
3328
+ "but got 'group' type : {}.".format(type(group))
3329
+ )
3330
+ if not isinstance(async_op, bool):
3331
+ raise TypeError(
3332
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
3333
+ )
3334
+ src = get_group_rank_from_world_rank(src, group)
3335
+ rank_size = get_cache_group_size(group)
3336
+ rank_id = get_cache_group_rank(group)
3337
+ if src == rank_id:
3338
+ _check_tensor_list(scatter_list, tensor, rank_size)
3339
+ output = dist_comm_scatter_op(tensor, scatter_list, rank_size, src, rank_id, group)
3340
+ _, handle = _deal_comm_outputs(output, async_op)
3341
+ return handle
3342
+
3343
+
3344
+ def gather(tensor, gather_list, dst=0, group=None, async_op=False):
3345
+ r"""
3346
+ Gathers tensors from the specified communication group. The operation gathers the tensor
+ from every process into `gather_list` on the `dst` process.
3348
+
3349
+ Note:
3350
+ - Only the tensor in process `dst` (global rank) will keep the gathered tensor. The other process
3351
+ will keep a tensor list which has no mathematical meaning.
3352
+ - The tensors must have the same shape and format in all processes of the collective.
3353
+ - Only support PyNative mode, Graph mode is not currently supported.
3354
+
3355
+ Args:
3356
+ tensor (Tensor): The tensor to be gathered.
3357
+ gather_list (list[Tensor]): List of same-sized tensors to use for gathered data.
3358
+ dst (int, optional): Specifies the rank (global rank) of the process that receives the tensor.
+ Only process `dst` will receive the gathered tensor. Default: ``0`` .
3360
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
3361
+ Ascend. Default: ``None``.
3362
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
3363
+
3364
+ Returns:
3365
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
3366
+ CommHandle will be None, when `async_op` is False.
3367
+
3368
+ Raises:
3369
+ TypeError: If the type of input tensor is not Tensor, or gather_list is not Tensor list.
3370
+ TypeError: If dst is not an integer, group is not a string or async_op is not bool.
3371
+ TypeError: If size of `gather_list` is not equal to group size.
3372
+ TypeError: If the type or shape of `tensor` not equal to the member of `gather_list`.
3373
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3374
+
3375
+ Supported Platforms:
3376
+ ``Ascend`` ``CPU``
3377
+
3378
+ Examples:
3379
+ .. note::
3380
+ Before running the following examples, you need to configure the communication environment variables.
3381
+
3382
+ For Ascend devices, it is recommended to use the msrun startup method
3383
+ without any third-party or configuration file dependencies.
3384
+ Please see the `msrun start up
3385
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3386
+ for more details.
3387
+
3388
+ This example should be run with 2 devices.
3389
+
3390
+ >>> import numpy as np
3391
+ >>> import mindspore as ms
3392
+ >>> import mindspore.nn as nn
3393
+ >>> from mindspore.ops.communication import init_process_group, gather
3394
+ >>> from mindspore import Tensor
3395
+ >>> # Launch 2 processes.
3396
+ >>> init_process_group()
3397
+ >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
3398
+ >>> outputs = [Tensor(np.zeros([2, 2]).astype(np.float32)),Tensor(np.zeros([2, 2]).astype(np.float32))]
3399
+ >>> gather(input, outputs, dst=0)
3400
+ >>> print(outputs)
3401
+ # rank_0
3402
+ [Tensor(shape=[2, 2], dtype=Float32, value=
3403
+ [[ 0.00000000e+00, 1.00000000e+00],
3404
+ [ 2.00000000e+00, 3.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
3405
+ [[ 0.00000000e+00, 1.00000000e+00], [ 2.00000000e+00, 3.00000000e+00]])]
3406
+ [Tensor(shape=[2, 2], dtype=Float32, value=[[ 0.00000000e+00, 1.00000000e+00],
3407
+ [ 2.00000000e+00, 3.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
3408
+ [[ 0.00000000e+00, 1.00000000e+00], [ 2.00000000e+00, 3.00000000e+00]])]
3409
+ # rank_1
3410
+ [Tensor(shape=[2, 2], dtype=Float32, value=[[ 0.00000000e+00, 0.00000000e+00],
3411
+ [ 0.00000000e+00, 0.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
3412
+ [[ 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00]])]
3413
+ [Tensor(shape=[2, 2], dtype=Float32, value=
3414
+ [[ 0.00000000e+00, 0.00000000e+00],
3415
+ [ 0.00000000e+00, 0.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
3416
+ [[ 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00]])]
3417
+ """
3418
+ if is_inplace_func() is False:
3419
+ raise ValueError("Non-inplace mode is currently not supported.")
3420
+ if not isinstance(tensor, (Tensor, Tensor_)):
3421
+ raise TypeError("For gather, the input tensor must be tensor")
3422
+ _check_all_tensors(gather_list)
3423
+ _check_all_tensor_same_dtype_and_shape(gather_list)
3424
+ if not isinstance(dst, int):
3425
+ raise TypeError("For gather, the dst must be int")
3426
+ if group is None:
3427
+ group = GlobalComm.WORLD_COMM_GROUP
3428
+ if not isinstance(group, str):
3429
+ raise TypeError(
3430
+ "The argument 'group' must be type of string, "
3431
+ "but got 'group' type : {}.".format(type(group))
3432
+ )
3433
+ if not isinstance(async_op, bool):
3434
+ raise TypeError(f"The argument 'async_op' must be a bool, but got {type(async_op)}.")
3435
+ group_size = get_cache_group_size(group)
3436
+ dst = get_group_rank_from_world_rank(dst, group)
3437
+ rank_id = get_cache_group_rank(group)
3438
+ if dst == rank_id:
3439
+ _check_tensor_list(gather_list, tensor, group_size)
3440
+ output = dist_comm_gather_op(tensor, gather_list, group_size, dst, rank_id, group)
3441
+ _, handle = _deal_comm_outputs(output, async_op)
3442
+ return handle
3443
+
3444
+
3445
+ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
3446
+ r"""
3447
+ Scatters picklable objects in scatter_object_input_list to the whole group.
3448
+
3449
+ Note:
3450
+ - Similar to :func:`mindspore.ops.communication.scatter`, but Python objects can be passed in.
3451
+ - Only the objects in process `src` (global rank) will do scatter.
3452
+ - Only support PyNative mode, Graph mode is not currently supported.
3453
+
3454
+ Args:
3455
+ scatter_object_output_list (list[Any]): Non-empty list whose first element
3456
+ will store the object scattered to this rank.
3457
+ scatter_object_input_list (list[Any]): List of python objects to scatter.
3458
+ it must be specified on the source rank.
3459
+ src (int, optional): Specifies the rank (global rank) of the process that sends the objects.
+ Only process `src` will send the objects. Default: ``0`` .
3461
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
3462
+ Ascend. Default: ``None``.
3463
+
3464
+ Raises:
3465
+ TypeError: If `group` is not a str or `src` is not an integer.
3466
+ TypeError: If size of `scatter_object_input_list` is not equal to group size.
3467
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3468
+
3469
+ Supported Platforms:
3470
+ ``Ascend``
3471
+
3472
+ Examples:
3473
+ .. note::
3474
+ Before running the following examples, you need to configure the communication environment variables.
3475
+
3476
+ For Ascend devices, it is recommended to use the msrun startup method
3477
+ without any third-party or configuration file dependencies.
3478
+ Please see the `msrun start up
3479
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3480
+ for more details.
3481
+
3482
+ This example should be run with 2 devices.
3483
+
3484
+ >>> from mindspore.ops.communication import init_process_group, scatter_object_list
3485
+ >>> init_process_group()
3486
+ >>> obj = ["test", {1: 2}]
3487
+ >>> scatter_object_output_list=[None]
3488
+ >>> scatter_object_list(scatter_object_output_list, obj)
3489
+ >>> print(scatter_object_output_list)
3490
+ # rank_0
3491
+ ['test']
3492
+ # rank_1
3493
+ [{1: 2}]
3494
+ """
3495
+ if is_inplace_func() is False:
3496
+ raise ValueError("Non-inplace mode is currently not supported.")
3497
+ if group is None:
3498
+ group = GlobalComm.WORLD_COMM_GROUP
3499
+ if not isinstance(group, str):
3500
+ raise TypeError(
3501
+ "For 'scatter_object_list', the argument 'group' must be type of string, "
3502
+ "but got 'group' type : {}.".format(type(group))
3503
+ )
3504
+ if not isinstance(scatter_object_output_list, list) or not scatter_object_output_list:
3505
+ raise TypeError(f"The scatter_object_output_list can not be empty.")
3506
+ if not isinstance(src, int):
3507
+ raise TypeError("For scatter_object_list, the src must be int")
3508
+ group_size = get_cache_group_size(group)
3509
+ rank_id = get_cache_group_rank()
3510
+ tensor_sizes = []
3511
+ tensor_list = []
3512
+ if rank_id == src:
3513
+ if not isinstance(scatter_object_input_list, list) or len(scatter_object_input_list) != group_size:
3514
+ raise TypeError(
3515
+ "The len of scatter_object_input_list must be equal to group rank size, "
3516
+ "but got {len(scatter_object_input_list)}."
3517
+ )
3518
+ for obj in scatter_object_input_list:
3519
+ _, size = _object_to_tensor(obj)
3520
+ tensor_sizes.append(Tensor([size], dtype=mstype.int32))
3521
+ max_size = int(max(tensor_sizes).item())
3522
+ for obj in scatter_object_input_list:
3523
+ tensor, _ = _object_to_tensor(obj, max_size)
3524
+ tensor_list.append(tensor)
3525
+ else:
3526
+ tensor_sizes = [Tensor([0], dtype=mstype.int32) for i in range(group_size)]
3527
+
3528
+ object_size = cat(tensor_sizes)
3529
+ broadcast(object_size, src, group)
3530
+ max_object_size = int(max(object_size).item())
3531
+ data = np.zeros((max_object_size)).astype(np.int8)
3532
+ if rank_id != src:
3533
+ tensor_list = [Tensor(data) for i in range(group_size)]
3534
+ out_tensor = Tensor(data)
3535
+ scatter(out_tensor, tensor_list, src, group)
3536
+ group_id = get_group_rank_from_world_rank(rank_id, group)
3537
+ scatter_object_output_list[0] = _tensor_to_object(out_tensor, object_size[group_id])
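+
+ # A sketch of the object <-> tensor round trip that scatter_object_list relies on.
+ # This is an assumption shown with pickle + numpy for illustration only; the actual
+ # _object_to_tensor/_tensor_to_object helpers may differ, and the name below is
+ # hypothetical.
+ def _object_roundtrip_sketch(obj, pad_to=None):
+     import pickle
+     raw = np.frombuffer(pickle.dumps(obj), dtype=np.int8)
+     size = raw.size
+     if pad_to is not None and pad_to > size:
+         # Pad so that every rank can exchange buffers of the same length.
+         raw = np.concatenate([raw, np.zeros(pad_to - size, dtype=np.int8)])
+     tensor = Tensor(raw.copy())
+     # Reverse direction: only the first `size` bytes are meaningful after padding.
+     restored = pickle.loads(tensor.asnumpy()[:size].tobytes())
+     return tensor, size, restored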
3538
+
3539
+
3540
+ def gather_object(obj, object_gather_list=None, dst=0, group=None):
3541
+ r"""
3542
+ Gathers python objects from the whole group in a single process.
3543
+
3544
+ Note:
3545
+ - Similar to :func:`mindspore.ops.communication.gather`, but Python objects can be passed in.
3546
+ - Only support PyNative mode, Graph mode is not currently supported.
3547
+
3548
+ Args:
3549
+ obj (Any): The python objects to be gathered.
3550
+ object_gather_list (list[Any], optional): List used to hold the gathered objects.
+ On the ``dst`` rank, it should be correctly sized as the size of the group for this
+ collective and will contain the output. Default: ``None``.
3553
+ dst (int, optional): Specifies the rank (global rank) of the process that receives the gathered objects.
+ Only process `dst` will receive the gathered objects. Default: ``0`` .
3555
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
3556
+ Ascend. Default: ``None``.
3557
+
3558
+ Raises:
3559
+ TypeError: If dst is not an integer, or group is not a string.
3560
+ TypeError: If size of `object_gather_list` is not equal to group size.
3561
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3562
+
3563
+ Supported Platforms:
3564
+ ``Ascend``
3565
+
3566
+ Examples:
3567
+ .. note::
3568
+ Before running the following examples, you need to configure the communication environment variables.
3569
+
3570
+ For Ascend devices, it is recommended to use the msrun startup method
3571
+ without any third-party or configuration file dependencies.
3572
+ Please see the `msrun start up
3573
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3574
+ for more details.
3575
+
3576
+ This example should be run with 2 devices.
3577
+
3578
+ >>> from mindspore.ops.communication import init_process_group, gather_object, get_rank
3579
+ >>> init_process_group()
3580
+ >>> rank = get_rank()
3581
+ >>> obj = ["test", {1: 2}]
3582
+ >>> object_gather_list=[None, None]
3583
+ >>> gather_object(obj[rank], object_gather_list)
3584
+ >>> print(object_gather_list)
3585
+ # rank_0
3586
+ ['test', {1: 2}]
3587
+ """
3588
+ if is_inplace_func() is False:
3589
+ raise ValueError("Non-inplace mode is currently not supported.")
3590
+ if group is None:
3591
+ group = GlobalComm.WORLD_COMM_GROUP
3592
+ if not isinstance(group, str):
3593
+ raise TypeError(
3594
+ "For 'gather_object', the argument 'group' must be type of string, "
3595
+ "but got 'group' type : {}.".format(type(group))
3596
+ )
3597
+ if not isinstance(dst, int):
3598
+ raise TypeError("For gather_object, the dst must be int")
3599
+ group_size = get_cache_group_size(group)
3600
+ rank_id = get_cache_group_rank()
3601
+ if rank_id == dst:
3602
+ if not isinstance(object_gather_list, list) or len(object_gather_list) != group_size:
3603
+ raise TypeError(
3604
+ f"The len of object_gather_list must be equal to group rank size, but got {len(object_gather_list)}."
3605
+ )
3606
+ _, size = _object_to_tensor(obj)
3607
+ tensor = Tensor([size], dtype=mstype.int32)
3608
+ object_size_list = [Tensor([0], dtype=mstype.int32) for i in range(group_size)]
3609
+ all_gather(object_size_list, tensor, group=group)
3610
+ max_object_size = int(max(object_size_list).item())
3611
+ in_tensor, size = _object_to_tensor(obj, max_object_size)
3612
+ data = np.zeros((size)).astype(np.int8)
3613
+ object_tensor_list = [Tensor(data) for i in range(group_size)]
3614
+ gather(in_tensor, object_tensor_list, dst, group)
3615
+ if rank_id != dst:
3616
+ return
3617
+ for i, item in enumerate(object_size_list):
3618
+ tensor_size = int(item.item())
3619
+ tensor = object_tensor_list[i]
3620
+ object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
3621
+
3622
+
3623
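`gather_object` above (and `all_gather_object` further below) follow the same two-phase pattern: exchange the serialized sizes first, pad every payload to the maximum size so the underlying tensor collective sees identical shapes, then trim each received buffer back to its true size. A communication-free sketch of that pattern with plain pickle/numpy stand-ins (not the MindSpore helpers):

```python
# Local illustration of the "exchange sizes, pad to max, recover by true size" pattern;
# the real functions replace the plain lists below with all_gather/gather collectives.
import pickle
import numpy as np

objs = ["test", {1: 2}]                        # one object per simulated rank
payloads = [np.frombuffer(pickle.dumps(o), dtype=np.int8) for o in objs]
sizes = [p.size for p in payloads]             # what object_size_list would carry
max_size = max(sizes)

# Pad every payload to the same length so a fixed-shape collective can move them.
padded = [np.concatenate([p, np.zeros(max_size - p.size, dtype=np.int8)])
          for p in payloads]

# Each entry is trimmed back to its true size before unpickling.
recovered = [pickle.loads(p[:s].tobytes()) for p, s in zip(padded, sizes)]
print(recovered)                               # ['test', {1: 2}]
```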
+ def broadcast_object_list(object_list, src=0, group=None, device=None):
3624
+ """
3625
+ Broadcasts the input list of Python objects to the whole group.
3626
+
3627
+ Note:
3628
+ - Similar to :func:`mindspore.ops.communication.broadcast`, but Python objects can be passed in.
3629
+ - Only supports PyNative mode; Graph mode is not currently supported.
3630
+
3631
+ Args:
3632
+ object_list (list[Any]): List of objects to be sent if `src` is the rank of the current process,
3633
+ and list used to save the received data otherwise.
3634
+ src (int, optional): Specifies the rank (global rank) of the process that broadcasts the Python
3635
+ objects. Only process `src` broadcasts the objects. Default: ``0``.
3636
+ group (str, optional): The communication group to work on. If ``None``, the ``"hccl_world_group"``
3637
+ group is used on Ascend. Default: ``None``.
3638
+ device (str, optional): Currently it is a reserved parameter. Default: ``None``.
3639
+
3640
+ Raises:
3641
+ TypeError: If `src` is not an integer or `group` is not a string.
3642
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3643
+
3644
+ Supported Platforms:
3645
+ ``Ascend``
3646
+
3647
+ Examples:
3648
+ .. note::
3649
+ Before running the following examples, you need to configure the communication environment variables.
3650
+
3651
+ For Ascend devices, it is recommended to use the msrun startup method
3652
+ without any third-party or configuration file dependencies.
3653
+ Please see the `msrun start up
3654
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3655
+ for more details.
3656
+
3657
+ This example should be run with 2 devices.
3658
+
3659
+ >>> from mindspore.ops.communication import init_process_group, broadcast_object_list, get_rank
3660
+ >>> init_process_group()
3661
+ >>> rank = get_rank()
3662
+ >>> obj = ["test", 12, {1: 2}]
3663
+ >>> if rank == 1:
3664
+ ... obj = [None, None, None]
3665
+ >>> broadcast_object_list(obj)
3666
+ >>> print(obj)
3667
+ ['test', 12, {1: 2}]
3668
+ """
3669
+ if is_inplace_func() is False:
3670
+ raise ValueError("Non-inplace mode is currently not supported.")
3671
+ if group is None:
3672
+ group = GlobalComm.WORLD_COMM_GROUP
3673
+ if not isinstance(group, str):
3674
+ raise TypeError(
3675
+ "For 'broadcast_object_list', the argument 'group' must be type of string, "
3676
+ "but got 'group' type : {}.".format(type(group))
3677
+ )
3678
+ if not isinstance(src, int):
3679
+ raise TypeError("For broadcast_object_list, the src must be int")
3680
+ if not isinstance(object_list, list) or not object_list:
3681
+ raise TypeError(f"The object_list can not be empty.")
3682
+ rank_id = get_cache_group_rank()
3683
+ tensor_sizes = []
3684
+ tensor_list = []
3685
+ size = 0
3686
+ object_size_list = [Tensor([0], dtype=mstype.int32) for i in range(len(object_list))]
3687
+ if rank_id == src:
3688
+ tensor_list, tensor_sizes = zip(
3689
+ *[_object_to_tensor(obj) for obj in object_list]
3690
+ )
3691
+ object_size_list = [Tensor([tensor_sizes[i]], dtype=mstype.int32) for i in range(len(tensor_sizes))]
3692
+ object_tensor = cat(tensor_list)
3693
+ object_size = cat(object_size_list)
3694
+ broadcast(object_size, src, group)
3695
+ size = int(sum(object_size).item())
3696
+ if rank_id != src:
3697
+ data = np.zeros((size)).astype(np.int8)
3698
+ object_tensor = Tensor(data)
3699
+ broadcast(object_tensor, src, group)
3700
+ if rank_id != src:
3701
+ offset = 0
3702
+ for i, item in enumerate(object_size):
3703
+ obj_size = item
3704
+ obj_view = object_tensor[offset: offset + obj_size]
3705
+ offset += obj_size
3706
+ object_list[i] = _tensor_to_object(obj_view, obj_size)
3707
+
3708
+
3709
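`broadcast_object_list` packs all objects into one concatenated int8 buffer plus a vector of per-object sizes; after both broadcasts, non-`src` ranks rebuild the list by slicing at cumulative offsets. The reconstruction step, sketched locally with numpy and pickle (an illustration, not the shipped helpers):

```python
# Offset-based reconstruction, as performed on non-src ranks after the two broadcasts.
import pickle
import numpy as np

object_list = ["test", 12, {1: 2}]
chunks = [np.frombuffer(pickle.dumps(o), dtype=np.int8) for o in object_list]
object_size = [c.size for c in chunks]      # broadcast first (the size vector)
object_tensor = np.concatenate(chunks)      # broadcast second (the packed payload)

rebuilt, offset = [], 0
for size in object_size:
    view = object_tensor[offset:offset + size]
    rebuilt.append(pickle.loads(view.tobytes()))
    offset += size
print(rebuilt)                              # ['test', 12, {1: 2}]
```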
+ def all_gather_object(object_list, obj, group=None):
3710
+ """
3711
+ Aggregates Python objects in a specified communication group.
3712
+
3713
+ Note:
3714
+ Similar to :func:`mindspore.ops.communication.all_gather`, but Python objects can be passed in.
3715
+
3716
+ Args:
3717
+ object_list (list[Any]): Output Python object list.
3718
+ obj (Any): Python object to be broadcast from current process.
3719
+ group (str, optional): The communication group to work on. If ``None``, the ``"hccl_world_group"``
3720
+ group is used on Ascend. Default: ``None``.
3721
+
3722
+ Raises:
3723
+ TypeError: If `group` is not a str.
3724
+ TypeError: If size of `object_list` is not equal to group size.
3725
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
3726
+
3727
+ Supported Platforms:
3728
+ ``Ascend``
3729
+
3730
+ Examples:
3731
+ .. note::
3732
+ Before running the following examples, you need to configure the communication environment variables.
3733
+
3734
+ For Ascend devices, it is recommended to use the msrun startup method
3735
+ without any third-party or configuration file dependencies.
3736
+ Please see the `msrun start up
3737
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3738
+ for more details.
3739
+
3740
+ This example should be run with 2 devices.
3741
+
3742
+ >>> from mindspore.ops.communication import init_process_group, get_rank
3743
+ >>> from mindspore.ops.communication import all_gather_object
3744
+ >>> init_process_group()
3745
+ >>> rank = get_rank()
3746
+ >>> obj = ["test", {1: 2}]
3747
+ >>> object_gather_list=[None, None]
3748
+ >>> all_gather_object(object_gather_list, obj[rank])
3749
+ >>> print(object_gather_list)
3750
+ # rank_0
3751
+ ['test', {1: 2}]
3752
+ # rank_1
3753
+ ['test', {1: 2}]
3754
+ """
3755
+ if is_inplace_func() is False:
3756
+ raise ValueError("Non-inplace mode is currently not supported.")
3757
+ if group is None:
3758
+ group = GlobalComm.WORLD_COMM_GROUP
3759
+ if not isinstance(group, str):
3760
+ raise TypeError(
3761
+ "For 'all_gather_object', the argument 'group' must be type of string, "
3762
+ "but got 'group' type : {}.".format(type(group))
3763
+ )
3764
+ group_size = get_cache_group_size(group)
3765
+ if not isinstance(object_list, list) or len(object_list) != group_size:
3766
+ raise TypeError(
3767
+ f"The len of argument object_list must be equal to group rank size, but got {len(object_list)}."
3768
+ )
3769
+ _, size = _object_to_tensor(obj)
3770
+ tensor = Tensor([size], dtype=mstype.int32)
3771
+ object_size_list = [Tensor([0], dtype=mstype.int32) for i in range(group_size)]
3772
+ all_gather(object_size_list, tensor, group=group)
3773
+ max_object_size = int(max(object_size_list).item())
3774
+ in_tensor, size = _object_to_tensor(obj, max_object_size)
3775
+ data = np.zeros((size)).astype(np.int8)
3776
+ object_tensor_list = [Tensor(data) for i in range(group_size)]
3777
+ all_gather(object_tensor_list, in_tensor, group=group)
3778
+
3779
+ for i, item in enumerate(object_size_list):
3780
+ tensor_size = int(item.item())
3781
+ tensor = object_tensor_list[i]
3782
+ object_list[i] = _tensor_to_object(tensor, tensor_size)
3783
+
3784
+
3785
+ def all_to_all_v_c(output, input, send_count_matrix, group=None, async_op=False):
3786
+ r"""
3787
+ Splits the input tensor according to the user-specified send counts and sends the chunks to other devices;
3788
+ the chunks received from the other devices are then concatenated into a single output tensor.
3789
+
3790
+ Note:
3791
+ Only supports PyNative mode; Graph mode is not currently supported.
3792
+
3793
+ Args:
3794
+ output (Tensor): The output tensor, concatenated from the chunks received from remote ranks.
3795
+ input (Tensor): The tensor to be scattered to remote ranks.
3796
+ send_count_matrix (list[int]): The send/receive counts of all ranks.
3797
+ :math:`\text{send_count_matrix}[i \times \text{rank_size}+j]` is the amount of data that
3798
+ rank i sends to rank j, counted in units of the tensor's first dimension. Here `rank_size`
3799
+ is the size of the communication group; see the worked sketch after this function.
3800
+ group (str, optional): The communication group to work on. If ``None``, the ``"hccl_world_group"``
3801
+ group is used on Ascend. Default: ``None``.
3802
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
3803
+
3804
+ Returns:
3805
+ CommHandle, an async work handle, if `async_op` is set to ``True``;
3806
+ ``None`` if `async_op` is ``False``.
3807
+
3808
+ Raises:
3809
+ TypeError: If `input` or `output` is not a Tensor, `group` is not a str, or `async_op` is not a bool.
3810
+
3811
+ Supported Platforms:
3812
+ ``Ascend``
3813
+
3814
+ Examples:
3815
+ .. note::
3816
+ Before running the following examples, you need to configure the communication environment variables.
3817
+
3818
+ For Ascend devices, it is recommended to use the msrun startup method
3819
+ without any third-party or configuration file dependencies.
3820
+ Please see the `msrun start up
3821
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
3822
+ for more details.
3823
+
3824
+ This example should be run with 2 devices.
3825
+
3826
+ >>> import numpy as np
3827
+ >>> import mindspore
3828
+ >>> from mindspore.ops.communication import init_process_group, get_rank
3829
+ >>> from mindspore.ops.communication import all_to_all_v_c
3830
+ >>> from mindspore import Tensor
3831
+ >>> from mindspore.ops import zeros
3832
+ >>>
3833
+ >>> init_process_group()
3834
+ >>> this_rank = get_rank()
3835
+ >>> if this_rank == 0:
3836
+ ... output = Tensor(np.zeros([3]).astype(np.float32))
3837
+ ... tensor = Tensor([0, 1, 2.]) * this_rank
3838
+ ... result = all_to_all_v_c(output, tensor, [0, 3, 3, 0])
3839
+ ... print(output)
3840
+ >>> if this_rank == 1:
3841
+ ... output = Tensor(np.zeros([3]).astype(np.float32))
3842
+ ... tensor = Tensor([0, 1, 2.]) * this_rank
3843
+ ... result = all_to_all_v_c(output, tensor, [0, 3, 3, 0])
3844
+ ... print(output)
3845
+ rank 0:
3846
+ [0. 1. 2.]
3847
+ rank 1:
3848
+ [0. 0. 0.]
3849
+ """
3850
+
3851
+ _check_all_tensors([input])
3852
+ _check_all_tensors([output])
3853
+ if group is None:
3854
+ group = GlobalComm.WORLD_COMM_GROUP
3855
+ if not isinstance(group, str):
3856
+ raise TypeError(
3857
+ "The argument 'group' must be type of string, "
3858
+ "but got 'group' type : {}.".format(type(group))
3859
+ )
3860
+ if not isinstance(async_op, bool):
3861
+ raise TypeError(
3862
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
3863
+ )
3864
+ if not isinstance(send_count_matrix, list):
3865
+ raise TypeError("send_count_matrix must be list, but got {}".format(type(send_count_matrix)))
3866
+ if not all(isinstance(x, int) for x in send_count_matrix):
3867
+ raise TypeError("send_count_matrix elements must be of type int")
3868
+ rank_size = get_cache_group_size(group)
3869
+ if rank_size * rank_size != len(send_count_matrix):
3870
+ raise TypeError(f"send_count_matrix must be square matrix, but got {len(send_count_matrix)}.")
3871
+ _send_count_matrix = _get_all_to_all_v_c_numel_list(output, input, send_count_matrix)
3872
+ _input = input.reshape(-1)
3873
+ rank_id = get_cache_group_rank(group)
3874
+ result = dist_comm_all_to_all_v_c_op(
3875
+ output,
3876
+ _input,
3877
+ group,
3878
+ _send_count_matrix,
3879
+ rank_size,
3880
+ rank_id,
3881
+ )
3882
+ _, handle = _deal_comm_outputs(result, async_op)
3883
+ return handle
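To make the flattened `send_count_matrix` layout concrete: for a group of `rank_size` ranks, entry `i * rank_size + j` holds how much rank `i` sends to rank `j`. The small helper below (hypothetical, for illustration only) derives each rank's send and receive counts from the 2-rank matrix used in the docstring example:

```python
# Decode a flattened send_count_matrix into per-rank send/receive counts.
# Purely illustrative; this helper is not part of the MindSpore API.
def split_counts(send_count_matrix, rank_size, rank_id):
    sends = send_count_matrix[rank_id * rank_size:(rank_id + 1) * rank_size]  # row: what rank_id sends
    recvs = send_count_matrix[rank_id::rank_size]                             # column: what rank_id receives
    return sends, recvs

matrix = [0, 3, 3, 0]          # the 2-rank matrix from the docstring example above
for rank in range(2):
    sends, recvs = split_counts(matrix, 2, rank)
    print(f"rank {rank}: sends {sends}, receives {recvs}")
# rank 0: sends [0, 3], receives [0, 3]
# rank 1: sends [3, 0], receives [3, 0]
```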