mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as potentially problematic by the registry.

Files changed (287)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -17,11 +17,18 @@
  Defines communication operators with functional form.
  """
  from mindspore.communication import GlobalComm, get_group_rank_from_world_rank, get_group_size
+ from mindspore.communication.management import _get_group
+ from mindspore.communication._comm_helper import _get_group_rank_from_world_rank_from_cache_helper
  from mindspore.common.tensor import Tensor
  from mindspore._c_expression import Tensor as Tensor_
  from mindspore.ops import ReduceOp, cat
  from mindspore.ops._primitive_cache import _get_cache_prim
  from mindspore.ops.primitive import _primexpr
+ from mindspore.ops.auto_generate.gen_ops_prim import (inner_comm_all_reduce_op, inner_comm_all_gather_op,
+ inner_comm_all_to_all_v_op, inner_comm_irecv_op,
+ inner_comm_isend_op, inner_comm_reduce_scatter_op)
+ from mindspore._c_expression import CommHandle as CommHandle_
+ from mindspore import jit_class

  __all__ = [
  'all_reduce',
@@ -36,15 +43,48 @@ __all__ = [
  'reduce_scatter_tensor',
  'reduce',
  'scatter_tensor',
+ 'send',
+ 'recv',
  'P2POp',
  'batch_isend_irecv',
  ]

  import mindspore.ops.operations as P

+ _GROPU_SIZE_CACHE = {}
+
+ @jit_class
+ class CommHandle(CommHandle_):
+ r"""
+ Usually, handles are created in C++during the execution of communication operators and returned to the Python
+ layer. It will not be created directly in Python. Only in scenarios where graph patterns are compatible,
+ handles will be created using Python.
+ """
+
+ def wait(self):
+ r"""
+ The wait for asynchronous handles will not take effect for handles created on the Python side.
+
+ >>> import numpy as np
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import all_reduce
+ >>> from mindspore import Tensor
+ >>>
+ >>> init()
+ >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> output, handle = all_reduce(input_tensor, async_op=True)
+ >>> handle.wait()
+ >>> print(output)
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
+ """
+
+
+ default_handle = CommHandle()
+

  def _check_split_sizes_sequence(tensor, sequence):
- if sequence == []:
+ if not sequence:
  raise TypeError(f"sequence can not be empty list.")
  element0 = sequence[0]
  for idx in range(1, len(sequence)):
@@ -132,7 +172,17 @@ def _get_size(shape):
  return numel


- def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+ def _is_split_sizes_empty(split_sizes):
+ return split_sizes is None or not split_sizes
+
+
+ def _contiguous(tensor):
+ if not tensor.is_contiguous() or tensor.storage_offset() != 0:
+ tensor = tensor.contiguous()
+ return tensor
+
+
+ def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
  """
  Reduce tensors across all devices in such a way that all deviceswill get the same final result,
  returns the tensor which is all reduced.
@@ -146,17 +196,20 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
  On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
  group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
  means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

  Returns:
- Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
- The contents depend on the specified operation.
+ Tuple(Tensor, CommHandle), the output tensor has the same shape of the input,
+ i.e., :math:`(x_1, x_2, ..., x_R)`. The contents depend on the specified operation.
+ CommHandle is an async work handle, if `async_op` is set to True. CommHandle will be None,
+ when `async_op` is False.

  Raises:
  TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str.
  RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

  Supported Platforms:
- ``Ascend`` ``GPU`` ``CPU``
+ ``Ascend``

  Examples:
  .. note::
@@ -165,7 +218,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -185,11 +238,17 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
  """
  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For all_reduce, the input tensor must be tensor")
- all_reduce_op = _get_cache_prim(P.AllReduce)(op=op, group=group)
- return all_reduce_op(tensor)
+ if not isinstance(op, str):
+ raise TypeError("For all_reduce, the input op type must be str")
+ if op not in ('sum', 'prod', 'min', 'max'):
+ raise TypeError("For all_reduce, the input op value must be one of sum, prod, min, max")
+ group = _get_group(group)
+ tensor = _contiguous(tensor)
+ output = inner_comm_all_reduce_op(tensor, op, group)
+ return _deal_comm_outputs(output, async_op)


- def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP):
+ def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
  """
  Gathers tensors from the specified communication group and returns the tensor which is all gathered.

@@ -201,10 +260,13 @@ def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP):
  The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
  means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

  Returns:
- Tensor. If the number of devices in the group is N,
- then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
+ Tuple(Tensor, CommHandle), if the number of devices in the group is N,
+ then the shape of output tensor is :math:`(N, x_1, x_2, ..., x_R)`.
+ CommHandle is an async work handle, if `async_op` is set to True.
+ CommHandle will be None, when `async_op` is False.

  Raises:
  TypeError: If the type of the first input parameter is not Tensor, or `group` is not a str.
@@ -213,7 +275,7 @@ def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP):
  RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

  Supported Platforms:
- ``Ascend`` ``GPU``
+ ``Ascend``

  Examples:
  .. note::
@@ -222,7 +284,7 @@ def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -248,11 +310,17 @@ def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP):

  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For all_gather_into_tensor, the input tensor must be tensor")
- all_gather_op = _get_cache_prim(P.AllGather)(group=group)
- return all_gather_op(tensor)
+ group = _get_group(group)
+ global _GROPU_SIZE_CACHE
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ group_size = _GROPU_SIZE_CACHE[group]
+ tensor = _contiguous(tensor)
+ output = inner_comm_all_gather_op(tensor, group_size, group)
+ return _deal_comm_outputs(output, async_op)


- def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+ def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
  r"""
  Reduces and scatters tensors from the specified communication group and
  returns the tensor which is reduced and scattered.
@@ -268,9 +336,12 @@ def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_G
  like SUM and MAX. Default: ``ReduceOp.SUM`` .
  group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
  means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

  Returns:
- Tensor, it has the same dtype as `input_x` with a shape of :math:`(N/rank\_size, *)`.
+ Tuple(Tensor, CommHandle), the output tensor has the same dtype as `input_x` with a shape of
+ :math:`(N/rank\_size, *)`. CommHandle is an async work handle, if `async_op` is set to True.
+ CommHandle will be None, when `async_op` is False.

  Raises:
  TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str.
@@ -278,7 +349,7 @@ def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_G
  RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

  Supported Platforms:
- ``Ascend`` ``GPU``
+ ``Ascend``

  Examples:
  .. note::
@@ -287,7 +358,7 @@ def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_G
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -312,8 +383,14 @@ def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_G

  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For reduce_scatter_tensor, the input tensor must be tensor")
- reduce_scatter_op = _get_cache_prim(P.ReduceScatter)(op=op, group=group)
- return reduce_scatter_op(tensor)
+ group = _get_group(group)
+ global _GROPU_SIZE_CACHE
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ rank_size = _GROPU_SIZE_CACHE[group]
+ tensor = _contiguous(tensor)
+ output = inner_comm_reduce_scatter_op(tensor, rank_size, op, group)
+ return _deal_comm_outputs(output, async_op)


  def reduce(tensor, dst, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
@@ -353,7 +430,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
  without any third-party or configuration file dependencies.

  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 4 devices.
@@ -428,6 +505,7 @@ class P2POp:
  >>> recv_op = P2POp(irecv, recv_tensor, 0)
  >>> recv_op = P2POp('irecv', (), 0, recv_dtype=mindspore.float32)
  """
+
  def __init__(self, op, tensor, peer, group=None, tag=0, *, recv_dtype=None):
  self.op = op
  self.tensor = tensor
@@ -482,7 +560,7 @@ def batch_isend_irecv(p2p_op_list):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -519,6 +597,8 @@ def batch_isend_irecv(p2p_op_list):
  receive_shapes = []
  receive_dtypes = []
  tags = []
+ if not p2p_op_list:
+ raise TypeError(f"p2p_op_list can not be empty list.")
  group = p2p_op_list[0].group
  if group is None:
  group = GlobalComm.WORLD_COMM_GROUP
@@ -596,7 +676,7 @@ def scatter_tensor(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -661,7 +741,7 @@ def gather_into_tensor(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -724,7 +804,7 @@ def broadcast(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -778,7 +858,7 @@ def barrier(group=GlobalComm.WORLD_COMM_GROUP):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -797,7 +877,19 @@ def barrier(group=GlobalComm.WORLD_COMM_GROUP):
  return _op()


- def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
+ def _deal_comm_outputs(output, async_op):
+ if isinstance(output, tuple):
+ if not async_op:
+ output[1].wait()
+ return (output[0], None)
+ return output
+
+ if not async_op:
+ return (output, None)
+ return (output, default_handle)
+
+
+ def send(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  """
  Send tensors to the specified dest_rank.

@@ -817,7 +909,7 @@ def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  ValueError: If the rank ID of the process is greater than the rank size of the communication group.

  Supported Platforms:
- ``Ascend`` ``GPU``
+ ``Ascend``

  Examples:
  .. note::
@@ -826,7 +918,138 @@ def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore import ops
+ >>> import mindspore.nn as nn
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import send
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>>
+ >>> init()
+ >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> send(input_, 0)
+ """
+ if not isinstance(tensor, (Tensor, Tensor_)):
+ raise TypeError("For send, the input tensor must be tensor")
+ group = _get_group(group)
+ _dst = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
+ tensor = _contiguous(tensor)
+ output = inner_comm_isend_op(tensor, _dst, group, tag)
+ _deal_comm_outputs(output, False)
+
+
+ def recv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
+ """
+ Receive tensors from src.
+
+ Note:
+ Send and Receive must be used in combination and have same tag.
+ The shape and dtype of input `tensor` is used to receive tensor, but the value
+ of input `tensor` would not take effect.
+ Only support PyNative mode, Graph mode is not currently supported.
+
+ Args:
+ tensor (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The shape and dtype of this
+ tensor is used to receive tensor, but the value of input `tensor` would not take effect.
+ src (int, optional): A required integer identifying the source rank(global rank). Default: 0.
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
+ be received by the Send op with the same "tag". Default: 0.
+
+ Returns:
+ Tensor, the shape of output is :math:`(x_1, x_2, ..., x_R)`.
+
+ Raises:
+ TypeError: If `src` is not an int or `group` is not a str.
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore import ops
+ >>> import mindspore.nn as nn
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import recv
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>>
+ # Launch 2 processes.
+ Process 0 send the following array to Process 1
+ [[ 0. 1.]
+ [ 2. 3.]]
+ >>> init()
+ >>> x = ms.Tensor(np.zeros([2, 2]))
+ # Process 1 receive tensor from Process 0.
+ >>> out = recv(x, src=0)
+ >>> print(out)
+ [[ 0. 1.]
+ [ 2. 3.]]
+ """
+ if not isinstance(tensor, (Tensor, Tensor_)):
+ raise TypeError("For recv, the input tensor must be tensor")
+ if not isinstance(src, int):
+ raise TypeError("For recv, the src must be int")
+ group = _get_group(group)
+ _src = _get_group_rank_from_world_rank_from_cache_helper(src, group)
+ tensor = _contiguous(tensor)
+ shape = tensor.shape
+ dtype = tensor.dtype
+ output, _ = _deal_comm_outputs(inner_comm_irecv_op(tag, _src, shape, group, dtype), False)
+ return output
+
+
+ def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
+ """
+ Send tensors to the specified dest_rank asynchronously.
+
+ Note:
+ Send and Receive must be used in combination and have same tag.
+ Only support PyNative mode, Graph mode is not currently supported.
+
+ Args:
+ tensor (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+ dst (int, optional): A required integer identifying the destination rank(global rank). Default: 0.
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
+ be received by the Receive op with the same "tag". Default: 0.
+
+ Returns:
+ CommHandle, it is an async work handle.
+
+ Raises:
+ TypeError: `dst` is not an int or `group` is not a str。
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -840,19 +1063,22 @@ def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  >>>
  >>> init()
  >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
- >>> isend(input_, 0)
+ >>> handle = isend(input_, 0)
+ >>> handle.wait()
  """
  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For isend, the input tensor must be tensor")
- _dst = get_group_rank_from_world_rank(dst, group)
- _op = _get_cache_prim(P.Send)(tag, _dst, group, group)
- _depend = _get_cache_prim(P.Depend)()
- return _depend(tensor, _op(tensor))
+ group = _get_group(group)
+ _dst = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
+ tensor = _contiguous(tensor)
+ output = inner_comm_isend_op(tensor, _dst, group, tag)
+ _, handle = _deal_comm_outputs(output, True)
+ return handle


  def irecv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  """
- Receive tensors from src.
+ Receive tensors from src asynchronously.

  Note:
  Send and Receive must be used in combination and have same tag.
@@ -870,14 +1096,16 @@ def irecv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  be received by the Send op with the same "tag". Default: 0.

  Returns:
- Tensor, the shape of output is :math:`(x_1, x_2, ..., x_R)`.
+ Tuple(Tensor, CommHandle), the shape of output is :math:`(x_1, x_2, ..., x_R)`.
+ CommHandle is an async work handle, if `async_op` is set to True.
+ CommHandle will be None, when `async_op` is False.

  Raises:
  TypeError: If `src` is not an int or `group` is not a str.
  ValueError: If the rank ID of the process is greater than the rank size of the communication group.

  Supported Platforms:
- ``Ascend`` ``GPU``
+ ``Ascend``

  Examples:
  .. note::
@@ -886,7 +1114,7 @@ def irecv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -905,19 +1133,22 @@ def irecv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
  >>> init()
  >>> x = ms.Tensor(np.zeros([2, 2]))
  # Process 1 receive tensor from Process 0.
- >>> out = irecv(x, src=0)
+ >>> out, handle = irecv(x, src=0)
+ >>> handle.wait()
  >>> print(out)
  [[ 0. 1.]
  [ 2. 3.]]
  """
- _src = get_group_rank_from_world_rank(src, group)
+ group = _get_group(group)
+ _src = _get_group_rank_from_world_rank_from_cache_helper(src, group)
+ tensor = _contiguous(tensor)
  shape = tensor.shape
  dtype = tensor.dtype
- _op = _get_cache_prim(P.Receive)(tag, _src, shape, dtype, group, group)
- return _op(tensor)
+ output = inner_comm_irecv_op(tag, _src, shape, group, dtype)
+ return _deal_comm_outputs(output, True)


- def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=None):
+ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=None, async_op=False):
  """
  scatter and gather list of tensor to/from all rank according to input/output tensor list.

@@ -932,9 +1163,12 @@ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=Non
  List of tensors to scatter to the remote rank.
  group (str, optional): The communication group to work on.
  Default: None, which means "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

  Returns:
- Tuple(Tensor), the tensors gathered from remote ranks.
+ Tuple(Tuple(Tensor), CommHandle), the tensors is gathered from remote ranks.
+ CommHandle is an async work handle, if `async_op` is set to True.
+ CommHandle will be None, when `async_op` is False.

  Raises:
  TypeError: If `input_tensor_list` is not list of tensors.
@@ -951,7 +1185,7 @@ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=Non
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -1004,28 +1238,40 @@ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=Non
  recv_numel_list.append(_get_size(_shape))
  recv_shape_list.append(_shape)

- _op = _get_cache_prim(P.AlltoAllV)(send_numel_list, recv_numel_list, group)
  send_flatten_tensor = cat(send_flatten_tensor)
- output = _op(send_flatten_tensor)
+ send_flatten_tensor = _contiguous(send_flatten_tensor)
+ group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
+ global _GROPU_SIZE_CACHE
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ rank_size = _GROPU_SIZE_CACHE[group]
+ output = inner_comm_all_to_all_v_op(send_flatten_tensor, group, send_numel_list, recv_numel_list,
+ rank_size, False)
+ output, handle = _deal_comm_outputs(output, async_op)
  result = []
  offset = 0
  for numel, shape in zip(recv_numel_list, recv_shape_list):
  result.append(output[offset:offset + numel].reshape(shape))
  offset = offset + numel
- return tuple(result)
+ return (tuple(result), handle)


  def _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes, input_split_sizes, group):
  """get numel list for all_to_all_single."""
- if input_split_sizes is None or not input_split_sizes:
- _world_size = get_group_size(group)
+ global _GROPU_SIZE_CACHE
+ if _is_split_sizes_empty(input_split_sizes):
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ _world_size = _GROPU_SIZE_CACHE[group]
  if tensor.shape[0] % _world_size != 0:
  raise ValueError("input shape at dim 0 must be divided by world_size, "
  f"but got {tensor.shape[0]} and {_world_size}.")
  _split_size = tensor.shape[0] // _world_size
  input_split_sizes = (_split_size,) * _world_size
- if output_split_sizes is None or not output_split_sizes:
- _world_size = get_group_size(group)
+ if _is_split_sizes_empty(output_split_sizes):
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ _world_size = _GROPU_SIZE_CACHE[group]
  shape_dim_0 = None
  if isinstance(output_shape, Tensor):
  shape_dim_0 = output_shape.shape[0]
@@ -1053,7 +1299,7 @@ def _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes,


  def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes=None,
- input_split_sizes=None, group=None):
+ input_split_sizes=None, group=None, async_op=False):
  """
  scatter and gather input with split size to/from all rank, and return result in a single tensor.

@@ -1071,11 +1317,13 @@ def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes
  it means equally split by ``world_size``. Default: None.
  group (str, optional): The communication group to work on.
  Default: None, which means "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

  Returns:
- Tensor, the tensors gathered concatenated from remote ranks.
+ Tuple(Tensor, CommHandle), the output tensor is gathered concatenated from remote ranks.
  If the numel of tensor gathered from remote is zero, it will return a Tensor will value 0,
- which has no actual meanning.
+ which has no actual meanning. CommHandle is an async work handle, if `async_op` is set to True.
+ CommHandle will be None, when `async_op` is False.

  Raises:
  TypeError: If `tensor` is not tensor.
@@ -1091,7 +1339,7 @@ def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes
  For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
  without any third-party or configuration file dependencies.
  Please see the `msrun start up
- <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
  for more details.

  This example should be run with 2 devices.
@@ -1129,12 +1377,19 @@ def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes
  if group is None:
  group = GlobalComm.WORLD_COMM_GROUP

+ split_sizes_empty = _is_split_sizes_empty(output_split_sizes) and _is_split_sizes_empty(input_split_sizes)
  send_numel_list, recv_numel_list, recv_shape_without_first_dim = \
  _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes, input_split_sizes, group)
- _op = _get_cache_prim(P.AlltoAllV)(send_numel_list, recv_numel_list, group)
+ tensor = _contiguous(tensor)
  _input = tensor.reshape(-1)
- result = _op(_input)
+ group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
+ global _GROPU_SIZE_CACHE
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ rank_size = _GROPU_SIZE_CACHE[group]
+ result = inner_comm_all_to_all_v_op(_input, group, send_numel_list, recv_numel_list, rank_size, split_sizes_empty)
+ result, handle = _deal_comm_outputs(result, async_op)
  if any(recv_numel_list):
  result = result.reshape((-1,) + recv_shape_without_first_dim)

- return result
+ return result, handle
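
Taken together, the comm_func.py changes above swap the cached primitives (P.AllReduce, P.Send, P.Receive, P.AlltoAllV) for the new inner_comm_*_op kernels, add an async_op flag to the collectives, and make them return a (tensor, CommHandle) pair. The following is a minimal usage sketch based only on the docstrings shown in this diff; it assumes a 2-device Ascend job launched with msrun and a successful init(), and is not part of the package.

# Hedged sketch of the 2.4.1 async communication API documented in the diff above.
import numpy as np
from mindspore import Tensor
from mindspore.communication import init, get_rank
from mindspore.communication.comm_func import all_reduce, isend, irecv

init()

# all_reduce now returns (tensor, handle); handle is None when async_op=False.
x = Tensor(np.ones([2, 8]).astype(np.float32))
out, handle = all_reduce(x, async_op=True)
handle.wait()      # block until the collective completes
print(out)         # each rank sees the element-wise sum

# Point-to-point: isend returns a CommHandle, irecv returns (tensor, handle).
if get_rank() == 0:
    send_handle = isend(x, dst=1)
    send_handle.wait()
else:
    buf = Tensor(np.zeros([2, 8]).astype(np.float32))
    recv_out, recv_handle = irecv(buf, src=0)
    recv_handle.wait()
    print(recv_out)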