mindspore-2.4.0-cp311-cp311-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
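
The filename encodes the wheel's install constraints: cp311/cp311 are the CPython 3.11 interpreter and ABI tags, and macosx_10_15_x86_64 targets x86_64 macOS 10.15 or newer. As a minimal sketch, assuming the third-party packaging library is available, the same tags can be checked against the running interpreter:

    from packaging.tags import sys_tags
    from packaging.utils import parse_wheel_filename

    WHEEL = "mindspore-2.4.0-cp311-cp311-macosx_10_15_x86_64.whl"

    # Split the filename into (distribution, version, build tag,
    # frozenset of interpreter/abi/platform tags) per PEP 427/425.
    name, version, _build, wheel_tags = parse_wheel_filename(WHEEL)

    # sys_tags() yields every tag the current interpreter accepts;
    # the wheel is installable if any of its tags overlap.
    supported = set(sys_tags())
    print(name, version, "installable here:", any(t in supported for t in wheel_tags))

This is the same tag match pip performs before it will install a wheel on a given platform.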

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1395 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+
+ """
+ Defines communication operators with functional form.
+ """
+ from mindspore.communication import GlobalComm, get_group_rank_from_world_rank, get_group_size
+ from mindspore.communication.management import _get_group
+ from mindspore.communication._comm_helper import _get_group_rank_from_world_rank_from_cache_helper
+ from mindspore.common.tensor import Tensor
+ from mindspore._c_expression import Tensor as Tensor_
+ from mindspore.ops import ReduceOp, cat
+ from mindspore.ops._primitive_cache import _get_cache_prim
+ from mindspore.ops.primitive import _primexpr
+ from mindspore.ops.auto_generate.gen_ops_prim import (inner_comm_all_reduce_op, inner_comm_all_gather_op,
+                                                       inner_comm_all_to_all_v_op, inner_comm_irecv_op,
+                                                       inner_comm_isend_op, inner_comm_reduce_scatter_op)
+ from mindspore._c_expression import CommHandle as CommHandle_
+ from mindspore import jit_class
+
+ __all__ = [
+     'all_reduce',
+     'all_gather_into_tensor',
+     'all_to_all_with_output_shape',
+     'all_to_all_single_with_output_shape',
+     'barrier',
+     'broadcast',
+     'gather_into_tensor',
+     'isend',
+     'irecv',
+     'reduce_scatter_tensor',
+     'reduce',
+     'scatter_tensor',
+     'send',
+     'recv',
+     'P2POp',
+     'batch_isend_irecv',
+ ]
+
+ import mindspore.ops.operations as P
+
+ _GROUP_SIZE_CACHE = {}
+
+ @jit_class
+ class CommHandle(CommHandle_):
+     r"""
+     Handles are usually created in C++ during the execution of communication operators and returned to the
+     Python layer; they are not created directly in Python. A handle is only created on the Python side in
+     scenarios that require graph-mode compatibility.
+     """
+
+     def wait(self):
+         r"""
+         Waits for the asynchronous communication to complete. Waiting has no effect for handles created
+         on the Python side.
+
+         >>> import numpy as np
+         >>> from mindspore.communication import init
+         >>> from mindspore.communication.comm_func import all_reduce
+         >>> from mindspore import Tensor
+         >>>
+         >>> init()
+         >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+         >>> output, handle = all_reduce(input_tensor, async_op=True)
+         >>> handle.wait()
+         >>> print(output)
+         [[2. 2. 2. 2. 2. 2. 2. 2.]
+          [2. 2. 2. 2. 2. 2. 2. 2.]]
+         """
+
+
+ default_handle = CommHandle()
+
+
+ def _check_split_sizes_sequence(tensor, sequence):
+     if not sequence:
+         raise TypeError("sequence can not be an empty list.")
+     element0 = sequence[0]
+     for idx in range(1, len(sequence)):
+         if sequence[idx] != element0:
+             raise TypeError("sequence containing different elements is not supported yet. "
+                             "Elements must be the same.")
+     if sum(sequence) != tensor.shape[0]:
+         raise TypeError("The sum of sequence should be equal to tensor.shape[0].")
+
+
+ def _check_compute_split_count(tensor, output_split_sizes, input_split_sizes, group):
+     """
+     Check output_split_sizes and input_split_sizes by the rules in _check_split_sizes_sequence,
+     then compute the split count and return it.
+     """
+     group_size = get_group_size(group)
+     if output_split_sizes:
+         _check_split_sizes_sequence(tensor, output_split_sizes)
+         output_split_value = output_split_sizes[0]
+     else:
+         output_split_value = None
+     if input_split_sizes:
+         _check_split_sizes_sequence(tensor, input_split_sizes)
+         input_split_value = input_split_sizes[0]
+     else:
+         input_split_value = None
+     split_count = 0
+     if input_split_value and output_split_value is None:
+         split_count = tensor.shape[0] // input_split_value
+     elif input_split_value is None and output_split_value:
+         split_count = tensor.shape[0] // output_split_value
+     elif input_split_value and output_split_value:
+         if input_split_value != output_split_value:
+             raise TypeError("input_split_value should be equal to output_split_value.")
+         split_count = tensor.shape[0] // input_split_value
+     else:
+         split_count = group_size
+     return split_count
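
The equal-split rule enforced above reduces to simple integer arithmetic. A minimal, framework-free sketch (the `demo_split_count` helper below is hypothetical and only mirrors the branch logic of `_check_compute_split_count`; real calls go through MindSpore tensors and communication groups):

    # Hypothetical stand-in that mirrors the split-count branches above.
    def demo_split_count(first_dim, split_sizes, group_size):
        if split_sizes:
            # all elements must be equal and sum to the first dimension
            assert len(set(split_sizes)) == 1, "elements must be the same"
            assert sum(split_sizes) == first_dim
            return first_dim // split_sizes[0]
        return group_size  # no sizes given: split evenly across the group

    # 8 rows split as [2, 2, 2, 2] -> 4 chunks; with no sizes and a
    # 4-rank group, the group size is used directly.
    assert demo_split_count(8, [2, 2, 2, 2], group_size=4) == 4
    assert demo_split_count(8, None, group_size=4) == 4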
+
+
+ @_primexpr
+ def _check_all_tensors(tensor_list):
+     """Check that all elements in tensor_list are of type Tensor."""
+     if not isinstance(tensor_list, (list, tuple)):
+         raise TypeError(f"Expected list or tuple, but got {type(tensor_list)}.")
+     for t in tensor_list:
+         if not isinstance(t, Tensor):
+             raise TypeError(f"Expected tensor, but got {type(t)}")
+
+
+ @_primexpr
+ def _check_all_tensors_or_tuple(tensor_list):
+     """Check that all elements in tensor_list are of type Tensor, tuple or list."""
+     if not isinstance(tensor_list, (list, tuple)):
+         raise TypeError(f"Expected list or tuple, but got {type(tensor_list)}.")
+     for t in tensor_list:
+         if not isinstance(t, (Tensor, tuple, list)):
+             raise TypeError(f"Expected tensor or tuple, but got {type(t)}")
+
+
+ @_primexpr
+ def _check_all_tensor_same_dtype(*tensor_lists):
+     """Check that all input tensors have the same dtype."""
+     consistent_dtype = None
+     for list_ in tensor_lists:
+         if not isinstance(list_, (list, tuple)):
+             list_ = [list_]
+         for tensor_ in list_:
+             if not isinstance(tensor_, Tensor):
+                 continue
+
+             dtype = tensor_.dtype
+             if consistent_dtype is None:
+                 consistent_dtype = dtype
+             else:
+                 if dtype != consistent_dtype:
+                     raise TypeError("all_to_all input dtype must be the same, "
+                                     f"but got {consistent_dtype} and {dtype}.")
+
+
+ def _get_size(shape):
+     numel = 1
+     for s in shape:
+         numel *= s
+     return numel
+
+
+ def _is_split_sizes_empty(split_sizes):
+     return split_sizes is None or not split_sizes
+
+
+ def _contiguous(tensor):
+     if not tensor.is_contiguous() or tensor.storage_offset() != 0:
+         tensor = tensor.contiguous()
+     return tensor
+
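
The `_contiguous` helper above matters because communication kernels read a tensor's flat memory buffer; a strided view would transmit the wrong bytes. A rough NumPy analogy (not the MindSpore internals) of why a copy is needed:

    import numpy as np

    # A transposed view is strided: its flat buffer no longer matches its
    # logical row-major order, which is what a communication kernel reads.
    a = np.arange(6).reshape(2, 3)
    view = a.T                            # logical shape (3, 2), strided
    assert not view.flags['C_CONTIGUOUS']
    packed = np.ascontiguousarray(view)   # analogue of tensor.contiguous()
    assert packed.flags['C_CONTIGUOUS'] and (packed == view).all()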
+
+ def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
+     """
+     Reduce tensors across all devices in such a way that all devices get the same final result,
+     and return the all-reduced tensor.
+
+     Note:
+         The tensors must have the same shape and format in all processes of the collection.
+
+     Args:
+         tensor (Tensor): The input tensor to be all reduced. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+         op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
+             On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
+         group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
+             means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+         async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
+
+     Returns:
+         Tuple(Tensor, CommHandle), the output tensor has the same shape as the input,
+         i.e., :math:`(x_1, x_2, ..., x_R)`. The contents depend on the specified operation.
+         CommHandle is an async work handle if `async_op` is set to True; it is None
+         when `async_op` is False.
+
+     Raises:
+         TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 2 devices.
+
+         >>> import numpy as np
+         >>> from mindspore.communication import init
+         >>> from mindspore.communication.comm_func import all_reduce
+         >>> from mindspore import Tensor
+         >>>
+         >>> init()
+         >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+         >>> output, _ = all_reduce(input_tensor)
+         >>> print(output)
+         [[2. 2. 2. 2. 2. 2. 2. 2.]
+          [2. 2. 2. 2. 2. 2. 2. 2.]]
+
+     """
+     if not isinstance(tensor, (Tensor, Tensor_)):
+         raise TypeError("For all_reduce, the input tensor must be tensor")
+     if not isinstance(op, str):
+         raise TypeError("For all_reduce, the input op type must be str")
+     if op not in ('sum', 'prod', 'min', 'max'):
+         raise TypeError("For all_reduce, the input op value must be one of sum, prod, min, max")
+     group = _get_group(group)
+     tensor = _contiguous(tensor)
+     output = inner_comm_all_reduce_op(tensor, op, group)
+     return _deal_comm_outputs(output, async_op)
+
+
+
+ def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
+     """
+     Gathers tensors from the specified communication group and returns the tensor which is all gathered.
+
+     Note:
+         - The tensors must have the same shape and format in all processes of the collection.
+
+     Args:
+         tensor (Tensor): The input tensor to be all gathered into tensor.
+             The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+         group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
+             means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+         async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
+
+     Returns:
+         Tuple(Tensor, CommHandle), if the number of devices in the group is N,
+         then the shape of the output tensor is :math:`(N, x_1, x_2, ..., x_R)`.
+         CommHandle is an async work handle if `async_op` is set to True;
+         it is None when `async_op` is False.
+
+     Raises:
+         TypeError: If the type of the first input parameter is not Tensor, or `group` is not a str.
+         ValueError: If the local rank id of the calling process in the group
+             is larger than the group's rank size.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 2 devices.
+
+         >>> import numpy as np
+         >>> import mindspore as ms
+         >>> from mindspore import ops
+         >>> from mindspore.communication import init
+         >>> from mindspore.communication.comm_func import all_gather_into_tensor
+         >>> from mindspore import Tensor
+         >>>
+         >>> ms.set_context(mode=ms.GRAPH_MODE)
+         >>> init()
+         >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+         >>> output, _ = all_gather_into_tensor(input_tensor)
+         >>> print(output)
+         [[1. 1. 1. 1. 1. 1. 1. 1.]
+          [1. 1. 1. 1. 1. 1. 1. 1.]
+          [1. 1. 1. 1. 1. 1. 1. 1.]
+          [1. 1. 1. 1. 1. 1. 1. 1.]]
+
+     """
+
+     if not isinstance(tensor, (Tensor, Tensor_)):
+         raise TypeError("For all_gather_into_tensor, the input tensor must be tensor")
+     group = _get_group(group)
+     global _GROUP_SIZE_CACHE
+     if group not in _GROUP_SIZE_CACHE:
+         _GROUP_SIZE_CACHE[group] = get_group_size(group)
+     group_size = _GROUP_SIZE_CACHE[group]
+     tensor = _contiguous(tensor)
+     output = inner_comm_all_gather_op(tensor, group_size, group)
+     return _deal_comm_outputs(output, async_op)
+
+
+
+ def reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
+     r"""
+     Reduces and scatters tensors from the specified communication group and
+     returns the tensor which is reduced and scattered.
+
+     Note:
+         The tensors must have the same shape and format in all processes of the collection.
+
+     Args:
+         tensor (Tensor): The input tensor to be reduced and scattered, suppose it has a shape :math:`(N, *)`,
+             where `*` means any number of additional dimensions. N must be divisible by rank_size.
+             rank_size refers to the number of cards in the communication group.
+         op (str, optional): Specifies an operation used for element-wise reductions,
+             like SUM and MAX. Default: ``ReduceOp.SUM`` .
+         group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
+             means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+         async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
+
+     Returns:
+         Tuple(Tensor, CommHandle), the output tensor has the same dtype as `tensor` with a shape of
+         :math:`(N/rank\_size, *)`. CommHandle is an async work handle if `async_op` is set to True;
+         it is None when `async_op` is False.
+
+     Raises:
+         TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str.
+         ValueError: If the first dimension of the input cannot be divided by the rank_size.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 2 devices.
+
+         >>> import mindspore as ms
+         >>> from mindspore import Tensor
+         >>> from mindspore.communication import init
+         >>> from mindspore.communication.comm_func import reduce_scatter_tensor
+         >>> import numpy as np
+         >>>
+         >>> ms.set_context(mode=ms.GRAPH_MODE)
+         >>> init()
+         >>> input_tensor = Tensor(np.ones([8, 8]).astype(np.float32))
+         >>> output, _ = reduce_scatter_tensor(input_tensor)
+         >>> print(output)
+         [[2. 2. 2. 2. 2. 2. 2. 2.]
+          [2. 2. 2. 2. 2. 2. 2. 2.]
+          [2. 2. 2. 2. 2. 2. 2. 2.]
+          [2. 2. 2. 2. 2. 2. 2. 2.]]
+
+     """
+
+     if not isinstance(tensor, (Tensor, Tensor_)):
+         raise TypeError("For reduce_scatter_tensor, the input tensor must be tensor")
+     group = _get_group(group)
+     global _GROUP_SIZE_CACHE
+     if group not in _GROUP_SIZE_CACHE:
+         _GROUP_SIZE_CACHE[group] = get_group_size(group)
+     rank_size = _GROUP_SIZE_CACHE[group]
+     tensor = _contiguous(tensor)
+     output = inner_comm_reduce_scatter_op(tensor, rank_size, op, group)
+     return _deal_comm_outputs(output, async_op)
+
+
+
+ def reduce(tensor, dst, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+     """
+     Reduces tensors across the processes in the specified communication group, sends the result
+     to the target dst (global rank), and returns the tensor which is sent to the target process.
+
+     Note:
+         Only the process with the destination rank receives the reduced output.
+         Only PyNative mode is supported; Graph mode is not currently supported.
+         Other processes only get a tensor with shape [1], which has no mathematical meaning.
+
+     Args:
+         tensor (Tensor): The input tensor to be reduced. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+         dst (int): The target rank of the process (global rank) that receives the reduced output.
+         op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
+             On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
+         group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
+             means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+
+     Returns:
+         Tensor. Return the tensor in the specific rank of the process after reduction.
+         The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+     Raises:
+         TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 4 devices.
+
+         >>> from mindspore import ops
+         >>> import mindspore.nn as nn
+         >>> from mindspore.communication import init
+         >>> from mindspore.communication.comm_func import reduce
+         >>> from mindspore import Tensor
+         >>> import numpy as np
+         >>> # Launch 4 processes.
+         >>> init()
+         >>> dest_rank = 1
+         >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+         >>> output = reduce(input_tensor, dst=dest_rank)
+         >>> print(output)
+         Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
+          [4. 4. 4. 4. 4. 4. 4. 4.]],
+         Other processes: [0.].
+     """
+
+     if not isinstance(tensor, (Tensor, Tensor_)):
+         raise TypeError("For reduce, the input tensor must be tensor")
+     group_rank = get_group_rank_from_world_rank(dst, group)
+     reduce_op = _get_cache_prim(P.Reduce)(dest_rank=group_rank, op=op, group=group)
+     return reduce_op(tensor)
+
+
+
+ class P2POp:
+     """
+     Object for `batch_isend_irecv` input, to store information of ``"isend"`` and ``"irecv"``.
+
+     Note:
+         - A receive shape may be passed in instead of a tensor when `op` is ``"irecv"``.
+         - `tensor` will not be modified in-place by the final result.
+
+     Args:
+         op (Union[str, function]): Only the strings ``"isend"`` and ``"irecv"``, or the functions
+             ``comm_func.isend`` and ``comm_func.irecv``, are allowed.
+         tensor (Union[Tensor, Tuple(int)]): tensor for sending/receiving, or the receive tensor shape
+             when `op` is ``"irecv"``.
+         peer (int): remote global rank for send/receive.
+         group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
+             means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+         tag (int, optional): currently not supported yet. Default: ``0``.
+
+     Keyword Args:
+         recv_dtype (mindspore.dtype, optional): when `tensor` is a tuple shape, this arg will be used and has
+             to be configured. Default: ``None``
+
+     Returns:
+         P2POp Object.
+
+     Raises:
+         ValueError: when `op` is not a string or function of 'isend' and 'irecv'.
+         TypeError: when `tensor` is not of type Tensor or Tuple.
+         NotImplementedError: when `tag` is not 0.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         >>> import numpy as np
+         >>> import mindspore
+         >>> from mindspore.communication.comm_func import P2POp, isend, irecv
+         >>> from mindspore import Tensor
+         >>> send_tensor = Tensor(1.)
+         >>> send_op = P2POp('isend', send_tensor, 1)
+         >>> send_op = P2POp(isend, send_tensor, 1)
+         >>> recv_tensor = Tensor(0.)
+         >>> recv_op = P2POp('irecv', recv_tensor, 0)
+         >>> recv_op = P2POp(irecv, recv_tensor, 0)
+         >>> recv_op = P2POp('irecv', (), 0, recv_dtype=mindspore.float32)
+     """
+
+     def __init__(self, op, tensor, peer, group=None, tag=0, *, recv_dtype=None):
+         self.op = op
+         self.tensor = tensor
+         self.peer = peer
+         self.group = group
+         self.tag = tag
+         self.recv_dtype = recv_dtype
+
+     def __new__(cls, op, tensor, peer, group=None, tag=0, recv_dtype=None):
+         if isinstance(op, str):
+             op_name = op
+         else:
+             op_name = op.__name__
+         if op_name not in ['isend', 'irecv']:
+             raise ValueError(f"Expected ``op`` to be of type ``isend`` or ``irecv``, but got {op_name}")
+         if not isinstance(tensor, (Tensor, tuple)):
+             raise TypeError(f"Expected ``tensor`` to be type of tuple or Tensor, but got {type(tensor)}.")
+         if tag != 0:
+             raise NotImplementedError("``tag`` is not supported yet.")
+         return object.__new__(cls)
+
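
Note that `P2POp` validates its arguments in `__new__`, so a bad call fails before any instance exists. A minimal sketch of the same pattern in plain Python (the `Checked` class is illustrative only, not part of the API):

    class Checked:
        def __new__(cls, kind):
            # runs before __init__; a bad argument never yields an object
            if kind not in ("isend", "irecv"):
                raise ValueError(f"unsupported kind: {kind}")
            return object.__new__(cls)

        def __init__(self, kind):
            self.kind = kind

    Checked("isend")       # ok
    # Checked("sendrecv")  # raises ValueError before __init__ runs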
+
+
+ def batch_isend_irecv(p2p_op_list):
+     """
+     Batch send and recv tensors asynchronously.
+
+     Note:
+         - The 'isend' and 'irecv' of `P2POp` in `p2p_op_list` between ranks need to match each other.
+         - `P2POp` in `p2p_op_list` can only use the same communication group.
+         - `tag` of `P2POp` in `p2p_op_list` is not supported yet.
+         - `tensor` of `P2POp` in `p2p_op_list` will not be modified in-place by the result.
+         - Only PyNative mode is supported; Graph mode is not currently supported.
+
+     Args:
+         p2p_op_list (list[P2POp]): list of `P2POp` objects. `P2POp` is of type
+             :class:`mindspore.communication.comm_func.P2POp`
+
+     Returns:
+         tuple(Tensor). Output tensors corresponding to `p2p_op_list`.
+         At an 'isend' position of `P2POp`, the output tensor is a placeholder scalar tensor with no meaning.
+         At an 'irecv' position of `P2POp`, the output tensor is the tensor received from the remote device.
+
+     Raises:
+         TypeError: If `p2p_op_list` elements are not all of type `P2POp`.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 2 devices.
+
+         >>> import numpy as np
+         >>> import mindspore
+         >>> from mindspore.communication import init, get_rank, get_group_size
+         >>> from mindspore.communication.comm_func import batch_isend_irecv, P2POp
+         >>> from mindspore import Tensor
+         >>>
+         >>> init()
+         >>> this_rank = get_rank()
+         >>> world_size = get_group_size()
+         >>> next_rank = (this_rank + 1) % world_size
+         >>> prev_rank = (this_rank + world_size - 1) % world_size
+         >>>
+         >>> send_tensor = Tensor(this_rank + 1, dtype=mindspore.float32)
+         >>> recv_tensor = Tensor(0., dtype=mindspore.float32)
+         >>>
+         >>> send_op = P2POp('isend', send_tensor, next_rank)
+         >>> recv_op = P2POp('irecv', recv_tensor, prev_rank)
+         >>>
+         >>> p2p_op_list = [send_op, recv_op]
+         >>> output = batch_isend_irecv(p2p_op_list)
+         >>> print(output)
+         rank 0:
+         (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 2))
+         rank 1:
+         (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 1))
+     """
+     send_tensors = []
+     op_types = []
+     remotes_ranks = []
+     receive_shapes = []
+     receive_dtypes = []
+     tags = []
+     if not p2p_op_list:
+         raise TypeError("p2p_op_list can not be an empty list.")
+     group = p2p_op_list[0].group
+     if group is None:
+         group = GlobalComm.WORLD_COMM_GROUP
+     type_ = None
+     for i, p2p_op in enumerate(p2p_op_list):
+         if not isinstance(p2p_op, P2POp):
+             raise TypeError("Each element of p2p_op_list must be of type P2POp")
+         if isinstance(p2p_op.op, str):
+             type_ = p2p_op.op
+         else:
+             type_ = p2p_op.op.__name__
+         rank_ = p2p_op.peer if p2p_op.group is None else \
+             get_group_rank_from_world_rank(p2p_op.peer, p2p_op.group)
+         remotes_ranks.append(rank_)
+         tags.append(p2p_op.tag)
+         if type_ == "isend":
+             send_tensors.append(p2p_op.tensor)
+         elif type_ == "irecv":
+             if isinstance(p2p_op.tensor, Tensor):
+                 receive_shapes.append(p2p_op.tensor.shape)
+                 receive_dtypes.append(p2p_op.tensor.dtype)
+             elif isinstance(p2p_op.tensor, tuple):
+                 receive_shapes.append(p2p_op.tensor)
+                 if p2p_op.recv_dtype is None:
+                     raise ValueError(f"'recv_dtype' of {i}th P2POp in p2p_op_list is None but op_types is"
+                                      "'irecv' and P2POp.tensor is a tuple type.")
+                 receive_dtypes.append(p2p_op.recv_dtype)
+             else:
+                 raise TypeError("p2p_op.tensor must be tensor or shape")
+         else:
+             raise TypeError("p2p_op.op must be isend or irecv")
+         op_types.append(type_)
+
+     _op = _get_cache_prim(P.BatchISendIRecv)(op_types,
+                                              remotes_ranks,
+                                              receive_shapes,
+                                              receive_dtypes,
+                                              group)
+     output = _op(send_tensors)
+     return output
+
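
The loop above flattens the op list into parallel arrays (op types, peer ranks, receive shapes and dtypes) before constructing the primitive. A framework-free sketch of that classification, using plain tuples as stand-ins for `P2POp`:

    # Each op is (kind, payload): a tensor stand-in for "isend", or a
    # (shape, dtype) pair for "irecv".
    ops = [("isend", "send_tensor"), ("irecv", ((2, 8), "float32"))]

    send_tensors, op_types = [], []
    receive_shapes, receive_dtypes = [], []
    for kind, payload in ops:
        op_types.append(kind)
        if kind == "isend":
            send_tensors.append(payload)
        else:  # "irecv": record the receive buffer's shape and dtype
            shape, dtype = payload
            receive_shapes.append(shape)
            receive_dtypes.append(dtype)

    assert op_types == ["isend", "irecv"]
    assert receive_shapes == [(2, 8)] and receive_dtypes == ["float32"]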
+
+
+ def scatter_tensor(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP):
+     r"""
+     Scatter tensor evenly across the processes in the specified communication group.
+
+     Note:
+         The interface behavior only supports Tensor input and even scattering, which
+         is different from that of `pytorch.distributed.scatter`.
+         Only the tensor in process `src` (global rank) will be scattered.
+         Only PyNative mode is supported; Graph mode is not currently supported.
+
+     Args:
+         tensor (Tensor): The input tensor to be scattered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+         src (int, optional): Specifies the rank (global rank) of the process that sends the tensor.
+             And only process `src` will send the tensor. Default: ``0``.
+         group (str, optional): The communication group to work on.
+             Default: ``GlobalComm.WORLD_COMM_GROUP``.
+
+     Returns:
+         Tensor, the shape of output is :math:`(x_1/rank\_size, x_2, ..., x_R)`. Dimension 0 of the output is
+         equal to dimension 0 of the input tensor divided by the number of ranks in the group, and the other
+         dimensions keep the same.
+
+     Raises:
+         TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 2 devices.
+
+         >>> import mindspore as ms
+         >>> from mindspore.communication import init
+         >>> from mindspore.communication.comm_func import scatter_tensor
+         >>> import numpy as np
+         >>> # Launch 2 processes.
+         >>>
+         >>> init()
+         >>> data = ms.Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
+         >>> out = scatter_tensor(tensor=data, src=0)
+         >>> print(out)
+         # rank_0
+         [[0. 1.]
+          [2. 3.]]
+         # rank_1
+         [[4. 5.]
+          [6. 7.]]
+     """
+     if not isinstance(tensor, (Tensor, Tensor_)):
+         raise TypeError("For scatter_tensor, the input tensor must be tensor")
+     if not isinstance(src, int):
+         raise TypeError("For scatter_tensor, the src must be int")
+     _src = get_group_rank_from_world_rank(src, group)
+     _op = _get_cache_prim(P.CollectiveScatter)(_src, group)
+     return _op(tensor)
+
+
+
+ def gather_into_tensor(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP):
+     r"""
+     Gathers tensors from the specified communication group. The operation will gather the tensor
+     from processes according to dimension 0.
+
+     Note:
+         Only the tensor in process `dst` (global rank) will keep the gathered tensor. The other processes
+         will keep a tensor with shape [1], which has no mathematical meaning.
+         Only PyNative mode is supported; Graph mode is not currently supported.
+
+     Args:
+         tensor (Tensor): The tensor to be gathered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+         dst (int, optional): Specifies the rank (global rank) of the process that receives the tensor.
+             And only process `dst` will receive the gathered tensor. Default: ``0``.
+         group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
+
+     Returns:
+         Tensor, the shape of output is :math:`(\sum x_1, x_2, ..., x_R)`. Dimension 0 of the output is equal to
+         the sum of dimension 0 of the input tensors across all processes, and the other dimensions keep the same.
+
+     Raises:
+         TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 2 devices.
+
+         >>> import numpy as np
+         >>> import mindspore as ms
+         >>> import mindspore.nn as nn
+         >>> from mindspore.communication import init
+         >>> from mindspore import Tensor
+         >>> from mindspore.communication.comm_func import gather_into_tensor
+         >>> # Launch 2 processes.
+         >>>
+         >>> init()
+         >>> data = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
+         >>> output = gather_into_tensor(tensor=data, dst=0)
+         >>> print(output)
+         Process with rank 0: [[0. 1.],
+                               [2. 3.],
+                               [0. 1.],
+                               [2. 3.]]
+         Process with rank 1: [0]
+     """
+     if not isinstance(tensor, (Tensor, Tensor_)):
+         raise TypeError("For gather_into_tensor, the input tensor must be tensor")
+     if not isinstance(dst, int):
+         raise TypeError("For gather_into_tensor, the dst must be int")
+     _dst = get_group_rank_from_world_rank(dst, group)
+     _op = _get_cache_prim(P.CollectiveGather)(_dst, group)
+     return _op(tensor)
+
+
+
+ def broadcast(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP):
+     """
+     Broadcasts the tensor to the whole group.
+
+     Note:
+         The tensors must have the same shape and format in all processes of the collection.
+         Only PyNative mode is supported; Graph mode is not currently supported.
+
+     Args:
+         tensor (Tensor): The tensor to be broadcasted. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+         src (int, optional): Specifies the rank (global rank) of the process that broadcasts the tensor.
+             And only process `src` will broadcast the tensor. Default: ``0``.
+         group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
+
+     Returns:
+         Tensor, tensor has the same shape as the input tensor :math:`(x_1, x_2, ..., x_R)`.
+
+     Raises:
+         TypeError: If src is not an integer or group is not a string.
+         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+             without any third-party or configuration file dependencies.
+             Please see the `msrun start up
+             <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+             for more details.
+
+             This example should be run with 2 devices.
+
+         >>> import mindspore as ms
+         >>> from mindspore import Tensor
+         >>> from mindspore.communication import init
+         >>> from mindspore.communication.comm_func import broadcast
+         >>> import numpy as np
+         >>> # Launch 2 processes.
+         >>>
+         >>> init()
+         >>> data = ms.Tensor(np.arange(8).reshape([2, 4]).astype(np.float32))
+         >>> out = broadcast(tensor=data, src=0)
+         >>> print(out)
+         [[0. 1. 2. 3.]
+          [4. 5. 6. 7.]]
+
+     Tutorial Examples:
+         - `Distributed Set Communication Primitives - Broadcast
+           <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#broadcast>`_
+
+     """
+     if not isinstance(tensor, (Tensor, Tensor_)):
+         raise TypeError("For broadcast, the input tensor must be tensor")
+     if not isinstance(src, int):
+         raise TypeError("For broadcast, the src must be int")
+     _src = get_group_rank_from_world_rank(src, group)
+     _op = _get_cache_prim(P.Broadcast)(_src, group)
+     return _op((tensor,))[0]
+
838
+
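The `(tensor,)` / `[0]` pair in the body exists because `P.Broadcast` consumes and produces tuples of tensors. A minimal sketch (same 2-device launch assumption as the docstring examples) broadcasting rank 1's data, so every rank ends up with identical values::

    # Hedged sketch: requires a 2-device distributed launch (e.g. msrun).
    import numpy as np
    import mindspore as ms
    from mindspore.communication import init, get_rank
    from mindspore.communication.comm_func import broadcast

    init()
    # Each rank starts with different data; only src's data survives.
    data = ms.Tensor(np.full((2, 2), get_rank(), np.float32))
    out = broadcast(data, src=1)
    print(out)  # all ones on both ranks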
839
+ def barrier(group=GlobalComm.WORLD_COMM_GROUP):
+ """
+ Synchronizes all processes in the specified group. Once a process calls this operation, it is blocked until
+ all processes have called it. After all processes have finished calling the operation, the blocked processes
+ are woken and continue their task.
+
+ Args:
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
+
+ Raises:
+ RuntimeError: If backend is invalid, or distributed initialization fails.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import barrier
+ >>> # Launch 2 processes.
+ >>> init()
+ >>> barrier()
+
+ Tutorial Examples:
+ - `Distributed Set Communication Primitives - Barrier
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#barrier>`_
+ """
+ _op = _get_cache_prim(P.Barrier)(group)
+ return _op()
+
+
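A typical use of `barrier` is phase separation: no rank starts phase 2 until every rank has finished phase 1. A minimal sketch under the same 2-device launch assumption::

    from mindspore.communication import init, get_rank
    from mindspore.communication.comm_func import barrier

    init()
    print(f"rank {get_rank()}: phase 1 done")
    barrier()  # blocks until all ranks in the group reach this line
    print(f"rank {get_rank()}: phase 2 starts")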
880
+ def _deal_comm_outputs(output, async_op):
+ """Unify comm op outputs: wait and drop the handle for sync calls, keep the handle for async calls."""
+ if isinstance(output, tuple):
+ if not async_op:
+ output[1].wait()
+ return (output[0], None)
+ return output
+
+ if not async_op:
+ return (output, None)
+ return (output, default_handle)
+
+
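`_deal_comm_outputs` normalizes every communication op result to a `(tensor, handle)` pair: synchronous callers get their wait done here and a `None` handle back; asynchronous callers get the live handle. A stand-alone illustration of the tuple branch, using a stand-in handle class (hypothetical, for demonstration only)::

    class FakeHandle:  # stand-in for a real CommHandle
        def wait(self):
            print("waited")

    def deal(output, async_op):  # mirrors the tuple branch above
        if isinstance(output, tuple):
            if not async_op:
                output[1].wait()
                return (output[0], None)
            return output

    print(deal(("tensor", FakeHandle()), async_op=False))  # waits, handle is None
    print(deal(("tensor", FakeHandle()), async_op=True))   # handle kept for caller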
892
+ def send(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
+ """
+ Send a tensor to the specified destination rank.
+
+ Note:
+ Send and Receive must be used in combination and have the same tag.
+
+ Args:
+ tensor (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+ dst (int, optional): A required integer identifying the destination rank (global rank). Default: 0.
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
+ be received by the Receive op with the same "tag". Default: 0.
+
+ Raises:
+ TypeError: If `dst` is not an int or `group` is not a str.
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import send
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>>
+ >>> init()
+ >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> send(input_, 0)
+ """
+ if not isinstance(tensor, (Tensor, Tensor_)):
+ raise TypeError("For send, the input tensor must be tensor")
+ group = _get_group(group)
+ _dst = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
+ tensor = _contiguous(tensor)
+ output = inner_comm_isend_op(tensor, _dst, group, tag)
+ _deal_comm_outputs(output, False)
+
+
946
+ def recv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
+ """
+ Receive a tensor from src.
+
+ Note:
+ Send and Receive must be used in combination and have the same tag.
+ The shape and dtype of the input `tensor` are used to receive the tensor, but the value
+ of the input `tensor` does not take effect.
+ Only PyNative mode is supported; Graph mode is not currently supported.
+
+ Args:
+ tensor (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The shape and dtype of this
+ tensor are used to receive the tensor, but the value of the input `tensor` does not take effect.
+ src (int, optional): A required integer identifying the source rank (global rank). Default: 0.
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
+ be sent by the Send op with the same "tag". Default: 0.
+
+ Returns:
+ Tensor, the shape of output is :math:`(x_1, x_2, ..., x_R)`.
+
+ Raises:
+ TypeError: If `src` is not an int or `group` is not a str.
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import recv
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>>
+ >>> # Launch 2 processes. Process 0 sends the array [[0. 1.], [2. 3.]] to Process 1.
+ >>> init()
+ >>> x = Tensor(np.zeros([2, 2]))
+ >>> # Process 1 receives the tensor from Process 0.
+ >>> out = recv(x, src=0)
+ >>> print(out)
+ [[ 0. 1.]
+ [ 2. 3.]]
+ """
+ if not isinstance(tensor, (Tensor, Tensor_)):
+ raise TypeError("For recv, the input tensor must be tensor")
+ if not isinstance(src, int):
+ raise TypeError("For recv, the src must be int")
+ group = _get_group(group)
+ _src = _get_group_rank_from_world_rank_from_cache_helper(src, group)
+ tensor = _contiguous(tensor)
+ shape = tensor.shape
+ dtype = tensor.dtype
+ output, _ = _deal_comm_outputs(inner_comm_irecv_op(tag, _src, shape, group, dtype), False)
+ return output
+
+
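Since the docstring examples above show `send` and `recv` in isolation, here is a sketch of the required pairing (2-device launch assumed): the sender's `dst`/`tag` must match the receiver's `src`/`tag`, and the receive buffer only supplies shape and dtype::

    import numpy as np
    from mindspore import Tensor
    from mindspore.communication import init, get_rank
    from mindspore.communication.comm_func import send, recv

    init()
    if get_rank() == 0:
        send(Tensor(np.ones([2, 8]).astype(np.float32)), dst=1, tag=7)
    else:
        buf = Tensor(np.zeros([2, 8]).astype(np.float32))  # shape/dtype template
        out = recv(buf, src=0, tag=7)
        print(out)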
1019
+ def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
+ """
+ Send a tensor to the specified destination rank asynchronously.
+
+ Note:
+ Send and Receive must be used in combination and have the same tag.
+ Only PyNative mode is supported; Graph mode is not currently supported.
+
+ Args:
+ tensor (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+ dst (int, optional): A required integer identifying the destination rank (global rank). Default: 0.
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
+ be received by the Receive op with the same "tag". Default: 0.
+
+ Returns:
+ CommHandle, it is an async work handle.
+
+ Raises:
+ TypeError: If `dst` is not an int or `group` is not a str.
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import isend
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>>
+ >>> init()
+ >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+ >>> handle = isend(input_, 0)
+ >>> handle.wait()
+ """
+ if not isinstance(tensor, (Tensor, Tensor_)):
+ raise TypeError("For isend, the input tensor must be tensor")
+ group = _get_group(group)
+ _dst = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
+ tensor = _contiguous(tensor)
+ output = inner_comm_isend_op(tensor, _dst, group, tag)
+ _, handle = _deal_comm_outputs(output, True)
+ return handle
+
+
1079
+ def irecv(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
+ """
+ Receive a tensor from src asynchronously.
+
+ Note:
+ Send and Receive must be used in combination and have the same tag.
+ The shape and dtype of the input `tensor` are used to receive the tensor, but the value
+ of the input `tensor` does not take effect.
+ Only PyNative mode is supported; Graph mode is not currently supported.
+
+ Args:
+ tensor (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The shape and dtype of this
+ tensor are used to receive the tensor, but the value of the input `tensor` does not take effect.
+ src (int, optional): A required integer identifying the source rank (global rank). Default: 0.
+ group (str, optional): The communication group to work on.
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ tag (int, optional): A required integer identifying the send/recv message tag. The message will
+ be sent by the Send op with the same "tag". Default: 0.
+
+ Returns:
+ Tuple(Tensor, CommHandle), the shape of the output tensor is :math:`(x_1, x_2, ..., x_R)`,
+ and CommHandle is an async work handle that can be waited on.
+
+ Raises:
+ TypeError: If `src` is not an int or `group` is not a str.
+ ValueError: If the rank ID of the process is greater than the rank size of the communication group.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore.communication import init
+ >>> from mindspore.communication.comm_func import irecv
+ >>> from mindspore import Tensor
+ >>> import numpy as np
+ >>>
+ >>> # Launch 2 processes. Process 0 sends the array [[0. 1.], [2. 3.]] to Process 1.
+ >>> init()
+ >>> x = Tensor(np.zeros([2, 2]))
+ >>> # Process 1 receives the tensor from Process 0.
+ >>> out, handle = irecv(x, src=0)
+ >>> handle.wait()
+ >>> print(out)
+ [[ 0. 1.]
+ [ 2. 3.]]
+ """
+ group = _get_group(group)
+ _src = _get_group_rank_from_world_rank_from_cache_helper(src, group)
+ tensor = _contiguous(tensor)
+ shape = tensor.shape
+ dtype = tensor.dtype
+ output = inner_comm_irecv_op(tag, _src, shape, group, dtype)
+ return _deal_comm_outputs(output, True)
+
+
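The async pair works like `send`/`recv` except both sides get a handle and choose when to synchronize, which allows overlapping communication with computation. A sketch under the same 2-device launch assumption::

    import numpy as np
    from mindspore import Tensor
    from mindspore.communication import init, get_rank
    from mindspore.communication.comm_func import isend, irecv

    init()
    if get_rank() == 0:
        handle = isend(Tensor(np.ones([2, 2]).astype(np.float32)), dst=1)
        # ... unrelated computation can run here ...
        handle.wait()
    else:
        buf = Tensor(np.zeros([2, 2]).astype(np.float32))
        out, handle = irecv(buf, src=0)
        handle.wait()  # out is only valid after the wait
        print(out)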
1151
+ def all_to_all_with_output_shape(output_shape_list, input_tensor_list, group=None, async_op=False):
+ """
+ Scatter and gather a list of tensors to/from all ranks according to the input/output tensor lists.
+
+ Note:
+ The tensor shapes in `output_shape_list` and `input_tensor_list` should match across ranks.
+ Only PyNative mode is supported; Graph mode is not currently supported.
+
+ Args:
+ output_shape_list (Union[Tuple(Tensor), List(Tensor), Tuple(Tuple(int))]): List of shapes
+ that indicate the shapes of the tensors gathered from remote ranks.
+ input_tensor_list (Union[Tuple(Tensor), List(Tensor)]):
+ List of tensors to scatter to remote ranks.
+ group (str, optional): The communication group to work on.
+ Default: None, which means "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
+
+ Returns:
+ Tuple(Tuple(Tensor), CommHandle), the tensors gathered from remote ranks.
+ CommHandle is an async work handle if `async_op` is set to ``True``;
+ it will be None when `async_op` is ``False``.
+
+ Raises:
+ TypeError: If `input_tensor_list` is not a list of tensors.
+ TypeError: If `output_shape_list` is not a list of tuples or tensors.
+ TypeError: If the tensors in `input_tensor_list` do not have the same dtype.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore.communication import init, get_rank
+ >>> from mindspore.communication.comm_func import all_to_all_with_output_shape
+ >>> from mindspore import Tensor
+ >>>
+ >>> init()
+ >>> this_rank = get_rank()
+ >>> if this_rank == 0:
+ >>> send_tensor_list = [Tensor(1.), Tensor([[2, 3], [4, 5.]])]
+ >>> recv_shape_list = [(), (2,)]
+ >>> if this_rank == 1:
+ >>> send_tensor_list = [Tensor([2, 2.]), Tensor([4, 5, 6, 7.])]
+ >>> recv_shape_list = [(2, 2), (4,)]
+ >>> output = all_to_all_with_output_shape(recv_shape_list, send_tensor_list)
+ >>> print(output)
+ rank 0:
+ (Tensor(shape=[], dtype=Float32, value= 1),
+ Tensor(shape=[2], dtype=Float32, value= [2.00000000e+00, 2.00000000e+00]))
+ rank 1:
+ (Tensor(shape=[2, 2], dtype=Float32, value=
+ [[2.00000000e+00, 3.00000000e+00],
+ [4.00000000e+00, 5.00000000e+00]]),
+ Tensor(shape=[4], dtype=Float32, value=[4.00000000e+00, 5.00000000e+00, 6.00000000e+00, 7.00000000e+00]))
+
+ """
+
+ _check_all_tensors(input_tensor_list)
+ _check_all_tensors_or_tuple(output_shape_list)
+ _check_all_tensor_same_dtype(input_tensor_list)
+ send_numel_list = []
+ send_flatten_tensor = []
+ recv_numel_list = []
+ recv_shape_list = []
+
+ for tensor in input_tensor_list:
+ send_numel_list.append(tensor.size)
+ send_flatten_tensor.append(tensor.reshape(-1))
+ for tensor in output_shape_list:
+ if isinstance(tensor, Tensor):
+ recv_numel_list.append(tensor.size)
+ recv_shape_list.append(tensor.shape)
+ else:
+ _shape = tensor
+ recv_numel_list.append(_get_size(_shape))
+ recv_shape_list.append(_shape)
+
+ send_flatten_tensor = cat(send_flatten_tensor)
+ send_flatten_tensor = _contiguous(send_flatten_tensor)
+ group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
+ global _GROPU_SIZE_CACHE
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ rank_size = _GROPU_SIZE_CACHE[group]
+ output = inner_comm_all_to_all_v_op(send_flatten_tensor, group, send_numel_list, recv_numel_list,
+ rank_size, False)
+ output, handle = _deal_comm_outputs(output, async_op)
+ result = []
+ offset = 0
+ for numel, shape in zip(recv_numel_list, recv_shape_list):
+ result.append(output[offset:offset + numel].reshape(shape))
+ offset = offset + numel
+ return (tuple(result), handle)
+
+
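Stripped of the collective itself, the function's bookkeeping is plain reshaping: flatten and concatenate on the send side, then slice the received flat buffer by numel and reshape on the receive side. A stand-alone numpy rehearsal of that bookkeeping (numpy stands in for the all-to-all-v kernel)::

    import numpy as np

    # Send side: flatten each tensor and record its numel.
    send_list = [np.ones((), np.float32), np.arange(4, dtype=np.float32).reshape(2, 2)]
    send_numel = [t.size for t in send_list]
    flat = np.concatenate([t.reshape(-1) for t in send_list])

    # Receive side: peers promised these shapes; slice the flat buffer back.
    recv_shapes = [(), (2,)]
    recv_numel = [int(np.prod(s)) for s in recv_shapes]
    recv_flat = np.arange(sum(recv_numel), dtype=np.float32)  # pretend kernel output
    result, offset = [], 0
    for numel, shape in zip(recv_numel, recv_shapes):
        result.append(recv_flat[offset:offset + numel].reshape(shape))
        offset += numel
    print(result)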
1259
+ def _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes, input_split_sizes, group):
+ """Get the numel lists for all_to_all_single_with_output_shape."""
+ global _GROPU_SIZE_CACHE
+ if _is_split_sizes_empty(input_split_sizes):
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ _world_size = _GROPU_SIZE_CACHE[group]
+ if tensor.shape[0] % _world_size != 0:
+ raise ValueError("input shape at dim 0 must be divisible by world_size, "
+ f"but got {tensor.shape[0]} and {_world_size}.")
+ _split_size = tensor.shape[0] // _world_size
+ input_split_sizes = (_split_size,) * _world_size
+ if _is_split_sizes_empty(output_split_sizes):
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ _world_size = _GROPU_SIZE_CACHE[group]
+ shape_dim_0 = None
+ if isinstance(output_shape, Tensor):
+ shape_dim_0 = output_shape.shape[0]
+ else:
+ shape_dim_0 = output_shape[0]
+ if shape_dim_0 % _world_size != 0:
+ raise ValueError("output shape at dim 0 must be divisible by world_size, "
+ f"but got {shape_dim_0} and {_world_size}.")
+ _split_size = shape_dim_0 // _world_size
+ output_split_sizes = (_split_size,) * _world_size
+
+ send_size_without_first_dim = _get_size(tensor.shape[1:])
+ send_numel_list = [size * send_size_without_first_dim for size in input_split_sizes]
+
+ recv_size_without_first_dim = None
+ recv_shape_without_first_dim = None
+ if isinstance(output_shape, Tensor):
+ recv_shape_without_first_dim = output_shape.shape[1:]
+ recv_size_without_first_dim = _get_size(recv_shape_without_first_dim)
+ else:
+ recv_shape_without_first_dim = output_shape[1:]
+ recv_size_without_first_dim = _get_size(recv_shape_without_first_dim)
+ recv_numel_list = [size * recv_size_without_first_dim for size in output_split_sizes]
+ return send_numel_list, recv_numel_list, recv_shape_without_first_dim
+
+
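A worked instance of the equal-split branch above: a `(4, 3)` input on a 2-rank group splits into two 2-row blocks, and each block contributes `rows * 3` elements::

    # Hedged, pure-Python rehearsal of the numel arithmetic above.
    world_size = 2
    input_shape = (4, 3)

    split = input_shape[0] // world_size          # 2 rows per rank
    input_split_sizes = (split,) * world_size     # (2, 2)
    size_without_first_dim = 1
    for d in input_shape[1:]:
        size_without_first_dim *= d               # 3
    send_numel_list = [s * size_without_first_dim for s in input_split_sizes]
    print(send_numel_list)                        # [6, 6]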
1301
+ def all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes=None,
+ input_split_sizes=None, group=None, async_op=False):
+ """
+ Scatter and gather the input with split sizes to/from all ranks, and return the result in a single tensor.
+
+ Note:
+ The shapes of `output_shape` and `tensor` should match across ranks.
+ Only PyNative mode is supported; Graph mode is not currently supported.
+
+ Args:
+ output_shape (Union(Tensor, Tuple(int))): shape indicating the shape of the tensor
+ gathered and concatenated from remote ranks.
+ tensor (Tensor): The tensor to be scattered to remote ranks.
+ output_split_sizes (Union(Tuple(int), List(int))): output split sizes at dim 0. If set to None,
+ it means equally split by ``world_size``. Default: None.
+ input_split_sizes (Union(Tuple(int), List(int))): input split sizes at dim 0. If set to None,
+ it means equally split by ``world_size``. Default: None.
+ group (str, optional): The communication group to work on.
+ Default: None, which means "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+ async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
+
+ Returns:
+ Tuple(Tensor, CommHandle), the output tensor is gathered and concatenated from remote ranks.
+ If the numel of the tensor gathered from remote ranks is zero, it will return a Tensor with value 0,
+ which has no actual meaning. CommHandle is an async work handle if `async_op` is set to ``True``;
+ it will be None when `async_op` is ``False``.
+
+ Raises:
+ TypeError: If `tensor` is not a tensor.
+ TypeError: If `output_shape` is not a tuple or tensor.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ .. note::
+ Before running the following examples, you need to configure the communication environment variables.
+
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+ without any third-party or configuration file dependencies.
+ Please see the `msrun start up
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+ for more details.
+
+ This example should be run with 2 devices.
+
+ >>> from mindspore.communication import init, get_rank
+ >>> from mindspore.communication.comm_func import all_to_all_single_with_output_shape
+ >>> from mindspore import Tensor
+ >>>
+ >>> init()
+ >>> this_rank = get_rank()
+ >>> if this_rank == 0:
+ >>> output_shape = (3, 3)
+ >>> tensor = Tensor([[0, 1, 2.], [3, 4, 5], [6, 7, 8]])
+ >>> result, _ = all_to_all_single_with_output_shape(output_shape, tensor, [2, 1], [2, 1])
+ >>> if this_rank == 1:
+ >>> output_shape = (2, 3)
+ >>> tensor = Tensor([[9, 10., 11], [12, 13, 14]])
+ >>> result, _ = all_to_all_single_with_output_shape(output_shape, tensor)
+ >>> print(result)
+ rank 0:
+ [[ 0. 1. 2.]
+ [ 3. 4. 5.]
+ [ 9. 10. 11.]]
+ rank 1:
+ [[ 6. 7. 8.]
+ [12. 13. 14.]]
+
+ """
+
+ _check_all_tensors([tensor])
+ _check_all_tensors_or_tuple([output_shape])
+ if group is None:
+ group = GlobalComm.WORLD_COMM_GROUP
+
+ split_sizes_empty = _is_split_sizes_empty(output_split_sizes) and _is_split_sizes_empty(input_split_sizes)
+ send_numel_list, recv_numel_list, recv_shape_without_first_dim = \
+ _get_all_to_all_single_numel_list(tensor, output_shape, output_split_sizes, input_split_sizes, group)
+ tensor = _contiguous(tensor)
+ _input = tensor.reshape(-1)
+ group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
+ global _GROPU_SIZE_CACHE
+ if group not in _GROPU_SIZE_CACHE:
+ _GROPU_SIZE_CACHE[group] = get_group_size(group)
+ rank_size = _GROPU_SIZE_CACHE[group]
+ result = inner_comm_all_to_all_v_op(_input, group, send_numel_list, recv_numel_list, rank_size, split_sizes_empty)
+ result, handle = _deal_comm_outputs(result, async_op)
+ if any(recv_numel_list):
+ result = result.reshape((-1,) + recv_shape_without_first_dim)
+
+ return result, handle
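Finally, a sketch of the default equal-split path (2-device launch assumed): with both split-size arguments left as None, dim 0 is cut into ``world_size`` equal blocks and block `i` goes to rank `i`, so each rank's output concatenates one block from every rank::

    import numpy as np
    import mindspore as ms
    from mindspore.communication import init, get_rank
    from mindspore.communication.comm_func import all_to_all_single_with_output_shape

    init()
    rank = get_rank()
    # On every rank: rows 0-1 go to rank 0, rows 2-3 go to rank 1.
    tensor = ms.Tensor(np.arange(4, dtype=np.float32).reshape(4, 1) + 10 * rank)
    out, _ = all_to_all_single_with_output_shape((4, 1), tensor)
    print(rank, out)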