mindspore-2.4.0-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (1406)
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +53 -0
  18. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1419 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +299 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1448 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4475 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +433 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +86 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_pijit_context.py +190 -0
  89. mindspore/common/_register_for_adapter.py +74 -0
  90. mindspore/common/_register_for_recompute.py +48 -0
  91. mindspore/common/_register_for_tensor.py +46 -0
  92. mindspore/common/_stub_tensor.py +210 -0
  93. mindspore/common/_tensor_overload.py +139 -0
  94. mindspore/common/_utils.py +122 -0
  95. mindspore/common/api.py +2064 -0
  96. mindspore/common/auto_dynamic_shape.py +507 -0
  97. mindspore/common/dtype.py +422 -0
  98. mindspore/common/dump.py +130 -0
  99. mindspore/common/file_system.py +48 -0
  100. mindspore/common/generator.py +254 -0
  101. mindspore/common/hook_handle.py +143 -0
  102. mindspore/common/initializer.py +880 -0
  103. mindspore/common/jit_config.py +98 -0
  104. mindspore/common/lazy_inline.py +240 -0
  105. mindspore/common/mindir_util.py +111 -0
  106. mindspore/common/mutable.py +234 -0
  107. mindspore/common/no_inline.py +54 -0
  108. mindspore/common/np_dtype.py +25 -0
  109. mindspore/common/parameter.py +1081 -0
  110. mindspore/common/recompute.py +292 -0
  111. mindspore/common/seed.py +260 -0
  112. mindspore/common/sparse_tensor.py +1175 -0
  113. mindspore/common/symbol.py +122 -0
  114. mindspore/common/tensor.py +5039 -0
  115. mindspore/communication/__init__.py +37 -0
  116. mindspore/communication/_comm_helper.py +501 -0
  117. mindspore/communication/_hccl_management.py +297 -0
  118. mindspore/communication/comm_func.py +1395 -0
  119. mindspore/communication/management.py +673 -0
  120. mindspore/config/op_info.config +533 -0
  121. mindspore/context.py +2077 -0
  122. mindspore/d3dcompiler_47.dll +0 -0
  123. mindspore/dataset/__init__.py +90 -0
  124. mindspore/dataset/audio/__init__.py +61 -0
  125. mindspore/dataset/audio/transforms.py +3690 -0
  126. mindspore/dataset/audio/utils.py +386 -0
  127. mindspore/dataset/audio/validators.py +1172 -0
  128. mindspore/dataset/callback/__init__.py +20 -0
  129. mindspore/dataset/callback/ds_callback.py +368 -0
  130. mindspore/dataset/callback/validators.py +32 -0
  131. mindspore/dataset/core/__init__.py +13 -0
  132. mindspore/dataset/core/config.py +1095 -0
  133. mindspore/dataset/core/datatypes.py +101 -0
  134. mindspore/dataset/core/py_util_helpers.py +65 -0
  135. mindspore/dataset/core/validator_helpers.py +781 -0
  136. mindspore/dataset/debug/__init__.py +21 -0
  137. mindspore/dataset/debug/debug_hook.py +97 -0
  138. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  139. mindspore/dataset/engine/__init__.py +124 -0
  140. mindspore/dataset/engine/cache_admin.py +47 -0
  141. mindspore/dataset/engine/cache_client.py +129 -0
  142. mindspore/dataset/engine/datasets.py +4582 -0
  143. mindspore/dataset/engine/datasets_audio.py +911 -0
  144. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  145. mindspore/dataset/engine/datasets_text.py +2161 -0
  146. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  147. mindspore/dataset/engine/datasets_vision.py +4816 -0
  148. mindspore/dataset/engine/iterators.py +371 -0
  149. mindspore/dataset/engine/obs/__init__.py +23 -0
  150. mindspore/dataset/engine/obs/config_loader.py +68 -0
  151. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  152. mindspore/dataset/engine/obs/util.py +482 -0
  153. mindspore/dataset/engine/offload.py +596 -0
  154. mindspore/dataset/engine/queue.py +304 -0
  155. mindspore/dataset/engine/samplers.py +895 -0
  156. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  157. mindspore/dataset/engine/validators.py +2895 -0
  158. mindspore/dataset/text/__init__.py +51 -0
  159. mindspore/dataset/text/transforms.py +1703 -0
  160. mindspore/dataset/text/utils.py +715 -0
  161. mindspore/dataset/text/validators.py +642 -0
  162. mindspore/dataset/transforms/__init__.py +45 -0
  163. mindspore/dataset/transforms/c_transforms.py +638 -0
  164. mindspore/dataset/transforms/py_transforms.py +393 -0
  165. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  166. mindspore/dataset/transforms/transforms.py +1260 -0
  167. mindspore/dataset/transforms/validators.py +410 -0
  168. mindspore/dataset/utils/__init__.py +19 -0
  169. mindspore/dataset/utils/browse_dataset.py +190 -0
  170. mindspore/dataset/utils/line_reader.py +126 -0
  171. mindspore/dataset/vision/__init__.py +65 -0
  172. mindspore/dataset/vision/c_transforms.py +2641 -0
  173. mindspore/dataset/vision/py_transforms.py +2120 -0
  174. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  175. mindspore/dataset/vision/transforms.py +7295 -0
  176. mindspore/dataset/vision/utils.py +863 -0
  177. mindspore/dataset/vision/validators.py +1483 -0
  178. mindspore/default_config.py +2 -0
  179. mindspore/dnnl.dll +0 -0
  180. mindspore/dpcmi.dll +0 -0
  181. mindspore/experimental/__init__.py +20 -0
  182. mindspore/experimental/es/__init__.py +22 -0
  183. mindspore/experimental/es/embedding_service.py +883 -0
  184. mindspore/experimental/es/embedding_service_layer.py +581 -0
  185. mindspore/experimental/llm_boost/__init__.py +21 -0
  186. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  187. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  188. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  189. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  190. mindspore/experimental/llm_boost/register.py +129 -0
  191. mindspore/experimental/llm_boost/utils.py +31 -0
  192. mindspore/experimental/map_parameter.py +309 -0
  193. mindspore/experimental/optim/__init__.py +40 -0
  194. mindspore/experimental/optim/adadelta.py +161 -0
  195. mindspore/experimental/optim/adagrad.py +168 -0
  196. mindspore/experimental/optim/adam.py +193 -0
  197. mindspore/experimental/optim/adamax.py +170 -0
  198. mindspore/experimental/optim/adamw.py +290 -0
  199. mindspore/experimental/optim/asgd.py +153 -0
  200. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  201. mindspore/experimental/optim/nadam.py +157 -0
  202. mindspore/experimental/optim/optimizer.py +262 -0
  203. mindspore/experimental/optim/radam.py +194 -0
  204. mindspore/experimental/optim/rmsprop.py +154 -0
  205. mindspore/experimental/optim/rprop.py +164 -0
  206. mindspore/experimental/optim/sgd.py +156 -0
  207. mindspore/hal/__init__.py +40 -0
  208. mindspore/hal/_ascend.py +57 -0
  209. mindspore/hal/_base.py +57 -0
  210. mindspore/hal/_cpu.py +56 -0
  211. mindspore/hal/_gpu.py +57 -0
  212. mindspore/hal/contiguous_tensors_handle.py +175 -0
  213. mindspore/hal/device.py +356 -0
  214. mindspore/hal/event.py +179 -0
  215. mindspore/hal/memory.py +326 -0
  216. mindspore/hal/stream.py +357 -0
  217. mindspore/include/OWNERS +7 -0
  218. mindspore/include/api/allocator.h +97 -0
  219. mindspore/include/api/callback/callback.h +93 -0
  220. mindspore/include/api/callback/ckpt_saver.h +41 -0
  221. mindspore/include/api/callback/loss_monitor.h +33 -0
  222. mindspore/include/api/callback/lr_scheduler.h +51 -0
  223. mindspore/include/api/callback/time_monitor.h +34 -0
  224. mindspore/include/api/callback/train_accuracy.h +37 -0
  225. mindspore/include/api/cell.h +90 -0
  226. mindspore/include/api/cfg.h +82 -0
  227. mindspore/include/api/context.h +602 -0
  228. mindspore/include/api/data_type.h +47 -0
  229. mindspore/include/api/delegate.h +178 -0
  230. mindspore/include/api/delegate_api.h +75 -0
  231. mindspore/include/api/dual_abi_helper.h +208 -0
  232. mindspore/include/api/format.h +28 -0
  233. mindspore/include/api/graph.h +46 -0
  234. mindspore/include/api/kernel.h +58 -0
  235. mindspore/include/api/kernel_api.h +168 -0
  236. mindspore/include/api/metrics/accuracy.h +36 -0
  237. mindspore/include/api/metrics/metrics.h +41 -0
  238. mindspore/include/api/model.h +438 -0
  239. mindspore/include/api/model_group.h +91 -0
  240. mindspore/include/api/model_parallel_runner.h +168 -0
  241. mindspore/include/api/serialization.h +185 -0
  242. mindspore/include/api/status.h +192 -0
  243. mindspore/include/api/types.h +431 -0
  244. mindspore/include/api/visible.h +41 -0
  245. mindspore/include/c_api/context_c.h +179 -0
  246. mindspore/include/c_api/data_type_c.h +52 -0
  247. mindspore/include/c_api/format_c.h +46 -0
  248. mindspore/include/c_api/model_c.h +347 -0
  249. mindspore/include/c_api/status_c.h +79 -0
  250. mindspore/include/c_api/tensor_c.h +146 -0
  251. mindspore/include/c_api/types_c.h +67 -0
  252. mindspore/include/dataset/config.h +163 -0
  253. mindspore/include/dataset/constants.h +363 -0
  254. mindspore/include/dataset/execute.h +196 -0
  255. mindspore/include/dataset/text.h +1092 -0
  256. mindspore/include/dataset/transforms.h +638 -0
  257. mindspore/include/dataset/vision.h +2129 -0
  258. mindspore/include/dataset/vision_ascend.h +206 -0
  259. mindspore/include/dataset/vision_lite.h +625 -0
  260. mindspore/jpeg62.dll +0 -0
  261. mindspore/log.py +633 -0
  262. mindspore/mindrecord/__init__.py +43 -0
  263. mindspore/mindrecord/common/__init__.py +17 -0
  264. mindspore/mindrecord/common/constant.py +20 -0
  265. mindspore/mindrecord/common/enums.py +44 -0
  266. mindspore/mindrecord/common/exceptions.py +311 -0
  267. mindspore/mindrecord/config.py +809 -0
  268. mindspore/mindrecord/filereader.py +174 -0
  269. mindspore/mindrecord/filewriter.py +722 -0
  270. mindspore/mindrecord/mindpage.py +210 -0
  271. mindspore/mindrecord/shardheader.py +141 -0
  272. mindspore/mindrecord/shardindexgenerator.py +74 -0
  273. mindspore/mindrecord/shardreader.py +117 -0
  274. mindspore/mindrecord/shardsegment.py +128 -0
  275. mindspore/mindrecord/shardutils.py +185 -0
  276. mindspore/mindrecord/shardwriter.py +237 -0
  277. mindspore/mindrecord/tools/__init__.py +17 -0
  278. mindspore/mindrecord/tools/cifar10.py +140 -0
  279. mindspore/mindrecord/tools/cifar100.py +153 -0
  280. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  281. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  282. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  283. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  284. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  285. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  286. mindspore/mindspore_backend.dll +0 -0
  287. mindspore/mindspore_common.dll +0 -0
  288. mindspore/mindspore_core.dll +0 -0
  289. mindspore/mindspore_glog.dll +0 -0
  290. mindspore/mindspore_np_dtype.dll +0 -0
  291. mindspore/mindspore_ops.dll +0 -0
  292. mindspore/mint/__init__.py +1586 -0
  293. mindspore/mint/distributed/__init__.py +31 -0
  294. mindspore/mint/distributed/distributed.py +254 -0
  295. mindspore/mint/linalg/__init__.py +22 -0
  296. mindspore/mint/nn/__init__.py +757 -0
  297. mindspore/mint/nn/functional.py +679 -0
  298. mindspore/mint/nn/layer/__init__.py +39 -0
  299. mindspore/mint/nn/layer/activation.py +133 -0
  300. mindspore/mint/nn/layer/normalization.py +477 -0
  301. mindspore/mint/nn/layer/pooling.py +110 -0
  302. mindspore/mint/optim/__init__.py +24 -0
  303. mindspore/mint/optim/adamw.py +206 -0
  304. mindspore/mint/special/__init__.py +63 -0
  305. mindspore/msobj140.dll +0 -0
  306. mindspore/mspdb140.dll +0 -0
  307. mindspore/mspdbcore.dll +0 -0
  308. mindspore/mspdbst.dll +0 -0
  309. mindspore/mspft140.dll +0 -0
  310. mindspore/msvcdis140.dll +0 -0
  311. mindspore/msvcp140.dll +0 -0
  312. mindspore/msvcp140_1.dll +0 -0
  313. mindspore/msvcp140_2.dll +0 -0
  314. mindspore/msvcp140_atomic_wait.dll +0 -0
  315. mindspore/msvcp140_codecvt_ids.dll +0 -0
  316. mindspore/multiprocessing/__init__.py +73 -0
  317. mindspore/nn/__init__.py +47 -0
  318. mindspore/nn/cell.py +2787 -0
  319. mindspore/nn/dynamic_lr.py +482 -0
  320. mindspore/nn/grad/__init__.py +21 -0
  321. mindspore/nn/grad/cell_grad.py +196 -0
  322. mindspore/nn/layer/__init__.py +63 -0
  323. mindspore/nn/layer/activation.py +1822 -0
  324. mindspore/nn/layer/basic.py +1629 -0
  325. mindspore/nn/layer/channel_shuffle.py +90 -0
  326. mindspore/nn/layer/combined.py +248 -0
  327. mindspore/nn/layer/container.py +734 -0
  328. mindspore/nn/layer/conv.py +1505 -0
  329. mindspore/nn/layer/dense.py +204 -0
  330. mindspore/nn/layer/embedding.py +869 -0
  331. mindspore/nn/layer/image.py +661 -0
  332. mindspore/nn/layer/math.py +1069 -0
  333. mindspore/nn/layer/normalization.py +1273 -0
  334. mindspore/nn/layer/padding.py +880 -0
  335. mindspore/nn/layer/pooling.py +2302 -0
  336. mindspore/nn/layer/rnn_cells.py +388 -0
  337. mindspore/nn/layer/rnns.py +849 -0
  338. mindspore/nn/layer/thor_layer.py +963 -0
  339. mindspore/nn/layer/timedistributed.py +155 -0
  340. mindspore/nn/layer/transformer.py +823 -0
  341. mindspore/nn/learning_rate_schedule.py +512 -0
  342. mindspore/nn/loss/__init__.py +36 -0
  343. mindspore/nn/loss/loss.py +2924 -0
  344. mindspore/nn/metrics.py +53 -0
  345. mindspore/nn/optim/__init__.py +45 -0
  346. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  347. mindspore/nn/optim/ada_grad.py +217 -0
  348. mindspore/nn/optim/adadelta.py +206 -0
  349. mindspore/nn/optim/adafactor.py +448 -0
  350. mindspore/nn/optim/adam.py +1297 -0
  351. mindspore/nn/optim/adamax.py +220 -0
  352. mindspore/nn/optim/adasum.py +548 -0
  353. mindspore/nn/optim/asgd.py +216 -0
  354. mindspore/nn/optim/ftrl.py +401 -0
  355. mindspore/nn/optim/lamb.py +296 -0
  356. mindspore/nn/optim/lars.py +202 -0
  357. mindspore/nn/optim/lazyadam.py +533 -0
  358. mindspore/nn/optim/momentum.py +239 -0
  359. mindspore/nn/optim/optimizer.py +1034 -0
  360. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  361. mindspore/nn/optim/rmsprop.py +264 -0
  362. mindspore/nn/optim/rprop.py +251 -0
  363. mindspore/nn/optim/sgd.py +237 -0
  364. mindspore/nn/optim/tft_wrapper.py +127 -0
  365. mindspore/nn/optim/thor.py +1310 -0
  366. mindspore/nn/probability/__init__.py +22 -0
  367. mindspore/nn/probability/bijector/__init__.py +35 -0
  368. mindspore/nn/probability/bijector/bijector.py +337 -0
  369. mindspore/nn/probability/bijector/exp.py +65 -0
  370. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  371. mindspore/nn/probability/bijector/invert.py +126 -0
  372. mindspore/nn/probability/bijector/power_transform.py +196 -0
  373. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  374. mindspore/nn/probability/bijector/softplus.py +189 -0
  375. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  376. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  377. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  378. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  379. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  380. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  381. mindspore/nn/probability/distribution/__init__.py +56 -0
  382. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  383. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  384. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  385. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  386. mindspore/nn/probability/distribution/beta.py +391 -0
  387. mindspore/nn/probability/distribution/categorical.py +435 -0
  388. mindspore/nn/probability/distribution/cauchy.py +383 -0
  389. mindspore/nn/probability/distribution/distribution.py +827 -0
  390. mindspore/nn/probability/distribution/exponential.py +350 -0
  391. mindspore/nn/probability/distribution/gamma.py +391 -0
  392. mindspore/nn/probability/distribution/geometric.py +335 -0
  393. mindspore/nn/probability/distribution/gumbel.py +257 -0
  394. mindspore/nn/probability/distribution/half_normal.py +133 -0
  395. mindspore/nn/probability/distribution/laplace.py +128 -0
  396. mindspore/nn/probability/distribution/log_normal.py +272 -0
  397. mindspore/nn/probability/distribution/logistic.py +379 -0
  398. mindspore/nn/probability/distribution/normal.py +336 -0
  399. mindspore/nn/probability/distribution/poisson.py +288 -0
  400. mindspore/nn/probability/distribution/student_t.py +149 -0
  401. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  402. mindspore/nn/probability/distribution/uniform.py +375 -0
  403. mindspore/nn/reinforcement/__init__.py +24 -0
  404. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  405. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  406. mindspore/nn/reinforcement/tensor_array.py +145 -0
  407. mindspore/nn/sparse/__init__.py +23 -0
  408. mindspore/nn/sparse/sparse.py +147 -0
  409. mindspore/nn/wrap/__init__.py +49 -0
  410. mindspore/nn/wrap/cell_wrapper.py +968 -0
  411. mindspore/nn/wrap/grad_reducer.py +608 -0
  412. mindspore/nn/wrap/loss_scale.py +694 -0
  413. mindspore/numpy/__init__.py +121 -0
  414. mindspore/numpy/array_creations.py +2731 -0
  415. mindspore/numpy/array_ops.py +2629 -0
  416. mindspore/numpy/dtypes.py +185 -0
  417. mindspore/numpy/fft.py +966 -0
  418. mindspore/numpy/logic_ops.py +936 -0
  419. mindspore/numpy/math_ops.py +5911 -0
  420. mindspore/numpy/utils.py +214 -0
  421. mindspore/numpy/utils_const.py +565 -0
  422. mindspore/opencv_core452.dll +0 -0
  423. mindspore/opencv_imgcodecs452.dll +0 -0
  424. mindspore/opencv_imgproc452.dll +0 -0
  425. mindspore/ops/__init__.py +56 -0
  426. mindspore/ops/_constants.py +30 -0
  427. mindspore/ops/_grad_experimental/__init__.py +31 -0
  428. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  429. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  430. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  431. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  432. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  433. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  434. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  435. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  436. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  437. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  438. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  439. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  440. mindspore/ops/_op_impl/__init__.py +23 -0
  441. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  442. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  443. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  447. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  448. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  449. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  450. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  451. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  452. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  453. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  462. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  463. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  464. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  465. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  471. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  472. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  473. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  474. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  475. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  476. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  477. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  478. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  479. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  480. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  481. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  490. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  491. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  494. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  495. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  498. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  499. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  500. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  501. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  502. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  503. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  504. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  505. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  506. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  507. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  508. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  510. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  511. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  513. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  514. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  515. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  516. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  517. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  518. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  519. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  520. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  521. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  522. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  523. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  524. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  526. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  527. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  529. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  531. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  532. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  533. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  534. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  535. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  537. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  538. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  539. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  540. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  541. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  542. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  544. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  545. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  546. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  547. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  548. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  550. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  551. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  552. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  553. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  554. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  556. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  557. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  558. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  559. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  560. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  561. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  562. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  563. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  564. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  565. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  567. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  568. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  569. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  570. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  571. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  572. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  574. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  575. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  576. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  577. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  578. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  579. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  580. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  581. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  582. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  583. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  584. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  585. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  586. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  587. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  588. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  589. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  590. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  591. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  592. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  593. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  594. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  595. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  596. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  597. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  598. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  599. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  600. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  601. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  603. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  610. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  611. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  612. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  615. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  616. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  617. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  618. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  619. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  620. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  621. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  622. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  623. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  625. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  626. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  627. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  628. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  629. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  630. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  634. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  635. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  636. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  637. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  638. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  639. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  640. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  641. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  642. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  643. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  644. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  646. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  647. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  648. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  650. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  651. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  652. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  653. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  654. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  655. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  656. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  657. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  658. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  659. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  660. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  661. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  662. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  663. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  664. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  665. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  667. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  668. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  669. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  670. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  673. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  675. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  676. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  677. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  679. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  680. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  681. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  682. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  683. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  684. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  685. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  686. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  687. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  688. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  699. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  700. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  701. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  702. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  706. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  707. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  708. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  709. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  710. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  711. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  712. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  713. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  714. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  715. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  716. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  717. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  718. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  720. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  721. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  722. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  723. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  724. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  725. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  726. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  727. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  728. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  729. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  730. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  731. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  732. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  733. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  734. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  735. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  736. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  737. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  738. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  739. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  740. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  741. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  742. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  743. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  744. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  745. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  746. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  747. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  748. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  749. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  750. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  751. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  752. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  753. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  754. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  755. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  757. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  758. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  759. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  760. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  761. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  762. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  763. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  764. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  765. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  766. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  767. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  768. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  769. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  770. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  774. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  775. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  776. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  777. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  778. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  779. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  782. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  783. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  784. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  785. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  786. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  787. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  788. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  789. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  790. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  791. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  792. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  793. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  794. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  795. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  796. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  797. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  798. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  799. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  804. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  805. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  806. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  810. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  811. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  812. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  813. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  817. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  818. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  819. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  820. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  821. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  823. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  825. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  826. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  827. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  828. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  829. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  862. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  863. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  864. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  865. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  866. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  868. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  869. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  870. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  871. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  872. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  873. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  874. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  876. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1318 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/function/__init__.py +782 -0
  1098. mindspore/ops/function/array_func.py +7226 -0
  1099. mindspore/ops/function/clip_func.py +384 -0
  1100. mindspore/ops/function/debug_func.py +181 -0
  1101. mindspore/ops/function/fft_func.py +44 -0
  1102. mindspore/ops/function/grad/__init__.py +34 -0
  1103. mindspore/ops/function/grad/grad_func.py +1425 -0
  1104. mindspore/ops/function/image_func.py +292 -0
  1105. mindspore/ops/function/linalg_func.py +416 -0
  1106. mindspore/ops/function/math_func.py +12228 -0
  1107. mindspore/ops/function/nn_func.py +8609 -0
  1108. mindspore/ops/function/other_func.py +115 -0
  1109. mindspore/ops/function/parameter_func.py +134 -0
  1110. mindspore/ops/function/random_func.py +1715 -0
  1111. mindspore/ops/function/reshard_func.py +104 -0
  1112. mindspore/ops/function/sparse_func.py +884 -0
  1113. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1114. mindspore/ops/function/spectral_func.py +150 -0
  1115. mindspore/ops/function/vmap_func.py +117 -0
  1116. mindspore/ops/functional.py +464 -0
  1117. mindspore/ops/op_info_register.py +1572 -0
  1118. mindspore/ops/operations/__init__.py +722 -0
  1119. mindspore/ops/operations/_csr_ops.py +403 -0
  1120. mindspore/ops/operations/_custom_grad.py +181 -0
  1121. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1122. mindspore/ops/operations/_grad_ops.py +2978 -0
  1123. mindspore/ops/operations/_infer_ops.py +19 -0
  1124. mindspore/ops/operations/_inner_ops.py +2544 -0
  1125. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1126. mindspore/ops/operations/_ms_kernel.py +601 -0
  1127. mindspore/ops/operations/_ocr_ops.py +379 -0
  1128. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1129. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1130. mindspore/ops/operations/_quant_ops.py +1844 -0
  1131. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1132. mindspore/ops/operations/_scalar_ops.py +106 -0
  1133. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1134. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1135. mindspore/ops/operations/_tensor_array.py +359 -0
  1136. mindspore/ops/operations/_thor_ops.py +807 -0
  1137. mindspore/ops/operations/array_ops.py +6124 -0
  1138. mindspore/ops/operations/comm_ops.py +1985 -0
  1139. mindspore/ops/operations/control_ops.py +127 -0
  1140. mindspore/ops/operations/custom_ops.py +1129 -0
  1141. mindspore/ops/operations/debug_ops.py +678 -0
  1142. mindspore/ops/operations/image_ops.py +1041 -0
  1143. mindspore/ops/operations/inner_ops.py +697 -0
  1144. mindspore/ops/operations/linalg_ops.py +95 -0
  1145. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1146. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1147. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1148. mindspore/ops/operations/math_ops.py +5095 -0
  1149. mindspore/ops/operations/nn_ops.py +9575 -0
  1150. mindspore/ops/operations/other_ops.py +874 -0
  1151. mindspore/ops/operations/random_ops.py +1288 -0
  1152. mindspore/ops/operations/reshard_ops.py +53 -0
  1153. mindspore/ops/operations/rl_ops.py +288 -0
  1154. mindspore/ops/operations/sparse_ops.py +2753 -0
  1155. mindspore/ops/operations/spectral_ops.py +111 -0
  1156. mindspore/ops/primitive.py +1046 -0
  1157. mindspore/ops/signature.py +54 -0
  1158. mindspore/ops/vm_impl_registry.py +91 -0
  1159. mindspore/ops_generate/__init__.py +27 -0
  1160. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1161. mindspore/ops_generate/arg_handler.py +197 -0
  1162. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1163. mindspore/ops_generate/gen_constants.py +36 -0
  1164. mindspore/ops_generate/gen_ops.py +1099 -0
  1165. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1166. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1167. mindspore/ops_generate/gen_utils.py +209 -0
  1168. mindspore/ops_generate/op_proto.py +145 -0
  1169. mindspore/ops_generate/pyboost_utils.py +367 -0
  1170. mindspore/ops_generate/template.py +261 -0
  1171. mindspore/parallel/__init__.py +30 -0
  1172. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1173. mindspore/parallel/_cell_wrapper.py +174 -0
  1174. mindspore/parallel/_cost_model_context.py +700 -0
  1175. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1176. mindspore/parallel/_offload_context.py +275 -0
  1177. mindspore/parallel/_parallel_serialization.py +561 -0
  1178. mindspore/parallel/_ps_context.py +242 -0
  1179. mindspore/parallel/_recovery_context.py +110 -0
  1180. mindspore/parallel/_tensor.py +730 -0
  1181. mindspore/parallel/_transformer/__init__.py +35 -0
  1182. mindspore/parallel/_transformer/layers.py +765 -0
  1183. mindspore/parallel/_transformer/loss.py +251 -0
  1184. mindspore/parallel/_transformer/moe.py +693 -0
  1185. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1186. mindspore/parallel/_transformer/transformer.py +3119 -0
  1187. mindspore/parallel/_utils.py +612 -0
  1188. mindspore/parallel/algo_parameter_config.py +400 -0
  1189. mindspore/parallel/checkpoint_transform.py +650 -0
  1190. mindspore/parallel/cluster/__init__.py +15 -0
  1191. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1192. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1193. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1194. mindspore/parallel/cluster/run.py +136 -0
  1195. mindspore/parallel/mpi/__init__.py +14 -0
  1196. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1197. mindspore/parallel/parameter_broadcast.py +151 -0
  1198. mindspore/parallel/shard.py +481 -0
  1199. mindspore/parallel/transform_safetensors.py +993 -0
  1200. mindspore/perf_msvcbuildinsights.dll +0 -0
  1201. mindspore/pgodb140.dll +0 -0
  1202. mindspore/pgort140.dll +0 -0
  1203. mindspore/profiler/__init__.py +28 -0
  1204. mindspore/profiler/common/__init__.py +14 -0
  1205. mindspore/profiler/common/constant.py +29 -0
  1206. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1207. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1208. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1209. mindspore/profiler/common/process_pool.py +41 -0
  1210. mindspore/profiler/common/registry.py +47 -0
  1211. mindspore/profiler/common/singleton.py +28 -0
  1212. mindspore/profiler/common/struct_type.py +118 -0
  1213. mindspore/profiler/common/util.py +472 -0
  1214. mindspore/profiler/common/validator/__init__.py +14 -0
  1215. mindspore/profiler/common/validator/validate_path.py +84 -0
  1216. mindspore/profiler/dynamic_profiler.py +694 -0
  1217. mindspore/profiler/envprofiling.py +254 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1229. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1230. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1231. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1232. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1233. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1234. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1235. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1236. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1237. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1238. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1239. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1240. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1241. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1242. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1243. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1244. mindspore/profiler/parser/container.py +229 -0
  1245. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1246. mindspore/profiler/parser/flops_parser.py +531 -0
  1247. mindspore/profiler/parser/framework_enum.py +111 -0
  1248. mindspore/profiler/parser/framework_parser.py +464 -0
  1249. mindspore/profiler/parser/framework_struct.py +61 -0
  1250. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1251. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1252. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1253. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1254. mindspore/profiler/parser/hccl_parser.py +573 -0
  1255. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1256. mindspore/profiler/parser/integrator.py +526 -0
  1257. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1258. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1259. mindspore/profiler/parser/minddata_parser.py +186 -0
  1260. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1261. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1262. mindspore/profiler/parser/optime_parser.py +250 -0
  1263. mindspore/profiler/parser/profiler_info.py +213 -0
  1264. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1265. mindspore/profiler/profiler.py +153 -0
  1266. mindspore/profiler/profiling.py +1922 -0
  1267. mindspore/rewrite/__init__.py +28 -0
  1268. mindspore/rewrite/api/__init__.py +17 -0
  1269. mindspore/rewrite/api/node.py +519 -0
  1270. mindspore/rewrite/api/node_type.py +53 -0
  1271. mindspore/rewrite/api/pattern_engine.py +490 -0
  1272. mindspore/rewrite/api/scoped_value.py +181 -0
  1273. mindspore/rewrite/api/symbol_tree.py +497 -0
  1274. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1275. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1276. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1277. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1278. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1279. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1280. mindspore/rewrite/common/__init__.py +19 -0
  1281. mindspore/rewrite/common/config.py +24 -0
  1282. mindspore/rewrite/common/error_log.py +39 -0
  1283. mindspore/rewrite/common/event.py +28 -0
  1284. mindspore/rewrite/common/namer.py +271 -0
  1285. mindspore/rewrite/common/namespace.py +118 -0
  1286. mindspore/rewrite/common/observable.py +44 -0
  1287. mindspore/rewrite/common/observer.py +54 -0
  1288. mindspore/rewrite/node/__init__.py +22 -0
  1289. mindspore/rewrite/node/call_function.py +95 -0
  1290. mindspore/rewrite/node/cell_container.py +139 -0
  1291. mindspore/rewrite/node/control_flow.py +113 -0
  1292. mindspore/rewrite/node/node.py +1428 -0
  1293. mindspore/rewrite/node/node_manager.py +283 -0
  1294. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1295. mindspore/rewrite/parsers/__init__.py +29 -0
  1296. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1297. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1298. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1299. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1300. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1301. mindspore/rewrite/parsers/container_parser.py +88 -0
  1302. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1303. mindspore/rewrite/parsers/for_parser.py +61 -0
  1304. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1305. mindspore/rewrite/parsers/if_parser.py +85 -0
  1306. mindspore/rewrite/parsers/module_parser.py +117 -0
  1307. mindspore/rewrite/parsers/parser.py +43 -0
  1308. mindspore/rewrite/parsers/parser_register.py +86 -0
  1309. mindspore/rewrite/parsers/return_parser.py +37 -0
  1310. mindspore/rewrite/parsers/while_parser.py +59 -0
  1311. mindspore/rewrite/sparsify/__init__.py +0 -0
  1312. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1313. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1314. mindspore/rewrite/sparsify/utils.py +179 -0
  1315. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1316. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1317. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1318. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1319. mindspore/run_check/__init__.py +20 -0
  1320. mindspore/run_check/_check_version.py +507 -0
  1321. mindspore/run_check/run_check.py +66 -0
  1322. mindspore/safeguard/__init__.py +18 -0
  1323. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1324. mindspore/swresample-4.dll +0 -0
  1325. mindspore/swscale-6.dll +0 -0
  1326. mindspore/tbbmalloc.dll +0 -0
  1327. mindspore/tinyxml2.dll +0 -0
  1328. mindspore/train/__init__.py +48 -0
  1329. mindspore/train/_utils.py +465 -0
  1330. mindspore/train/amp.py +935 -0
  1331. mindspore/train/anf_ir_pb2.py +1517 -0
  1332. mindspore/train/callback/__init__.py +44 -0
  1333. mindspore/train/callback/_backup_and_restore.py +117 -0
  1334. mindspore/train/callback/_callback.py +613 -0
  1335. mindspore/train/callback/_checkpoint.py +814 -0
  1336. mindspore/train/callback/_cluster_monitor.py +201 -0
  1337. mindspore/train/callback/_dataset_graph.py +150 -0
  1338. mindspore/train/callback/_early_stop.py +239 -0
  1339. mindspore/train/callback/_flops_collector.py +239 -0
  1340. mindspore/train/callback/_history.py +92 -0
  1341. mindspore/train/callback/_lambda_callback.py +80 -0
  1342. mindspore/train/callback/_landscape.py +1049 -0
  1343. mindspore/train/callback/_loss_monitor.py +107 -0
  1344. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1345. mindspore/train/callback/_on_request_exit.py +298 -0
  1346. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1347. mindspore/train/callback/_summary_collector.py +1184 -0
  1348. mindspore/train/callback/_tft_register.py +352 -0
  1349. mindspore/train/callback/_time_monitor.py +141 -0
  1350. mindspore/train/checkpoint_pb2.py +233 -0
  1351. mindspore/train/data_sink.py +219 -0
  1352. mindspore/train/dataset_helper.py +692 -0
  1353. mindspore/train/lineage_pb2.py +1260 -0
  1354. mindspore/train/loss_scale_manager.py +213 -0
  1355. mindspore/train/memory_profiling_pb2.py +298 -0
  1356. mindspore/train/metrics/__init__.py +175 -0
  1357. mindspore/train/metrics/accuracy.py +133 -0
  1358. mindspore/train/metrics/auc.py +129 -0
  1359. mindspore/train/metrics/bleu_score.py +170 -0
  1360. mindspore/train/metrics/confusion_matrix.py +700 -0
  1361. mindspore/train/metrics/cosine_similarity.py +109 -0
  1362. mindspore/train/metrics/dice.py +116 -0
  1363. mindspore/train/metrics/error.py +175 -0
  1364. mindspore/train/metrics/fbeta.py +167 -0
  1365. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1366. mindspore/train/metrics/loss.py +97 -0
  1367. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1368. mindspore/train/metrics/metric.py +373 -0
  1369. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1370. mindspore/train/metrics/perplexity.py +133 -0
  1371. mindspore/train/metrics/precision.py +160 -0
  1372. mindspore/train/metrics/recall.py +159 -0
  1373. mindspore/train/metrics/roc.py +223 -0
  1374. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1375. mindspore/train/metrics/topk.py +167 -0
  1376. mindspore/train/mind_ir_pb2.py +1908 -0
  1377. mindspore/train/model.py +2252 -0
  1378. mindspore/train/node_strategy_pb2.py +653 -0
  1379. mindspore/train/print_pb2.py +184 -0
  1380. mindspore/train/profiling_parallel_pb2.py +151 -0
  1381. mindspore/train/serialization.py +3325 -0
  1382. mindspore/train/summary/__init__.py +23 -0
  1383. mindspore/train/summary/_lineage_adapter.py +41 -0
  1384. mindspore/train/summary/_summary_adapter.py +496 -0
  1385. mindspore/train/summary/_writer_pool.py +207 -0
  1386. mindspore/train/summary/enums.py +56 -0
  1387. mindspore/train/summary/summary_record.py +581 -0
  1388. mindspore/train/summary/writer.py +167 -0
  1389. mindspore/train/summary_pb2.py +1165 -0
  1390. mindspore/train/train_thor/__init__.py +20 -0
  1391. mindspore/train/train_thor/convert_utils.py +268 -0
  1392. mindspore/train/train_thor/dataset_helper.py +192 -0
  1393. mindspore/train/train_thor/model_thor.py +257 -0
  1394. mindspore/turbojpeg.dll +0 -0
  1395. mindspore/utils/__init__.py +21 -0
  1396. mindspore/utils/utils.py +60 -0
  1397. mindspore/vcmeta.dll +0 -0
  1398. mindspore/vcomp140.dll +0 -0
  1399. mindspore/vcruntime140.dll +0 -0
  1400. mindspore/vcruntime140_1.dll +0 -0
  1401. mindspore/version.py +1 -0
  1402. mindspore-2.4.0.dist-info/METADATA +352 -0
  1403. mindspore-2.4.0.dist-info/RECORD +1406 -0
  1404. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1405. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1406. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1985 @@
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Communication APIs.
17
+ """
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+
21
+ from mindspore.common import Tensor
22
+ from mindspore import _checkparam as validator
23
+ from mindspore.communication.management import get_rank, get_group_size, GlobalComm, _get_group, _host_distribute
24
+ from mindspore.common import dtype as mstype
25
+ from mindspore.ops.primitive import PrimitiveWithInfer, PrimitiveWithCheck, Primitive, prim_attr_register
26
+ from mindspore.common.api import context
27
+
28
+
29
+ class ReduceOp:
30
+ """
31
+ Operation options for reducing tensors. This is an enumerated type, not an operator.
32
+
33
+ The main calling methods are as follows:
34
+
35
+ - SUM: ReduceOp.SUM.
36
+ - MAX: ReduceOp.MAX.
37
+ - MIN: ReduceOp.MIN.
38
+ - PROD: ReduceOp.PROD.
39
+
40
+ There are four kinds of operation options, "SUM", "MAX", "MIN", and "PROD".
41
+
42
+ - SUM: Take the sum.
43
+ - MAX: Take the maximum.
44
+ - MIN: Take the minimum.
45
+ - PROD: Take the product.
46
+
47
+ Supported Platforms:
48
+ ``Ascend`` ``GPU``
49
+
50
+ Examples:
51
+ .. note::
52
+ Before running the following examples, you need to configure the communication environment variables.
53
+
54
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
55
+ without any third-party or configuration file dependencies.
56
+ Please see the `msrun start up
57
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
58
+ for more details.
59
+
60
+ This example should be run with multiple devices.
61
+
62
+ >>> import numpy as np
63
+ >>> import mindspore
64
+ >>> from mindspore.communication import init
65
+ >>> from mindspore import Tensor, ops, nn
66
+ >>> from mindspore.ops import ReduceOp
67
+ >>>
68
+ >>> init()
69
+ >>> class Net(nn.Cell):
70
+ ... def __init__(self):
71
+ ... super(Net, self).__init__()
72
+ ... self.allreduce_sum = ops.AllReduce(ReduceOp.SUM)
73
+ ...
74
+ ... def construct(self, x):
75
+ ... return self.allreduce_sum(x)
76
+ ...
77
+ >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
78
+ >>> net = Net()
79
+ >>> output = net(input_)
80
+ >>> print(output)
81
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
82
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
83
+ """
84
+ SUM = "sum"
85
+ MAX = "max"
86
+ MIN = "min"
87
+ PROD = "prod"
88
+
89
+
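The four members above are plain lowercase strings rather than enum instances, so they compare and pass around like ordinary str values; a minimal sketch:

from mindspore.ops import ReduceOp

# ReduceOp members are ordinary strings.
assert ReduceOp.SUM == "sum" and ReduceOp.MAX == "max"
assert ReduceOp.MIN == "min" and ReduceOp.PROD == "prod"
# Collectives that accept an `op` argument take these members directly,
# e.g. ops.AllReduce(ReduceOp.MAX); the bare string "max" satisfies the same str type check.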
90
+ def check_collective_target_dtype(data_name, data_dtype, prim_name):
91
+ """Check if data type is valid."""
92
+ default_target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.bfloat16)
93
+ gpu_target_dtypes = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64, mstype.uint32, mstype.uint64,
94
+ mstype.float16, mstype.float32, mstype.float64)
95
+
96
+ valid_dtype = gpu_target_dtypes if context.get_context("device_target") == "GPU" else default_target_dtypes
97
+ validator.check_tensor_dtype_valid(data_name, data_dtype, valid_dtype, prim_name)
98
+
99
+
100
+ def check_hcom_group_valid(group, prim_name=None):
101
+ """Check if hcom group is valid."""
102
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
103
+ if not _host_distribute() and context.get_context("mode") == context.PYNATIVE_MODE and \
104
+ group != GlobalComm.WORLD_COMM_GROUP:
105
+ raise RuntimeError(f"{msg_prefix} 'group' only support 'hccl_world_group' in pynative mode, but got "
106
+ f"'group': {group}. Please start by using mpi-run.")
107
+
108
+
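A brief sketch of the guard above, using the module-level helpers and a hypothetical group name: in PyNative mode, a job not launched through the host-distribution path may only use the world group.

import mindspore as ms

ms.set_context(mode=ms.PYNATIVE_MODE)
# The world group is always accepted.
check_hcom_group_valid(GlobalComm.WORLD_COMM_GROUP, prim_name="AllReduce")
# A hypothetical user-created group is rejected in this setup (RuntimeError),
# unless _host_distribute() reports that the job was launched through the host path.
# check_hcom_group_valid("my_sub_group", prim_name="AllReduce")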
109
+ class AllReduce(Primitive):
110
+ """
111
+ Reduces tensors across all devices in such a way that all devices will get the same final result,
112
+ and returns the tensor which is all-reduced.
113
+
114
+ Note:
115
+ The tensors must have the same shape and format in all processes of the collection.
116
+
117
+ Args:
118
+ op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
119
+ On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
120
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
121
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
122
+
123
+ Inputs:
124
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
125
+
126
+ Outputs:
127
+ Tensor, has the same shape as the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
128
+ The contents depend on the specified operation.
129
+
130
+ Raises:
131
+ TypeError: If any of `op` and `group` is not a str or the input's dtype is bool.
132
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
133
+
134
+ Supported Platforms:
135
+ ``Ascend`` ``GPU`` ``CPU``
136
+
137
+ Examples:
138
+ .. note::
139
+ Before running the following examples, you need to configure the communication environment variables.
140
+
141
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
142
+ without any third-party or configuration file dependencies.
143
+ Please see the `msrun start up
144
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
145
+ for more details.
146
+
147
+ This example should be run with 2 devices.
148
+
149
+ >>> import numpy as np
150
+ >>> from mindspore.communication import init
151
+ >>> from mindspore import Tensor
152
+ >>> from mindspore.ops import ReduceOp
153
+ >>> import mindspore.nn as nn
154
+ >>> from mindspore import ops
155
+ >>>
156
+ >>> init()
157
+ >>> class Net(nn.Cell):
158
+ ... def __init__(self):
159
+ ... super(Net, self).__init__()
160
+ ... self.allreduce_sum = ops.AllReduce(ReduceOp.SUM)
161
+ ...
162
+ ... def construct(self, x):
163
+ ... return self.allreduce_sum(x)
164
+ ...
165
+ >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
166
+ >>> net = Net()
167
+ >>> output = net(input_)
168
+ >>> print(output)
169
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
170
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
171
+
172
+ Tutorial Examples:
173
+ - `Distributed Set Communication Primitives - AllReduce
174
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allreduce>`_
175
+
176
+ """
177
+
178
+ @prim_attr_register
179
+ def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
180
+ """Initialize AllReduce."""
181
+ self.group = _get_group(group)
182
+ if not isinstance(op, type(ReduceOp.SUM)):
183
+ raise TypeError(f"For '{self.name}', the 'op' must be str, but got {type(op).__name__}.")
184
+ if not isinstance(self.group, str):
185
+ raise TypeError(f"For '{self.name}', the 'group' must be str, "
186
+ f"but got {type(self.group).__name__}.")
187
+ check_hcom_group_valid(self.group, prim_name=self.name)
188
+ self.op = op
189
+ self.add_prim_attr('group', self.group)
190
+ self.add_prim_attr('fusion', 0)
191
+ self.add_prim_attr('index', 0)
192
+ self.add_prim_attr('no_eliminate', True)
193
+
194
+
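Since ReduceOp.SUM is a plain string, the isinstance(op, type(ReduceOp.SUM)) check in the constructor above amounts to a str check. A short sketch of what it accepts (actually executing the primitive still requires an initialized communication environment):

from mindspore import ops
from mindspore.ops import ReduceOp

allreduce_max = ops.AllReduce(ReduceOp.MAX)   # accepted: ReduceOp members are str values
allreduce_sum = ops.AllReduce("sum")          # also accepted: any str passes the type check
# ops.AllReduce(0)                            # rejected at construction time with a TypeError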
195
+ class Reduce(PrimitiveWithInfer):
196
+ """
197
+ Reduces tensors across the processes in the specified communication group, sends the result
198
+ to the target dest_rank(local rank), and returns the tensor which is sent to the target process.
199
+
200
+ Note:
201
+ Only the process with the destination rank receives the reduced output.
202
+ Both PyNative mode and Graph mode are supported, but Graph mode only supports a graph compilation level of O0.
203
+ Other processes only get a tensor with shape [1], which has no mathematical meaning.
204
+
205
+ Args:
206
+ dest_rank (int): The target process(local rank) in the specific group that receives the reduced output.
207
+ op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
208
+ On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
209
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
210
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
211
+
212
+ Inputs:
213
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
214
+
215
+ Outputs:
216
+ Tensor. Return the tensor in the specific rank of the process after reduction.
217
+ The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
218
+
219
+ Raises:
220
+ TypeError: If the type of the first input parameter is not Tensor,
221
+ or any of `op` and `group` is not a str.
222
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
223
+
224
+ Supported Platforms:
225
+ ``Ascend``
226
+
227
+ Examples:
228
+ .. note::
229
+ Before running the following examples, you need to configure the communication environment variables.
230
+
231
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method without any third-party
232
+ or configuration file dependencies.
233
+ Please see the `msrun start up
234
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
235
+ for more details.
236
+
237
+ This example should be run with 4 devices.
238
+
239
+ >>> from mindspore import ops
240
+ >>> import mindspore.nn as nn
241
+ >>> from mindspore.communication import init
242
+ >>> from mindspore import Tensor
243
+ >>> import numpy as np
244
+ >>> # Launch 4 processes.
245
+ >>> init()
246
+ >>> class ReduceNet(nn.Cell):
247
+ ... def __init__(self):
248
+ ... super(ReduceNet, self).__init__()
249
+ ... self.reduce = ops.Reduce(dest_rank=1)
250
+ ...
251
+ ... def construct(self, x):
252
+ ... out = self.reduce(x)
253
+ ... return out
254
+ >>> input = Tensor(np.ones([2, 8]).astype(np.float32))
255
+ >>> net = ReduceNet()
256
+ >>> output = net(input)
257
+ >>> print(output)
258
+ Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
259
+ [4. 4. 4. 4. 4. 4. 4. 4.]],
260
+ Other processes: [0.].
261
+ """
262
+
263
+ @prim_attr_register
264
+ def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
265
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
266
+ validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
267
+ self.dest_rank = dest_rank
268
+ self.op = op
269
+ self.group = _get_group(group)
270
+ self.add_prim_attr('group', _get_group(group))
271
+ self.add_prim_attr('dest_rank', dest_rank)
272
+
273
+ def infer_shape(self, x_shape):
274
+ # The process with dest_rank returns the reduced output.
275
+ # Other processes only gets a tensor with shape [1], which has no mathematical meaning.
276
+ if self.dest_rank == get_rank():
277
+ return x_shape
278
+ return [1]
279
+
280
+ def infer_dtype(self, x_dtype):
281
+ return x_dtype
282
+
283
+
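A worked example of the per-rank shape inference above, assuming a hypothetical 4-rank job with dest_rank=1 as in the docstring example:

# With a per-rank input of shape [2, 8] and ops.Reduce(dest_rank=1):
#   - rank 1 gets infer_shape([2, 8]) -> [2, 8] and holds the element-wise reduced values;
#   - every other rank gets [1], a placeholder with no mathematical meaning.
dest_rank, this_rank, x_shape = 1, 0, [2, 8]
out_shape = x_shape if this_rank == dest_rank else [1]
assert out_shape == [1]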
284
+ class AllGather(PrimitiveWithInfer):
285
+ """
286
+ Gathers tensors from the specified communication group and returns the tensor which is all gathered.
287
+
288
+ Note:
289
+ - The tensors must have the same shape and format in all processes of the collection.
290
+
291
+ Args:
292
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
293
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
294
+
295
+ Inputs:
296
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
297
+
298
+ Outputs:
299
+ Tensor. If the number of devices in the group is N,
300
+ then the shape of output is :math:`(N*x_1, x_2, ..., x_R)`, i.e. the per-rank inputs concatenated along the first dimension.
301
+
302
+ Raises:
303
+ TypeError: If `group` is not a str.
304
+ ValueError: If the local rank id of the calling process in the group
305
+ is larger than the group's rank size.
306
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
307
+
308
+ Supported Platforms:
309
+ ``Ascend`` ``GPU``
310
+
311
+ Examples:
312
+ .. note::
313
+ Before running the following examples, you need to configure the communication environment variables.
314
+
315
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
316
+ without any third-party or configuration file dependencies.
317
+ Please see the `msrun start up
318
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
319
+ for more details.
320
+
321
+ This example should be run with 2 devices.
322
+
323
+ >>> import numpy as np
324
+ >>> import mindspore as ms
325
+ >>> from mindspore import ops
326
+ >>> import mindspore.nn as nn
327
+ >>> from mindspore.communication import init
328
+ >>> from mindspore import Tensor
329
+ >>>
330
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
331
+ >>> init()
332
+ >>> class Net(nn.Cell):
333
+ ... def __init__(self):
334
+ ... super(Net, self).__init__()
335
+ ... self.allgather = ops.AllGather()
336
+ ...
337
+ ... def construct(self, x):
338
+ ... return self.allgather(x)
339
+ ...
340
+ >>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
341
+ >>> net = Net()
342
+ >>> output = net(input_x)
343
+ >>> print(output)
344
+ [[1. 1. 1. 1. 1. 1. 1. 1.]
345
+ [1. 1. 1. 1. 1. 1. 1. 1.]
346
+ [1. 1. 1. 1. 1. 1. 1. 1.]
347
+ [1. 1. 1. 1. 1. 1. 1. 1.]]
348
+
349
+ Tutorial Examples:
350
+ - `Distributed Set Communication Primitives - AllGather
351
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allgather>`_
352
+
353
+ """
354
+
355
+ @prim_attr_register
356
+ def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
357
+ """Initialize AllGather."""
358
+ self.group = _get_group(group)
359
+ validator.check_value_type('group', self.group, (str,), self.name)
360
+ self.rank = get_rank(self.group)
361
+ self.rank_size = get_group_size(self.group)
362
+ validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
363
+ self.add_prim_attr('rank_size', self.rank_size)
364
+ self.add_prim_attr('group', self.group)
365
+ self.add_prim_attr('fusion', 0)
366
+ self.add_prim_attr('mean_flag', False)
367
+ self.add_prim_attr('no_eliminate', True)
368
+
369
+ def infer_shape(self, x_shape):
370
+ validator.check_positive_int(len(x_shape), "x shape", self.name)
371
+ if x_shape[0] > 0:
372
+ x_shape[0] = x_shape[0] * self.rank_size
373
+ return x_shape
374
+
375
+ def infer_dtype(self, x_dtype):
376
+ check_collective_target_dtype('x', x_dtype, self.name)
377
+ return x_dtype
378
+
379
+
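A worked example of the shape arithmetic in infer_shape above: gathering concatenates the per-rank tensors along dimension 0, so the first dimension is multiplied by the group size.

# Assuming rank_size == 2 and a per-rank input of shape (2, 8), as in the docstring example:
rank_size, x_shape = 2, [2, 8]
out_shape = [x_shape[0] * rank_size] + x_shape[1:]
assert out_shape == [4, 8]   # matches the 4 x 8 output printed in the AllGather example above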
380
+ class _MiniStepAllGather(PrimitiveWithInfer):
381
+ """
382
+ Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
383
+ internal use of parallel modules and cannot be called by users.
384
+
385
+ Args:
386
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
387
+ grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
388
+ """
389
+
390
+ @prim_attr_register
391
+ def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
392
+ """Initialize _MiniStepAllGather."""
393
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
394
+ self.rank = get_rank(_get_group(group))
395
+ self.rank_size = get_group_size(_get_group(group))
396
+ validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
397
+ self.add_prim_attr('rank_size', self.rank_size)
398
+ self.add_prim_attr('group', _get_group(group))
399
+ self.add_prim_attr('fusion', 1)
400
+ self.grad_accumulation_step = grad_accumulation_step
401
+ self.mean_flag = mean_flag
402
+ self.add_prim_attr('order_enforce_skip', True)
403
+ self.add_prim_attr('side_effect_backprop_mem', True)
404
+
405
+ def infer_shape(self, x_shape, z_shape):
406
+ validator.check_positive_int(len(x_shape), "x shape", self.name)
407
+ if x_shape[0] > 0:
408
+ x_shape[0] = x_shape[0] * self.rank_size
409
+ return x_shape
410
+
411
+ def infer_dtype(self, x_dtype, z_shape):
412
+ check_collective_target_dtype('x', x_dtype, self.name)
413
+ return x_dtype
414
+
415
+
416
+ class _MicroStepAllGather(PrimitiveWithInfer):
417
+ """
418
+ Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
419
+ internal use of parallel modules and cannot be called by users.
420
+
421
+ Args:
422
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
423
+ """
424
+
425
+ @prim_attr_register
426
+ def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
427
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
428
+ self.rank_size = 1
429
+ if group != "":
430
+ self.rank = get_rank(_get_group(group))
431
+ self.rank_size = get_group_size(_get_group(group))
432
+ validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
433
+ self.add_prim_attr('rank_size', self.rank_size)
434
+ self.add_prim_attr('group', _get_group(group))
435
+ self.add_prim_attr('fusion', 1)
436
+ self.add_prim_attr('do_mirror', False)
437
+ self.mean_flag = mean_flag
438
+ self.add_prim_attr('order_enforce_skip', True)
439
+
440
+ def infer_shape(self, x_shape, z_shape):
441
+ validator.check_positive_int(len(x_shape), "x shape", self.name)
442
+ if x_shape[0] > 0:
443
+ x_shape[0] = x_shape[0] * self.rank_size
444
+ return x_shape
445
+
446
+ def infer_dtype(self, x_dtype, z_dtype):
447
+ check_collective_target_dtype('x', x_dtype, self.name)
448
+ return x_dtype
449
+
450
+
451
+ class _HostAllGather(PrimitiveWithInfer):
452
+ """
453
+ Gathers tensors from the specified communication group on host.
454
+
455
+ Note:
456
+ The tensors must have the same shape and format in all processes of the collection.
457
+ _HostAllGather is a host-side operator. It depends on OpenMPI and requires the build option -M on
458
+ to be enabled. Use the mpirun command to run it:
459
+ mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
460
+
461
+ Args:
462
+ group (Union[tuple[int],list[int]]): The rank_ids of the communication group to work on. Default: ``None`` .
463
+
464
+ Raises:
465
+ TypeError: If group is not a list nor tuple, or elements of group are not int.
466
+ ValueError: If group is not set, or rank_id from group not in [0, 7].
467
+
468
+ Inputs:
469
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
470
+
471
+ Outputs:
472
+ Tensor. If the number of devices in the group is N,
473
+ then the shape of output is :math:`(N*x_1, x_2, ..., x_R)`.
474
+ """
475
+
476
+ @prim_attr_register
477
+ def __init__(self, group=None):
478
+ """Initialize _HostAllGather."""
479
+ if group is None:
480
+ raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
481
+ validator.check_value_type('group', group, (tuple, list), self.name)
482
+ validator.check_int(len(group), 2, validator.GE, "group size", self.name)
483
+ for r in group:
484
+ validator.check_int_range(r, 0, 7, validator.INC_BOTH, "rank_id", self.name)
485
+ validator.check_value_type("rank_id", r, (int,), self.name)
486
+ self.group_size = len(group)
487
+ self.add_prim_attr('group', group)
488
+ self.add_prim_attr('no_eliminate', True)
489
+ self.add_prim_attr('order_enforce_skip', True)
490
+
491
+ def infer_shape(self, x_shape):
492
+ validator.check_positive_int(len(x_shape), "x shape", self.name)
493
+ if x_shape[0] > 0:
494
+ x_shape[0] = x_shape[0] * self.group_size
495
+ return x_shape
496
+
497
+ def infer_dtype(self, x_dtype):
498
+ check_collective_target_dtype('x', x_dtype, self.name)
499
+ return x_dtype
500
+
501
+ def __call__(self, tensor):
502
+ raise NotImplementedError
503
+
504
+
505
+ class ReduceScatter(Primitive):
506
+ r"""
507
+ Reduces and scatters tensors from the specified communication group
508
+ and returns the tensor which is reduced and scattered.
509
+
510
+ Note:
511
+ The tensors must have the same shape and format in all processes of the collection.
512
+
513
+ Args:
514
+ op (str, optional): Specifies an operation used for element-wise reductions,
515
+ like SUM and MAX. Default: ``ReduceOp.SUM`` .
516
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
517
+
518
+ Inputs:
519
+ - **input_x** (Tensor) - Input Tensor, suppose it has a shape :math:`(N, *)`, where `*`
520
+ means any number of additional dimensions. N must be divisible by rank_size.
521
+ rank_size refers to the number of cards in the communication group.
522
+
523
+ Outputs:
524
+ Tensor, it has the same dtype as `input_x` with a shape of :math:`(N/rank\_size, *)`.
525
+
526
+ Raises:
527
+ TypeError: If any of operation and group is not a string.
528
+ ValueError: If the first dimension of the input cannot be divided by the rank_size.
529
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
530
+
531
+ Supported Platforms:
532
+ ``Ascend`` ``GPU``
533
+
534
+ Examples:
535
+ .. note::
536
+ Before running the following examples, you need to configure the communication environment variables.
537
+
538
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
539
+ without any third-party or configuration file dependencies.
540
+ Please see the `msrun start up
541
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
542
+ for more details.
543
+
544
+ This example should be run with 2 devices.
545
+
546
+ >>> import mindspore as ms
547
+ >>> from mindspore import Tensor
548
+ >>> from mindspore.communication import init
549
+ >>> from mindspore.ops import ReduceOp
550
+ >>> import mindspore.nn as nn
551
+ >>> from mindspore import ops
552
+ >>> import numpy as np
553
+ >>>
554
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
555
+ >>> init()
556
+ >>> class Net(nn.Cell):
557
+ ... def __init__(self):
558
+ ... super(Net, self).__init__()
559
+ ... self.reducescatter = ops.ReduceScatter(ReduceOp.SUM)
560
+ ...
561
+ ... def construct(self, x):
562
+ ... return self.reducescatter(x)
563
+ ...
564
+ >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
565
+ >>> net = Net()
566
+ >>> output = net(input_)
567
+ >>> print(output)
568
+ [[2. 2. 2. 2. 2. 2. 2. 2.]
569
+ [2. 2. 2. 2. 2. 2. 2. 2.]
570
+ [2. 2. 2. 2. 2. 2. 2. 2.]
571
+ [2. 2. 2. 2. 2. 2. 2. 2.]]
572
+
573
+ Tutorial Examples:
574
+ - `Distributed Set Communication Primitives - ReduceScatter
575
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#reducescatter>`_
576
+
577
+ """
578
+
579
+ @prim_attr_register
580
+ def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
581
+ """Initialize ReduceScatter."""
582
+ validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
583
+ self.group = _get_group(group)
584
+ validator.check_value_type('group', self.group, (str,), self.name)
585
+ self.op = op
586
+ self.rank_size = get_group_size(self.group)
587
+ self.add_prim_attr('rank_size', self.rank_size)
588
+ self.add_prim_attr('group', self.group)
589
+ self.add_prim_attr('fusion', 0)
590
+ self.add_prim_attr('no_eliminate', True)
591
+
592
+
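A worked example of the ReduceScatter shape rule described above, assuming rank_size == 2 as in the docstring example:

# Each rank feeds an (8, 8) tensor; the group-wise reduction is then scattered so that
# every rank keeps a slice whose first dimension is divided by rank_size.
rank_size, x_shape = 2, [8, 8]
assert x_shape[0] % rank_size == 0            # otherwise a ValueError is raised
out_shape = [x_shape[0] // rank_size] + x_shape[1:]
assert out_shape == [4, 8]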
593
+ class _HostReduceScatter(PrimitiveWithInfer):
594
+ """
595
+ Reduces and scatters tensors from the specified communication group on host.
596
+
597
+ Note:
598
+ The tensors must have the same shape and format in all processes of the collection.
599
+ _HostReduceScatter is a host-side operator. It depends on OpenMPI and requires the build option
600
+ -M on to be enabled. Use the mpirun command to run it:
601
+ mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
602
+
603
+ Args:
604
+ op (str): Specifies an operation used for element-wise reductions,
605
+ like sum, max, avg. Default: ``ReduceOp.SUM`` .
606
+ group (Union[tuple[int],list[int]]): The rank_ids of the communication group to work on. Default: ``None`` .
607
+
608
+ Raises:
609
+ TypeError: If op is not a string and group is not a list nor tuple,
610
+ or elements of group are not int.
611
+ ValueError: If the first dimension of input can not be divided by group size,
612
+ or group is not set, or rank_id not in [0, 7].
613
+ """
614
+
615
+ @prim_attr_register
616
+ def __init__(self, op=ReduceOp.SUM, group=None):
617
+ """Initialize _HostReduceScatter."""
618
+ if group is None:
619
+ raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
620
+ validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
621
+ validator.check_value_type('group', group, (tuple, list), self.name)
622
+ validator.check_int(len(group), 2, validator.GE, "group size", self.name)
623
+ for r in group:
624
+ validator.check_int_range(r, 0, 7, validator.INC_BOTH, "rank_id", self.name)
625
+ validator.check_value_type("rank_id", r, (int,), self.name)
626
+ self.op = op
627
+ self.group_size = len(group)
628
+ self.add_prim_attr('group', group)
629
+ self.add_prim_attr('no_eliminate', True)
630
+ self.add_prim_attr('order_enforce_skip', True)
631
+
632
+ def infer_shape(self, x_shape):
633
+ if x_shape[0] % self.group_size != 0:
634
+ raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' must be divided by 'group_size', "
635
+ f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.group_size}.")
636
+ x_shape[0] = int(x_shape[0] / self.group_size)
637
+ return x_shape
638
+
639
+ def infer_dtype(self, x_dtype):
640
+ check_collective_target_dtype('x', x_dtype, self.name)
641
+ return x_dtype
642
+
643
+ def __call__(self, tensor):
644
+ raise NotImplementedError
645
+
646
+
647
+ class Broadcast(PrimitiveWithInfer):
648
+ """
649
+ Broadcasts the tensor to the whole group.
650
+
651
+ Note:
652
+ The tensors must have the same shape and format in all processes of the collection.
653
+
654
+ Args:
655
+ root_rank (int): Specifies the rank (global rank) of the process that broadcasts the tensor.
656
+ Only process `root_rank` will broadcast the tensor.
657
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
658
+
659
+ Inputs:
660
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
661
+
662
+ Outputs:
663
+ tuple[Tensor], Tensor has the same shape as the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
664
+ The contents depend on the data of the `root_rank` device.
665
+
666
+ Raises:
667
+ TypeError: If root_rank is not an integer or group is not a string.
668
+
669
+ Supported Platforms:
670
+ ``Ascend`` ``GPU``
671
+
672
+ Examples:
673
+ .. note::
674
+ Before running the following examples, you need to configure the communication environment variables.
675
+
676
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
677
+ without any third-party or configuration file dependencies.
678
+ Please see the `msrun start up
679
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
680
+ for more details.
681
+
682
+ This example should be run with 2 devices.
683
+
684
+ >>> import mindspore as ms
685
+ >>> from mindspore import Tensor
686
+ >>> from mindspore.communication import init
687
+ >>> import mindspore.nn as nn
688
+ >>> from mindspore import ops
689
+ >>> import numpy as np
690
+ >>>
691
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
692
+ >>> init()
693
+ >>> class Net(nn.Cell):
694
+ ... def __init__(self):
695
+ ... super(Net, self).__init__()
696
+ ... self.broadcast = ops.Broadcast(1)
697
+ ...
698
+ ... def construct(self, x):
699
+ ... return self.broadcast((x,))
700
+ ...
701
+ >>> input_x = Tensor(np.ones([2, 4]).astype(np.int32))
702
+ >>> net = Net()
703
+ >>> output = net(input_x)
704
+ >>> print(output)
705
+ (Tensor(shape=[2, 4], dtype=Int32, value=
706
+ [[1, 1, 1, 1],
707
+ [1, 1, 1, 1]]),)
708
+
709
+ Tutorial Examples:
710
+ - `Distributed Set Communication Primitives - Broadcast
711
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#broadcast>`_
712
+
713
+ """
714
+
715
+ @prim_attr_register
716
+ def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
717
+ """Initialize Broadcast."""
718
+ validator.check_value_type('root_rank', root_rank, (int,), self.name)
719
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
720
+ check_hcom_group_valid(group, prim_name=self.name)
721
+ self.add_prim_attr('group', _get_group(group))
722
+ self.add_prim_attr('no_eliminate', True)
723
+
724
+
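A short sketch of the calling convention shown in the Broadcast docstring above: the input is a tuple of tensors and the output is a tuple of the same length, with every rank receiving root_rank's data; running it requires an initialized communication group.

from mindspore import ops

broadcast = ops.Broadcast(root_rank=1)   # rank 1 is the sender
# x must be wrapped in a tuple; the result is a tuple of the same length:
# (y,) = broadcast((x,))                 # on every rank, y now holds rank 1's x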
725
+ class _AllSwap(PrimitiveWithCheck):
726
+ """
727
+ _AllSwap is a collective operation.
728
+
729
+ _AllSwap sends data from all processes to all processes in the specified group. It has two phases:
730
+
731
+ - The scatter phase: On each process, the operand is split into the send size of blocks along the
732
+ 0-th axis, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process.
733
+ - The gather phase: Each process concatenates the received blocks along the 0-th axis.
734
+
735
+ Note:
736
+ The tensors must have the same format in all processes of the collection.
737
+
738
+ Args:
739
+ group (str): The communication group name.
740
+
741
+ Inputs:
742
+ tensor_in (tensor): A 2-D tensor. On each process, it is divided into blocks according to the send size.
743
+ send_size (tensor): A 1-D int64 tensor. The element is the send data size for each process.
744
+ recv_size (tensor): A 1-D int64 tensor. The element is the receive data size for each process.
745
+
746
+ Returns:
747
+ tensor_out (tensor): The result tensor.
748
+
749
+ Raises:
750
+ TypeError: If group is not a string.
751
+ """
752
+
753
+ @prim_attr_register
754
+ def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
755
+ """Initialize _AllSwap"""
756
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
757
+ self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
758
+ self.add_prim_attr('group', _get_group(group))
759
+ self.add_prim_attr('no_eliminate', True)
760
+ self.add_prim_attr('order_enforce_skip', True)
761
+
762
+ def __check__(self, tensor_in, send_size, recv_size):
763
+ validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor_type, self.name)
764
+ validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
765
+ self.name)
766
+ validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
767
+ self.name)
768
+
769
+ validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
770
+ validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
771
+ validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)
772
+
773
+ out_shape = [-1] + [tensor_in['shape'][1]]
774
+ out = {'shape': out_shape,
775
+ 'dtype': tensor_in['dtype'],
776
+ 'value': None}
777
+ return out
778
+
779
+
780
+ class NeighborExchange(Primitive):
781
+ """
782
+ NeighborExchange is a collective operation.
783
+
784
+ NeighborExchange sends data from the local rank to ranks in the send_rank_ids,
785
+ while receiving data from recv_rank_ids.
786
+
787
+ Note:
788
+ The user needs to preset
789
+ communication environment variables before running the following example, please check the details on the
790
+ official website of `MindSpore \
791
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.ops.primitive.html#communication-operator>`_.
792
+
793
+ This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
794
+ in the same subnet, please check the `details \
795
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_.
796
+
797
+ Args:
798
+ send_rank_ids (list(int)): Ranks which the data is sent to.
799
+ recv_rank_ids (list(int)): Ranks which the data is received from.
800
+ recv_shapes (tuple(list(int))): Data shapes which are received from recv_rank_ids.
801
+ send_shapes (tuple(list(int))): Data shapes which are sent to the send_rank_ids.
802
+ recv_type (type): Data type which is received from recv_rank_ids.
803
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
804
+
805
+ Inputs:
806
+ - **input_x** (tuple[Tensor]) - Shapes are same as args of send_shapes.
807
+
808
+ Outputs:
809
+ Tuple tensor, shapes are same as args of recv_shapes.
810
+
811
+ Supported Platforms:
812
+ ``Ascend``
813
+
814
+ Examples:
815
+ >>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn
816
+ >>> import os
817
+ >>> import mindspore as ms
818
+ >>> from mindspore import Tensor
819
+ >>> from mindspore.communication import init
820
+ >>> import mindspore.nn as nn
821
+ >>> from mindspore import ops
822
+ >>> import numpy as np
823
+ >>> class Net(nn.Cell):
824
+ ... def __init__(self):
825
+ ... super(Net, self).__init__()
826
+ ... self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
827
+ ... recv_shapes=([2, 2],), send_shapes=([3, 3],),
828
+ ... recv_type=ms.float32)
829
+ ...
830
+ ...
831
+ ... def construct(self, x):
832
+ ... out = self.neighborexchange((x,))
833
+ ...         return out
834
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
835
+ >>> init()
836
+ >>> net = Net()
837
+ >>> input_x = Tensor(np.ones([3, 3]), dtype = ms.float32)
838
+ >>> output = net(input_x)
839
+ >>> print(output)
840
+ [[2. 2.], [2. 2.]]
841
+
842
+ Tutorial Examples:
843
+ - `Distributed Set Communication Primitives - NeighborExchange
844
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#neighborexchange>`_
845
+
846
+ """
847
+
848
+ @prim_attr_register
849
+ def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type,
850
+ group=GlobalComm.WORLD_COMM_GROUP):
851
+ self.init_prim_io_names(inputs=['x'], outputs=['output'])
852
+ self.send_rank_ids = send_rank_ids
853
+ self.recv_rank_ids = recv_rank_ids
854
+ self.recv_shapes = recv_shapes
855
+ self.send_shapes = send_shapes
856
+ self.recv_type = recv_type
857
+ self.add_prim_attr('group', _get_group(group))
858
+ self.add_prim_attr('no_eliminate', True)
859
+
860
+ def __call__(self, tensor):
861
+ raise NotImplementedError
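+ # A rough sketch (assumption, added for clarity) of how send_shapes/recv_shapes must pair
+ # up with the peer ranks, using the docstring example on 2 devices:
+ #     this rank:  send_rank_ids=[1], send_shapes=([3, 3],)   # sends a (3, 3) tensor to rank 1
+ #                 recv_rank_ids=[1], recv_shapes=([2, 2],)   # expects a (2, 2) tensor from rank 1
+ # so rank 1 is expected to send a (2, 2) tensor back and to declare recv_shapes=([3, 3],).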
862
+
863
+
864
+ class AlltoAll(PrimitiveWithInfer):
865
+ r"""
866
+ AlltoAll is a collective operation.
867
+
868
+ AlltoAll sends data from the all processes to the all processes in the specified group. It has two phases:
869
+
870
+ - The scatter phase: On each process, the operand is split into split_count number of blocks along the
871
+ split_dim, and the blocks are scattered to all processes, e.g., the ith block is sent to the ith process.
872
+ - The gather phase: Each process concatenates the received blocks along the concat_dim.
873
+
874
+ Note:
875
+ This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
876
+ in the same subnet. Please check the `details \
877
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_.
878
+
879
+ Args:
880
+ split_count (int): The number of blocks each process splits its input into; it must equal the group's rank size.
881
+ split_dim (int): The dimension along which each process splits its input.
882
+ concat_dim (int): The dimension along which each process concatenates the received blocks.
883
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
884
+
885
+ Inputs:
886
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
887
+
888
+ Outputs:
889
+ Tensor. If the shape of input tensor is :math:`(x_1, x_2, ..., x_R)`, then the shape of output tensor is
890
+ :math:`(y_1, y_2, ..., y_R)`, where:
891
+
892
+ - :math:`y_{split\_dim} = x_{split\_dim} / split\_count`
893
+ - :math:`y_{concat\_dim} = x_{concat\_dim} * split\_count`
894
+ - :math:`y_{other} = x_{other}`.
895
+
896
+ Raises:
897
+ TypeError: If group is not a string.
898
+
899
+ Supported Platforms:
900
+ ``Ascend``
901
+
902
+ Examples:
903
+ .. note::
904
+ Before running the following examples, you need to configure the communication environment variables.
905
+
906
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
907
+ without any third-party or configuration file dependencies.
908
+ Please see the `msrun start up
909
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
910
+ for more details.
911
+
912
+ This example should be run with 8 devices.
913
+
914
+ >>> import os
915
+ >>> import mindspore as ms
916
+ >>> from mindspore import Tensor
917
+ >>> from mindspore.communication import init
918
+ >>> import mindspore.nn as nn
919
+ >>> from mindspore import ops
920
+ >>> import numpy as np
921
+ >>> class Net(nn.Cell):
922
+ ... def __init__(self):
923
+ ... super(Net, self).__init__()
924
+ ... self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1)
925
+ ...
926
+ ... def construct(self, x):
927
+ ... out = self.alltoall(x)
928
+ ... return out
929
+ ...
930
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
931
+ >>> init()
932
+ >>> net = Net()
933
+ >>> rank_id = int(os.getenv("RANK_ID"))
934
+ >>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32)
935
+ >>> output = net(input_x)
936
+ >>> print(output)
937
+ [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
938
+
939
+ Tutorial Examples:
940
+ - `Distributed Set Communication Primitives - AlltoAll
941
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#alltoall>`_
942
+
943
+ """
944
+
945
+ @prim_attr_register
946
+ def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
947
+ """Initialize AlltoAll"""
948
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
949
+ validator.check_is_int(split_count, 'split_count', self.name)
950
+ validator.check_is_int(split_dim, 'split_dim', self.name)
951
+ validator.check_is_int(concat_dim, 'concat_dim', self.name)
952
+ self.split_count = split_count
953
+ self.split_dim = split_dim
954
+ self.concat_dim = concat_dim
955
+ self.add_prim_attr('group', _get_group(group))
956
+ self.add_prim_attr('no_eliminate', True)
957
+
958
+ def infer_shape(self, x_shape):
959
+ rank_size = get_group_size(_get_group(self.group))
960
+ if self.split_count != rank_size:
961
+ raise ValueError(f"For '{self.name}', the 'split_count' must be equal to 'rank_size', "
962
+ f"but got 'split_count': {self.split_count}, 'rank_size': {rank_size}.")
963
+ if x_shape[self.split_dim] >= 0 and x_shape[self.split_dim] % self.split_count != 0:
964
+ raise ValueError(f"For '{self.name}', the 'x_shape[self.split_dim]' must be divisible by 'split_count', "
965
+ f"but got 'x_shape[self.split_dim]' {x_shape[self.split_dim]}, "
966
+ f"'split_count' {self.split_count}.")
967
+ if x_shape[self.concat_dim] >= 0:
968
+ x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
969
+ if x_shape[self.split_dim] >= 0:
970
+ x_shape[self.split_dim] = int(x_shape[self.split_dim] / self.split_count)
971
+ return x_shape
972
+
973
+ def infer_dtype(self, x_dtype):
974
+ check_collective_target_dtype('x', x_dtype, self.name)
975
+ return x_dtype
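+ # Illustrative sketch (added for clarity) of the shape arithmetic implemented by
+ # infer_shape above, using the docstring example (split_count=8, split_dim=-2, concat_dim=-1):
+ #     x_shape = [1, 1, 8, 1]
+ #     x_shape[-2] //= 8     # split_dim is divided by split_count   -> 1
+ #     x_shape[-1] *= 8      # concat_dim is multiplied by split_count -> 8
+ #     # result: [1, 1, 1, 8], matching the printed output [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]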
976
+
977
+
978
+ class NeighborExchangeV2(Primitive):
979
+ r"""
980
+ NeighborExchangeV2 is a collective communication operation.
981
+
982
+ NeighborExchangeV2 sends data from the local rank to ranks in the `send_rank_ids`,
983
+ while receiving data from `recv_rank_ids`. Please refer to the tutorial examples
984
+ below to learn about how the data is exchanged between neighborhood devices.
985
+
986
+ Note:
987
+ This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
988
+ in the same subnet. Please check the `details \
989
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_.
990
+
991
+ Args:
992
+ send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one
993
+ direction is not sent to, set it to -1.
994
+ recv_rank_ids (list(int)): Ranks which the data is received from. 8 rank_ids represents 8 directions,
995
+ if one direction is not received from, set it to -1.
996
+ send_lens (list(int)): Data lengths sent to the send_rank_ids; 4 numbers represent the lengths of
997
+ [send_top, send_bottom, send_left, send_right].
998
+ recv_lens (list(int)): Data lengths received from recv_rank_ids; 4 numbers represent the lengths of
999
+ [recv_top, recv_bottom, recv_left, recv_right].
1000
+ data_format (str): Data format; only "NCHW" is supported currently.
1001
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
1002
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
1003
+
1004
+ Inputs:
1005
+ - **input_x** (Tensor) - The Tensor before being exchanged. It has a shape of :math:`(N, C, H, W)`.
1006
+
1007
+ Outputs:
1008
+ The Tensor after being exchanged. If input shape is :math:`(N, C, H, W)`, output shape is
1009
+ :math:`(N, C, H+recv\_top+recv\_bottom, W+recv\_left+recv\_right)`.
1010
+
1011
+ Raises:
1012
+ TypeError: If `group` is not a string or any one of `send_rank_ids`,
1013
+ `recv_rank_ids`, `send_lens`, `recv_lens` is not a list.
1014
+ ValueError: If `send_rank_ids` or `recv_rank_ids` has value less than -1 or has repeated values.
1015
+ ValueError: If `send_lens`, `recv_lens` has value less than 0.
1016
+ ValueError: If `data_format` is not "NCHW".
1017
+
1018
+ Supported Platforms:
1019
+ ``Ascend``
1020
+
1021
+ Examples:
1022
+ .. note::
1023
+ Before running the following examples, you need to configure the communication environment variables.
1024
+
1025
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1026
+ without any third-party or configuration file dependencies.
1027
+ Please see the `msrun start up
1028
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1029
+ for more details.
1030
+
1031
+ This example should be run with 2 devices.
1032
+
1033
+ >>> import os
1034
+ >>> import mindspore as ms
1035
+ >>> from mindspore.communication import init
1036
+ >>> import mindspore.nn as nn
1037
+ >>> from mindspore import ops
1038
+ >>> import numpy as np
1039
+ >>>
1040
+ >>> class Net0(nn.Cell):
1041
+ ... def __init__(self):
1042
+ ... super(Net0, self).__init__()
1043
+ ... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
1044
+ ... send_lens=[0, 1, 0, 0],
1045
+ ... recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
1046
+ ... recv_lens=[0, 1, 0, 0], data_format="NCHW")
1047
+ ...
1048
+ ... def construct(self, x):
1049
+ ... out = self.neighbor_exchangev2(x)
1050
+ ... return out
1051
+ >>> class Net1(nn.Cell):
1052
+ ... def __init__(self):
1053
+ ... super(Net1, self).__init__()
1054
+ ... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
1055
+ ... send_lens=[1, 0, 0, 0],
1056
+ ... recv_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
1057
+ ... recv_lens=[1, 0, 0, 0], data_format="NCHW")
1058
+ ...
1059
+ ... def construct(self, x):
1060
+ ... out = self.neighbor_exchangev2(x)
1061
+ ... return out
1062
+ >>>
1063
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
1064
+ >>> init()
1065
+ >>> rank_id = int(os.getenv("RANK_ID"))
1066
+ >>> if (rank_id % 2 == 0):
+ ...     input_x = ms.Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32)
+ ...     net = Net0()
+ ...     output = net(input_x)
+ ...     print(output)
+ ... else:
+ ...     input_x = ms.Tensor(np.ones([1, 1, 2, 2]) * 2, dtype = ms.float32)
+ ...     net = Net1()
+ ...     output = net(input_x)
+ ...     print(output)
1076
+ [[[[1. 1.], [1. 1.], [2. 2.]]]]
1077
+
1078
+ Tutorial Examples:
1079
+ - `Distributed Set Communication Primitives - NeighborExchangeV2
1080
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#neighborexchangev2>`_
1081
+
1082
+ """
1083
+
1084
+ @prim_attr_register
1085
+ def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
1086
+ group=GlobalComm.WORLD_COMM_GROUP):
1087
+ self.init_prim_io_names(inputs=['x'], outputs=['output'])
1088
+ self.send_rank_ids = send_rank_ids
1089
+ self.recv_rank_ids = recv_rank_ids
1090
+ self.send_lens = send_lens
1091
+ self.recv_lens = recv_lens
1092
+ self.format = data_format
1093
+ self.add_prim_attr('group', _get_group(group))
1094
+ self.add_prim_attr('no_eliminate', True)
1095
+ self.rank_size = get_group_size(_get_group(group))
1096
+ for rank_id in send_rank_ids:
1097
+ if rank_id != -1:
1098
+ validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
1099
+ "rank_id in send_rank_ids")
1100
+ for rank_id in recv_rank_ids:
1101
+ if rank_id != -1:
1102
+ validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
1103
+ "rank_id in recv_rank_ids")
1104
+
1105
+ def __call__(self, tensor):
1106
+ raise NotImplementedError
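+ # Illustrative sketch (added for clarity) of the output-shape rule stated in the docstring:
+ #     N, C, H, W = input_x.shape
+ #     out_H = H + recv_lens[0] + recv_lens[1]   # recv_top + recv_bottom
+ #     out_W = W + recv_lens[2] + recv_lens[3]   # recv_left + recv_right
+ # In the example above, recv_lens=[0, 1, 0, 0] on a (1, 1, 2, 2) input gives a
+ # (1, 1, 3, 2) output, i.e. the printed [[[[1. 1.], [1. 1.], [2. 2.]]]].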
1107
+
1108
+
1109
+ class CollectiveScatter(Primitive):
1110
+ r"""
1111
+ Scatter tensor evenly across the processes in the specified communication group.
1112
+
1113
+ Note:
1114
+ This interface only supports Tensor input, which is scattered evenly across the processes.
1115
+ Only the tensor in process `src_rank` (global rank) will be scattered.
1116
+
1117
+ Args:
1118
+ src_rank (int, optional): Specifies the rank of the process that sends the tensor. Default: ``0`` .
1119
+ And only process `src_rank` will send the tensor.
1120
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1121
+
1122
+ Inputs:
1123
+ - **input_x** (Tensor) - The input tensor to be scattered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1124
+
1125
+ Outputs:
1126
+ Tensor, the shape of output is :math:`(x_1/rank\_size, x_2, ..., x_R)`. The dimension 0 of data is equal to
1127
+ dimension 0 of the input tensor divided by the group's rank size, and the other dimensions keep the same.
1128
+
1129
+ Raises:
1130
+ TypeError: If `group` is not a str.
1131
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1132
+ ValueError: If the local rank id of the calling process in the group
1133
+ is larger than the group's rank size.
1134
+
1135
+ Supported Platforms:
1136
+ ``Ascend``
1137
+
1138
+ Examples:
1139
+ .. note::
1140
+ Before running the following examples, you need to configure the communication environment variables.
1141
+
1142
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1143
+ without any third-party or configuration file dependencies.
1144
+ Please see the `msrun start up
1145
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1146
+ for more details.
1147
+
1148
+ This example should be run with 2 devices.
1149
+
1150
+ >>> import numpy as np
1151
+ >>> import mindspore.nn as nn
1152
+ >>> from mindspore import Tensor
1153
+ >>> from mindspore.communication.management import init, get_rank
1154
+ >>> from mindspore import ops
1155
+ >>> # Launch 2 processes.
1156
+ >>> init()
1157
+ >>> class CollectiveScatterNet(nn.Cell):
1158
+ ...     def __init__(self):
+ ...         super(CollectiveScatterNet, self).__init__()
+ ...         self.collective_scatter = ops.CollectiveScatter(src_rank=0)
+ ...
+ ...     def construct(self, x):
+ ...         return self.collective_scatter(x)
+ ...
1165
+ >>> input = Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
1166
+ >>> net = CollectiveScatterNet()
1167
+ >>> output = net(input)
1168
+ >>> print(output)
1169
+ Process with rank 0: [[0. 1.],
1170
+ [2. 3.]]
1171
+ Process with rank 1: [[4. 5.],
1172
+ [6. 7.]]
1173
+
1174
+ Tutorial Examples:
1175
+ - `Distributed Set Communication Primitives - CollectiveScatter
1176
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#reducescatter>`_
1177
+
1178
+ """
1179
+
1180
+ @prim_attr_register
1181
+ def __init__(self, src_rank=0, group=GlobalComm.WORLD_COMM_GROUP):
1182
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
1183
+ self.rank_id = get_rank(_get_group(group))
1184
+ self.src_rank = src_rank
1185
+ self.rank_size = get_group_size(_get_group(group))
1186
+ validator.check('rank', self.rank_id, 'rank_size', self.rank_size, validator.LT, self.name)
1187
+ self.add_prim_attr('rank_id', self.rank_id)
1188
+ self.add_prim_attr('src_rank', self.src_rank)
1189
+ self.add_prim_attr('rank_size', self.rank_size)
1190
+ self.add_prim_attr('group', _get_group(group))
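+ # A rough sketch (assumption, added for clarity) of the even split this operator performs:
+ # the input of shape (rows, ...) on process src_rank is split along dimension 0 and every
+ # process in the group receives rows // rank_size of it. In the docstring example, a (4, 2)
+ # input scattered over 2 ranks yields a (2, 2) slice on each rank.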
1191
+
1192
+
1193
+ class CollectiveGather(Primitive):
1194
+ r"""
1195
+ Gathers tensors from the specified communication group. The operation will gather the tensor
1196
+ from all processes along dimension 0.
1197
+
1198
+ Note:
1199
+ Only the tensor in process `dest_rank` (global rank) will keep the gathered tensor. The other processes
1200
+ will keep a tensor with shape [1], which has no mathematical meaning.
1201
+
1202
+ Args:
1203
+ dest_rank (int): Specifies the rank of the process that receives the tensor.
1204
+ And only process `dest_rank` will receive the gathered tensor.
1205
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1206
+
1207
+ Inputs:
1208
+ - **input_x** (Tensor) - The tensor to be gathered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1209
+
1210
+ Outputs:
1211
+ Tensor, the shape of output is :math:`(\sum x_1, x_2, ..., x_R)`. The dimension 0 of data is equal to
1212
+ the sum of dimension 0 of the input tensors across all processes, and the other dimensions keep the same.
1213
+
1214
+ Raises:
1215
+ TypeError: If `group` is not a str.
1216
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1217
+ ValueError: If the local rank id of the calling process in the group
1218
+ is larger than the group's rank size.
1219
+
1220
+ Supported Platforms:
1221
+ ``Ascend``
1222
+
1223
+ Examples:
1224
+ .. note::
1225
+ Before running the following examples, you need to configure the communication environment variables.
1226
+
1227
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1228
+ without any third-party or configuration file dependencies.
1229
+ Please see the `msrun start up
1230
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1231
+ for more details.
1232
+
1233
+ This example should be run with 2 devices.
1234
+
1235
+ >>> import numpy as np
1236
+ >>> import mindspore as ms
1237
+ >>> import mindspore.nn as nn
1238
+ >>> from mindspore.communication import init
1239
+ >>> from mindspore import Tensor
1240
+ >>> from mindspore import ops
1241
+ >>> # Launch 2 processes.
1242
+ >>>
1243
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
1244
+ >>> init()
1245
+ >>> class CollectiveGatherNet(nn.Cell):
1246
+ ... def __init__(self):
1247
+ ... super(CollectiveGatherNet, self).__init__()
1248
+ ... self.collective_gather = ops.CollectiveGather(dest_rank=0)
1249
+ ...
1250
+ ... def construct(self, x):
1251
+ ... return self.collective_gather(x)
1252
+ ...
1253
+ >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
1254
+ >>> net = CollectiveGatherNet()
1255
+ >>> output = net(input)
1256
+ >>> print(output)
1257
+ Process with rank 0: [[0. 1.],
1258
+ [2. 3.],
1259
+ [0. 1.],
1260
+ [2. 3.]]
1261
+ Process with rank 1: [0.]
1262
+
1263
+ Tutorial Examples:
1264
+ - `Distributed Set Communication Primitives - CollectiveGather
1265
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#collectivegather>`_
1266
+
1267
+ """
1268
+
1269
+ @prim_attr_register
1270
+ def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
1271
+ """Initialize Gather."""
1272
+ validator.check_value_type('group', _get_group(group), (str,), self.name)
1273
+ self.rank_id = get_rank(_get_group(group))
1274
+ self.dest_rank = dest_rank
1275
+ self.rank_size = get_group_size(_get_group(group))
1276
+ validator.check('rank', self.rank_id, 'rank_size', self.rank_size, validator.LT, self.name)
1277
+ self.add_prim_attr('rank_size', self.rank_size)
1278
+ self.add_prim_attr('group', _get_group(group))
1279
+ self.add_prim_attr('dest_rank', self.dest_rank)
1280
+ self.add_prim_attr('rank_id', self.rank_id)
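+ # Illustrative note (added for clarity): every rank contributes its input along dimension 0,
+ # so with 2 ranks each holding a (2, 2) tensor, dest_rank=0 ends up with a (4, 2) tensor
+ # while every other rank gets a shape-[1] placeholder, as shown in the docstring example.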
1281
+
1282
+
1283
+ class Barrier(PrimitiveWithInfer):
1284
+ """
1285
+ Synchronizes all processes in the specified group. Once a process calls this operation, it is blocked until
1286
+ all processes have called this operation. After all processes finish calling the operation, the blocked processes
1287
+ are woken up and continue their tasks.
1288
+
1289
+ Args:
1290
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1291
+
1292
+ Raises:
1293
+ TypeError: If `group` is not a str.
1294
+ RuntimeError: If backend is invalid, or distributed initialization fails.
1295
+
1296
+ Supported Platforms:
1297
+ ``Ascend``
1298
+
1299
+ Examples:
1300
+ .. note::
1301
+ Before running the following examples, you need to configure the communication environment variables.
1302
+
1303
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1304
+ without any third-party or configuration file dependencies.
1305
+ Please see the `msrun start up
1306
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1307
+ for more details.
1308
+
1309
+ This example should be run with 2 devices.
1310
+
1311
+ >>> import numpy as np
1312
+ >>> import mindspore.nn as nn
1313
+ >>> from mindspore.communication import init
1314
+ >>> from mindspore import Tensor
1315
+ >>> from mindspore import ops
1316
+ >>> # Launch 2 processes.
1317
+ >>> init()
1318
+ >>> class BarrierNet(nn.Cell):
1319
+ ...     def __init__(self):
+ ...         super(BarrierNet, self).__init__()
+ ...         self.barrier = ops.Barrier()
+ ...
+ ...     def construct(self):
+ ...         self.barrier()
1325
+ >>> net = BarrierNet()
1326
+ >>> net()
1327
+
1328
+ Tutorial Examples:
1329
+ - `Distributed Set Communication Primitives - Barrier
1330
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#barrier>`_
1331
+
1332
+ """
1333
+
1334
+ @prim_attr_register
1335
+ def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
1336
+ self.group = group
1337
+ self.add_prim_attr("side_effect_mem", True)
1338
+
1339
+ def infer_shape(self):
1340
+ return [1]
1341
+
1342
+ def infer_dtype(self):
1343
+ return mstype.float32
1344
+
1345
+
1346
+ class Send(PrimitiveWithInfer):
1347
+ """
1348
+ Send tensors to the specified dest_rank.
1349
+
1350
+ Note:
1351
+ Send and Receive must be used in combination and must have the same sr_tag.
1352
+
1353
+ Args:
1354
+ sr_tag (int): The tag to identify the send/recv message. The message will
1355
+ be received by the Receive op with the same "sr_tag".
1356
+ dest_rank (int): A required integer identifying the destination rank.
1357
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1358
+ group_back (str, optional): The communication group for backpropagation.
1359
+ Default: ``GlobalComm.WORLD_COMM_GROUP``.
1360
+
1361
+ Inputs:
1362
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1363
+
1364
+ Raises:
1365
+ TypeError: If `group` is not a str.
1366
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1367
+ ValueError: If the local rank id of the calling process in the group
1368
+ is larger than the group's rank size.
1369
+
1370
+ Supported Platforms:
1371
+ ``Ascend`` ``GPU``
1372
+
1373
+ Examples:
1374
+ .. note::
1375
+ Before running the following examples, you need to configure the communication environment variables.
1376
+
1377
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1378
+ without any third-party or configuration file dependencies.
1379
+ Please see the `msrun start up
1380
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1381
+ for more details.
1382
+
1383
+ This example should be run with 2 devices.
1384
+
1385
+ >>> import numpy as np
1386
+ >>> import mindspore.nn as nn
1387
+ >>> from mindspore.communication import init
1388
+ >>> from mindspore import Tensor
1389
+ >>> from mindspore import ops
1390
+ >>>
1391
+ >>> init()
1392
+ >>> class SendNet(nn.Cell):
1393
+ ...     def __init__(self):
+ ...         super(SendNet, self).__init__()
+ ...         self.depend = ops.Depend()
+ ...         self.send = ops.Send(sr_tag=0, dest_rank=1, group="hccl_world_group")
+ ...
+ ...     def construct(self, x):
+ ...         out = self.depend(x, self.send(x))
+ ...         return out
1401
+ >>>
1402
+ >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
1403
+ >>> net = SendNet()
1404
+ >>> output = net(input_)
1405
+
1406
+ Tutorial Examples:
1407
+ - `Distributed Set Communication Primitives - Send
1408
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#send>`_
1409
+
1410
+ """
1411
+
1412
+ @prim_attr_register
1413
+ def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
1414
+ self.rank = dest_rank
1415
+ self.sr_tag = sr_tag
1416
+ self.group = _get_group(group)
1417
+ self.add_prim_attr("no_eliminate", True)
1418
+
1419
+ def infer_shape(self, x_shape):
1420
+ self.add_prim_attr("shape", x_shape)
1421
+ return x_shape
1422
+
1423
+ def infer_dtype(self, x_dtype):
1424
+ return x_dtype
1425
+
1426
+
1427
+ class Receive(PrimitiveWithInfer):
1428
+ """
1429
+ Receive tensors from src_rank.
1430
+
1431
+ Note:
1432
+ Send and Receive must be used in combination and must have the same sr_tag.
1433
+
1434
+ Args:
1435
+ sr_tag (int): A required integer identifying the send/recv message tag. The message will
1436
+ be sent by the Send op with the same "sr_tag".
1437
+ src_rank (int): A required integer identifying the source rank.
1438
+ shape (list[int]): A required list identifying the shape of the tensor to be received.
1439
+ dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
1440
+ int8/int16/int32/int64/uint8/uint16/uint32/uint64/float16/float32/float64/bfloat16.
1441
+ group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1442
+ group_back (str, optional): The communication group for backpropagation.
1443
+ Default: ``GlobalComm.WORLD_COMM_GROUP``.
1444
+
1445
+ Outputs:
1446
+ Tensor, output has the same shape as the Tensor sent by `Send` operation.
1447
+
1448
+ Raises:
1449
+ TypeError: If `group` is not a str.
1450
+ RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1451
+ ValueError: If the local rank id of the calling process in the group
1452
+ is larger than the group's rank size.
1453
+
1454
+ Supported Platforms:
1455
+ ``Ascend`` ``GPU``
1456
+
1457
+ Examples:
1458
+ .. note::
1459
+ Before running the following examples, you need to configure the communication environment variables.
1460
+
1461
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1462
+ without any third-party or configuration file dependencies.
1463
+ Please see the `msrun start up
1464
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1465
+ for more details.
1466
+
1467
+ This example should be run with 2 devices.
1468
+
1469
+ >>> import numpy as np
1470
+ >>> import mindspore as ms
+ >>> import mindspore.nn as nn
1471
+ >>> from mindspore.communication import init
1472
+ >>> from mindspore import Tensor
1473
+ >>> from mindspore import ops
1474
+ >>>
1475
+ >>> init()
1476
+ >>> class ReceiveNet(nn.Cell):
1477
+ ...     def __init__(self):
+ ...         super(ReceiveNet, self).__init__()
+ ...         self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=ms.float32,
+ ...                                 group="hccl_world_group")
+ ...
+ ...     def construct(self):
+ ...         out = self.recv()
+ ...         return out
1485
+ >>>
1486
+ >>> net = ReceiveNet()
1487
+ >>> output = net()
1488
+
1489
+ Tutorial Examples:
1490
+ - `Distributed Set Communication Primitives - Receive
1491
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#receive>`_
1492
+
1493
+ """
1494
+
1495
+ @prim_attr_register
1496
+ def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
1497
+ group_back=GlobalComm.WORLD_COMM_GROUP):
1498
+ self.rank = src_rank
1499
+ self.tag = sr_tag
1500
+ self.shape = shape
1501
+ self.dtype = dtype
1502
+ self.group = _get_group(group)
1503
+ self.add_prim_attr("no_eliminate", True)
1504
+ valid_type = [mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16,
1505
+ mstype.int8, mstype.int16, mstype.int32, mstype.int64,
1506
+ mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64]
1507
+ args = {"dtype": dtype}
1508
+ validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
1509
+
1510
+ def infer_shape(self, x_shape=None):
1511
+ return self.get_attr_dict()['shape']
1512
+
1513
+ def infer_dtype(self, x_dtype=None):
1514
+ return self.get_attr_dict()['dtype']
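+ # A rough sketch (assumption, added for clarity) of how Send and Receive pair up on two
+ # ranks via a matching sr_tag:
+ #     rank 0: ops.Send(sr_tag=0, dest_rank=1)                       # sends a (2, 8) float32 tensor
+ #     rank 1: ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8],
+ #                         dtype=ms.float32)                         # declares the same shape/dtype
+ # Note that infer_shape/infer_dtype above read the receiver's own 'shape'/'dtype'
+ # attributes; nothing is inferred from the sending side.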
1515
+
1516
+
1517
+ class _MirrorOperator(PrimitiveWithInfer):
1518
+ """
1519
+ Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
1520
+ internal use of parallel modules and cannot be called by users.
1521
+
1522
+ Args:
1523
+ group (str): The communication group to work on. Default: ``None`` .
1524
+ dev_num (int): The device number of the group. Default: ``None`` .
1525
+ mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1526
+ """
1527
+
1528
+ @prim_attr_register
1529
+ def __init__(self, group=None, dev_num=None, mean_flag=None):
1530
+ """Initialize _MirrorOperator."""
1531
+ self.group = group
1532
+ self.dev_num = dev_num
1533
+ self.mean_flag = mean_flag
1534
+ self.add_prim_attr("fusion", 1)
1535
+ self.add_prim_attr('order_enforce_skip', True)
1536
+
1537
+ def infer_shape(self, x_shape):
1538
+ return x_shape
1539
+
1540
+ def infer_dtype(self, x_dtype):
1541
+ return x_dtype
1542
+
1543
+
1544
+ mirror = _MirrorOperator()
1545
+
1546
+
1547
+ class _MirrorMiniStepOperator(PrimitiveWithInfer):
1548
+ """
1549
+ Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
1550
+ internal use of parallel modules and cannot be called by users.
1551
+
1552
+ Args:
1553
+ group (str): The communication group to work on. Default: ``None`` .
1554
+ dev_num (int): The device number of the group. Default: ``None`` .
1555
+ mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1556
+ grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
1557
+ """
1558
+
1559
+ @prim_attr_register
1560
+ def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
1561
+ """Initialize _MirrorMiniStepOperator."""
1562
+ self.group = group
1563
+ self.dev_num = dev_num
1564
+ self.mean_flag = mean_flag
1565
+ self.grad_accumulation_step = grad_accumulation_step
1566
+ self.add_prim_attr('order_enforce_skip', True)
1567
+ self.add_prim_attr('side_effect_backprop_mem', True)
1568
+
1569
+ def infer_shape(self, x_shape, z_shape):
1570
+ return x_shape
1571
+
1572
+ def infer_dtype(self, x_dtype, z_shape):
1573
+ return x_dtype
1574
+
1575
+
1576
+ mirror_mini_step = _MirrorMiniStepOperator()
1577
+
1578
+
1579
+ class _VirtualDiv(PrimitiveWithInfer):
1580
+ """
1581
+ Auto parallel virtual operator. Do nothing in forward, do Div in backward.
1582
+
1583
+ Args:
1584
+ divisor: float32
1585
+ divisor (float32): The divisor applied by Div in backward. Default: ``None`` .
1586
+
1587
+ @prim_attr_register
1588
+ def __init__(self, divisor=None):
1589
+ """Initialize _VirtualDiv."""
1590
+ self.divisor = divisor
1591
+ self.add_prim_attr('order_enforce_skip', True)
1592
+
1593
+ def infer_shape(self, x_shape):
1594
+ return x_shape
1595
+
1596
+ def infer_dtype(self, x_dtype):
1597
+ return x_dtype
1598
+
1599
+
1600
+ virtual_div = _VirtualDiv()
1601
+
1602
+
1603
+ class _VirtualPipelineEnd(PrimitiveWithInfer):
1604
+ """
1605
+ Auto parallel virtual operator. Do nothing in forward and backward, mark end node in pipeline parallel.
1606
+
1607
1609
+ """
1610
+
1611
+ @prim_attr_register
1612
+ def __init__(self):
1613
+ """Initialize _VirtualPipelineEnd."""
1614
+
1615
+ def infer_shape(self, x_shape):
1616
+ return x_shape
1617
+
1618
+ def infer_dtype(self, x_dtype):
1619
+ return x_dtype
1620
+
1621
+
1622
+ virtual_pipeline_end = _VirtualPipelineEnd()
1623
+
1624
+
1625
+ class _VirtualAdd(PrimitiveWithInfer):
1626
+ """Auto parallel virtual operator. Do nothing in forward, do Add in backward."""
1627
+
1628
+ @prim_attr_register
1629
+ def __init__(self):
1630
+ """Initialize _VirtualAdd."""
1631
+ self.add_prim_attr('order_enforce_skip', True)
1632
+
1633
+ def infer_shape(self, x_shape, y_shape):
1634
+ return x_shape
1635
+
1636
+ def infer_dtype(self, x_dtype, y_dtype):
1637
+ return x_dtype
1638
+
1639
+
1640
+ class _VirtualDataset(PrimitiveWithInfer):
1641
+ """
1642
+ Auto parallel virtual dataset operator.
1643
+
1644
+ It would insert VirtualDataset operator in forward computation and be deleted before backward computation.
1645
+ """
1646
+
1647
+ @prim_attr_register
1648
+ def __init__(self):
1649
+ """Initialize _VirtualDataset."""
1650
+ self.add_prim_attr('order_enforce_skip', True)
1651
+
1652
+ def infer_shape(self, *args):
1653
+ return args
1654
+
1655
+ def infer_dtype(self, *args):
1656
+ return args
1657
+
1658
+
1659
+ virtual_dataset = _VirtualDataset()
1660
+
1661
+
1662
+ class _VirtualAssignAdd(PrimitiveWithInfer):
1663
+ """
1664
+ Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
1665
+ internal use of parallel modules and cannot be called by users.
1666
+
1667
+ """
1668
+
1669
+ @prim_attr_register
1670
+ def __init__(self):
1671
+ """Initialize _VirtualAssignAdd."""
1672
+ self.add_prim_attr('order_enforce_skip', True)
1673
+ self.add_prim_attr('side_effect_backprop_mem', True)
1674
+
1675
+ def infer_shape(self, x_shape, y_shape):
1676
+ return x_shape
1677
+
1678
+ def infer_dtype(self, x_dtype, y_dtype):
1679
+ return x_dtype
1680
+
1681
+
1682
+ virtual_assign_add = _VirtualAssignAdd()
1683
+
1684
+
1685
+ class _VirtualAccuGrad(PrimitiveWithInfer):
1686
+ """
1687
+ Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
1688
+ internal use of parallel modules and cannot be called by users.
1689
+ """
1690
+
1691
+ @prim_attr_register
1692
+ def __init__(self):
1693
+ """Initialize _VirtualAccuGrad."""
1694
+ self.add_prim_attr('order_enforce_skip', True)
1695
+
1696
+ def infer_shape(self, x_shape, y_shape):
1697
+ return x_shape
1698
+
1699
+ def infer_dtype(self, x_dtype, y_dtype):
1700
+ return x_dtype
1701
+
1702
+
1703
+ virtual_accu_grad = _VirtualAccuGrad()
1704
+
1705
+
1706
+ class _MirrorMicroStepOperator(PrimitiveWithInfer):
1707
+ """
1708
+ Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
1709
+ internal use of parallel modules and cannot be called by users.
1710
+
1711
+ Args:
1712
+ group (str): The communication group to work on. Default: ``None`` .
1713
+ dev_num (int): The device number of the group. Default: ``None`` .
1714
+ mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1715
+ """
1716
+
1717
+ @prim_attr_register
1718
+ def __init__(self, group=None, dev_num=None, mean_flag=None):
1719
+ """Initialize _MirrorMicroStepOperator."""
1720
+ self.group = group
1721
+ self.dev_num = dev_num
1722
+ self.mean_flag = mean_flag
1723
+ self.add_prim_attr('order_enforce_skip', True)
1724
+ self.add_prim_attr('side_effect_backprop_mem', True)
1725
+
1726
+ def infer_shape(self, x_shape, z_shape):
1727
+ return x_shape
1728
+
1729
+ def infer_dtype(self, x_dtype, z_shape):
1730
+ return x_dtype
1731
+
1732
+
1733
+ class _VirtualOutput(PrimitiveWithInfer):
1734
+ """
1735
+ Auto parallel virtual out operator.
1736
+
1737
+ It would insert VirtualOutput operator in forward computation and be deleted before backward computation.
1738
+ """
1739
+
1740
+ @prim_attr_register
1741
+ def __init__(self):
1742
+ """Initialize _VirtualOutput."""
1743
+ self.add_prim_attr('order_enforce_skip', True)
1744
+
1745
+ def infer_shape(self, x_shape):
1746
+ return x_shape
1747
+
1748
+ def infer_dtype(self, x_dtype):
1749
+ return x_dtype
1750
+
1751
+
1752
+ class _GetTensorSlice(PrimitiveWithInfer):
1753
+ """
1754
+ Gets tensor slice by device matrix and tensor map.
1755
+
1756
+ Args:
1757
+ dev_mat (tuple): The device matrix of the slice tensor.
1758
+ tensor_map (tuple): The tensor map of the slice tensor.
1759
+ """
1760
+
1761
+ @prim_attr_register
1762
+ def __init__(self):
1763
+ """Initialize _GetTensorSlice."""
1764
+ self.add_prim_attr('order_enforce_skip', True)
1765
+
1766
+ def infer_value(self, x, dev_mat, tensor_map, slice_shape, full_shape):
1767
+ from mindspore.parallel._tensor import _load_tensor
1768
+ validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
1769
+ validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
1770
+ tensor_slice = _load_tensor(x, dev_mat, tensor_map, full_shape)
1771
+ if tensor_slice.shape != slice_shape:
1772
+ tensor_slice = tensor_slice.reshape(slice_shape)
1773
+ return Tensor(tensor_slice, x.dtype)
1774
+
1775
+
1776
+ class BatchISendIRecv(PrimitiveWithInfer):
1777
+ """
1778
+ Batch send and recv tensors asynchronously.
1779
+
1780
+ Note:
1781
+ - The ``isend`` and ``irecv`` in ``op_types`` between ranks need to match each other.
1782
+ - ``isend`` and ``irecv`` in a batch can only be used in the same communication group.
1783
+
1784
+ Args:
1785
+ op_types(Union[tuple[str], list[str]]): "isend" or "irecv" to indicate the order and number of communication.
1786
+ remote_ranks(Union[tuple[int], list[int]]): src or dst rank that matches the op_types.
1787
+ receive_shapes(Union[tuple[int], list[int]]): receive tensor shapes that matches "irecv" in op_types.
1788
+ receive_dtypes (Union[tuple[mindspore.dtype], list[mindspore.dtype]]): receive tensor dtypes
1789
+ that match the "irecv" entries in op_types.
1790
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
1791
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
1792
+
1793
+ Inputs:
1794
+ - **input_x** (Union[tuple[Tensor], list[Tensor], tuple(None)]) -
1795
+ The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1796
+
1797
+ Outputs:
1798
+ tuple(Tensor). Output tensors is corresponding to ``op_types``:
1799
+ At an ``"isend"`` position, the output tensor is a placeholder scalar tensor with no meaning.
1800
+ At an ``"irecv"`` position, the output tensor is the tensor received from the remote end.
1801
+
1802
+
1803
+ Raises:
1804
+ TypeError: If ``group`` is not a str.
1805
+ TypeError: If ``op_types``, ``receive_shapes``, ``receive_dtypes``, ``remote_ranks`` are not tuple or list.
1806
+ ValueError: If the length of ``receive_shapes`` and ``receive_dtypes`` are not the same.
1807
+ ValueError: If the length of ``op_types`` and ``remote_ranks`` are not the same.
1808
+ RuntimeError: If the length of input tensors and ``"isend"`` count are not the same.
1809
+
1810
+ Supported Platforms:
1811
+ ``Ascend``
1812
+
1813
+ Examples:
1814
+ .. note::
1815
+ Before running the following examples, you need to configure the communication environment variables.
1816
+
1817
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1818
+ without any third-party or configuration file dependencies.
1819
+
1820
+ Please see the `msrun start up
1821
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1822
+ for more details.
1823
+
1824
+ This example should be run with 2 devices.
1825
+
1826
+ >>> import numpy as np
1827
+ >>> import mindspore as ms
1828
+ >>> from mindspore import ops
1829
+ >>> import mindspore.nn as nn
1830
+ >>> from mindspore.communication import init, get_rank
1831
+ >>> from mindspore import Tensor
1832
+ >>>
1833
+ >>> init()
1834
+ >>> rank = get_rank()
1835
+ >>> class Net(nn.Cell):
1836
+ ... def __init__(self):
1837
+ ... super(Net, self).__init__()
1838
+ ... if rank == 0:
1839
+ ... remote_rank = [1, 1]
1840
+ ... else:
1841
+ ... remote_rank = [0, 0]
1842
+ ... self.batchisendirecv = ops.BatchISendIRecv(("isend", "irecv"), remote_rank, [()], (ms.float32,))
1843
+ ...
1844
+ ... def construct(self, x):
1845
+ ... if isinstance(x, Tensor):
1846
+ ... x = (x,)
1847
+ ... return self.batchisendirecv(x)
1848
+ ...
1849
+ >>> send_x = Tensor(rank + 1).astype(ms.float32)
1850
+ >>> net = Net()
1851
+ >>> output = net(send_x)
1852
+ >>> print(output)
1853
+ rank 0:
1854
+ (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 2))
1855
+ rank 1:
1856
+ (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 1))
1857
+
1858
+ Tutorial Examples:
1859
+ - `Distributed Set Communication Primitives - BatchISendIRecv
1860
+ <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allgather>`_
1861
+
1862
+ """
1863
+
1864
+ @prim_attr_register
1865
+ def __init__(self, op_types, remote_ranks, receive_shapes=None,
1866
+ receive_dtypes=None, group=GlobalComm.WORLD_COMM_GROUP):
1867
+ if receive_shapes is None:
1868
+ receive_shapes = ()
1869
+ else:
1870
+ validator.check_value_type("receive_shapes", receive_shapes, [tuple, list], self.name)
1871
+
1872
+ if receive_dtypes is None:
1873
+ receive_dtypes = ()
1874
+ else:
1875
+ validator.check_value_type("receive_dtypes", receive_dtypes, [tuple, list], self.name)
1876
+
1877
+ validator.check_value_type("op_types", op_types, [tuple, list], self.name)
1878
+ validator.check_value_type("remote_ranks", remote_ranks, [tuple, list], self.name)
1879
+
1880
+ if len(receive_shapes) != len(receive_dtypes):
1881
+ raise ValueError("length of receive_shapes and receive_dtypes must be the same, "
1882
+ f"but got receive_shapes: {len(receive_shapes)} "
1883
+ f"and receive_dtypes: {len(receive_dtypes)}.")
1884
+
1885
+ if len(op_types) != len(remote_ranks):
1886
+ raise ValueError("length of op_types and remote_ranks must be the same.")
1887
+
1888
+ if group is None:
1889
+ group = GlobalComm.WORLD_COMM_GROUP
1890
+ self.add_prim_attr('group', group)
1891
+ self.add_prim_attr('op_types', op_types)
1892
+ self.add_prim_attr('remote_ranks', remote_ranks)
1893
+ self.add_prim_attr('receive_shapes', receive_shapes)
1894
+ self.add_prim_attr('receive_dtypes', receive_dtypes)
1895
+ self.add_prim_attr('no_eliminate', True)
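+ # Illustrative sketch (added for clarity) of how the argument lists line up, based on the
+ # docstring example: op_types and remote_ranks carry one entry per communication, while
+ # receive_shapes/receive_dtypes only describe the "irecv" entries, in order:
+ #     op_types       = ("isend", "irecv")
+ #     remote_ranks   = [1, 1]          # peer rank for each op
+ #     receive_shapes = [()]            # one entry, for the single "irecv" (a scalar here)
+ #     receive_dtypes = (ms.float32,)   # dtype of that "irecv"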
1896
+
1897
+
1898
+ class AlltoAllV(PrimitiveWithInfer):
1899
+ """
1900
+ AlltoAll operation which supports uneven splits.
1901
+
1902
+ Note:
1903
+ - Only a flattened tensor is supported as input. The input tensor should be flattened and
1904
+ concatenated before calling this primitive.
1905
+
1906
+ Args:
1907
+ send_numel_list (Union[tuple[int], list[int]]): The number of elements to send to each remote rank.
1908
+ recv_numel_list (Union[tuple[int], list[int]]): The number of elements to receive from each remote rank.
1909
+ group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
1910
+ means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
1911
1912
+
1913
+ Inputs:
1914
+ - **input_x** (Tensor) - flatten tensor to scatter. The shape of tensor is :math:`(x_1)`.
1915
+
1916
+ Outputs:
1917
+ Tensor. The flattened and concatenated tensor gathered from remote ranks.
1918
+ If the gather result is empty, a Tensor with value 0 is returned, which has no actual meaning.
1919
+
1920
+ Raises:
1921
+ TypeError: If 'send_numel_list' or 'recv_numel_list' is neither a tuple nor a list.
1922
+
1923
+ Supported Platforms:
1924
+ ``Ascend``
1925
+
1926
+ Examples:
1927
+ .. note::
1928
+ Before running the following examples, you need to configure the communication environment variables.
1929
+
1930
+ For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1931
+ without any third-party or configuration file dependencies.
1932
+
1933
+ Please see the `msrun start up
1934
+ <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
1935
+ for more details.
1936
+
1937
+ This example should be run with 2 devices.
1938
+
1939
+ >>> import numpy as np
1940
+ >>> import mindspore as ms
1941
+ >>> from mindspore import ops
1942
+ >>> import mindspore.nn as nn
1943
+ >>> from mindspore.communication import init, get_rank
1944
+ >>> from mindspore import Tensor
1945
+ >>>
1946
+ >>> init()
1947
+ >>> rank = get_rank()
1948
+ >>> class Net(nn.Cell):
1949
+ ... def __init__(self):
1950
+ ... super(Net, self).__init__()
1951
+ ... if rank == 0:
1952
+ ... self.all_to_all = ops.AlltoAllV([1, 2], [1, 2])
1953
+ ... else:
1954
+ ... self.all_to_all = ops.AlltoAllV([2, 1], [2, 1])
1955
+ ...
1956
+ ... def construct(self, x):
1957
+ ... return self.all_to_all(x)
1958
+ ...
1959
+ >>> if rank == 0:
+ ...     send_tensor = Tensor([0, 1, 2.])
+ ... elif rank == 1:
+ ...     send_tensor = Tensor([3, 4, 5.])
1963
+ >>> net = Net()
1964
+ >>> output = net(send_tensor)
1965
+ >>> print(output)
1966
+ rank 0:
1967
+ [0. 3. 4]
1968
+ rank 1:
1969
+ [1. 2. 5]
1970
+
1971
+ """
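+     # A rough sketch (assumption, added for clarity) of how the numel lists must match
+     # across ranks in the docstring example (2 ranks, flattened inputs of 3 elements each):
+     #     rank 0: send_numel_list=[1, 2], recv_numel_list=[1, 2]
+     #     rank 1: send_numel_list=[2, 1], recv_numel_list=[2, 1]
+     # In general, rank i's recv_numel_list[j] must equal rank j's send_numel_list[i]; here
+     # rank 0 keeps [0.] from itself and receives [3. 4.] from rank 1, giving the printed [0. 3. 4.].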
1972
+
1973
+ @prim_attr_register
1974
+ def __init__(self, send_numel_list, recv_numel_list, group=None, split_sizes_empty=False):
1975
+ validator.check_value_type("send_numel_list", send_numel_list, [tuple, list], self.name)
1976
+ validator.check_value_type("recv_numel_list", recv_numel_list, [tuple, list], self.name)
1977
+ self.group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
1978
+ self.send_numel_list = send_numel_list
1979
+ self.recv_numel_list = recv_numel_list
1980
+ self.split_sizes_empty = split_sizes_empty
1981
+ self.rank_size = get_group_size(self.group)
1982
+
1983
+ self.add_prim_attr('group', self.group)
1984
+ self.add_prim_attr('send_numel_list', send_numel_list)
1985
+ self.add_prim_attr('recv_numel_list', recv_numel_list)