mindspore 2.7.0rc1__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.
Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,16 @@ import hashlib
  import builtins
  import io
  import pickle
+ from datetime import timedelta
  import numpy as np
  from mindspore import log as logger
  from mindspore.common import dtype as mstype
+ from mindspore._checkparam import args_type_check
  from mindspore.ops import ReduceOp, cat
  from mindspore.common.tensor import Tensor
  from mindspore._c_expression import TensorPy as Tensor_
  from mindspore.ops.primitive import _primexpr
+ from mindspore.common.api import _pynative_executor
  from mindspore.communication._comm_helper import (
  _destroy_group_helper,
  _get_rank_helper,
@@ -33,10 +36,11 @@ from mindspore.communication._comm_helper import (
  _get_group_ranks,
  _is_available,
  _is_initialized,
+ _ExistingGroup,
  )
+ from mindspore.communication.management import _init_without_sched
  from mindspore.communication import (
  init,
- release,
  get_group_size,
  get_world_rank_from_group_rank,
  create_group,
@@ -58,9 +62,11 @@ from mindspore.ops.auto_generate.gen_ops_prim import (
  dist_comm_isend_op,
  dist_comm_all_to_all_v_op,
  dist_comm_reduce_scatter_tensor_op,
+ dist_comm_reduce_scatter_tensor_uneven_op,
  dist_comm_all_to_all_v_single_op,
  dist_comm_broadcast_op,
  dist_comm_all_gather_into_tensor_op,
+ dist_comm_all_gather_into_tensor_uneven_op,
  dist_comm_irecv_op,
  dist_comm_scatter_tensor_op,
  dist_comm_gather_into_tensor_op,
@@ -70,7 +76,7 @@ from mindspore.ops.auto_generate.gen_ops_prim import (
  dist_comm_barrier_op,
  dist_comm_batch_isend_irecv_op,
  )
- from mindspore._c_expression import TCPStoreClient, GroupOptions
+ from mindspore._c_expression import TCPStoreClient, GroupOptions, _finalize_collective

  _pickler = pickle.Pickler
  _unpickler = pickle.Unpickler
@@ -144,28 +150,26 @@ class TCPStore:

  Note:
  - The function is implemented by CPU and does not involve any hardware operations related to Ascend.
- - Currently, all parameters provided by the TCPStore class constructor are not supported.
- The master node and port number are uniformly specified by the MindSpore framework.
- The following parameters are provided, currently not supported and settings are invalid.
- - The current TcpStore function is limited and only supports scenarios where the key is
+ - Currently, all parameters provided by the TCPStore class constructor are not supported
+ except for `host_name`, `port`, `world_size`, `is_master`, `timeout` and `wait_for_workers`,
+ which are reserved parameters and invalid settings.
+ - The current TCPStore function is limited and only supports scenarios where the key is
  less than 4k and the value is less than 1G. Complex scenarios are to be supported.
- - The timeout interval for message sending and receiving in the TcpStore function is controlled by
- the `MS_RECEIVE_MSG_TIMEOUT` environment variable, in seconds, with a default value of ``15``.
- If a timeout occurs, the user needs to increase the configuration value.

  Args:
- host_name (str, invalid, optional): The hostname or IP Address the server store should run on.
- Default is ``None``.
- port (int, invalid, optional): The port on which the server store should listen for incoming requests.
- Default is ``None``.
- world_size (int, invalid, optional): The total number of store users (number of clients + 1 for the server).
- Default is ``None`` (``None`` indicates a non-fixed number of store users).
- is_master (bool, invalid, optional): True when initializing the server store and False for client stores.
+ host_name (str): The hostname or IP Address the server store should run on.
+ Currently only supports user input IP addresses.
+ port (int): The port on which the server store should listen for incoming requests.
+ world_size (int, optional): The total number of store users (number of clients + 1 for the server).
+ Default is ``None``, indicates a non-fixed number of store users. This parameter is
+ only valid for the server.
+ is_master (bool, optional): True when initializing the server store and False for client stores.
  Default is ``False``.
- timeout (timedelta, invalid, optional): Timeout used by the store during initialization, Unit: seconds.
- Default is ``300``.
- wait_for_workers (bool, invalid, optional): Whether to wait for all the workers to connect with the server
- store. This is only applicable when `world_size` is a fixed value. Default is ``True``.
+ timeout (timedelta, optional): Timeout used by the store during initialization. Default is
+ ``timedelta(seconds=300)``.
+ wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server
+ store. This is only applicable when `world_size` is a fixed value. Default is ``True``. This
+ parameter is only valid for the server.
  multi_tenant (bool, invalid, optional): If ``True``, all ``TCPStore`` instances in the current process with
  the same host/port will use the same underlying ``TCPServer``. Default is ``False``.
  master_listen_fd (int, invalid, optional): If specified, the underlying ``TCPServer`` will listen on this file
@@ -191,12 +195,106 @@ class TCPStore:
  for more details.

  >>> from mindspore.mint.distributed import TCPStore
- >>> store = TCPStore()
+ >>> store = TCPStore("127.0.0.1", 1234)
  """

- def __init__(self, host_name=None, port=None, world_size=None, is_master=False, timeout=300,
+ def __init__(self, host_name, port, world_size=None, is_master=False, timeout=timedelta(seconds=300),
  wait_for_workers=True, multi_tenant=False, master_listen_fd=None, use_libuv=True):
- self.instance = TCPStoreClient.get_instance()
+ if not isinstance(host_name, str):
+ raise TypeError(
+ "For 'TCPStore', the argument 'host_name' must be type of string, "
+ "but got 'host_name' type : {}.".format(type(host_name))
+ )
+ if not isinstance(port, int):
+ raise TypeError(
+ "For 'TCPStore', the argument 'port' must be type of int, "
+ "but got 'port' type : {}.".format(type(port))
+ )
+ if not isinstance(is_master, bool):
+ raise TypeError(
+ "For 'TCPStore', the argument 'is_master' must be type of bool, "
+ "but got 'is_master' type : {}.".format(type(is_master))
+ )
+ if not isinstance(timeout, timedelta):
+ raise TypeError(
+ "For 'TCPStore', the argument 'timeout' must be type of timedelta, "
+ "but got 'timeout' type : {}.".format(type(timeout))
+ )
+ if not isinstance(wait_for_workers, bool):
+ raise TypeError(
+ "For 'TCPStore', the argument 'wait_for_workers' must be type of bool, "
+ "but got 'wait_for_workers' type : {}.".format(type(wait_for_workers))
+ )
+ if world_size is None:
+ world_size = 1
+ if not isinstance(world_size, int):
+ raise TypeError(
+ "For 'TCPStore', the argument 'world_size' must be type of int, "
+ "but got 'world_size' type : {}.".format(type(world_size))
+ )
+ if port < 0 or port > 65535:
+ raise ValueError(
+ "For 'TCPStore', the argument 'port' must be legal, "
+ f"but got {port}."
+ )
+ if world_size <= 0:
+ raise ValueError(
+ "For 'TCPStore', the argument 'world_size' must be legal, "
+ f"but got {world_size}."
+ )
+ timeout_ms = int(timeout.total_seconds() * 1000)
+ self.instance = TCPStoreClient(host_name, port, is_master, timeout_ms, world_size, wait_for_workers)
+ self.host = host_name
+ self.port = port
+
+
+ def add(self, key, amount):
+ """
+ When the `add` function is called for the first time with a given key, it creates a counter in
+ the storage corresponding to that key, with the initial value set to `amount`. Subsequent calls
+ to `add` with the same key increment the counter by amount.
+
+ Args:
+ key (str): The key whose counter value will be incremented.
+ amount (int): The amount by which the counter will be incremented.
+
+ Returns:
+ int, value of counter with `key`.
+
+ Raises:
+ TypeError: If `key` is not string.
+ TypeError: If `amount` is not int.
+ RuntimeError: If the `add` and `set` pass the same `key` and the `value` passed by `set` cannot
+ be correctly converted to a numerical value, calling `add` will result in an error.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+ for more details.
+
+ >>> from mindspore.mint.distributed import TCPStore
+ >>> store = TCPStore("127.0.0.1", 1234)
+ >>> store.add("first_key", 1)
+ """
+ if not isinstance(key, str):
+ raise TypeError(
+ "For 'TCPStore.add', the argument 'key' must be type of string, "
+ "but got 'key' type : {}.".format(type(key))
+ )
+ if not isinstance(amount, int):
+ raise TypeError(
+ "For 'TCPStore.add', the argument 'amount' must be type of string or int, "
+ "but got 'amount' type : {}.".format(type(amount))
+ )
+ return self.instance.add(key, amount)


  def set(self, key, value):
@@ -227,7 +325,7 @@ class TCPStore:
  for more details.

  >>> from mindspore.mint.distributed import TCPStore
- >>> store = TCPStore()
+ >>> store = TCPStore("127.0.0.1", 1234)
  >>> store.set("first_key", "first_value")
  """
  if not isinstance(key, str):
@@ -245,8 +343,9 @@ class TCPStore:

  def get(self, key):
  """
- Retrieves the value associated with the given `key` in the store. If `key` is not
- present in the store, the function will return "".
+ Retrieves the value associated with the given `key` in the store. If the `key` does not exist
+ in the storage, this function will wait for the `timeout` set by the class initialization and then
+ throw an exception.

  Args:
  key (str): The function will return the value associated with this key.
@@ -256,6 +355,7 @@ class TCPStore:

  Raises:
  TypeError: If `key` is not string.
+ RuntimeError: If `get` runs out of time.

  Supported Platforms:
  ``Ascend``
@@ -271,7 +371,7 @@ class TCPStore:
  for more details.

  >>> from mindspore.mint.distributed import TCPStore
- >>> store = TCPStore()
+ >>> store = TCPStore("127.0.0.1", 1234)
  >>> store.set("first_key", "first_value")
  >>> data = store.get("first_key")
  >>> print(data)
@@ -299,7 +399,7 @@ class TCPStore:
  TypeError: If `key` is not string.

  Supported Platforms:
- ``CPU``
+ ``Ascend``

  Examples:
  .. note::
@@ -312,7 +412,7 @@ class TCPStore:
  for more details.

  >>> from mindspore.mint.distributed import TCPStore
- >>> store = TCPStore()
+ >>> store = TCPStore("127.0.0.1", 1234)
  >>> store.set("first_key", "first_value")
  >>> # This should return true
  >>> store.delete_key("first_key")
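Taken together, the TCPStore hunks above replace the old no-argument placeholder with a real client: `host_name` and `port` are now required, the arguments are validated, and an `add()` counter method is introduced. The following is a minimal sketch combining the docstring examples shown above; it assumes the communication environment described in the notes is already configured and uses a placeholder address.

    from mindspore.mint.distributed import TCPStore

    # Placeholder address for illustration; in practice this is the master node
    # reachable by every process in the job.
    store = TCPStore("127.0.0.1", 1234)

    store.set("first_key", "first_value")   # write a key/value pair
    value = store.get("first_key")          # blocks up to the constructor's timeout if the key is missing
    counter = store.add("steps", 1)         # first call creates the counter at 1, later calls increment it
    print(value, counter)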
@@ -387,6 +487,7 @@ def is_initialized():
  return _is_initialized()


+ @args_type_check(init_method=str, timeout=timedelta, world_size=int, rank=int, store=TCPStore)
  def init_process_group(backend="hccl",
  init_method=None,
  timeout=None,
@@ -404,26 +505,29 @@ def init_process_group(backend="hccl",
  and the instantiation and execution of any operation and net.

  Args:
- backend (str, optional): The backend to ues. default is hccl and now only support hccl.
- init_method (str, invalid): URL specifying how to init collective communication group. Provides parameters
- consistent with pytorch, but is not currently support, setting is invalid.
- timeout (timedelta, invalid): Timeout for API executed. Provides parameters consistent with pytorch, but is not
- currently support, setting is invalid.
- world_size (int, optional): Number of the processes participating in the job.
- rank (int, invalid): Rank of the current process. Provides parameters consistent with pytorch, but is not
- currently support, setting is invalid.
- store (Store, invalid): Key/Value store accessible to all workers, used to exchange connection/address
- information. Provides parameters consistent with pytorch, but is not currently support,
- setting is invalid.
+ backend (str, optional): The backend to ues. Default is ``"hccl"`` and now only support hccl.
+ init_method (str, optional): URL specifying how to init collective communication group. Default is ``None``.
+ timeout (timedelta, optional): Timeout for API executed. Default is ``None``. Currently, this parameter is
+ only supported for host-side cluster network configuration using `init_method` or `store`.
+ world_size (int, optional): Number of the processes participating in the job. Default is ``-1``.
+ rank (int, optional): Rank of the current process. Default is ``-1``.
+ store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process
+ communication addresses and connection information. Default is ``None``. Currently, only the
+ ``TCPStore`` type is supported.
  pg_options (ProcessGroupOptions, invalid): process group options specifying what additional options need to be
- passed in during the construction of specific process group. Provides parameters consistent with pytorch,
- but is not currently support, setting is invalid.
- device_id (int, invalid): the device id to exeute. Provides parameters consistent with pytorch, but is not
- currently support, setting is invalid.
+ passed in during the construction of specific process group. The provided parameter is a reserved
+ parameter, and the current setting does not take effect.
+ device_id (int, invalid): the device id to exeute. The provided parameter is a reserved parameter,
+ and the current setting does not take effect.

  Raises:
  ValueError: If `backend` is not hccl.
  ValueError: If `world_size` is not equal to -1 or process group number.
+ ValueError: If both `init_method` and `store` are set.
+ ValueError: `world_size` is not correctly set as a positive integer value, when using the initialization
+ method `init_method` or `store`.
+ ValueError: `rank` is not correctly set as a non-negative integer, when using the initialization method
+ `init_method` or `store`.
  RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails,
  or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
  have not been exported when backend is HCCL.
@@ -447,25 +551,34 @@ def init_process_group(backend="hccl",
  >>> init_process_group()
  >>> destroy_process_group()
  """
- if init_method is not None:
- logger.warning("init_method is ignored, setting is invalid")
- if timeout is not None:
- logger.warning("timeout is ignored, setting is invalid")
- if store is not None:
- logger.warning("store is ignored, setting is invalid")
  if pg_options is not None:
  logger.warning("pg_options is ignored, setting is invalid")
  if device_id is not None:
  logger.warning("device_id is ignored, setting is invalid")
- if rank != -1:
- logger.warning("rank is ignored, setting is invalid")
  if backend != "hccl":
  raise ValueError(
  "Only support hccl now, please setting backend to hccl or using default value"
  )

- # init hccl & create world group
- init(backend)
+ if init_method is not None and store is not None:
+ raise ValueError(
+ "Only one of init_method and store is supported."
+ )
+ if init_method is not None or store is not None:
+ if world_size <= 0:
+ raise ValueError(
+ "Specified world_size must be a positive integer."
+ )
+ if rank < 0:
+ raise ValueError(
+ "Specified rank must be a non-negative integer."
+ )
+ if timeout is None:
+ timeout = timedelta(seconds=300)
+ timeout_ms = int(timeout.total_seconds() * 1000)
+ _init_without_sched(backend, init_method, timeout_ms, world_size, rank, store)
+ else:
+ init(backend)

  if world_size != -1 and world_size != get_group_size():
  raise ValueError(
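The new initialization path rejects passing `init_method` and `store` together, requires a positive `world_size` and a non-negative `rank` when either is used, converts the `timeout` to milliseconds, and then calls `_init_without_sched` instead of `init(backend)`. Below is a hedged sketch of the store-based variant; the address, the environment-variable names and the two-process layout are illustrative assumptions, not part of the API.

    import os
    from datetime import timedelta
    import mindspore as ms
    from mindspore.mint.distributed import TCPStore, init_process_group, destroy_process_group

    ms.set_device(device_target="Ascend")

    # Illustrative only: obtain this process's rank and the total process count from your launcher.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "2"))

    # Rank 0 hosts the store; the others connect to it. init_method and store must not be mixed.
    store = TCPStore("127.0.0.1", 1234, world_size=world_size, is_master=(rank == 0))
    init_process_group(backend="hccl", store=store, world_size=world_size, rank=rank,
                       timeout=timedelta(seconds=300))

    # ... run collectives here ...
    destroy_process_group()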
@@ -513,7 +626,10 @@ def destroy_process_group(group=None):
  """

  if group == GlobalComm.WORLD_COMM_GROUP or group is None:
- release()
+ _pynative_executor.sync()
+ _finalize_collective()
+ _ExistingGroup.ITEMS.clear()
+ _ExistingGroup.GROUP_RANKS.clear()
  elif not isinstance(group, str):
  raise TypeError(
  "For 'destroy_group', the argument 'group' must be type of string or None, "
@@ -671,6 +787,12 @@ def new_group(ranks=None,
  hccl_config(dict)
  }

+ `hccl_config` currently only supports "hccl_buffer_size" or "hccl_comm".
+
+ - hccl_buffer_size (uint32): specifies the size of the HCCL communication buffer.
+ - hccl_comm (int64): specifies an existing HcclComm pointer. If "hccl_comm" is set,
+ "hccl_buffer_size" will be ignored.
+
  use_local_synchronization (bool, invalid): Currently it is a reserved parameter.
  group_desc (str, invalid): Currently it is a reserved parameter.

@@ -989,6 +1111,22 @@ def _check_all_tensor_same_dtype_and_shape(*tensor_lists):
  )


+ @_primexpr
+ def _check_output_shape(output, expected_shape, op_name):
+ if output.shape != expected_shape:
+ raise TypeError(
+ f"For {op_name}, the output shape should be {expected_shape}, "
+ f"but got {output.shape}.")
+
+
+ @_primexpr
+ def _check_output_dtype(output, expected_dtype, op_name):
+ if output.dtype != expected_dtype:
+ raise TypeError(
+ f"For {op_name}, the output dtype should be {expected_dtype}, "
+ f"but got {output.dtype}.")
+
+
  def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
  """
  Reduce tensors across all devices in such a way that all deviceswill get the same final result,
@@ -1153,6 +1291,91 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
  return handle


+ def all_gather_into_tensor_uneven(output, input, output_split_sizes=None, group=None, async_op=False):
+ r"""
+ Gathers and concatenates tensors across devices with uneven first dimensions.
+
+ Note:
+ - Input tensors must have identical shapes except for the first dimension.
+ - Output tensor's first dimension should equal to the sum of all devices' input first dimensions.
+
+ Args:
+ output (Tensor): Concatenated output tensor with shape :math:`(\sum_{i=0}^{N-1} x_{i1}, x_2, ..., x_R)`,
+ where N is the number of devices in the group.
+ input (Tensor): Local input tensor with shape :math:`(x_{k1}, x_2, ..., x_R)`, where k is current device's rank.
+ output_split_sizes (list[int], optional): Specifies first dimension sizes from each device.
+ Must match actual input dimensions when provided.
+ If ``None``, assumes equal split sizes across devices. Default: ``None``.
+ group (str, optional): The communication group to work on. If ``None``,
+ which means ``"hccl_world_group"`` in Ascend. Default: ``None``.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False``.
+
+ Returns:
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
+ CommHandle will be None, when `async_op` is False.
+
+ Raises:
+ ValueError: If the shape of `input` does not match the constraints of `output_split_sizes`.
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import ops
+ >>> from mindspore.mint.distributed import init_process_group, get_rank
+ >>> from mindspore.mint.distributed import all_gather_into_tensor_uneven
+ >>> from mindspore import Tensor
+ >>>
+ >>> ms.set_device(device_target="Ascend")
+ >>> init_process_group()
+ >>> if get_rank() == 0:
+ ... input_tensor = Tensor(np.ones([3, 4]).astype(np.float32))
+ ... else:
+ ... input_tensor = Tensor(np.ones([2, 4]).astype(np.float32))
+ >>> out_tensor = Tensor(np.zeros([5, 4]).astype(np.float32))
+ >>> output_split_sizes = [3, 2]
+ >>> output = all_gather_into_tensor_uneven(out_tensor, input_tensor, output_split_sizes)
+ >>> print(out_tensor)
+ [[1. 1. 1. 1.]
+ [1. 1. 1. 1.]
+ [1. 1. 1. 1.]
+ [1. 1. 1. 1.]
+ [1. 1. 1. 1.]]
+ """
+ if group is None:
+ group = GlobalComm.WORLD_COMM_GROUP
+ if not isinstance(group, str):
+ raise TypeError(
+ "The argument 'group' must be type of string, "
+ "but got 'group' type : {}.".format(type(group))
+ )
+ if not isinstance(async_op, bool):
+ raise TypeError(
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
+ )
+ group_size = get_cache_group_size(group)
+ output_split_sizes = [] if output_split_sizes is None else output_split_sizes
+ result = dist_comm_all_gather_into_tensor_uneven_op(
+ output, input, output_split_sizes, group_size, group
+ )
+ _, handle = _deal_comm_outputs(result, async_op)
+ return handle
+
+
  def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
  r"""
  Reduces and scatters tensors from the specified communication group and
@@ -1243,6 +1466,101 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
  return handle


+ def reduce_scatter_tensor_uneven(output, input, input_split_sizes=None, op=ReduceOp.SUM, group=None, async_op=False):
+ r"""
+ Reduce tensors from the specified communication group and scatter to the output tensor
+ according to `input_split_sizes`.
+
+ Note:
+ - The input tensor must have identical shape and format across all processes.
+ - The first dimension of input tensor should equal to the sum of `input_split_sizes`.
+
+ Args:
+ output(Tensor): the output tensor has the same dtype as `input` with a shape of
+ :math:`(input\_split\_sizes[rank], *)`, where rank is the local rank id of the device.
+ input(Tensor): The input tensor to be reduced and scattered, Expected shape :math:`(N, *)`, where `*`
+ means any number of additional dimensions. N must equal the sum of `input_split_sizes` across ranks.
+ input_split_sizes (list[int], optional): List specifying how to split the first dimension of input tensor.
+ If ``None``, splits evenly according to group size. Default: ``None``.
+ op (str, optional): Specifies an operation used for element-wise reductions,
+ One of ReduceOp: 'SUM', 'MIN', 'MAX'. Default: ``ReduceOp.SUM``.
+ group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
+ Ascend. Default: ``None``.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False``.
+
+ Returns:
+ CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
+ CommHandle will be None, when `async_op` is False.
+
+ Raises:
+ ValueError: If the shape of `output` does not match the constraints of `input_split_sizes`.
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor
+ >>> from mindspore.mint.distributed import init_process_group, get_rank
+ >>> from mindspore.mint.distributed import reduce_scatter_tensor_uneven
+ >>> import numpy as np
+ >>>
+ >>> ms.set_device(device_target="Ascend")
+ >>> init_process_group()
+ >>> input_tensor = Tensor(np.ones([5, 8]).astype(np.float32))
+ >>> if get_rank() == 0:
+ ... output_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+ ... else:
+ ... output_tensor = Tensor(np.ones([3, 8]).astype(np.float32))
+ >>> input_split_sizes = [2, 3]
+ >>> output = reduce_scatter_tensor_uneven(output_tensor, input_tensor, input_split_sizes)
+ >>> print(output_tensor)
+ rank 0:
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
+ rank 1:
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
+ [2. 2. 2. 2. 2. 2. 2. 2.]
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
+ """
+ if not isinstance(op, str):
+ raise TypeError("For reduce_scatter_tensor_uneven, the input op type must be str")
+ if op not in ("sum", "min", "max"):
+ raise TypeError(
+ "For reduce_scatter_tensor_uneven, the input op value must be one of sum, prod, min, max"
+ )
+ if group is None:
+ group = GlobalComm.WORLD_COMM_GROUP
+ if not isinstance(group, str):
+ raise TypeError(
+ "The argument 'group' must be type of string, "
+ "but got 'group' type : {}.".format(type(group))
+ )
+ if not isinstance(async_op, bool):
+ raise TypeError(
+ f"The argument 'async_op' must be a bool, but got {type(async_op)}."
+ )
+ input_split_sizes = [] if input_split_sizes is None else input_split_sizes
+ rank_size = get_cache_group_size(group)
+ result = dist_comm_reduce_scatter_tensor_uneven_op(
+ output, input, input_split_sizes, rank_size, op, group
+ )
+ _, handle = _deal_comm_outputs(result, async_op)
+ return handle
+
+
  def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
  """
  Reduces tensors across the processes in the specified communication group, sends the result
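Both new uneven collectives tie tensor shapes to their split-size lists: `all_gather_into_tensor_uneven` expects the output's first dimension to be the sum of `output_split_sizes`, while `reduce_scatter_tensor_uneven` expects the input's first dimension to be the sum of `input_split_sizes` and gives rank `k` an output with `input_split_sizes[k]` rows. The following shape-only sketch of that arithmetic runs without any device; the helper names are illustrative, not part of the API.

    def uneven_all_gather_output_shape(input_shape, output_split_sizes):
        # The gathered output keeps the trailing dims and concatenates along dim 0.
        return (sum(output_split_sizes),) + tuple(input_shape[1:])

    def uneven_reduce_scatter_output_shape(input_shape, input_split_sizes, rank):
        # Input dim 0 must equal the sum of the splits; each rank keeps its own slice.
        assert input_shape[0] == sum(input_split_sizes)
        return (input_split_sizes[rank],) + tuple(input_shape[1:])

    # Matches the two-device docstring examples above:
    assert uneven_all_gather_output_shape((3, 4), [3, 2]) == (5, 4)
    assert uneven_reduce_scatter_output_shape((5, 8), [2, 3], rank=0) == (2, 8)
    assert uneven_reduce_scatter_output_shape((5, 8), [2, 3], rank=1) == (3, 8)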
@@ -2386,10 +2704,7 @@ def all_to_all_single(output,

  def _check_tensor_list(tensor_list, tensor, group_size):
  """check all elements in tensor_list are type of Tensor or tuple or list"""
- if not tensor_list or len(tensor_list) != group_size:
- raise TypeError(
- f"The argument list tensor len must be equal to group rank size, but got {len(tensor_list)}."
- )
+ _check_group_tensor_list(tensor_list, group_size)
  if tensor.dtype != tensor_list[0].dtype:
  raise TypeError(
  f"The argument list tensor type must be equal to tensor type, but got {tensor_list[0].dtype}."
@@ -2400,13 +2715,17 @@ def _check_tensor_list(tensor_list, tensor, group_size):
  )


+ def _check_group_tensor_list(tensor_list, group_size):
+ if not tensor_list or len(tensor_list) != group_size:
+ raise TypeError(
+ f"The argument list tensor len must be equal to group rank size, but got {len(tensor_list)}."
+ )
+
+
  def all_gather(tensor_list, tensor, group=None, async_op=False):
  """
  Gathers tensors from the specified communication group and returns the tensor list which is all gathered.

- Note:
- The tensors must have the same shape and format in all processes of the collection.
-
  Args:
  tensor_list (list[Tensor]): Output list.
  tensor (Tensor): The input tensor to be all gathered into tensor.
@@ -2461,7 +2780,7 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):

  """
  _check_all_tensors(tensor_list)
- _check_all_tensor_same_dtype_and_shape(tensor_list)
+ _check_all_tensor_same_dtype(tensor_list)
  if not isinstance(tensor, (Tensor, Tensor_)):
  raise TypeError("For all_gather_into_tensor, the input tensor must be tensor")
  if group is None:
@@ -2476,7 +2795,10 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
  f"The argument 'async_op' must be a bool, but got {type(async_op)}."
  )
  group_size = get_cache_group_size(group)
- _check_tensor_list(tensor_list, tensor, group_size)
+ _check_group_tensor_list(tensor_list, group_size)
+ rank_id = get_group_rank_from_world_rank(get_rank(), group)
+ _check_output_shape(tensor, tensor_list[rank_id].shape, "all_gather")
+ _check_output_dtype(tensor, tensor_list[0].dtype, "all_gather")
  result = dist_comm_all_gather_op(tensor_list, tensor, group_size, group)
  _, handle = _deal_comm_outputs(result, async_op)
  return handle
@@ -2487,9 +2809,6 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
  Reduces and scatters tensors from the specified communication group and
  returns the tensor which is reduced and scattered.

- Note:
- The tensors must have the same shape and format in all processes of the collection.
-
  Args:
  output (Tensor): the output tensor.
  input_list (list[Tensor]): List of tensors to reduce and scatter.
@@ -2543,7 +2862,7 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
  """

  _check_all_tensors(input_list)
- _check_all_tensor_same_dtype_and_shape(input_list)
+ _check_all_tensor_same_dtype(input_list)
  if not isinstance(output, (Tensor, Tensor_)):
  raise TypeError("For reduce_scatter, the output tensor must be tensor")
  if group is None:
@@ -2564,7 +2883,11 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
  "For reduce_scatter, the input op value must be one of sum, prod, min, max"
  )
  rank_size = get_cache_group_size(group)
- _check_tensor_list(input_list, output, rank_size)
+ _check_group_tensor_list(input_list, rank_size)
+
+ rank_id = get_group_rank_from_world_rank(get_rank(), group)
+ _check_output_shape(output, input_list[rank_id].shape, "reduce_scatter")
+ _check_output_dtype(output, input_list[0].dtype, "reduce_scatter")
  result = dist_comm_reduce_scatter_op(output, input_list, rank_size, op, group)
  _, handle = _deal_comm_outputs(result, async_op)
  return handle
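The last hunks split the old `_check_tensor_list` call into a length check (`_check_group_tensor_list`) plus per-rank `_check_output_shape` and `_check_output_dtype` checks, and relax the list-wide same-shape requirement to a same-dtype requirement for `all_gather` and `reduce_scatter`. The following shape-only sketch mirrors that relaxed Python-side validation under those assumptions; it involves no device or communication, and the helper name is illustrative rather than part of the API.

    def check_all_gather_args(tensor_shape, tensor_dtype, list_shapes, list_dtypes, rank, group_size):
        # Mirrors the new checks: list length == group size, uniform dtype across the list,
        # and the local input must match the rank's own slot in the output list.
        if len(list_shapes) != group_size:
            raise TypeError("tensor_list length must equal the group size")
        if any(dt != list_dtypes[0] for dt in list_dtypes):
            raise TypeError("all tensors in tensor_list must share one dtype")
        if tensor_shape != list_shapes[rank] or tensor_dtype != list_dtypes[0]:
            raise TypeError("input tensor must match tensor_list[rank] in shape and dtype")

    # As far as this validation goes, per-rank shapes may now differ across the list,
    # e.g. [(3, 4), (2, 4)] on 2 devices:
    check_all_gather_args((3, 4), "float32", [(3, 4), (2, 4)], ["float32", "float32"],
                          rank=0, group_size=2)
    print("validation passed")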