mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387) hide show
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2149 @@
1
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """array_ops vmap impl."""
17
+ from __future__ import absolute_import
18
+
19
+ import mindspore
20
+ import mindspore.numpy as mnp
21
+ from mindspore import ops
22
+ from mindspore.common import Tensor
23
+ from mindspore._c_expression import Tensor as Tensor_
24
+ from mindspore.ops import operations as P
25
+ from mindspore.ops import functional as F
26
+ from mindspore.ops.primitive import constexpr, _primexpr
27
+ from mindspore.ops.operations._grad_ops import MaskedSelectGrad
28
+ from mindspore.ops.operations import _grad_ops as G
29
+ from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
30
+ TensorScatterElements
31
+ from mindspore.ops.operations.random_ops import RandomPoisson
32
+ from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
33
+ from mindspore.ops.primitive import Primitive
34
+ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
35
+ _raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
36
+ get_unop_vmap_rule, _get_reduce_out_dim, _get_reduce_batch_axis, \
37
+ _bdim_at_any
38
+ from mindspore.ops.function import _VmapGeneralRule
39
+
40
+
41
@vmap_rules_getters.register(P.NoRepeatNGram)
def get_no_repeat_ngram_vmap_rule(prim, axis_size):
    """VmapRule for `NoRepeatNGram` operation.

    Both inputs get their vmapped axis moved to the front. Every axis except the
    trailing two is then folded into a single leading axis, the primitive runs
    once on the folded tensors, and the result is reshaped back to the batched
    shape of `log_probs`. The output's vmapped axis is 0.
    """

    def vmap_rule(state_seq_bdim, log_probs_bdim):
        all_none, passthrough = vmap_general_preprocess(prim, state_seq_bdim, log_probs_bdim)
        if all_none:
            return passthrough

        seq_val, seq_axis = state_seq_bdim
        probs_val, probs_axis = log_probs_bdim
        seq_val = _bdim_at_front(seq_val, seq_axis, axis_size)
        probs_val = _bdim_at_front(probs_val, probs_axis, axis_size)
        seq_shape = F.shape(seq_val)
        probs_shape = F.shape(probs_val)
        # Fold all leading axes (vmapped one included) into one; keep the last two.
        flat_seq = F.reshape(seq_val, (-1,) + seq_shape[-2:])
        flat_probs = F.reshape(probs_val, (-1,) + probs_shape[-2:])
        flat_out = prim(flat_seq, flat_probs)
        return F.reshape(flat_out, probs_shape), 0

    return vmap_rule
63
+
64
+
65
@vmap_rules_getters.register("Cast")
def get_cast_vmap_rule(prim, axis_size):
    """VmapRule for `Cast` operation.

    The input tensor is handed to the primitive unchanged. The output reports the
    same vmapped axis as the input. The target dtype must not be batched.

    Args:
        prim: the `Cast` primitive instance, or its name as a string.
        axis_size: size of the vmapped axis (not used by this rule).
    """
    if not isinstance(prim, str):
        prim_name = prim.name
    else:
        prim_name = prim
        prim = Primitive(prim_name)

    def vmap_rule(input_bdim, type_bdim):
        x_val, x_axis = input_bdim
        target_dtype, dtype_axis = type_bdim
        if dtype_axis is not None:
            _raise_value_error("The source axis of 'type' in `{}` must be None, "
                               "but got {}.".format(prim_name, dtype_axis))
        return prim(x_val, target_dtype), x_axis

    return vmap_rule
84
+
85
+
86
@vmap_rules_getters.register(P.Argmax)
@vmap_rules_getters.register(P.Argmin)
def get_argmin_vmap_rule(prim, axis_size):
    """VmapRule for the `Argmax` and `Argmin` operations.

    The caller's reduction axis is translated for the batched input by
    `_get_reduce_batch_axis`. The output's vmapped axis is computed by
    `_get_reduce_out_dim`. Only the source axis of `x` is inspected; `axis` and
    the output type are used as plain values.
    """

    def vmap_rule(x_bdim, axis_bdim, type_bdim):
        all_none, passthrough = vmap_general_preprocess(prim, x_bdim)
        if all_none:
            return passthrough
        x_val, x_axis = x_bdim
        reduce_axis, _ = axis_bdim
        out_type, _ = type_bdim
        mapped_axis = _get_reduce_batch_axis(reduce_axis, x_axis, ops.rank(x_val))
        reduced = prim(x_val, mapped_axis, out_type)
        return reduced, _get_reduce_out_dim(x_axis, mapped_axis)

    return vmap_rule
105
+
106
+
107
@vmap_rules_getters.register(P.ArgMaxWithValue)
@vmap_rules_getters.register(P.ArgMinWithValue)
def get_arg_min_max_with_value_vmap_rule(prim, axis_size):
    """VmapRule for `ArgMaxWithValue` and `ArgMinWithValue` operations.

    A fresh primitive of the matching class is built with the reduction axis
    translated for the batched input. Both outputs (index and value) report the
    same output vmapped axis.

    Args:
        prim: the primitive instance, or its name as a string.
        axis_size: size of the vmapped axis (not used by this rule).

    Returns:
        A rule returning ``((index, out_dim), (value, out_dim))``.
    """
    # Accept a primitive name as well as an instance, like the other rules in
    # this file do; a plain `str` has no `.name` attribute.
    if isinstance(prim, str):
        prim_name = prim
        prim = Primitive(prim)
    else:
        prim_name = prim.name
    prim_class_map = {
        "ArgMaxWithValue": P.ArgMaxWithValue,
        "ArgMinWithValue": P.ArgMinWithValue,
    }
    prim_class = prim_class_map.get(prim_name)

    def vmap_rule(x_bdim, axis_bdim, keep_dims_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result
        var, x_dim = x_bdim
        axis_data, _ = axis_bdim
        keep_dims_data, _ = keep_dims_bdim
        x_ndim = ops.rank(var)
        batch_axis = _get_reduce_batch_axis(axis_data, x_dim, x_ndim)
        index, out = prim_class(batch_axis, keep_dims_data)(var)
        out_dim = _get_reduce_out_dim(x_dim, batch_axis, keep_dims_data)
        return (index, out_dim), (out, out_dim)

    return vmap_rule
131
+
132
+
133
@_primexpr
def _get_prefix(indices_shape, axis_size, indices_dtype):
    """
    Generate prefix by indices shape, whose -1 axis value is the index value of axis 0.
    eg. if the indices is Tensor([[[1, 2], [2, 3]],
                                  [[2, 3], [3, 4]]])
    we got the indices_shape (2, 2, 2),
    the generated prefix is a Tensor([[[0], [0]],
                                      [[1], [1]]])

    For a 1-D `indices_shape` the prefix is simply ``range(axis_size)``.

    Raises:
        ValueError: if `indices_shape` is empty.
    """
    if not indices_shape:
        raise ValueError("indices_shape is empty in _get_prefix.")

    rank = len(indices_shape)
    if rank == 1:
        return P.Range()(Tensor(0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype))

    # Broadcast target: the indices shape with its trailing axis collapsed to 1.
    prefix_shape = tuple(indices_shape[:-1]) + (1,)
    # Reshape target for the range: keep axis 0, make every other axis a singleton.
    expand_shape = (indices_shape[0],) + (1,) * (rank - 1)
    batch_ids = P.Range()(Tensor(0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype))
    return P.BroadcastTo(prefix_shape)(P.Reshape()(batch_ids, expand_shape))
169
+
170
+
171
@vmap_rules_getters.register(P.Transpose)
def get_transpose_vmap_rule(prim, axis_size):
    """VmapRule for `Transpose` operation.

    The batched permutation lists the vmapped axis first. It then applies the
    per-sample `perm`, with each axis index renumbered to its position in the
    batched tensor. The output's vmapped axis is therefore 0.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    @_primexpr
    def _get_transpose_batch_perm(dim, perm, x_rank):
        """Generate batch_perm based on the original perm of transpose operation and dim of the input."""
        batch_axis = dim + x_rank if dim < 0 else dim
        num_axes = len(perm)
        batch_perm = (batch_axis,)
        for axis in perm:
            axis = axis + num_axes if axis < 0 else axis
            # Per-sample axes at or after the vmapped axis sit one slot later in the batched tensor.
            batch_perm = batch_perm + ((axis if axis < batch_axis else axis + 1),)
        return batch_perm

    def vmap_rule(x_bdim, perm_bdim):
        all_none, passthrough = vmap_general_preprocess(prim, x_bdim, perm_bdim)
        if all_none:
            return passthrough

        x_val, x_axis = x_bdim
        perm, perm_axis = perm_bdim
        if perm_axis is not None:
            _raise_value_error("The source axis of perm in `Transpose` must be None, "
                               "but got {}.".format(perm_axis))
        new_perm = _get_transpose_batch_perm(x_axis, perm, F.rank(x_val))
        return prim(x_val, new_perm), 0

    return vmap_rule
213
+
214
+
215
@vmap_rules_getters.register("Tile")
def get_tile_vmap_rule(prim, axis_size):
    """VmapRule for `P.Tile` operation.

    The vmapped axis gets a multiple of 1 so it is never repeated. The rule
    reports where that axis lands in the tiled output.
    """

    @_primexpr
    def _get_batch_multiples(input_shape, dim, dims):
        """Build the batched `multiples` and the output position of the vmapped axis.

        Args:
            input_shape: shape of the batched input, vmapped axis included.
            dim: position of the vmapped axis in `input_shape` (assumed non-negative).
            dims: the per-sample `multiples` tuple.

        Returns:
            ``(batch_multiples, out_dim)``.
        """
        input_ndim = len(input_shape)
        if len(dims) < input_ndim - 1:
            # Left-pad with 1s so that `dims` covers every per-sample axis.
            dims = (1,) * (input_ndim - 1 - len(dims)) + dims
        # Measure AFTER padding. Assuming Tile's output rank is
        # max(rank(x), len(multiples)), the output has len(dims) + 1 axes here,
        # and the vmapped axis sits `rev_dim` axes from its end. Using the
        # pre-padding length would misplace the output axis whenever padding
        # happened (e.g. per-sample (3, 4), dims (2,), dim 0 gave -1 instead of 0).
        multiples_ndim = len(dims)

        rev_dim = input_ndim - 1 - dim
        if rev_dim == 0:
            return dims + (1,), multiples_ndim

        batch_multiples = list(dims)
        batch_multiples.insert(-rev_dim, 1)
        return tuple(batch_multiples), multiples_ndim - rev_dim

    def vmap_rule(input_bdim, multiples_bdim):
        is_all_none, result = vmap_general_preprocess(prim, input_bdim, multiples_bdim)
        if is_all_none:
            return result

        input_x, dim = input_bdim
        dims, dims_dim = multiples_bdim
        if dims_dim is not None:
            _raise_value_error("The source axis of shape in `Tile` must be None, but got {}.".format(dims_dim))

        input_shape = F.shape(input_x)
        batch_multiples, out_dim = _get_batch_multiples(input_shape, dim, dims)
        repeat_tensor = prim(input_x, batch_multiples)
        return repeat_tensor, out_dim

    return vmap_rule
250
+
251
+
252
@vmap_rules_getters.register("Concat")
def get_concat_vmap_rule(prim, axis_size):
    """VmapRule for `Concat` operation.

    Every input gets its vmapped axis moved to the front. A non-negative
    concatenation axis is then shifted by one to step over that new leading
    axis. The output's vmapped axis is 0.
    """

    @_primexpr
    def _get_concat_batch_axis(axis):
        # Negative axes count from the end and are unaffected by the new leading axis.
        return axis + 1 if axis >= 0 else axis

    def vmap_rule(inputs_bdim, axis_bdim):
        all_none, passthrough = vmap_general_preprocess(prim, inputs_bdim, axis_bdim)
        if all_none:
            return passthrough

        if not isinstance(inputs_bdim, (tuple, list)):
            _raise_value_error("The 'x' of Concat is neither tuple nor list.")

        moved = ()
        for pair in inputs_bdim:
            tensor, src_axis = pair
            moved = moved + (_bdim_at_front(tensor, src_axis, axis_size),)

        concat_axis, concat_axis_src = axis_bdim
        if concat_axis_src is not None:
            _raise_value_error(
                "The source axis of `axis` in P.Concat must be None, but got {}.".format(concat_axis_src))

        return prim(moved, _get_concat_batch_axis(concat_axis)), 0

    return vmap_rule
285
+
286
+
287
@vmap_rules_getters.register(P.Stack)
def get_stack_vmap_rule(prim, axis_size):
    """VmapRule for `Stack` operation.

    Every input tensor gets its vmapped axis moved to the front. A non-negative
    stack axis is shifted by one to step over that leading axis. The output's
    vmapped axis is 0.
    """
    if not isinstance(prim, str):
        stack_axis = prim.axis
    else:
        prim = P.Stack(0)
        stack_axis = 0
    if stack_axis >= 0:
        stack_axis = stack_axis + 1

    def vmap_rule(*inputs_bdim):
        all_none, passthrough = vmap_general_preprocess(prim, *inputs_bdim)
        if all_none:
            return passthrough

        if not isinstance(inputs_bdim, (tuple, list)):
            _raise_value_error("The 'x' of P.Stack is neither tuple nor list.")

        # The first positional argument holds the (tensor, source_axis) pairs to stack.
        moved = ()
        for pair in inputs_bdim[0]:
            tensor, src_axis = pair
            moved = moved + (_bdim_at_front(tensor, src_axis, axis_size),)

        return P.Stack(stack_axis)(moved), 0

    return vmap_rule
317
+
318
+
319
@vmap_rules_getters.register(P.Unstack)
def get_unstack_vmap_rule(prim, axis_size):
    """VmapRule for `Unstack` operation.

    The input gets its vmapped axis moved to the front, and a non-negative
    unstack axis is shifted by one to step over it. Each returned slice reports
    a vmapped axis of 0.
    """
    if not isinstance(prim, str):
        unstack_axis = prim.axis
    else:
        prim = P.Unstack(0)
        unstack_axis = 0
    if unstack_axis >= 0:
        unstack_axis = unstack_axis + 1

    def vmap_rule(inputs_bdim):
        all_none, passthrough = vmap_general_preprocess(prim, inputs_bdim)
        if all_none:
            return passthrough

        tensor, src_axis = inputs_bdim
        pieces = P.Unstack(unstack_axis)(_bdim_at_front(tensor, src_axis, axis_size))
        result = ()
        for piece in pieces:
            result = result + ((piece, 0),)
        return result

    return vmap_rule
345
+
346
+
347
@vmap_rules_getters.register(P.Reshape)
def get_reshape_vmap_rule(prim, axis_size):
    """VmapRule for `Reshape` operation.

    Tries to insert the vmapped axis directly into the requested target shape so
    the input can be reshaped without moving any axis. When no valid insertion
    point exists, it falls back to moving the vmapped axis to the front.
    """

    @_primexpr
    def get_batch_shape(x_shape, x_dim, target_shape, axis_size):
        """Compute the batched target shape for `Reshape`.

        Args:
            x_shape: shape of the batched input (vmapped axis included).
            x_dim: position of the vmapped axis in `x_shape`.
            target_shape: per-sample target shape requested by the caller.
            axis_size: size of the vmapped axis.

        Returns:
            ``(batch_shape, out_axis, need_moveaxis)``. When ``need_moveaxis`` is
            True the first two entries are placeholders (``0, 0``), and the
            caller must move the vmapped axis to the front instead.

        Raises:
            ValueError: if `target_shape` contains more than one -1.
        """
        def _check(neg_index, target_shape):
            # `neg_index` stays -1 until the first -1 entry is seen; a second one is an error.
            if neg_index != -1:
                raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')

        # Vmapped axis already leading: prepend it to the target shape.
        if x_dim == 0:
            return (axis_size,) + target_shape, 0, False

        # Vmapped axis trailing: append it; it lands at index len(target_shape).
        if x_dim in (len(x_shape) - 1, -1):
            return target_shape + (axis_size,), len(target_shape), False

        # Locate a -1 entry (if any) and multiply the explicit entries.
        neg_index = -1
        dim_prod = 1
        for i, shp_i in enumerate(target_shape):
            if shp_i == -1:
                _check(neg_index, target_shape)
                neg_index = i
            else:
                dim_prod *= shp_i
        # Total element count of the batched input.
        arr_prod = 1
        for i in x_shape:
            arr_prod *= i
        target_shape_list = list(target_shape)
        if neg_index != -1:
            # Resolve -1 from the per-sample element count (total / axis_size).
            neg_index_size = int(arr_prod // (dim_prod * axis_size))
            target_shape_list[neg_index] = neg_index_size

        # Element count of the input axes that precede the vmapped axis.
        arr_prod_before_dim = 1
        for i in x_shape[:x_dim]:
            arr_prod_before_dim *= i
        # The vmapped axis can stay in place only if some prefix of the target
        # shape holds exactly that many elements; insert it right after that prefix.
        dim_prod = 1
        for i, shp_i in enumerate(target_shape_list, start=1):
            dim_prod *= shp_i
            if dim_prod == arr_prod_before_dim:
                return tuple(target_shape_list[:i]) + (axis_size,) + tuple(target_shape_list[i:]), i, False
            if dim_prod > arr_prod_before_dim:
                # Prefix overshot: no insertion point exists.
                return 0, 0, True

        return 0, 0, True

    def vmap_rule(operand_bdim, shape_bdim):
        is_all_none, result = vmap_general_preprocess(prim, operand_bdim, shape_bdim)
        if is_all_none:
            return result

        x, dim = operand_bdim
        shape, shape_dim = shape_bdim
        if shape_dim is not None:
            _raise_value_error("The source axis of shape in `Reshape` must be None, but got {}.".format(shape_dim))

        x_shape = F.shape(x)
        batch_shape, out_axis, need_moveaxis = get_batch_shape(x_shape, dim, shape, axis_size)
        if need_moveaxis:
            # for such case: `x_shape` is (2, 3, 4, 5, 6), `x_dim` is 3, and `shape` is (-1,)
            # Fallback: move the vmapped axis to the front and prepend it to the target shape.
            x = mnp.moveaxis(x, dim, 0)
            batch_shape = (axis_size,) + shape
            out = prim(x, batch_shape)
            return out, 0

        out = prim(x, batch_shape)
        return out, out_axis

    return vmap_rule
415
+
416
+
417
@vmap_rules_getters.register(P.ReverseSequence)
def get_reverse_sequence_vmap_rule(prim, axis_size):
    """VmapRule for `ReverseSequence` operation.

    Strategy:
    1. Move the vmapped axis to the front and put the primitive's own batch axis
       right after it.
    2. Fold those two axes into one, so a single `ReverseSequence` call (with
       its default ``batch_dim=0``) handles every (vmapped sample, batch entry)
       pair.
    3. Unfold the result and move the batch axis back to its per-sample
       position.

    The output's vmapped axis is 0.

    Assumes ``batch_dim_`` / ``seq_dim_`` are non-negative per-sample axes, as the
    previous implementation also did (TODO confirm the primitive rejects
    negative values).
    """
    if isinstance(prim, str):
        prim = Primitive(prim)
    reshape = P.Reshape()
    batch_dim = prim.batch_dim_
    seq_dim = prim.seq_dim_
    # In the folded layout the per-sample axes other than `batch_dim` follow
    # axis 0 in their original order. So the sequence axis shifts right by one
    # only when it preceded the batch axis. (The previous per-case bookkeeping
    # got this wrong for unbatched inputs and for per-sample rank > 2.)
    folded_seq_dim = seq_dim + 1 if seq_dim < batch_dim else seq_dim

    def vmap_rule(x_bdim, seq_lengths_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, seq_lengths_bdim)
        if is_all_none:
            return result
        x, dim = x_bdim
        seq_lengths, seq_lengths_dim = seq_lengths_bdim
        # Vmapped axis first for both operands (an unbatched operand gets the
        # axis materialized, as the previous implementation relied on).
        x = _bdim_at_front(x, dim, axis_size)
        seq_lengths = _bdim_at_front(seq_lengths, seq_lengths_dim, axis_size)
        # Per-sample batch axis right after the vmapped axis: (axis_size, batch, *rest).
        if batch_dim != 0:
            x = mnp.moveaxis(x, batch_dim + 1, 1)
        unfolded_shape = x.shape
        folded_shape = (unfolded_shape[0] * unfolded_shape[1],) + tuple(unfolded_shape[2:])
        x = reshape(x, folded_shape)
        # Presumably (axis_size, batch) -> (axis_size * batch,), matching the fold order of x.
        seq_lengths = reshape(seq_lengths, (-1,))
        x = P.ReverseSequence(seq_dim=folded_seq_dim)(x, seq_lengths)
        out = reshape(x, unfolded_shape)
        # Undo the earlier move so the output matches the per-sample layout.
        # This move is needed whenever batch_dim != 0; the previous version
        # skipped it when its shifted batch axis happened to be 1.
        if batch_dim != 0:
            out = mnp.moveaxis(out, 1, batch_dim + 1)
        return out, 0

    return vmap_rule
483
+
484
+
485
@vmap_rules_getters.register(P.Flatten)
def get_flatten_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `Flatten` operation.

    The vmap axis is brought to the front. Everything after the first two axes
    (the vmap axis and the original leading axis) is then collapsed into one.
    """

    def vmap_rule(x_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = x_bdim
        data = _bdim_at_front(data, data_axis, axis_size)
        leading_dims = F.shape(data)[0:2]
        flattened = F.reshape(data, leading_dims + (-1,))
        return flattened, 0

    return vmap_rule
501
+
502
+
503
@vmap_rules_getters.register(P.Select)
def get_select_vmap_rule(prim, axis_size):
    """Build the vmap rule for the 'Select' operation.

    `condition`, `x` and `y` are each aligned so that the vmap axis is axis 0.
    Any unbatched input is broadcast. The primitive is then applied once to the
    aligned tensors.
    """
    select_op = P.Select() if isinstance(prim, str) else prim

    def vmap_rule(condition_bdim, x_bdim, y_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(select_op, condition_bdim, x_bdim, y_bdim)
        if all_unbatched:
            return unbatched_out

        cond_val, cond_axis = condition_bdim
        lhs, lhs_axis = x_bdim
        rhs, rhs_axis = y_bdim

        batched_out = select_op(_bdim_at_front(cond_val, cond_axis, axis_size),
                                _bdim_at_front(lhs, lhs_axis, axis_size),
                                _bdim_at_front(rhs, rhs_axis, axis_size))
        return batched_out, 0

    return vmap_rule
527
+
528
+
529
@vmap_rules_getters.register(P.ScatterNd)
def get_scatter_nd_vmap_rule(prim, axis_size):
    """
    VmapRule for `ScatterNd` operation.

    An example for the rule:
    --- inputs info
    indices.shape = [10, 3, 2, 2]
    updates.shape = [10, 3, 2, 5]
    shape = [6, 4, 5]
    the first dim (10) is batch.
    the shape without batch dim are:
    indices.shape = [3, 2, 2]
    updates.shape = [3, 2, 5]
    shape = [6, 4, 5]
    --- step 1
    Change the `shape` to `[60, 4, 5]`, set the indices `offset` to 6 (original first dim).
    Since there's a constraint `updates.shape = indices.shape[:-1] + shape[indices.shape[-1]:]` in the `ScatterNd` op,
    the `shape` with a batch dim is invalid, but its first dim can be changed.
    --- step 2
    Generate a constant offset tensor for the indices, where `indices_offset.shape = [10, 1, 1, 2]`,
    for i in [0, 10), set `indices_offset[i, :, :, 0] = i * offset`.
    The output batch dim is concatenated along the original 0-axis, so the indices must be offset.
    Only the 0-dim of the output is changed, so only `indices_offset[i,:,:,0]` is set, and `indices_offset[i,:,:,1]`
    is left as zero.
    --- step 3
    Add the `indices_offset` to `indices`.
    --- step 4
    Call `ScatterNd` with the new `indices`, the old `updates`, and the new `shape (60, 4, 5)`.
    --- step 5
    Reshape the output tensor to `[10, 6, 4, 5]`
    """

    @_primexpr
    def _refine_shape(shape, bdim_size):
        # Returns (merged scatter shape, per-batch row offset, final batched output shape).
        offset = shape[0]
        return (bdim_size * shape[0],) + tuple(shape[1:]), offset, (bdim_size,) + tuple(shape)

    @_primexpr
    def _gen_indices_offset(shape, offset):
        # original rank(indices.shape) is required >= 2, so indices with batch dim's rank >= 3.
        # Offset shape keeps the batch and index-depth axes and is 1 elsewhere,
        # so it broadcasts against `indices` in the Add below.
        shape = (shape[0],) + (1,) * (len(shape) - 2) + (shape[-1],)
        val = P.Zeros()((shape[0], shape[-1]), mindspore.int32)
        for i in range(shape[0]):
            # Only the first index component (row of the merged output) is offset.
            val[i, 0] = i * offset
        return P.Reshape()(val, shape)

    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(indices_bdim, updates_bdim, shape_bdim):
        is_all_none, result = vmap_general_preprocess(prim, indices_bdim, updates_bdim, shape_bdim)
        if is_all_none:
            return result
        indices, indices_dim = indices_bdim
        updates, updates_dim = updates_bdim
        shape, shape_dim = shape_bdim
        # The output shape cannot vary across the vmap axis.
        if shape_dim is not None:
            _raise_value_error("The source axis of `shape` in `{}` must be None, "
                               "but got {}.".format(prim.name, shape_dim))
        indices = _bdim_at_front(indices, indices_dim, axis_size)
        updates = _bdim_at_front(updates, updates_dim, axis_size)
        new_shape, offset, out_shape = _refine_shape(shape, axis_size)
        indices_shape = F.shape(indices)
        indices_dtype = F.dtype(indices)
        offset_val = _gen_indices_offset(indices_shape, offset)
        # Match the offset dtype to `indices` before adding (step 3).
        indices_offset = P.Cast()(offset_val, indices_dtype)
        new_indices = P.Add()(indices, indices_offset)
        out = prim(new_indices, updates, new_shape)
        # Split the merged leading axis back into (vmap axis, original first axis) (step 5).
        real_out = P.Reshape()(out, out_shape)
        return real_out, 0

    return vmap_rule
602
+
603
+
604
@vmap_rules_getters.register(P.ScatterAdd)
@vmap_rules_getters.register(P.ScatterMul)
@vmap_rules_getters.register(P.ScatterMin)
@vmap_rules_getters.register(P.ScatterMax)
@vmap_rules_getters.register(P.ScatterDiv)
@vmap_rules_getters.register(P.ScatterNdAdd)
@vmap_rules_getters.register(P.ScatterNdSub)
@vmap_rules_getters.register(P.ScatterNdMin)
@vmap_rules_getters.register(P.ScatterNdMax)
@vmap_rules_getters.register(P.array_ops.ScatterNdMul)
@vmap_rules_getters.register(P.ScatterNdDiv)
@vmap_rules_getters.register(P.ScatterNdUpdate)
@vmap_rules_getters.register(P.ScatterUpdate)
def get_scatter_op_vmap_rule(prim, axis_size):
    """
    VmapRule for `Scatter*` operations, such as `ScatterAdd`, `ScatterNdAdd`, `ScatterMin` and `ScatterMax`.
    scatter_func_map: maps each Scatter / ScatterNd operator name to the ScatterNd
        class used to implement the batched (higher-dimensional) version.
    scatter_func_list: the non-Nd Scatter operators, whose indices get an extra
        trailing axis before being used as ScatterNd-style indices.

    `ref` must be unbatched (with unbatched `indices` / `updates`) or batched along axis 0.
    The returned source axis equals that of `ref`.
    """
    scatter_func_map = {
        "ScatterAdd": P.ScatterNdAdd,
        "ScatterMul": P.array_ops.ScatterNdMul,
        "ScatterMin": P.ScatterNdMin,
        "ScatterMax": P.ScatterNdMax,
        "ScatterDiv": P.ScatterNdDiv,
        "ScatterNdAdd": P.ScatterNdAdd,
        "ScatterNdSub": P.ScatterNdSub,
        "ScatterNdMin": P.ScatterNdMin,
        "ScatterNdMax": P.ScatterNdMax,
        "ScatterNdMul": P.array_ops.ScatterNdMul,
        "ScatterNdDiv": P.ScatterNdDiv,
        "ScatterNdUpdate": P.ScatterNdUpdate,
        "ScatterUpdate": P.ScatterNdUpdate
    }
    scatter_func_list = ["ScatterAdd", "ScatterMul", "ScatterMin", "ScatterMax", "ScatterDiv", "ScatterUpdate"]
    if isinstance(prim, str):
        prim_name = prim
        prim = Primitive(prim)
        # No instance attributes are available for a name-only primitive.
        use_locking = False
    else:
        prim_name = prim.name
        use_locking = prim.use_locking

    # The ScatterNd* counterpart, created with the same `use_locking` setting.
    scatter_func = scatter_func_map.get(prim_name)(use_locking)
    # Concatenates along the last (index-depth) axis.
    concat = P.Concat(-1)

    def vmap_rule(ref_bdim, indices_bdim, updates_bdim, u_monad):
        ref, ref_dim = ref_bdim
        indices, indices_dim = indices_bdim
        updates, updates_dim = updates_bdim

        if ref_dim is None:
            # An unbatched `ref` combined with batched indices/updates would require
            # several ordered writes into the same tensor, which is rejected.
            if indices_dim is not None or updates_dim is not None:
                _raise_value_error("The source axis of `ref` is None, but the source axis of "
                                   "`indices` or `updates` is not None. The execution order of "
                                   "operator `{}` cannot be guaranteed.".format(prim_name))
            out = prim(ref, indices, updates, u_monad)
        elif ref_dim == 0:
            indices = _bdim_at_front(indices, indices_dim, axis_size)
            updates = _bdim_at_front(updates, updates_dim, axis_size)
            if prim_name in scatter_func_list:
                # Non-Nd scatter indices address only the first axis. Add a
                # trailing axis so they become ScatterNd-style index tuples.
                indices = F.expand_dims(indices, -1)

            indices_shape = F.shape(indices)
            # `_get_prefix` is defined elsewhere. Presumably it yields the batch
            # index for every index entry, so that each batch element scatters
            # into its own slice of `ref` along axis 0; verify against its definition.
            prefix = _get_prefix(indices_shape, axis_size, F.dtype(indices))
            indices = concat((prefix, indices))
            out = scatter_func(ref, indices, updates, u_monad)
        else:
            _raise_value_error("The source axis of `ref` in `{}` must be 0 or None, "
                               "but got {}.".format(prim_name, ref_dim))
            # Keeps `out` defined on this branch.
            out = None
        return out, ref_dim

    return vmap_rule
679
+
680
+
681
@vmap_rules_getters.register(G.SliceGrad)
def get_slice_grad_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `SliceGrad` operation.

    `begin` and `size` must be unbatched. They are extended with a leading entry
    (start 0, extent `axis_size`) so the entire vmap axis is kept. `dy` and `x`
    have their vmap axis moved to the front.
    """
    if isinstance(prim, str):
        op_name = prim
        prim = Primitive(prim)
    else:
        op_name = prim.name

    def vmap_rule(dy_bdim, x_bdim, begin_bdim, size_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, dy_bdim, x_bdim, begin_bdim, size_bdim)
        if all_unbatched:
            return unbatched_out

        grad, grad_axis = dy_bdim
        data, data_axis = x_bdim
        start, start_axis = begin_bdim
        extent, extent_axis = size_bdim

        if start_axis is not None:
            _raise_value_error("The source axis of `begin` in {} only supports None currently, "
                               "but got {}.".format(op_name, start_axis))
        if extent_axis is not None:
            _raise_value_error("The source axis of `size` in {} must be None, but got {}.".format(op_name, extent_axis))

        grad = _bdim_at_front(grad, grad_axis, axis_size)
        data = _bdim_at_front(data, data_axis, axis_size)

        batched_start = (0,) + start
        batched_extent = (axis_size,) + extent
        return prim(grad, data, batched_start, batched_extent), 0

    return vmap_rule
717
+
718
+
719
@vmap_rules_getters.register(P.TensorScatterAdd)
@vmap_rules_getters.register(P.TensorScatterSub)
@vmap_rules_getters.register(P.TensorScatterMul)
@vmap_rules_getters.register(P.TensorScatterDiv)
@vmap_rules_getters.register(P.TensorScatterMax)
def get_tensor_scatter_op_vmap_rule(prim, axis_size):
    """
    VmapRule for `TensorScatter*` operations, such as `TensorScatterMul`.
    tensor_scatter_func_map: TensorScatter implementation for recording TensorScatter class operators.

    All three inputs are aligned with `_bdim_at_front` so their vmap axis is axis 0
    (unbatched inputs are broadcast). A prefix from `_get_prefix` is then concatenated
    onto the last axis of `indices`; presumably this is the batch index, so each batch
    element scatters into its own slice of `input_x`. The output is therefore batched
    along axis 0.

    Args:
        prim: The `TensorScatter*` primitive, or its name as a string.
        axis_size: Size of the vmap axis.

    Returns:
        The `vmap_rule` closure.
    """
    tensor_scatter_func_map = {
        "TensorScatterAdd": P.TensorScatterAdd,
        "TensorScatterSub": P.TensorScatterSub,
        "TensorScatterMul": P.TensorScatterMul,
        "TensorScatterDiv": P.TensorScatterDiv,
        "TensorScatterMax": P.TensorScatterMax,
    }
    if isinstance(prim, str):
        prim_name = prim
        prim = Primitive(prim)
    else:
        prim_name = prim.name

    tensor_scatter_func = tensor_scatter_func_map.get(prim_name)()
    concat = P.Concat(-1)

    def vmap_rule(input_x_bdim, indices_bdim, updates_bdim):
        is_all_none, result = vmap_general_preprocess(prim, input_x_bdim, indices_bdim, updates_bdim)
        if is_all_none:
            return result
        input_x, input_x_dim = input_x_bdim
        indices, indices_dim = indices_bdim
        updates, updates_dim = updates_bdim

        input_x = _bdim_at_front(input_x, input_x_dim, axis_size)
        indices = _bdim_at_front(indices, indices_dim, axis_size)
        updates = _bdim_at_front(updates, updates_dim, axis_size)

        indices_shape = F.shape(indices)
        prefix = _get_prefix(indices_shape, axis_size, F.dtype(indices))
        indices = concat((prefix, indices))
        out = tensor_scatter_func(input_x, indices, updates)
        # `input_x` was moved to (or broadcast along) axis 0 above, so the output's
        # vmap axis is 0. Reporting `input_x_dim` here would be wrong when it is None
        # (a batched result claimed as unbatched) or non-zero.
        return out, 0

    return vmap_rule
764
+
765
+
766
@vmap_rules_getters.register(P.UnsortedSegmentMin)
@vmap_rules_getters.register(P.UnsortedSegmentMax)
@vmap_rules_getters.register(P.UnsortedSegmentProd)
def get_unsorted_segment_arithmetic_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `UnsortedSegment{Min,Max,Prod}` operations.

    A fresh primitive of the same kind is created with a `batch_rank` attribute
    one higher than the incoming primitive's (or 1 if it has none). It is applied
    to `input` and `segment_ids` after their vmap axes are moved to the front.
    `num_segments` must be unbatched because it determines the output shape.
    """

    op_table = {
        "UnsortedSegmentMin": P.UnsortedSegmentMin,
        "UnsortedSegmentMax": P.UnsortedSegmentMax,
        "UnsortedSegmentProd": P.UnsortedSegmentProd,
    }
    op_name = prim.name
    batched_op = op_table.get(op_name)()

    next_rank = prim.batch_rank + 1 if hasattr(prim, 'batch_rank') else 1
    batched_op.add_prim_attr('batch_rank', next_rank)

    def vmap_rule(input_bdim, segment_ids_bdim, num_segment_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, input_bdim, segment_ids_bdim, num_segment_bdim)
        if all_unbatched:
            return unbatched_out

        # The segment count fixes the output shape, so it cannot be batched.
        n_segments, n_segments_axis = num_segment_bdim
        if n_segments_axis is not None:
            _raise_value_error("The source axis of `num_segment` in `{}` must be None, "
                               "but got {}.".format(op_name, n_segments_axis))

        data, data_axis = input_bdim
        ids, ids_axis = segment_ids_bdim

        data = _bdim_at_front(data, data_axis, axis_size)
        ids = _bdim_at_front(ids, ids_axis, axis_size)

        return batched_op(data, ids, n_segments), 0

    return vmap_rule
808
+
809
+
810
@vmap_rules_getters.register(P.UnsortedSegmentSum)
def get_unsorted_segment_sum_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `UnsortedSegmentSum` operation.

    The primitive is cloned and labelled with a `batch_rank` one higher than the
    incoming label (or 1 if unlabelled). `input` and `segment_ids` have their vmap
    axes moved to the front. `num_segments` must be unbatched.
    """

    op_name = prim.name
    next_rank = prim.get_label("batch_rank") + 1 if prim.has_label("batch_rank") else 1

    prim = prim.clone()
    prim.set_label('batch_rank', next_rank)

    def vmap_rule(input_bdim, segment_ids_bdim, num_segment_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, input_bdim, segment_ids_bdim, num_segment_bdim)
        if all_unbatched:
            return unbatched_out

        # The segment count fixes the output shape, so it cannot be batched.
        n_segments, n_segments_axis = num_segment_bdim
        if n_segments_axis is not None:
            _raise_value_error("The source axis of `num_segment` in `{}` must be None, "
                               "but got {}.".format(op_name, n_segments_axis))

        data, data_axis = input_bdim
        ids, ids_axis = segment_ids_bdim

        data = _bdim_at_front(data, data_axis, axis_size)
        ids = _bdim_at_front(ids, ids_axis, axis_size)

        return prim(data, ids, n_segments), 0

    return vmap_rule
844
+
845
+
846
@vmap_rules_getters.register(P.Fill)
def get_fill_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `Fill` operation.

    `type` and `shape` must be unbatched. `value` must be a rank-1 tensor batched
    along axis 0, giving one fill value per batch element. The value is cast,
    reshaped to broadcast over `shape`, and broadcast to `(axis_size,) + shape`.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)
    to_dtype = P.Cast()

    def vmap_rule(dtype_bdim, shape_bdim, value_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, dtype_bdim, shape_bdim, value_bdim)
        if all_unbatched:
            return unbatched_out
        target_type, target_type_axis = dtype_bdim
        if target_type_axis is not None:
            _raise_value_error("The source axis of `type` in `P.Fill` must be None, but got {}.".format(target_type_axis))
        fill_shape, fill_shape_axis = shape_bdim
        if fill_shape_axis is not None:
            _raise_value_error("The source axis of `shape` in `P.Fill` must be None, but got {}.".format(fill_shape_axis))
        fill_value, fill_value_axis = value_bdim
        fill_value_rank = F.rank(fill_value)
        if fill_value_rank != 1 or fill_value_axis != 0:
            _raise_value_error("The `value` in `P.Fill` must be constant value, thus the value only "
                               "can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
                               "{} with source axis: {}.".format(fill_value_rank, fill_value_axis))
        fill_value = to_dtype(fill_value, target_type)
        fill_value = F.reshape(fill_value, (axis_size,) + (1,) * len(fill_shape))
        batched_shape = (axis_size,) + fill_shape
        return P.BroadcastTo(batched_shape)(fill_value), 0

    return vmap_rule
875
+
876
+
877
@constexpr
def to_tensor_with_type(x, dtype):
    """Convert `x` to a `Tensor` whose data type is `dtype`."""
    converted = Tensor(x, dtype)
    return converted
881
+
882
+
883
@vmap_rules_getters.register(P.FillV2)
def get_fill_v2_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `FillV2` operation.

    `shape` must be unbatched and may be a tuple or a 1-D tensor. `value` must be
    a rank-1 tensor batched along axis 0. The result is `value` broadcast to
    `(axis_size,) + shape`: statically via `BroadcastTo` for a tuple shape, and via
    `DynamicBroadcastTo` for a tensor shape.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(shape_bdim, value_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, shape_bdim, value_bdim)
        if all_unbatched:
            return unbatched_out

        target_shape, target_shape_axis = shape_bdim
        if target_shape_axis is not None:
            _raise_value_error("The source axis of `shape` in `P.FillV2` must be None, but got {}."
                               .format(target_shape_axis))

        fill_value, fill_value_axis = value_bdim
        fill_value_rank = F.rank(fill_value)
        if fill_value_rank != 1 or fill_value_axis != 0:
            _raise_value_error("The `value` in `P.FillV2` must be constant value, thus the value only "
                               "can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
                               "{} with source axis: {}.".format(fill_value_rank, fill_value_axis))
        fill_value = F.reshape(fill_value, (axis_size,) + (1,) * len(target_shape))

        batched = None
        if isinstance(target_shape, (Tensor_, Tensor)):
            target_shape_rank = F.rank(target_shape)
            if target_shape_rank != 1:
                _raise_value_error("The `shape` in `P.FillV2` must be 1-D tensor, thus the shape only "
                                   "can be rank: 1, but got shape rank: "
                                   "{}.".format(target_shape_rank))
            leading = to_tensor_with_type((axis_size,), F.dtype(target_shape))
            full_shape = F.concat((leading, target_shape))
            batched = DynamicBroadcastTo()(fill_value, full_shape)
        elif isinstance(target_shape, tuple):
            batched = P.BroadcastTo((axis_size,) + target_shape)(fill_value)
        else:
            _raise_value_error(
                f"For `P.FillV2`, the input `shape` should be Tuple or Tensor, but got `shape`: {target_shape}."
            )

        return batched, 0

    return vmap_rule
931
+
932
+
933
@vmap_rules_getters.register(Fills)
def get_fills_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `Fills` operation.

    `value` is cast to the dtype of `x`. An unbatched `value` is broadcast to the
    shape of `x` and keeps the source axis of `x`. A batched `value` must be rank 1
    with source axis 0; it is broadcast over the (front-batched) shape of `x`, and
    the result is batched along axis 0.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)
    to_dtype = P.Cast()

    def vmap_rule(x_bdim, value_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim, value_bdim)
        if all_unbatched:
            return unbatched_out
        ref, ref_axis = x_bdim
        fill_value, fill_value_axis = value_bdim
        target_shape = ref.shape
        fill_value = to_dtype(fill_value, ref.dtype)
        if fill_value_axis is None:
            return P.BroadcastTo(target_shape)(fill_value), ref_axis

        fill_value_rank = F.rank(fill_value)
        if fill_value_rank != 1 or fill_value_axis != 0:
            _raise_value_error("The `value` in `F.fills` only accept scalar or 0-dims tensor, thus the value only "
                               "can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
                               "{} with source axis: {}.".format(fill_value_rank, fill_value_axis))
        if ref_axis is not None:
            ref = _bdim_at_front(ref, ref_axis, axis_size)
            target_shape = ref.shape
            fill_value = F.reshape(fill_value, (axis_size,) + (1,) * (len(target_shape) - 1))
        else:
            fill_value = F.reshape(fill_value, (axis_size,) + (1,) * len(target_shape))
            target_shape = (axis_size,) + target_shape
        return P.BroadcastTo(target_shape)(fill_value), 0

    return vmap_rule
968
+
969
+
970
@vmap_rules_getters.register(P.Range)
def get_range_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `Range` operation.

    Batched inputs are not supported: the rule only passes through the unbatched
    result and reports an error if any input carries a source axis.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(start_bdim, limit_bdim, delta_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, start_bdim, limit_bdim, delta_bdim)
        if all_unbatched:
            return unbatched_out
        _, start_axis = start_bdim
        _, limit_axis = limit_bdim
        _, delta_axis = delta_bdim
        _raise_value_error("For operator Range, all axis for inputs should be None, but got start_dim: {},"
                           " limit_dim: {} and delta_dim: {}.".format(start_axis, limit_axis, delta_axis))
        return unbatched_out

    return vmap_rule
987
+
988
+
989
@vmap_rules_getters.register(P.UniqueWithPad)
def get_unique_with_pad_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `UniqueWithPad` operation.

    A clone of the primitive is given a `batch_rank` attribute one higher than the
    incoming primitive's (or 1 if it has none). It is applied to `x` and `pad_num`
    after both have their vmap axis moved to the front. Both outputs are reported
    as batched along axis 0.
    """
    next_rank = prim.batch_rank + 1 if hasattr(prim, 'batch_rank') else 1

    batched_op = _vmap_clone_prim(prim)
    batched_op.add_prim_attr("batch_rank", next_rank)

    def vmap_rule(x_bdim, pad_num_bdim):
        data, data_axis = x_bdim
        pad_value, pad_value_axis = pad_num_bdim

        data = _bdim_at_front(data, data_axis, axis_size)
        pad_value = _bdim_at_front(pad_value, pad_value_axis, axis_size)
        uniques, positions = batched_op(data, pad_value)
        return (uniques, 0), (positions, 0)

    return vmap_rule
1020
+
1021
+
1022
@vmap_rules_getters.register(P.array_ops.MatrixDiagV3)
def get_matrix_diag_v3_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `MatrixDiagV3` operation.

    Only `x` may be batched. `k`, `num_rows`, `num_cols` and `padding_value` must
    all be unbatched. `x` has its vmap axis moved to the front and the primitive
    is applied directly.
    """
    if isinstance(prim, str):
        op_name = prim
        prim = P.array_ops.MatrixDiagV3()
    else:
        op_name = prim.name

    def vmap_rule(x_bdim, k_bdim, num_rows_bdim, num_cols_bdim, padding_value_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim, k_bdim, num_rows_bdim, num_cols_bdim,
                                                              padding_value_bdim)
        if all_unbatched:
            return unbatched_out

        diag, diag_axis = x_bdim
        offsets, offsets_axis = k_bdim
        rows, rows_axis = num_rows_bdim
        cols, cols_axis = num_cols_bdim
        pad, pad_axis = padding_value_bdim
        if offsets_axis is not None:
            _raise_value_error("The source axis of `k` in {} must be None, but got {}.".format(op_name, offsets_axis))
        if rows_axis is not None:
            _raise_value_error(
                "The source axis of `num_rows` in {} must be None, but got {}.".format(op_name, rows_axis))
        if cols_axis is not None:
            _raise_value_error(
                "The source axis of `num_cols` in {} must be None, but got {}.".format(op_name, cols_axis))
        if pad_axis is not None:
            _raise_value_error("The source axis of `padding_value` in {} must be None, "
                               "but got {}.".format(op_name, pad_axis))

        diag = _bdim_at_front(diag, diag_axis, axis_size)
        return prim(diag, offsets, rows, cols, pad), 0

    return vmap_rule
1059
+
1060
+
1061
@vmap_rules_getters.register("TensorShape")
def get_tensor_shape_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `TensorShape` operation.

    Every slice along the vmap axis has the same shape, so the shape of the first
    slice is returned as an unbatched result.
    """
    def vmap_rule(input_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, input_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = input_bdim
        first_slice = P.Unstack(data_axis)(data)[0]
        return prim(first_slice), None

    return vmap_rule
1076
+
1077
+
1078
@constexpr
def _get_one_hot_vmap_axis(orig_axis, ndim, indices_dim):
    """Find vmap axis for OneHot.

    Returns a pair `(one_hot_axis, out_batch_axis)`: the `axis` to pass to the
    primitive on the batched `indices` (of rank `ndim`), and the position of the
    vmap axis in the resulting output.
    """
    if orig_axis == -1:
        # One-hot appended last. If the vmap axis is currently last, insert the
        # one-hot axis just before it so it stays last per batch element.
        if indices_dim == ndim - 1:
            return ndim - 1, indices_dim + 1
        return orig_axis, indices_dim
    if orig_axis >= 0 and indices_dim <= orig_axis:
        # The vmap axis sits at or before the one-hot position, which shifts right by one.
        return orig_axis + 1, indices_dim
    # The one-hot axis is inserted before the vmap axis, pushing it right by one.
    return orig_axis, indices_dim + 1
1088
+
1089
+
1090
@vmap_rules_getters.register(P.OneHot)
def get_one_hot_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `OneHot` operation.

    Only `indices` may be batched. `depth`, `on_value` and `off_value` must be
    unbatched, and `axis` must be a constant. The vmap axis of `indices` is left
    in place; `_get_one_hot_vmap_axis` adjusts both the one-hot `axis` and the
    reported output batch axis.
    """
    op_name = prim.name

    def vmap_rule(indices_bdim, depth_bdim, on_value_bdim, off_value_bdim, axis_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, indices_bdim, depth_bdim, on_value_bdim,
                                                              off_value_bdim, axis_bdim)
        if all_unbatched:
            return unbatched_out

        ids, ids_axis = indices_bdim
        depth, depth_axis = depth_bdim
        on_val, on_val_axis = on_value_bdim
        off_val, off_val_axis = off_value_bdim
        one_hot_axis, _ = axis_bdim

        if depth_axis is not None:
            _raise_value_error(
                "The source axis of `depth` in {} must be None, but got {}.".format(op_name, depth_axis))

        if on_val_axis is not None:
            _raise_value_error(
                "The source axis of `on_value` in {} must be None, but got {}.".format(op_name, on_val_axis))

        if off_val_axis is not None:
            _raise_value_error(
                "The source axis of `off_value` in {} must be None, but got {}.".format(op_name, off_val_axis))

        if not F.isconstant(one_hot_axis):
            _raise_value_error("'axis' in {} must be constant.".format(op_name))
        ids_rank = F.rank(ids)
        batched_axis, out_batch_axis = _get_one_hot_vmap_axis(one_hot_axis, ids_rank, ids_axis)
        return prim(ids, depth, on_val, off_val, batched_axis), out_batch_axis

    return vmap_rule
1128
+
1129
+
1130
@vmap_rules_getters.register(P.MaskedSelect)
def get_masked_select_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `MaskedSelect` operation.

    `mask` must be unbatched; it is broadcast along a leading vmap axis. `x` is
    moved to the front and broadcast against the mask. The flat selection is then
    reshaped to `(axis_size, -1)` when `x` has more than one dimension.
    """

    def vmap_rule(x_bdim, mask_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim, mask_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = x_bdim
        keep, keep_axis = mask_bdim
        if keep_axis is not None:
            _raise_value_error("The source axis of `mask` in `P.MaskedSelect` must be None, "
                               "but got {}.".format(keep_axis))

        data = _bdim_at_front(data, data_axis, axis_size)
        keep = _bdim_at_front(keep, keep_axis, axis_size)
        data_shape = F.shape(data)
        data = _handle_broadcasting(data, data_shape, F.shape(keep))
        selected = prim(data, keep)
        if F.rank(data) > 1:
            selected = F.reshape(selected, (data_shape[0], -1))
        return selected, 0

    return vmap_rule
1157
+
1158
+
1159
@vmap_rules_getters.register(MaskedSelectGrad)
def get_masked_select_grad_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `MaskedSelectGrad` operation.

    `mask` must be unbatched. `x`, `mask` and the incoming gradient are moved to
    (or broadcast along) a leading vmap axis. The first two axes of the gradient
    are flattened into one before the primitive is applied.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim, mask_bdim, outgrad_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim, mask_bdim, outgrad_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = x_bdim
        keep, keep_axis = mask_bdim
        upstream, upstream_axis = outgrad_bdim
        if keep_axis is not None:
            _raise_value_error("The source axis of `mask` in `P.MaskedSelect` must be None, "
                               "but got {}.".format(keep_axis))

        data = _bdim_at_front(data, data_axis, axis_size)
        keep = _bdim_at_front(keep, keep_axis, axis_size)
        upstream = _bdim_at_front(upstream, upstream_axis, axis_size)
        upstream_shape = F.shape(upstream)
        upstream = F.reshape(upstream, (upstream_shape[0] * upstream_shape[1],))
        return prim(data, keep, upstream), 0

    return vmap_rule
1186
+
1187
+
1188
@vmap_rules_getters.register(P.array_ops.MatrixBandPart)
def get_matrix_band_part_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `MatrixBandPart` operation.

    A new `MatrixBandPart` primitive is tagged with a `batch_rank` one higher than
    the incoming primitive's (or 1 if it has none). It is applied after `x`,
    `lower` and `upper` have their vmap axes moved to the front.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    next_rank = prim.batch_rank + 1 if hasattr(prim, 'batch_rank') else 1

    batched_op = P.array_ops.MatrixBandPart()
    batched_op.add_prim_attr('batch_rank', next_rank)

    def vmap_rule(x_bdim, lower_bdim, upper_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim, lower_bdim, upper_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = x_bdim
        lo, lo_axis = lower_bdim
        hi, hi_axis = upper_bdim

        data = _bdim_at_front(data, data_axis, axis_size)
        lo = _bdim_at_front(lo, lo_axis, axis_size)
        hi = _bdim_at_front(hi, hi_axis, axis_size)

        return batched_op(data, lo, hi), 0

    return vmap_rule
1219
+
1220
+
1221
@vmap_rules_getters.register(P.array_ops.MatrixDiagPartV3)
def get_matrix_diag_part_v3_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `MatrixDiagPartV3` operation.

    A new `MatrixDiagPartV3` with the same `align` (default "RIGHT_LEFT" when only
    a name is given) is applied to `x` after its vmap axis is moved to the front.
    `k` and `padding_value` must be unbatched.
    """
    if isinstance(prim, str):
        op_name = prim
        alignment = "RIGHT_LEFT"
    else:
        op_name = prim.name
        alignment = prim.align

    diag_part_op = P.array_ops.MatrixDiagPartV3(align=alignment)

    def vmap_rule(x_bdim, k_bdim, padding_value_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim, k_bdim, padding_value_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = x_bdim
        offsets, offsets_axis = k_bdim
        pad, pad_axis = padding_value_bdim
        if offsets_axis is not None:
            _raise_value_error("The source axis of `k` in {} must be None, but got {}.".format(op_name, offsets_axis))
        if pad_axis is not None:
            _raise_value_error("The source axis of `padding_value` in {} must be None, "
                               "but got {}.".format(op_name, pad_axis))

        data = _bdim_at_front(data, data_axis, axis_size)
        return diag_part_op(data, offsets, pad), 0

    return vmap_rule
1253
+
1254
+
1255
@vmap_rules_getters.register(P.array_ops.MatrixSetDiagV3)
def get_matrix_set_diag_v3_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `MatrixSetDiagV3` operation.

    A new `MatrixSetDiagV3` with the same `align` (default "RIGHT_LEFT" when only
    a name is given) is applied after `x` and `diagonal` have their vmap axes moved
    to the front. `k` must be unbatched.
    """
    if isinstance(prim, str):
        op_name = prim
        alignment = "RIGHT_LEFT"
    else:
        op_name = prim.name
        alignment = prim.align

    set_diag_op = P.array_ops.MatrixSetDiagV3(align=alignment)

    def vmap_rule(x_bdim, diagonal_bdim, k_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim, diagonal_bdim, k_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = x_bdim
        offsets, offsets_axis = k_bdim
        new_diag, new_diag_axis = diagonal_bdim
        if offsets_axis is not None:
            _raise_value_error("The source axis of `k` in {} must be None, but got {}.".format(op_name, offsets_axis))

        data = _bdim_at_front(data, data_axis, axis_size)
        new_diag = _bdim_at_front(new_diag, new_diag_axis, axis_size)
        return set_diag_op(data, new_diag, offsets), 0

    return vmap_rule
1285
+
1286
+
1287
@vmap_rules_getters.register(P.Padding)
def get_padding_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `Padding` operation.

    When the vmap axis is the last axis of a non-scalar input, it is moved to the
    front before padding and the output is batched along axis 0. Otherwise the
    primitive is applied in place and the source axis is preserved.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x_bdim)
        if all_unbatched:
            return unbatched_out

        data, data_axis = x_bdim
        data_rank = F.rank(data)
        if data_rank and data_axis in (-1, data_rank - 1):
            return prim(_bdim_at_front(data, data_axis, axis_size)), 0
        return prim(data), data_axis

    return vmap_rule
1307
+
1308
+
1309
@vmap_rules_getters.register(P.Ger)
def get_ger_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `Ger` operation.

    A new `Ger` primitive is tagged with a `batch_rank` one higher than the
    incoming primitive's (or 1 if it has none). It is applied to both operands
    after their vmap axes are moved to the front.
    """
    next_rank = prim.batch_rank + 1 if hasattr(prim, 'batch_rank') else 1

    batched_op = P.Ger()
    batched_op.add_prim_attr('batch_rank', next_rank)

    def vmap_rule(x1_bdim, x2_bdim):
        all_unbatched, unbatched_out = vmap_general_preprocess(prim, x1_bdim, x2_bdim)
        if all_unbatched:
            return unbatched_out

        lhs, lhs_axis = x1_bdim
        rhs, rhs_axis = x2_bdim
        lhs = _bdim_at_front(lhs, lhs_axis, axis_size)
        rhs = _bdim_at_front(rhs, rhs_axis, axis_size)
        return batched_op(lhs, rhs), 0

    return vmap_rule
1333
+
1334
+
1335
@vmap_rules_getters.register(P.GatherD)
def get_gatherd_vmap_rule(prim, axis_size):
    """VmapRule for GatherD operations.

    `dim` must be an unbatched Python int. `x` and `index` are aligned so that
    their vmap axes coincide:
    - an unbatched `x` is broadcast at the vmap axis of `index`;
    - an unbatched `index` is broadcast at the vmap axis of `x`;
    - when both are batched at different axes, `x`'s vmap axis is moved to `index`'s.
    `dim` is then remapped for that common vmap axis, and the output is reported
    at the same axis.

    Args:
        prim: The `GatherD` primitive, or its name as a string.
        axis_size: Size of the vmap axis.

    Returns:
        The `vmap_rule` closure.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim, dim_bdim, index_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        dim_value, axis_dim = dim_bdim
        index, index_dim = index_bdim

        # `dim` will be a Tensor in dynamic shape case, do not support its vmap.
        if axis_dim is not None:
            _raise_value_error("The source axis of `dim` in `GatherD` must be None, "
                               "but got {}.".format(axis_dim))
        if not isinstance(dim_value, int):
            _raise_value_error("The `dim` in `GatherD` must be a int, but got {}.".format(dim_value))

        # Common vmap axis of `x` and `index` after alignment below.
        out_dim = index_dim

        # Broadcast or move if needed so both tensors share the vmap axis `out_dim`.
        if x_dim is None:
            x = _broadcast_by_axis(x, index_dim, axis_size)
        elif index_dim is None:
            index = _broadcast_by_axis(index, x_dim, axis_size)
            out_dim = x_dim
        elif x_dim != index_dim:
            # `moveaxis` returns a new tensor; the result must be kept, otherwise
            # `x` and `index` stay misaligned.
            x = mnp.moveaxis(x, x_dim, index_dim)

        # Adapt `dim` to vmap case. After alignment the vmap axis of `x` is
        # `out_dim` in every branch. `x_dim` would be None (broadcast case) or
        # stale (moved case).
        x_ndim = ops.rank(x)
        dim_value = _get_reduce_batch_axis(dim_value, out_dim, x_ndim)

        out = prim(x, dim_value, index)
        return out, out_dim

    return vmap_rule
1376
+
1377
+
1378
@vmap_rules_getters.register(G.GatherDGradV2)
def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
    """VmapRule for GatherDGradV2 operations.

    `x`, `index` and `grad` are all moved so their mapped axis sits at one common
    position (the first mapped axis found among them, or 0).  The static `dim` is
    validated against the un-mapped rank and shifted past that common axis.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def _update_dim(dim, x_rank, batch_dim):
        # Normalize against the un-mapped rank, then step over the mapped axis.
        norm_dim = dim + x_rank if dim < 0 else dim
        if norm_dim < 0 or norm_dim >= x_rank:
            _raise_value_error(
                "The `dim` in `GatherDGradV2` must be in range [{}, {}], but got {}.".format(-x_rank, x_rank - 1, dim))
        return norm_dim + 1 if norm_dim >= batch_dim else norm_dim

    def vmap_rule(x_bdim, dim_bdim, index_bdim, grad_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim, grad_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        dim, dim_dim = dim_bdim
        if dim_dim is not None:
            _raise_value_error("The dim of 'dim' in `GatherDGradV2` must be None, but got {}.".format(dim_dim))
        index, index_dim = index_bdim
        grad, grad_dim = grad_bdim

        # Common mapped-axis position: first one found among x, index, grad.
        if x_dim is not None:
            batch_dim = x_dim
        elif index_dim is not None:
            batch_dim = index_dim
        elif grad_dim is not None:
            batch_dim = grad_dim
        else:
            batch_dim = 0

        x = _bdim_at_any(x, x_dim, batch_dim, axis_size)
        index = _bdim_at_any(index, index_dim, batch_dim, axis_size)
        grad = _bdim_at_any(grad, grad_dim, batch_dim, axis_size)

        # Rank of `x` without its mapped axis.
        logical_rank = F.rank(x) - 1
        new_dim = _update_dim(dim, logical_rank, batch_dim)
        return (prim(x, new_dim, index, grad), batch_dim)

    return vmap_rule
1426
+
1427
+
1428
@vmap_rules_getters.register(P.SpaceToBatchND)
def get_space_to_batch_nd_vmap_rule(prim, axis_size):
    """VmapRule for `SpaceToBatchND`.

    The mapped axis is moved to position 1, leaving axis 0 as the primitive's own
    leading axis, and the output reports the mapped axis at position 1 as well.
    """

    def vmap_rule(input_xdim):
        is_all_none, result = vmap_general_preprocess(prim, input_xdim)
        if is_all_none:
            return result

        x, x_dim = input_xdim
        return prim(mnp.moveaxis(x, x_dim, 1)), 1

    return vmap_rule
1443
+
1444
+
1445
@vmap_rules_getters.register(P.BatchToSpaceND)
def get_batch_to_space_nd_vmap_rule(prim, axis_size):
    """VmapRule for `BatchToSpaceND`.

    The mapped axis is moved to position 1, leaving axis 0 as the primitive's own
    leading axis, and the output reports the mapped axis at position 1 as well.
    """

    def vmap_rule(input_xdim):
        is_all_none, result = vmap_general_preprocess(prim, input_xdim)
        if is_all_none:
            return result

        x, x_dim = input_xdim
        return prim(mnp.moveaxis(x, x_dim, 1)), 1

    return vmap_rule
1460
+
1461
+
1462
@vmap_rules_getters.register(P.GatherNd)
def get_gather_nd_vmap_rule(prim, axis_size):
    """VmapRule for GatherND operations.

    Both inputs get the mapped axis at the front.  A prefix tensor built by
    `_get_prefix` is concatenated onto `indices` along the last axis before the
    gather (presumably the per-batch position index -- see `_get_prefix`).
    """
    if isinstance(prim, str):
        prim = P.GatherNd()
    concat = P.Concat(-1)

    def vmap_rule(x_bdim, indices_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, indices_bdim)
        if is_all_none:
            return result

        x = _bdim_at_front(x_bdim[0], x_bdim[1], axis_size)
        indices = _bdim_at_front(indices_bdim[0], indices_bdim[1], axis_size)
        prefix = _get_prefix(F.shape(indices), axis_size, F.dtype(indices))
        return prim(x, concat((prefix, indices))), 0

    return vmap_rule
1484
+
1485
+
1486
@vmap_rules_getters.register(P.Meshgrid)
def get_meshgrid_vmap_rule(prim, axis_size):
    """VmapRule for `P.Meshgrid` operation.

    Every input must be mapped and 1-D (2-D once its mapped axis is at the front).
    Each input is reshaped so its values lie along its own grid axis and is then
    multiplied by a ones tensor of the full batched grid shape.  With "xy" indexing
    the first two grid axes are swapped.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)
    indexing = prim.indexing

    def vmap_rule(*inputs_bdim):
        is_all_none, result = vmap_general_preprocess(prim, *inputs_bdim)
        if is_all_none:
            return result

        if not isinstance(inputs_bdim, (tuple)):
            _raise_value_error("The inputs of P.Meshgrid is not tuple.")
        args = inputs_bdim[0]
        if len(args) <= 1:
            _raise_value_error(
                "The input number of P.Meshgrid must be greater than 1.")

        # Pass 1: validate inputs and collect the batched grid shape.
        grid_shape = [axis_size]
        unit_shape = [axis_size]
        for each_arg in args:
            x, bdim = each_arg
            if bdim is None:
                _raise_value_error(
                    "For Meshgrid vmap, the axis of each input must be same.")
            x = _bdim_at_front(x, bdim, axis_size)
            if F.rank(x) != 2:
                _raise_value_error(
                    "Each input of Meshgrid must be 1D, but got {}.".format(F.rank(x) - 1))
            grid_shape.append(F.shape(x)[-1])
            unit_shape.append(1)

        if indexing == "xy":
            grid_shape[1], grid_shape[2] = grid_shape[2], grid_shape[1]

        first_input, _ = args[0]
        ones_tensor = F.fill(F.dtype(first_input), tuple(grid_shape), 1)

        # Pass 2: lay each input along its grid axis and broadcast it to the full grid.
        vals_out_tuple = ()
        for position in range(len(args)):
            x, bdim = args[position]
            x = _bdim_at_front(x, bdim, axis_size)
            grid_axis = position
            if indexing == "xy" and position <= 1:
                grid_axis = 1 - position
            # Temporarily expose this input's length on its own grid axis.
            unit_shape[grid_axis + 1] = grid_shape[grid_axis + 1]
            expanded = P.Reshape()(x, tuple(unit_shape))
            unit_shape[grid_axis + 1] = 1
            vals_out_tuple = vals_out_tuple + ((P.Mul()(expanded, ones_tensor), 0),)

        return vals_out_tuple

    return vmap_rule
1545
+
1546
+
1547
@vmap_rules_getters.register(P.MaskedFill)
def get_masked_fill_vmap_rule(prim, axis_size):
    """VmapRule for `MaskedFill` operation.

    A clone of the primitive is labelled with a `batch_rank` one higher than the
    incoming label (or 1); all three operands get their mapped axis moved to the
    front and the result carries the mapped axis at position 0.
    """
    batch_rank = prim.get_label('batch_rank') + 1 if prim.has_label('batch_rank') else 1
    prim = prim.clone()
    prim.set_label('batch_rank', batch_rank)

    def vmap_rule(input_bdim, mask_bdim, value_bdim):
        is_all_none, result = vmap_general_preprocess(prim, input_bdim, mask_bdim, value_bdim)
        if is_all_none:
            return result

        input_x = _bdim_at_front(input_bdim[0], input_bdim[1], axis_size)
        mask = _bdim_at_front(mask_bdim[0], mask_bdim[1], axis_size)
        value = _bdim_at_front(value_bdim[0], value_bdim[1], axis_size)
        return prim(input_x, mask, value), 0

    return vmap_rule
1573
+
1574
+
1575
@vmap_rules_getters.register(P.Gather)
def get_gather_vmap_rule(prim, axis_size):
    """VmapRule for `Gather` operation.

    Three cases are handled:
    * only `x` mapped: mapped axis moved to the front, `axis` shifted past it;
    * only `indices` mapped: the output carries the mapped axis at `axis`;
    * both mapped: the mapped axis of `x` is fused with the gather axis and every
      batch's indices are offset into its own segment of the fused axis.
    `axis` and `batch_dims` must be unmapped.
    """
    prim_name = prim if isinstance(prim, str) else prim.name
    if isinstance(prim, str):
        prim = P.Gather()

    @_primexpr
    def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
        # `x_shape_size` is the rank of `x` as passed in; it includes the mapped axis
        # only when `has_xdim` is true.
        if has_xdim and not has_idim:
            # Mapped axis is at the front: non-negative axes step over it.
            if axis >= 0:
                return axis + 1
            return axis
        if axis < 0:
            # Normalize against the rank of `x` without its mapped axis.
            logical_rank = x_shape_size - 1 if has_xdim else x_shape_size
            return logical_rank + axis
        return axis

    @_primexpr
    def get_x_dst_shape(x_shape, axis):
        # The mapped axis sits at `axis`, right before the gather axis; fuse the two.
        target_axis_size = x_shape[axis + 1]
        max_axis_size = axis_size * target_axis_size
        x_dst_shape = x_shape[:axis] + (max_axis_size,) + x_shape[axis + 2:]
        return target_axis_size, x_dst_shape, max_axis_size

    def vmap_rule(x_bdim, indices_bdim, axis_bdim, batch_dims_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, indices_bdim, batch_dims_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        indices, indices_dim = indices_bdim
        axis, axis_dim = axis_bdim
        batch_dims, batch_dims_dim = batch_dims_bdim

        if axis_dim is not None:
            _raise_value_error("The source axis of `axis` in {} must be None, but got {}.".format(prim_name, axis_dim))

        if batch_dims_dim is not None:
            _raise_value_error("The source batch_dims of `batch_dims` in {} must be None, but got {}."
                               .format(prim_name, batch_dims_dim))

        x_shape_len = len(x.shape)
        x_mapped = x_dim is not None
        indices_mapped = indices_dim is not None

        if x_mapped and not indices_mapped:
            x = _bdim_at_front(x, x_dim, axis_size)
            axis = process_axis(axis, x_shape_len, True, False)
            return prim(x, indices, axis, batch_dims), 0

        if indices_mapped and not x_mapped:
            indices = _bdim_at_front(indices, indices_dim, axis_size)
            axis = process_axis(axis, x_shape_len, False, True)
            return prim(x, indices, axis, batch_dims), axis

        # Both mapped: fold the mapped axis of `x` into its gather axis.
        x = _bdim_at_front(x, x_dim, axis_size)
        indices = _bdim_at_front(indices, indices_dim, axis_size)
        axis = process_axis(axis, x_shape_len, True, True)
        x = mnp.moveaxis(x, 0, axis)
        target_axis_size, x_dst_shape, max_axis_size = get_x_dst_shape(x.shape, axis)
        x = x.reshape(x_dst_shape)

        # Per-batch offset (b * target_axis_size), laid out like `indices`.
        offsets = F.mul(mnp.arange(0, axis_size, 1), target_axis_size)
        offsets = P.BroadcastTo(indices.shape[1:] + (axis_size,))(offsets)
        offsets = mnp.moveaxis(offsets, -1, 0)

        # Indices past a batch's own segment are pushed beyond the fused axis so they
        # stay out of range rather than reading a neighbouring batch.
        overflow = mnp.where(indices > target_axis_size - 1, x=max_axis_size, y=0)
        indices = F.add(F.add(indices, offsets), overflow)

        return prim(x, indices, axis, batch_dims), axis

    return vmap_rule
1665
+
1666
+
1667
@vmap_rules_getters.register(TensorScatterElements)
def get_tensor_scatter_elements_vmap_rule(prim, axis_size):
    """VmapRule for TensorScatterElements operations.

    `input_x`, `indices` and `updates` are aligned so that all three carry the mapped
    axis at the same position `out_dim`: unmapped operands are broadcast there and
    mapped ones are moved there.  The scatter `axis` is then remapped to skip over
    `out_dim`, and the result is returned with the mapped axis at `out_dim`.
    """
    if isinstance(prim, str):
        axis = 0
        reduction = 'none'
    else:
        axis = prim.axis
        reduction = prim.reduction

    def two_dims_are_none(i_bdim, j_no_dim, k_no_dim, axis_size):
        # Only `i` is mapped: broadcast `j` and `k` so they carry the mapped axis at `i_dim`.
        i, i_dim = i_bdim
        j = _broadcast_by_axis(j_no_dim, i_dim, axis_size)
        k = _broadcast_by_axis(k_no_dim, i_dim, axis_size)
        new_inputs = (i, j, k)
        return (new_inputs, i_dim)

    def one_dim_is_none(i_bdim, j_bdim, k_no_dim, axis_size):
        # `i` and `j` are mapped: move `j`'s mapped axis to `i_dim`, broadcast `k` there.
        i, i_dim = i_bdim
        j, j_dim = j_bdim
        # `moveaxis` returns a new tensor; the result must be kept.
        j = mnp.moveaxis(j, j_dim, i_dim)
        k = _broadcast_by_axis(k_no_dim, i_dim, axis_size)
        new_inputs = (i, j, k)
        return (new_inputs, i_dim)

    def no_dim_is_none(i_bdim, j_bdim, k_bdim):
        # All three are mapped: move `j` and `k` to match `i`.
        i, i_dim = i_bdim
        j, j_dim = j_bdim
        k, k_dim = k_bdim
        j = mnp.moveaxis(j, j_dim, i_dim)
        k = mnp.moveaxis(k, k_dim, i_dim)
        new_inputs = (i, j, k)
        return new_inputs, i_dim

    def vmap_rule(x_bdim, index_bdim, update_bdim):
        is_all_none, result = vmap_general_preprocess(
            prim, x_bdim, index_bdim, update_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        index, index_dim = index_bdim
        update, update_dim = update_bdim

        numbers = [x_dim, index_dim, update_dim].count(None)
        if numbers == 2:
            if x_dim is not None:
                inputs, out_dim = two_dims_are_none(
                    x_bdim, index, update, axis_size)
                x, index, update = inputs
            elif index_dim is not None:
                inputs, out_dim = two_dims_are_none(
                    index_bdim, x, update, axis_size)
                index, x, update = inputs
            else:
                inputs, out_dim = two_dims_are_none(
                    update_bdim, x, index, axis_size)
                update, x, index = inputs
        elif numbers == 1:
            if x_dim is None:
                inputs, out_dim = one_dim_is_none(
                    index_bdim, update_bdim, x, axis_size)
                index, update, x = inputs
            elif index_dim is None:
                inputs, out_dim = one_dim_is_none(
                    x_bdim, update_bdim, index, axis_size)
                x, update, index = inputs
            else:
                inputs, out_dim = one_dim_is_none(
                    x_bdim, index_bdim, update, axis_size)
                x, index, update = inputs
        else:
            inputs, out_dim = no_dim_is_none(x_bdim, index_bdim, update_bdim)
            x, index, update = inputs

        # Adapt `axis` to vmap case.  A negative axis is first normalized against the
        # rank of `x` without its mapped axis; left negative it would resolve against
        # the mapped tensor and could land on the mapped axis itself (e.g. axis=-1
        # when the mapped axis is the last one).
        new_axis = axis
        if new_axis < 0:
            new_axis = new_axis + F.rank(x) - 1
        if new_axis >= out_dim:
            new_axis = new_axis + 1

        out = TensorScatterElements(new_axis, reduction)(x, index, update)
        return out, out_dim

    return vmap_rule
1749
+
1750
+
1751
@vmap_rules_getters.register(IndexFill)
def get_index_fill_rule(prim, axis_size):
    """VmapRule for `IndexFill` operation.

    When only `x` is mapped, its mapped axis is moved to the front and `dim` is
    shifted accordingly.  If `dim`, `index` or `value` is mapped, the call is
    delegated to `_VmapGeneralRule`.
    """
    prim_vmap = _VmapGeneralRule(prim, axis_size)

    def vmap_rule(x_bdim, dim_bdim, index_bdim, value_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim, value_bdim)
        if is_all_none:
            return result

        dim, dim_dim = dim_bdim
        index, index_dim = index_bdim
        value, value_dim = value_bdim
        only_x_mapped = dim_dim is None and index_dim is None and value_dim is None
        if not only_x_mapped:
            return prim_vmap(x_bdim, dim_bdim, index_bdim, value_bdim)

        x, x_dim = x_bdim
        x = _bdim_at_front(x, x_dim, axis_size)
        # Non-negative dims step over the new leading mapped axis; negative dims still
        # count from the end and are left alone.
        shifted_dim = F.select(dim < 0, dim, dim + 1)
        return prim(x, shifted_dim, index, value), 0

    return vmap_rule
1774
+
1775
+
1776
@vmap_rules_getters.register(P.DataFormatDimMap)
def get_data_format_dim_map_vmap_rule(prim, axis_size):
    """VmapRule for `DataFormatDimMap`.

    A new `DataFormatDimMap` carrying the original primitive's `src_format` /
    `dst_format` and a `batch_rank` one higher than the incoming one (or 1) is
    applied to `x` with its mapped axis moved to the front.
    """
    if hasattr(prim, 'batch_rank'):
        batch_rank = prim.batch_rank + 1
    else:
        batch_rank = 1

    # Preserve the configured formats; a bare `P.DataFormatDimMap()` would silently
    # fall back to the constructor defaults.
    src_format = getattr(prim, 'src_format', 'NHWC')
    dst_format = getattr(prim, 'dst_format', 'NCHW')
    batch_prim = P.DataFormatDimMap(src_format, dst_format)
    batch_prim.add_prim_attr('batch_rank', batch_rank)

    def vmap_rule(x_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result
        x, x_dim = x_bdim
        x = _bdim_at_front(x, x_dim, axis_size)
        out = batch_prim(x)
        return out, 0

    return vmap_rule
1797
+
1798
+
1799
@vmap_rules_getters.register(P.ExpandDims)
def get_expand_dims_vmap_rule(prim, axis_size):
    """VmapRule for `ExpandDims`.

    The insertion axis is translated into the mapped tensor's coordinates and the
    mapped axis's new position in the output is reported.  `axis` must be unmapped.
    """

    @_primexpr
    def process_axis(axis, rank, x_dim):
        # `rank` is the rank of the mapped input, which equals the output rank without
        # its mapped axis, so adding it resolves a negative insertion axis.
        if axis < 0:
            axis += rank
        if axis >= x_dim:
            # Insertion at or after the mapped axis: step over it.
            return axis + 1, x_dim
        # Insertion before the mapped axis pushes it one to the right.
        return axis, x_dim + 1

    def vmap_rule(x_bdim, axis_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim)
        if is_all_none:
            return result

        axis, axis_dim = axis_bdim
        if axis_dim is not None:
            _raise_value_error("The source axis of shape in `ExpandDims` must be None, but got {}.".format(axis_dim))

        x, x_dim = x_bdim
        new_axis, out_dim = process_axis(axis, ops.rank(x), x_dim)
        return prim(x, new_axis), out_dim

    return vmap_rule
1828
+
1829
+
1830
@vmap_rules_getters.register(P.Diag)
def get_diag_vmap_rule(prim, axis_size):
    """VmapRule for `Diag` operations.

    A clone of the primitive is labelled with a `batch_rank` one higher than the
    incoming label (or 1) and applied to `x` with its mapped axis at the front.
    """
    batch_rank = prim.get_label("batch_rank") + 1 if prim.has_label("batch_rank") else 1
    prim = prim.clone()
    prim.set_label('batch_rank', batch_rank)

    def vmap_rule(x_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result
        moved = _bdim_at_front(x_bdim[0], x_bdim[1], axis_size)
        return prim(moved), 0

    return vmap_rule
1851
+
1852
+
1853
@vmap_rules_getters.register(P.Slice)
def get_slice_vmap_rule(prim, axis_size):
    """VmapRule for `Slice` operation.

    `x` gets its mapped axis at the front; the slice window is extended with a
    leading entry covering that whole axis.  `begin` and `size` must be unmapped.
    """
    prim_name = prim if isinstance(prim, str) else prim.name
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim, begin_bdim, size_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, begin_bdim, size_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        begin, begin_dim = begin_bdim
        size, size_dim = size_bdim

        if begin_dim is not None:
            _raise_value_error("The source axis of `begin` in {} only supports None currently, "
                               "but got {}.".format(prim_name, begin_dim))
        if size_dim is not None:
            _raise_value_error("The source axis of `size` in {} must be None, but got {}.".format(prim_name, size_dim))

        x = _bdim_at_front(x, x_dim, axis_size)
        # Keep the whole mapped axis, apply the original window to the rest.
        return prim(x, (0,) + begin, (axis_size,) + size), 0

    return vmap_rule
1887
+
1888
+
1889
@vmap_rules_getters.register(P.Squeeze)
def get_squeeze_vmap_rule(prim, axis_size):
    """VmapRule for `Squeeze`.

    The mapped axis is moved to the front.  Explicit squeeze axes are shifted past
    it.  When no axis is configured (`None`, or the empty tuple that `Squeeze` uses
    as its default), every size-1 axis except the mapped one is squeezed, so the
    mapped axis survives even when `axis_size` is 1.
    """
    prim_axis = getattr(prim, 'axis', None)
    if isinstance(prim_axis, int):
        # Accept a bare int the same way `Squeeze` does.
        prim_axis = (prim_axis,)

    @_primexpr
    def move_axis(axes):
        # Non-negative axes step over the leading mapped axis; negative ones count
        # from the end and are unaffected.
        new_axis = ()
        for axis in axes:
            if axis < 0:
                new_axis = new_axis + (axis,)
            else:
                new_axis = new_axis + (axis + 1,)
        return new_axis

    @_primexpr
    def get_unit_axes_except_first(x_shape):
        # Positions of the size-1 axes, skipping axis 0 (the mapped axis).  Only
        # size-1 axes may be passed to `Squeeze`, which rejects any other axis.
        new_axis = ()
        for i in range(1, len(x_shape)):
            if x_shape[i] == 1:
                new_axis = new_axis + (i,)
        return new_axis

    def vmap_rule(x_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        x = _bdim_at_front(x, x_dim, axis_size)

        if not prim_axis:
            if axis_size == 1:
                # A plain squeeze-all would also remove the size-1 mapped axis.
                new_axis = get_unit_axes_except_first(F.shape(x))
                if not new_axis:
                    # Nothing to squeeze; an empty axis tuple would mean "all".
                    return x, 0
                batch_squeeze = P.Squeeze(axis=new_axis)
                out = batch_squeeze(x)
                return out, 0

            # The mapped axis has size > 1, so squeezing all size-1 axes keeps it.
            out = prim(x)
            return out, 0

        new_axis = move_axis(prim_axis)
        batch_squeeze = P.Squeeze(axis=new_axis)
        out = batch_squeeze(x)
        return out, 0

    return vmap_rule
1938
+
1939
+
1940
@vmap_rules_getters.register(P.StridedSlice)
def get_stridedslice_vmap_rule(prim, axis_size):
    """VmapRule for `StridedSlice`.

    `x` gets its mapped axis at the front and an entry for that axis is prepended to
    `begin`/`end`/`strides`.  Every mask bit addresses one entry of those sequences,
    so all masks are shifted left by one bit.  The new lowest bit of `begin_mask` and
    `end_mask` is set so the prepended entry spans the whole mapped axis.  `begin`,
    `end` and `strides` must be unmapped.
    """
    @_primexpr
    def get_new_begin_end_strided(begin, end, strided):
        new_begin = (0,) + begin
        new_end = (0,) + end
        new_strided = (1,) + strided
        return new_begin, new_end, new_strided

    def _get_mask_value_and_prim(begin_mask_bdim, end_mask_bdim, ellipsis_mask_bdim, new_axis_mask_bdim,
                                 shrink_axis_mask_bdim):
        begin_mask, _ = begin_mask_bdim
        end_mask, _ = end_mask_bdim
        ellipsis_mask, _ = ellipsis_mask_bdim
        new_axis_mask, _ = new_axis_mask_bdim
        shrink_axis_mask, _ = shrink_axis_mask_bdim
        new_begin_mask = begin_mask * 2 + 1
        new_end_mask = end_mask * 2 + 1
        # The ellipsis bit must move with its entry like every other mask.  Left
        # unshifted, it would mark the prepended mapped-axis entry (or a neighbouring
        # entry) as the ellipsis.
        new_ellipsis_mask = ellipsis_mask * 2
        new_new_axis_mask = new_axis_mask * 2
        new_shrink_axis_mask = shrink_axis_mask * 2
        batch_stridedslice = P.StridedSlice(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask,
                                            new_shrink_axis_mask)
        return batch_stridedslice

    def vmap_rule(x_bdim, begin_bdim, end_bdim, strided_bdim, begin_mask_bdim, end_mask_bdim, ellipsis_mask_bdim,
                  new_axis_mask_bdim, shrink_axis_mask_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, begin_bdim, end_bdim, strided_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        begin, begin_dim = begin_bdim
        end, end_dim = end_bdim
        strided, strided_dim = strided_bdim
        batch_stridedslice = _get_mask_value_and_prim(begin_mask_bdim, end_mask_bdim, ellipsis_mask_bdim,
                                                      new_axis_mask_bdim, shrink_axis_mask_bdim)

        if any(dim is not None for dim in [begin_dim, end_dim, strided_dim]):
            _raise_value_error("vmap of `StridedSlice` not support `begin`, `end` or `strided` has batch dimension, "
                               "but got {}, {}, {}".format(begin_dim, end_dim, strided_dim))

        # x_dim is not None, and the others are None
        x = _bdim_at_front(x, x_dim, axis_size)
        new_begin, new_end, new_strided = get_new_begin_end_strided(begin, end, strided)
        result = batch_stridedslice(x, new_begin, new_end, new_strided)
        return result, 0

    return vmap_rule
1990
+
1991
+
1992
@vmap_rules_getters.register(G.StridedSliceGrad)
def get_stridedslice_grad_vmap_rule(prim, axis_size):
    """VmapRule for `StridedSliceGrad`.

    `dy` gets its mapped axis at the front and an entry covering the whole mapped
    axis is prepended to `xshape`/`begin`/`end`/`strides`.  Every mask bit addresses
    one entry of those sequences, so all masks are shifted left by one bit.  The new
    lowest bit of `begin_mask` and `end_mask` is set for the prepended entry.
    `xshape`, `begin`, `end` and `strides` must be unmapped.
    """
    new_begin_mask = prim.begin_mask * 2 + 1
    new_end_mask = prim.end_mask * 2 + 1
    # The ellipsis bit must move with its entry like every other mask.  Left
    # unshifted, it would mark the prepended mapped-axis entry (or a neighbouring
    # entry) as the ellipsis.
    new_ellipsis_mask = prim.ellipsis_mask * 2
    new_new_axis_mask = prim.new_axis_mask * 2
    new_shrink_axis_mask = prim.shrink_axis_mask * 2
    batch_stridedslice_grad = G.StridedSliceGrad(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask,
                                                 new_shrink_axis_mask)

    @_primexpr
    def get_new_xshape_begin_end_strided(xshape, begin, end, strided):
        new_xshape = (axis_size,) + xshape
        new_begin = (0,) + begin
        new_end = (axis_size,) + end
        new_strided = (1,) + strided
        return new_xshape, new_begin, new_end, new_strided

    def vmap_rule(dy_bdim, xshape_bdim, begin_bdim, end_bdim, strided_bdim):
        is_all_none, result = vmap_general_preprocess(prim, dy_bdim, xshape_bdim, begin_bdim, end_bdim, strided_bdim)
        if is_all_none:
            return result

        dy, dy_dim = dy_bdim
        xshape, xshape_dim = xshape_bdim
        begin, begin_dim = begin_bdim
        end, end_dim = end_bdim
        strided, strided_dim = strided_bdim

        if any(dim is not None for dim in [xshape_dim, begin_dim, end_dim, strided_dim]):
            _raise_value_error("vmap of `StridedSliceGrad` not support `xshape`, `begin`, "
                               "`end` or `strided` has batch dimension, "
                               "but got {}, {}, {}, {}".format(xshape_dim, begin_dim, end_dim, strided_dim))

        # Only dy is mapped; everything else is unmapped.
        dy = _bdim_at_front(dy, dy_dim, axis_size)

        new_xshape, new_begin, new_end, new_strided = get_new_xshape_begin_end_strided(xshape, begin, end, strided)

        result = batch_stridedslice_grad(dy, new_xshape, new_begin, new_end, new_strided)
        return result, 0

    return vmap_rule
2036
+
2037
+
2038
@vmap_rules_getters.register(P.TopK)
def get_topk_vmap_rule(prim, axis_size):
    """VmapRule for `TopK` operation.

    The mapped axis is only relocated (to the front) when it is the trailing axis
    of `x`; otherwise the primitive runs directly and both outputs keep the mapped
    axis at its original position.  `k` must be unmapped.
    """
    prim_name = prim if isinstance(prim, str) else prim.name
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim, k_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, k_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        k, k_dim = k_bdim

        if k_dim is not None:
            _raise_value_error("The source axis of `k` in {} must be None, but got {}.".format(prim_name, k_dim))

        x_rank = F.rank(x)
        if not (x_rank and x_dim in (-1, x_rank - 1)):
            # Mapped axis is not the trailing one: leave it in place.
            values, indices = prim(x, k)
            return (values, x_dim), (indices, x_dim)

        # Mapped axis is the trailing one: move it to the front first.
        x = _bdim_at_front(x, x_dim, axis_size)
        values, indices = prim(x, k)
        return (values, 0), (indices, 0)

    return vmap_rule
2067
+
2068
+
2069
@vmap_rules_getters.register(P.Im2Col)
def get_im2col_vmap_rule(prim, axis_size):
    """VmapRule for `Im2Col` operations.

    The mapped axis is moved to the front and merged with the next axis into a
    single leading axis.  The primitive is applied, and the merged leading axis of
    the output is split back into the original two.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result
        x, x_dim = x_bdim
        x = _bdim_at_front(x, x_dim, axis_size)
        batched_shape = x.shape

        folded_out = prim(x.reshape((-1,) + batched_shape[2:]))
        unfolded = folded_out.reshape(batched_shape[:2] + folded_out.shape[1:])
        return unfolded, 0

    return vmap_rule
2092
+
2093
+
2094
@vmap_rules_getters.register(P.Split)
def get_split_vmap_rule(prim, axis_size):
    """VmapRule for `Split`.

    `x` gets its mapped axis at the front and a new `Split` is built with the split
    axis shifted past it.  Every output piece reports the mapped axis at position 0.
    """
    def vmap_rule(x_bdim, axis_bdim, output_num_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim, output_num_bdim)
        if is_all_none:
            return result
        x, x_dim = x_bdim
        split_axis, _ = axis_bdim
        split_num, _ = output_num_bdim
        if split_axis >= 0:
            # Skip the leading mapped axis; negative axes count from the end.
            split_axis += 1
        x = _bdim_at_front(x, x_dim, axis_size)
        pieces = P.Split(split_axis, split_num)(x)
        output = ()
        for piece in pieces:
            output = output + ((piece, 0),)
        return output

    return vmap_rule
2115
+
2116
+
2117
@vmap_rules_getters.register(P.SearchSorted)
def get_searchsorted_vmap_rule(prim, axis_size):
    """VmapRule for `SearchSorted`.

    `sequence`, `values` and, when given, `sorter` are all brought to carry the
    mapped axis at the front (broadcasting any unmapped one).  `dtype` and `right`
    are forwarded unchanged.
    """
    def vmap_rule(sequence_bdim, values_bdim, sorter_bdim, dtype_bdim, right_bdim):
        is_all_none, result = vmap_general_preprocess(prim, sequence_bdim, values_bdim,
                                                      sorter_bdim, dtype_bdim, right_bdim)
        if is_all_none:
            return result

        sequence, sequence_dim = sequence_bdim
        values, values_dim = values_bdim
        sorter, sorter_dim = sorter_bdim

        sequence = _bdim_at_front(sequence, sequence_dim, axis_size)
        values = _bdim_at_front(values, values_dim, axis_size)
        if sorter is not None:
            # `sorter` must match the (now batched) shape of `sequence`, so an unmapped
            # sorter is broadcast along the mapped axis as well.
            sorter = _bdim_at_front(sorter, sorter_dim, axis_size)

        outputs = prim(sequence, values, sorter, dtype_bdim[0], right_bdim[0])

        return outputs, 0

    return vmap_rule
2140
+
2141
+
2142
# Register the shared `get_unsupported_dynamic_vmap_rule` getter for each primitive
# below.  `register(...)` is applied the way a decorator would be, and its return
# value is rebound to the same name.
# NOTE(review): judging by the getter's name, these ops presumably have data-dependent
# output shapes and the shared rule rejects vmap for them -- confirm against its
# definition.
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(NonZero)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = \
    vmap_rules_getters.register(UniqueConsecutive)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(Col2Im)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(RandomPoisson)(get_unsupported_dynamic_vmap_rule)
# ZerosLike (registered by its string name) and OnesLike reuse the shared
# `get_unop_vmap_rule` getter.
get_unop_vmap_rule = vmap_rules_getters.register("ZerosLike")(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.OnesLike)(get_unop_vmap_rule)