mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -52,12 +52,15 @@ class ReduceOp:
52
52
  Before running the following examples, you need to configure the communication environment variables.
53
53
 
54
54
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
55
- Please see the `Ascend tutorial
56
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
55
+ Please see the `rank table Startup
56
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
57
57
  for more details.
58
58
 
59
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
60
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
59
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
60
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
61
+
62
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
63
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
61
64
 
62
65
  This example should be run with multiple devices.
63
66
 
@@ -117,8 +120,9 @@ class AllReduce(Primitive):
117
120
 
118
121
  Args:
119
122
  op (str): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
120
- On the CPU, only 'sum' is supported. Default: ReduceOp.SUM.
121
- group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
123
+ On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
124
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
125
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
122
126
 
123
127
  Inputs:
124
128
  - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -139,14 +143,17 @@ class AllReduce(Primitive):
139
143
  Before running the following examples, you need to configure the communication environment variables.
140
144
 
141
145
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
142
- Please see the `Ascend tutorial
143
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
146
+ Please see the `rank table Startup
147
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
144
148
  for more details.
145
149
 
146
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
147
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
150
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
151
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
148
152
 
149
- This example should be run with multiple devices.
153
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
154
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
155
+
156
+ This example should be run with 2 devices.
150
157
 
151
158
  >>> import numpy as np
152
159
  >>> from mindspore.communication import init
@@ -170,6 +177,11 @@ class AllReduce(Primitive):
170
177
  >>> print(output)
171
178
  [[2. 2. 2. 2. 2. 2. 2. 2.]
172
179
  [2. 2. 2. 2. 2. 2. 2. 2.]]
180
+
181
+ Tutorial Examples:
182
+ - `Distributed Set Communication Primitives - AllReduce
183
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#allreduce>`_
184
+
173
185
  """
174
186
 
175
187
  @prim_attr_register
@@ -197,7 +209,8 @@ class AllGather(PrimitiveWithInfer):
197
209
  - Currently only supports GRAPH_MODE and it should be called in Cell.
198
210
 
199
211
  Args:
200
- group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
212
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
213
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
201
214
 
202
215
  Inputs:
203
216
  - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -219,12 +232,15 @@ class AllGather(PrimitiveWithInfer):
219
232
  Before running the following examples, you need to configure the communication environment variables.
220
233
 
221
234
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
222
- Please see the `Ascend tutorial
223
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
235
+ Please see the `rank table Startup
236
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
224
237
  for more details.
225
238
 
226
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
227
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
239
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
240
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
241
+
242
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
243
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
228
244
 
229
245
  This example should be run with 2 devices.
230
246
 
@@ -253,6 +269,11 @@ class AllGather(PrimitiveWithInfer):
253
269
  [1. 1. 1. 1. 1. 1. 1. 1.]
254
270
  [1. 1. 1. 1. 1. 1. 1. 1.]
255
271
  [1. 1. 1. 1. 1. 1. 1. 1.]]
272
+
273
+ Tutorial Examples:
274
+ - `Distributed Set Communication Primitives - AllGather
275
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#allgather>`_
276
+
256
277
  """
257
278
 
258
279
  @prim_attr_register
@@ -268,9 +289,6 @@ class AllGather(PrimitiveWithInfer):
268
289
  self.add_prim_attr('mean_flag', False)
269
290
  self.add_prim_attr('no_eliminate', True)
270
291
 
271
- def __call__(self, tensor):
272
- raise NotImplementedError
273
-
274
292
  def infer_shape(self, x_shape):
275
293
  validator.check_positive_int(len(x_shape), "x shape", self.name)
276
294
  if x_shape[0] > 0:
@@ -288,8 +306,8 @@ class _MiniStepAllGather(PrimitiveWithInfer):
288
306
  internal use of parallel modules and cannot be called by users.
289
307
 
290
308
  Args:
291
- group (str): The communication group to work on. Default: None.
292
- grad_accumulation_step (int): The grad accumulation step. Default: None.
309
+ group (str): The communication group to work on. Default: ``None`` .
310
+ grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
293
311
  """
294
312
 
295
313
  @prim_attr_register
@@ -324,7 +342,7 @@ class _MicroStepAllGather(PrimitiveWithInfer):
324
342
  internal use of parallel modules and cannot be called by users.
325
343
 
326
344
  Args:
327
- group (str): The communication group to work on. Default: None.
345
+ group (str): The communication group to work on. Default: ``None`` .
328
346
  """
329
347
 
330
348
  @prim_attr_register
@@ -364,7 +382,7 @@ class _HostAllGather(PrimitiveWithInfer):
364
382
  mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
365
383
 
366
384
  Args:
367
- group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: None.
385
+ group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: ``None`` .
368
386
 
369
387
  Raises:
370
388
  TypeError: If group is not a list nor tuple, or elements of group are not int.
@@ -410,16 +428,14 @@ class _HostAllGather(PrimitiveWithInfer):
410
428
  class ReduceScatter(Primitive):
411
429
  r"""
412
430
  Reduces and scatters tensors from the specified communication group.
413
- For more details about it, please refer to `Distributed Set Communication Primitives - ReduceScatter \
414
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/communicate_ops.html#reducescatter>`_ .
415
431
 
416
432
  Note:
417
433
  The tensors must have the same shape and format in all processes of the collection.
418
434
 
419
435
  Args:
420
- op (str): Specifies an operation used for element-wise reductions,
421
- like SUM and MAX. Default: ReduceOp.SUM.
422
- group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
436
+ op (str, optional): Specifies an operation used for element-wise reductions,
437
+ like SUM and MAX. Default: ``ReduceOp.SUM`` .
438
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
423
439
 
424
440
  Inputs:
425
441
  - **input_x** (Tensor) - Input Tensor, suppose it has a shape :math:`(N, *)`, where `*`
@@ -441,12 +457,15 @@ class ReduceScatter(Primitive):
441
457
  Before running the following examples, you need to configure the communication environment variables.
442
458
 
443
459
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
444
- Please see the `Ascend tutorial
445
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
460
+ Please see the `rank table Startup
461
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
446
462
  for more details.
447
463
 
448
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
449
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
464
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
465
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
466
+
467
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
468
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
450
469
 
451
470
  This example should be run with 2 devices.
452
471
 
@@ -476,6 +495,11 @@ class ReduceScatter(Primitive):
476
495
  [2. 2. 2. 2. 2. 2. 2. 2.]
477
496
  [2. 2. 2. 2. 2. 2. 2. 2.]
478
497
  [2. 2. 2. 2. 2. 2. 2. 2.]]
498
+
499
+ Tutorial Examples:
500
+ - `Distributed Set Communication Primitives - ReduceScatter
501
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#reducescatter>`_
502
+
479
503
  """
480
504
 
481
505
  @prim_attr_register
@@ -490,9 +514,6 @@ class ReduceScatter(Primitive):
490
514
  self.add_prim_attr('fusion', 0)
491
515
  self.add_prim_attr('no_eliminate', True)
492
516
 
493
- def __call__(self, tensor):
494
- raise NotImplementedError
495
-
496
517
 
497
518
  class _HostReduceScatter(PrimitiveWithInfer):
498
519
  """
@@ -506,8 +527,8 @@ class _HostReduceScatter(PrimitiveWithInfer):
506
527
 
507
528
  Args:
508
529
  op (str): Specifies an operation used for element-wise reductions,
509
- like sum, max, avg. Default: ReduceOp.SUM.
510
- group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: None.
530
+ like sum, max, avg. Default: ``ReduceOp.SUM`` .
531
+ group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: ``None`` .
511
532
 
512
533
  Raises:
513
534
  TypeError: If op is not a string and group is not a list nor tuple,
@@ -558,7 +579,7 @@ class Broadcast(PrimitiveWithInfer):
558
579
  Args:
559
580
  root_rank (int): Source rank. Required in all processes except the one
560
581
  that is sending the data.
561
- group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
582
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
562
583
 
563
584
  Inputs:
564
585
  - **input_x** (tuple[Tensor]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -578,12 +599,15 @@ class Broadcast(PrimitiveWithInfer):
578
599
  Before running the following examples, you need to configure the communication environment variables.
579
600
 
580
601
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
581
- Please see the `Ascend tutorial
582
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
602
+ Please see the `rank table Startup
603
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
583
604
  for more details.
584
605
 
585
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
586
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
606
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
607
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
608
+
609
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
610
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
587
611
 
588
612
  This example should be run with multiple devices.
589
613
 
@@ -611,6 +635,11 @@ class Broadcast(PrimitiveWithInfer):
611
635
  (Tensor(shape[2,4], dtype=Int32, value=
612
636
  [[1, 1, 1, 1],
613
637
  [1, 1, 1, 1]]),)
638
+
639
+ Tutorial Examples:
640
+ - `Distributed Set Communication Primitives - Broadcast
641
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#broadcast>`_
642
+
614
643
  """
615
644
 
616
645
  @prim_attr_register
@@ -629,7 +658,7 @@ class Broadcast(PrimitiveWithInfer):
629
658
  if not isinstance(x_dtype, tuple):
630
659
  raise TypeError(f"For '{self.name}', the 'input_x' must be a tuple, but got {type(x_dtype).__name__}!")
631
660
  for _ele in x_dtype:
632
- check_collective_target_dtype('input_x', _ele, self.name)
661
+ check_collective_target_dtype('tuple input_x', _ele, self.name)
633
662
  return x_dtype
634
663
 
635
664
 
@@ -671,7 +700,7 @@ class _AllSwap(PrimitiveWithCheck):
671
700
  self.add_prim_attr('order_enforce_skip', True)
672
701
 
673
702
  def __check__(self, tensor_in, send_size, recv_size):
674
- validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
703
+ validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor_type, self.name)
675
704
  validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
676
705
  self.name)
677
706
  validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
@@ -699,11 +728,11 @@ class NeighborExchange(Primitive):
699
728
  The user needs to preset
700
729
  communication environment variables before running the following example, please check the details on the
701
730
  official website of `MindSpore \
702
- <https://www.mindspore.cn/docs/en/r2.0/api_python/mindspore.ops.html#communication-operator>`_.
731
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.ops.primitive.html#communication-operator>`_.
703
732
 
704
733
  This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
705
734
  in the same subnet, please check the `details \
706
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/communicate_ops.html#notes>`_.
735
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#notes>`_.
707
736
 
708
737
  Args:
709
738
  send_rank_ids (list(int)): Ranks which the data is sent to.
@@ -711,7 +740,7 @@ class NeighborExchange(Primitive):
711
740
  recv_shapes (tuple(list(int))): Data shape which received from recv_rank_ids.
712
741
  send_shapes (tuple(list(int))): Data shape which send to the send_rank_ids.
713
742
  recv_type (type): Data type which received from recv_rank_ids
714
- group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
743
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
715
744
 
716
745
  Inputs:
717
746
  - **input_x** (tuple[Tensor]) - Shapes are same as args of send_shapes.
@@ -742,13 +771,18 @@ class NeighborExchange(Primitive):
742
771
  ... def construct(self, x):
743
772
  ... out = self.neighborexchange((x,))
744
773
  ...
745
- >>> ms.set_context(mode=ms.GRAPH_MODE, device_target='Ascend')
774
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
746
775
  >>> init()
747
776
  >>> net = Net()
748
777
  >>> input_x = Tensor(np.ones([3, 3]), dtype = ms.float32)
749
778
  >>> output = net(input_x)
750
779
  >>> print(output)
751
780
  [[2. 2.], [2. 2.]]
781
+
782
+ Tutorial Examples:
783
+ - `Distributed Set Communication Primitives - NeighborExchange
784
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#neighborexchange>`_
785
+
752
786
  """
753
787
 
754
788
  @prim_attr_register
@@ -760,6 +794,7 @@ class NeighborExchange(Primitive):
760
794
  self.recv_shapes = recv_shapes
761
795
  self.send_shapes = send_shapes
762
796
  self.recv_type = recv_type
797
+ self.add_prim_attr('group', _get_group(group))
763
798
  self.add_prim_attr('no_eliminate', True)
764
799
 
765
800
  def __call__(self, tensor):
@@ -779,13 +814,13 @@ class AlltoAll(PrimitiveWithInfer):
779
814
  Note:
780
815
  This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
781
816
  in the same subnet, please check the `details \
782
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/communicate_ops.html#notes>`_.
817
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#notes>`_.
783
818
 
784
819
  Args:
785
820
  split_count (int): On each process, divide blocks into split_count number.
786
821
  split_dim (int): On each process, split blocks along the split_dim.
787
822
  concat_dim (int): On each process, gather the received blocks along the concat_dimension.
788
- group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
823
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
789
824
 
790
825
  Inputs:
791
826
  - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -809,12 +844,15 @@ class AlltoAll(PrimitiveWithInfer):
809
844
  Before running the following examples, you need to configure the communication environment variables.
810
845
 
811
846
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
812
- Please see the `Ascend tutorial
813
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
847
+ Please see the `rank table Startup
848
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
814
849
  for more details.
815
850
 
816
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
817
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
851
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
852
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
853
+
854
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
855
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
818
856
 
819
857
  This example should be run with 8 devices.
820
858
 
@@ -834,7 +872,7 @@ class AlltoAll(PrimitiveWithInfer):
834
872
  ... out = self.alltoall(x)
835
873
  ... return out
836
874
  ...
837
- >>> ms.set_context(mode=ms.GRAPH_MODE, device_target='Ascend')
875
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
838
876
  >>> init()
839
877
  >>> net = Net()
840
878
  >>> rank_id = int(os.getenv("RANK_ID"))
@@ -842,6 +880,11 @@ class AlltoAll(PrimitiveWithInfer):
842
880
  >>> output = net(input_x)
843
881
  >>> print(output)
844
882
  [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
883
+
884
+ Tutorial Examples:
885
+ - `Distributed Set Communication Primitives - AlltoAll
886
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#alltoall>`_
887
+
845
888
  """
846
889
 
847
890
  @prim_attr_register
@@ -882,15 +925,13 @@ class NeighborExchangeV2(Primitive):
882
925
  NeighborExchangeV2 is a collective communication operation.
883
926
 
884
927
  NeighborExchangeV2 sends data from the local rank to ranks in the `send_rank_ids`,
885
- as while receive data from `recv_rank_ids`. Please refer to
886
- `Distributed Set Communication Primitives - NeighborExchangeV2 \
887
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/communicate_ops.html#neighborexchangev2>`_
888
- to learn about how the data is exchanged between neighborhood devices.
928
+ as while receive data from `recv_rank_ids`. Please refer to the tutorial examples
929
+ below to learn about how the data is exchanged between neighborhood devices.
889
930
 
890
931
  Note:
891
932
  This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
892
933
  in the same subnet, please check the `details \
893
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/communicate_ops.html#notes>`_.
934
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#notes>`_.
894
935
 
895
936
  Args:
896
937
  send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one
@@ -902,8 +943,8 @@ class NeighborExchangeV2(Primitive):
902
943
  recv_lens (list(int)): Data lens which received from recv_rank_ids, 4 numbers represent the lens of
903
944
  [recv_top, recv_bottom, recv_left, recv_right].
904
945
  data_format (str): Data format, only support NCHW now.
905
- group (str, optional): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP", which means
906
- "hccl_world_group" in Ascend, and "nccl_world_group" in GPU.
946
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
947
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
907
948
 
908
949
  Inputs:
909
950
  - **input_x** (Tensor) - The Tensor before being exchanged. It has a shape of :math:`(N, C, H, W)`.
@@ -927,42 +968,68 @@ class NeighborExchangeV2(Primitive):
927
968
  Before running the following examples, you need to configure the communication environment variables.
928
969
 
929
970
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
930
- Please see the `Ascend tutorial
931
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
971
+ Please see the `rank table Startup
972
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
932
973
  for more details.
933
974
 
934
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
935
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
975
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
976
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
977
+
978
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
979
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
936
980
 
937
981
  This example should be run with 2 devices.
938
982
 
939
983
  >>> import os
940
984
  >>> import mindspore as ms
941
- >>> from mindspore import Tensor
942
985
  >>> from mindspore.communication import init
943
986
  >>> import mindspore.nn as nn
944
987
  >>> import mindspore.ops as ops
945
988
  >>> import numpy as np
946
- >>> class Net(nn.Cell):
989
+ >>>
990
+ >>> class Net0(nn.Cell):
947
991
  ... def __init__(self):
948
- ... super(Net, self).__init__()
949
- ... self.neighborexchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
950
- ... send_lens=[0, 1, 0, 0],
951
- ... recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
952
- ... recv_lens=[0, 1, 0, 0],
953
- ... data_format="NCHW")
992
+ ... super(Net0, self).__init__()
993
+ ... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
994
+ ... send_lens=[0, 1, 0, 0],
995
+ ... recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
996
+ ... recv_lens=[0, 1, 0, 0], data_format="NCHW")
954
997
  ...
955
998
  ... def construct(self, x):
956
- ... out = self.neighborexchangev2(x)
999
+ ... out = self.neighbor_exchangev2(x)
957
1000
  ... return out
1001
+ >>>
1002
+ ... class Net1(nn.Cell):
1003
+ ... def __init__(self):
1004
+ ... super(Net1, self).__init__()
1005
+ ... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
1006
+ ... send_lens=[1, 0, 0, 0],
1007
+ ... recv_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
1008
+ ... recv_lens=[1, 0, 0, 0], data_format="NCHW")
958
1009
  ...
959
- >>> ms.set_context(mode=ms.GRAPH_MODE, device_target='Ascend')
1010
+ ... def construct(self, x):
1011
+ ... out = self.neighbor_exchangev2(x)
1012
+ ... return out
1013
+ >>>
1014
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
960
1015
  >>> init()
961
- >>> input_x = Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32)
962
- >>> net = Net()
963
- >>> output = net(input_x)
964
- >>> print(output)
1016
+ >>> rank_id = int(os.getenv("RANK_ID"))
1017
+ >>> if (rank_id % 2 == 0):
1018
+ >>> input_x = ms.Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32)
1019
+ >>> net = Net0()
1020
+ >>> output = net(input_x)
1021
+ >>> print(output)
1022
+ >>> else:
1023
+ >>> input_x = ms.Tensor(np.ones([1, 1, 2, 2]) * 2, dtype = ms.float32)
1024
+ >>> net = Net1()
1025
+ >>> output = net(input_x)
1026
+ >>> print(output)
965
1027
  [[[[1. 1.], [1. 1.], [2. 2.]]]]
1028
+
1029
+ Tutorial Examples:
1030
+ - `Distributed Set Communication Primitives - NeighborExchangeV2
1031
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#neighborexchangev2>`_
1032
+
966
1033
  """
967
1034
 
968
1035
  @prim_attr_register
@@ -976,6 +1043,15 @@ class NeighborExchangeV2(Primitive):
976
1043
  self.format = data_format
977
1044
  self.add_prim_attr('group', _get_group(group))
978
1045
  self.add_prim_attr('no_eliminate', True)
1046
+ self.rank_size = get_group_size(_get_group(group))
1047
+ for rank_id in send_rank_ids:
1048
+ if rank_id != -1:
1049
+ validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
1050
+ "rank_id in send_rank_ids")
1051
+ for rank_id in recv_rank_ids:
1052
+ if rank_id != -1:
1053
+ validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
1054
+ "rank_id in recv_rank_ids")
979
1055
 
980
1056
  def __call__(self, tensor):
981
1057
  raise NotImplementedError
@@ -987,9 +1063,9 @@ class _MirrorOperator(PrimitiveWithInfer):
987
1063
  internal use of parallel modules and cannot be called by users.
988
1064
 
989
1065
  Args:
990
- group (str): The communication group to work on. Default: None.
991
- dev_num (int): The device number of the group. Default: None.
992
- mean_flag (bool): Whether use mean in backward. Default: None.
1066
+ group (str): The communication group to work on. Default: ``None`` .
1067
+ dev_num (int): The device number of the group. Default: ``None`` .
1068
+ mean_flag (bool): Whether use mean in backward. Default: ``None`` .
993
1069
  """
994
1070
 
995
1071
  @prim_attr_register
@@ -1017,10 +1093,10 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
1017
1093
  internal use of parallel modules and cannot be called by users.
1018
1094
 
1019
1095
  Args:
1020
- group (str): The communication group to work on. Default: None.
1021
- dev_num (int): The device number of the group. Default: None.
1022
- mean_flag (bool): Whether use mean in backward. Default: None.
1023
- grad_accumulation_step (int): The grad accumulation step. Default: None.
1096
+ group (str): The communication group to work on. Default: ``None`` .
1097
+ dev_num (int): The device number of the group. Default: ``None`` .
1098
+ mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1099
+ grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
1024
1100
  """
1025
1101
 
1026
1102
  @prim_attr_register
@@ -1176,9 +1252,9 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer):
1176
1252
  internal use of parallel modules and cannot be called by users.
1177
1253
 
1178
1254
  Args:
1179
- group (str): The communication group to work on. Default: None.
1180
- dev_num (int): The device number of the group. Default: None.
1181
- mean_flag (bool): Whether use mean in backward. Default: None.
1255
+ group (str): The communication group to work on. Default: ``None`` .
1256
+ dev_num (int): The device number of the group. Default: ``None`` .
1257
+ mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1182
1258
  """
1183
1259
 
1184
1260
  @prim_attr_register
@@ -25,8 +25,8 @@ class GeSwitch(PrimitiveWithInfer):
25
25
  """
26
26
  Adds control switch to data.
27
27
 
28
- Switch data flows into false or true branch depending on the condition. If the condition is true,
29
- the true branch will be activated, or vise verse.
28
+ Switch data flows into ``False`` or ``True`` branch depending on the condition. If the condition is ``True`` ,
29
+ the ``True`` branch will be activated, or vise verse.
30
30
 
31
31
  Inputs:
32
32
  - **data** (Union[Tensor, Number]) - The data to be used for switch control.
@@ -81,7 +81,7 @@ class GeSwitch(PrimitiveWithInfer):
81
81
 
82
82
  def infer_dtype(self, data_type, pred_type):
83
83
  validator.check_subclass(
84
- "data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
84
+ "data", data_type, (mstype.tensor_type,) + mstype.number_type, self.name)
85
85
  validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
86
86
  return data_type, data_type
87
87