mindspore-2.4.0-cp311-cp311-win_amd64.whl

This diff shows the contents of a publicly available package version released to one of the supported registries. The information is provided for informational purposes only and reflects the package as it appears in its public registry.
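Because a wheel is an ordinary zip archive, a file manifest like the one below can be reproduced locally. The following is a minimal sketch, not the tooling this page actually uses; it assumes only that the wheel named above has already been downloaded into the working directory.

# Minimal sketch: list the files inside a wheel (wheels are zip archives).
# Assumes mindspore-2.4.0-cp311-cp311-win_amd64.whl is in the working directory.
import zipfile

WHEEL = "mindspore-2.4.0-cp311-cp311-win_amd64.whl"

with zipfile.ZipFile(WHEEL) as wf:
    for info in wf.infolist():
        # file_size helps spot the large binary payloads (.dll / .pyd)
        print(f"{info.file_size:>12}  {info.filename}")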

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (1406)
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +53 -0
  18. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1419 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +299 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1448 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4475 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +433 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +86 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_pijit_context.py +190 -0
  89. mindspore/common/_register_for_adapter.py +74 -0
  90. mindspore/common/_register_for_recompute.py +48 -0
  91. mindspore/common/_register_for_tensor.py +46 -0
  92. mindspore/common/_stub_tensor.py +210 -0
  93. mindspore/common/_tensor_overload.py +139 -0
  94. mindspore/common/_utils.py +122 -0
  95. mindspore/common/api.py +2064 -0
  96. mindspore/common/auto_dynamic_shape.py +507 -0
  97. mindspore/common/dtype.py +422 -0
  98. mindspore/common/dump.py +130 -0
  99. mindspore/common/file_system.py +48 -0
  100. mindspore/common/generator.py +254 -0
  101. mindspore/common/hook_handle.py +143 -0
  102. mindspore/common/initializer.py +880 -0
  103. mindspore/common/jit_config.py +98 -0
  104. mindspore/common/lazy_inline.py +240 -0
  105. mindspore/common/mindir_util.py +111 -0
  106. mindspore/common/mutable.py +234 -0
  107. mindspore/common/no_inline.py +54 -0
  108. mindspore/common/np_dtype.py +25 -0
  109. mindspore/common/parameter.py +1081 -0
  110. mindspore/common/recompute.py +292 -0
  111. mindspore/common/seed.py +260 -0
  112. mindspore/common/sparse_tensor.py +1175 -0
  113. mindspore/common/symbol.py +122 -0
  114. mindspore/common/tensor.py +5039 -0
  115. mindspore/communication/__init__.py +37 -0
  116. mindspore/communication/_comm_helper.py +501 -0
  117. mindspore/communication/_hccl_management.py +297 -0
  118. mindspore/communication/comm_func.py +1395 -0
  119. mindspore/communication/management.py +673 -0
  120. mindspore/config/op_info.config +533 -0
  121. mindspore/context.py +2077 -0
  122. mindspore/d3dcompiler_47.dll +0 -0
  123. mindspore/dataset/__init__.py +90 -0
  124. mindspore/dataset/audio/__init__.py +61 -0
  125. mindspore/dataset/audio/transforms.py +3690 -0
  126. mindspore/dataset/audio/utils.py +386 -0
  127. mindspore/dataset/audio/validators.py +1172 -0
  128. mindspore/dataset/callback/__init__.py +20 -0
  129. mindspore/dataset/callback/ds_callback.py +368 -0
  130. mindspore/dataset/callback/validators.py +32 -0
  131. mindspore/dataset/core/__init__.py +13 -0
  132. mindspore/dataset/core/config.py +1095 -0
  133. mindspore/dataset/core/datatypes.py +101 -0
  134. mindspore/dataset/core/py_util_helpers.py +65 -0
  135. mindspore/dataset/core/validator_helpers.py +781 -0
  136. mindspore/dataset/debug/__init__.py +21 -0
  137. mindspore/dataset/debug/debug_hook.py +97 -0
  138. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  139. mindspore/dataset/engine/__init__.py +124 -0
  140. mindspore/dataset/engine/cache_admin.py +47 -0
  141. mindspore/dataset/engine/cache_client.py +129 -0
  142. mindspore/dataset/engine/datasets.py +4582 -0
  143. mindspore/dataset/engine/datasets_audio.py +911 -0
  144. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  145. mindspore/dataset/engine/datasets_text.py +2161 -0
  146. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  147. mindspore/dataset/engine/datasets_vision.py +4816 -0
  148. mindspore/dataset/engine/iterators.py +371 -0
  149. mindspore/dataset/engine/obs/__init__.py +23 -0
  150. mindspore/dataset/engine/obs/config_loader.py +68 -0
  151. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  152. mindspore/dataset/engine/obs/util.py +482 -0
  153. mindspore/dataset/engine/offload.py +596 -0
  154. mindspore/dataset/engine/queue.py +304 -0
  155. mindspore/dataset/engine/samplers.py +895 -0
  156. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  157. mindspore/dataset/engine/validators.py +2895 -0
  158. mindspore/dataset/text/__init__.py +51 -0
  159. mindspore/dataset/text/transforms.py +1703 -0
  160. mindspore/dataset/text/utils.py +715 -0
  161. mindspore/dataset/text/validators.py +642 -0
  162. mindspore/dataset/transforms/__init__.py +45 -0
  163. mindspore/dataset/transforms/c_transforms.py +638 -0
  164. mindspore/dataset/transforms/py_transforms.py +393 -0
  165. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  166. mindspore/dataset/transforms/transforms.py +1260 -0
  167. mindspore/dataset/transforms/validators.py +410 -0
  168. mindspore/dataset/utils/__init__.py +19 -0
  169. mindspore/dataset/utils/browse_dataset.py +190 -0
  170. mindspore/dataset/utils/line_reader.py +126 -0
  171. mindspore/dataset/vision/__init__.py +65 -0
  172. mindspore/dataset/vision/c_transforms.py +2641 -0
  173. mindspore/dataset/vision/py_transforms.py +2120 -0
  174. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  175. mindspore/dataset/vision/transforms.py +7295 -0
  176. mindspore/dataset/vision/utils.py +863 -0
  177. mindspore/dataset/vision/validators.py +1483 -0
  178. mindspore/default_config.py +2 -0
  179. mindspore/dnnl.dll +0 -0
  180. mindspore/dpcmi.dll +0 -0
  181. mindspore/experimental/__init__.py +20 -0
  182. mindspore/experimental/es/__init__.py +22 -0
  183. mindspore/experimental/es/embedding_service.py +883 -0
  184. mindspore/experimental/es/embedding_service_layer.py +581 -0
  185. mindspore/experimental/llm_boost/__init__.py +21 -0
  186. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  187. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  188. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  189. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  190. mindspore/experimental/llm_boost/register.py +129 -0
  191. mindspore/experimental/llm_boost/utils.py +31 -0
  192. mindspore/experimental/map_parameter.py +309 -0
  193. mindspore/experimental/optim/__init__.py +40 -0
  194. mindspore/experimental/optim/adadelta.py +161 -0
  195. mindspore/experimental/optim/adagrad.py +168 -0
  196. mindspore/experimental/optim/adam.py +193 -0
  197. mindspore/experimental/optim/adamax.py +170 -0
  198. mindspore/experimental/optim/adamw.py +290 -0
  199. mindspore/experimental/optim/asgd.py +153 -0
  200. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  201. mindspore/experimental/optim/nadam.py +157 -0
  202. mindspore/experimental/optim/optimizer.py +262 -0
  203. mindspore/experimental/optim/radam.py +194 -0
  204. mindspore/experimental/optim/rmsprop.py +154 -0
  205. mindspore/experimental/optim/rprop.py +164 -0
  206. mindspore/experimental/optim/sgd.py +156 -0
  207. mindspore/hal/__init__.py +40 -0
  208. mindspore/hal/_ascend.py +57 -0
  209. mindspore/hal/_base.py +57 -0
  210. mindspore/hal/_cpu.py +56 -0
  211. mindspore/hal/_gpu.py +57 -0
  212. mindspore/hal/contiguous_tensors_handle.py +175 -0
  213. mindspore/hal/device.py +356 -0
  214. mindspore/hal/event.py +179 -0
  215. mindspore/hal/memory.py +326 -0
  216. mindspore/hal/stream.py +357 -0
  217. mindspore/include/OWNERS +7 -0
  218. mindspore/include/api/allocator.h +97 -0
  219. mindspore/include/api/callback/callback.h +93 -0
  220. mindspore/include/api/callback/ckpt_saver.h +41 -0
  221. mindspore/include/api/callback/loss_monitor.h +33 -0
  222. mindspore/include/api/callback/lr_scheduler.h +51 -0
  223. mindspore/include/api/callback/time_monitor.h +34 -0
  224. mindspore/include/api/callback/train_accuracy.h +37 -0
  225. mindspore/include/api/cell.h +90 -0
  226. mindspore/include/api/cfg.h +82 -0
  227. mindspore/include/api/context.h +602 -0
  228. mindspore/include/api/data_type.h +47 -0
  229. mindspore/include/api/delegate.h +178 -0
  230. mindspore/include/api/delegate_api.h +75 -0
  231. mindspore/include/api/dual_abi_helper.h +208 -0
  232. mindspore/include/api/format.h +28 -0
  233. mindspore/include/api/graph.h +46 -0
  234. mindspore/include/api/kernel.h +58 -0
  235. mindspore/include/api/kernel_api.h +168 -0
  236. mindspore/include/api/metrics/accuracy.h +36 -0
  237. mindspore/include/api/metrics/metrics.h +41 -0
  238. mindspore/include/api/model.h +438 -0
  239. mindspore/include/api/model_group.h +91 -0
  240. mindspore/include/api/model_parallel_runner.h +168 -0
  241. mindspore/include/api/serialization.h +185 -0
  242. mindspore/include/api/status.h +192 -0
  243. mindspore/include/api/types.h +431 -0
  244. mindspore/include/api/visible.h +41 -0
  245. mindspore/include/c_api/context_c.h +179 -0
  246. mindspore/include/c_api/data_type_c.h +52 -0
  247. mindspore/include/c_api/format_c.h +46 -0
  248. mindspore/include/c_api/model_c.h +347 -0
  249. mindspore/include/c_api/status_c.h +79 -0
  250. mindspore/include/c_api/tensor_c.h +146 -0
  251. mindspore/include/c_api/types_c.h +67 -0
  252. mindspore/include/dataset/config.h +163 -0
  253. mindspore/include/dataset/constants.h +363 -0
  254. mindspore/include/dataset/execute.h +196 -0
  255. mindspore/include/dataset/text.h +1092 -0
  256. mindspore/include/dataset/transforms.h +638 -0
  257. mindspore/include/dataset/vision.h +2129 -0
  258. mindspore/include/dataset/vision_ascend.h +206 -0
  259. mindspore/include/dataset/vision_lite.h +625 -0
  260. mindspore/jpeg62.dll +0 -0
  261. mindspore/log.py +633 -0
  262. mindspore/mindrecord/__init__.py +43 -0
  263. mindspore/mindrecord/common/__init__.py +17 -0
  264. mindspore/mindrecord/common/constant.py +20 -0
  265. mindspore/mindrecord/common/enums.py +44 -0
  266. mindspore/mindrecord/common/exceptions.py +311 -0
  267. mindspore/mindrecord/config.py +809 -0
  268. mindspore/mindrecord/filereader.py +174 -0
  269. mindspore/mindrecord/filewriter.py +722 -0
  270. mindspore/mindrecord/mindpage.py +210 -0
  271. mindspore/mindrecord/shardheader.py +141 -0
  272. mindspore/mindrecord/shardindexgenerator.py +74 -0
  273. mindspore/mindrecord/shardreader.py +117 -0
  274. mindspore/mindrecord/shardsegment.py +128 -0
  275. mindspore/mindrecord/shardutils.py +185 -0
  276. mindspore/mindrecord/shardwriter.py +237 -0
  277. mindspore/mindrecord/tools/__init__.py +17 -0
  278. mindspore/mindrecord/tools/cifar10.py +140 -0
  279. mindspore/mindrecord/tools/cifar100.py +153 -0
  280. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  281. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  282. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  283. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  284. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  285. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  286. mindspore/mindspore_backend.dll +0 -0
  287. mindspore/mindspore_common.dll +0 -0
  288. mindspore/mindspore_core.dll +0 -0
  289. mindspore/mindspore_glog.dll +0 -0
  290. mindspore/mindspore_np_dtype.dll +0 -0
  291. mindspore/mindspore_ops.dll +0 -0
  292. mindspore/mint/__init__.py +1586 -0
  293. mindspore/mint/distributed/__init__.py +31 -0
  294. mindspore/mint/distributed/distributed.py +254 -0
  295. mindspore/mint/linalg/__init__.py +22 -0
  296. mindspore/mint/nn/__init__.py +757 -0
  297. mindspore/mint/nn/functional.py +679 -0
  298. mindspore/mint/nn/layer/__init__.py +39 -0
  299. mindspore/mint/nn/layer/activation.py +133 -0
  300. mindspore/mint/nn/layer/normalization.py +477 -0
  301. mindspore/mint/nn/layer/pooling.py +110 -0
  302. mindspore/mint/optim/__init__.py +24 -0
  303. mindspore/mint/optim/adamw.py +206 -0
  304. mindspore/mint/special/__init__.py +63 -0
  305. mindspore/msobj140.dll +0 -0
  306. mindspore/mspdb140.dll +0 -0
  307. mindspore/mspdbcore.dll +0 -0
  308. mindspore/mspdbst.dll +0 -0
  309. mindspore/mspft140.dll +0 -0
  310. mindspore/msvcdis140.dll +0 -0
  311. mindspore/msvcp140.dll +0 -0
  312. mindspore/msvcp140_1.dll +0 -0
  313. mindspore/msvcp140_2.dll +0 -0
  314. mindspore/msvcp140_atomic_wait.dll +0 -0
  315. mindspore/msvcp140_codecvt_ids.dll +0 -0
  316. mindspore/multiprocessing/__init__.py +73 -0
  317. mindspore/nn/__init__.py +47 -0
  318. mindspore/nn/cell.py +2787 -0
  319. mindspore/nn/dynamic_lr.py +482 -0
  320. mindspore/nn/grad/__init__.py +21 -0
  321. mindspore/nn/grad/cell_grad.py +196 -0
  322. mindspore/nn/layer/__init__.py +63 -0
  323. mindspore/nn/layer/activation.py +1822 -0
  324. mindspore/nn/layer/basic.py +1629 -0
  325. mindspore/nn/layer/channel_shuffle.py +90 -0
  326. mindspore/nn/layer/combined.py +248 -0
  327. mindspore/nn/layer/container.py +734 -0
  328. mindspore/nn/layer/conv.py +1505 -0
  329. mindspore/nn/layer/dense.py +204 -0
  330. mindspore/nn/layer/embedding.py +869 -0
  331. mindspore/nn/layer/image.py +661 -0
  332. mindspore/nn/layer/math.py +1069 -0
  333. mindspore/nn/layer/normalization.py +1273 -0
  334. mindspore/nn/layer/padding.py +880 -0
  335. mindspore/nn/layer/pooling.py +2302 -0
  336. mindspore/nn/layer/rnn_cells.py +388 -0
  337. mindspore/nn/layer/rnns.py +849 -0
  338. mindspore/nn/layer/thor_layer.py +963 -0
  339. mindspore/nn/layer/timedistributed.py +155 -0
  340. mindspore/nn/layer/transformer.py +823 -0
  341. mindspore/nn/learning_rate_schedule.py +512 -0
  342. mindspore/nn/loss/__init__.py +36 -0
  343. mindspore/nn/loss/loss.py +2924 -0
  344. mindspore/nn/metrics.py +53 -0
  345. mindspore/nn/optim/__init__.py +45 -0
  346. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  347. mindspore/nn/optim/ada_grad.py +217 -0
  348. mindspore/nn/optim/adadelta.py +206 -0
  349. mindspore/nn/optim/adafactor.py +448 -0
  350. mindspore/nn/optim/adam.py +1297 -0
  351. mindspore/nn/optim/adamax.py +220 -0
  352. mindspore/nn/optim/adasum.py +548 -0
  353. mindspore/nn/optim/asgd.py +216 -0
  354. mindspore/nn/optim/ftrl.py +401 -0
  355. mindspore/nn/optim/lamb.py +296 -0
  356. mindspore/nn/optim/lars.py +202 -0
  357. mindspore/nn/optim/lazyadam.py +533 -0
  358. mindspore/nn/optim/momentum.py +239 -0
  359. mindspore/nn/optim/optimizer.py +1034 -0
  360. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  361. mindspore/nn/optim/rmsprop.py +264 -0
  362. mindspore/nn/optim/rprop.py +251 -0
  363. mindspore/nn/optim/sgd.py +237 -0
  364. mindspore/nn/optim/tft_wrapper.py +127 -0
  365. mindspore/nn/optim/thor.py +1310 -0
  366. mindspore/nn/probability/__init__.py +22 -0
  367. mindspore/nn/probability/bijector/__init__.py +35 -0
  368. mindspore/nn/probability/bijector/bijector.py +337 -0
  369. mindspore/nn/probability/bijector/exp.py +65 -0
  370. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  371. mindspore/nn/probability/bijector/invert.py +126 -0
  372. mindspore/nn/probability/bijector/power_transform.py +196 -0
  373. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  374. mindspore/nn/probability/bijector/softplus.py +189 -0
  375. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  376. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  377. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  378. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  379. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  380. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  381. mindspore/nn/probability/distribution/__init__.py +56 -0
  382. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  383. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  384. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  385. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  386. mindspore/nn/probability/distribution/beta.py +391 -0
  387. mindspore/nn/probability/distribution/categorical.py +435 -0
  388. mindspore/nn/probability/distribution/cauchy.py +383 -0
  389. mindspore/nn/probability/distribution/distribution.py +827 -0
  390. mindspore/nn/probability/distribution/exponential.py +350 -0
  391. mindspore/nn/probability/distribution/gamma.py +391 -0
  392. mindspore/nn/probability/distribution/geometric.py +335 -0
  393. mindspore/nn/probability/distribution/gumbel.py +257 -0
  394. mindspore/nn/probability/distribution/half_normal.py +133 -0
  395. mindspore/nn/probability/distribution/laplace.py +128 -0
  396. mindspore/nn/probability/distribution/log_normal.py +272 -0
  397. mindspore/nn/probability/distribution/logistic.py +379 -0
  398. mindspore/nn/probability/distribution/normal.py +336 -0
  399. mindspore/nn/probability/distribution/poisson.py +288 -0
  400. mindspore/nn/probability/distribution/student_t.py +149 -0
  401. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  402. mindspore/nn/probability/distribution/uniform.py +375 -0
  403. mindspore/nn/reinforcement/__init__.py +24 -0
  404. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  405. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  406. mindspore/nn/reinforcement/tensor_array.py +145 -0
  407. mindspore/nn/sparse/__init__.py +23 -0
  408. mindspore/nn/sparse/sparse.py +147 -0
  409. mindspore/nn/wrap/__init__.py +49 -0
  410. mindspore/nn/wrap/cell_wrapper.py +968 -0
  411. mindspore/nn/wrap/grad_reducer.py +608 -0
  412. mindspore/nn/wrap/loss_scale.py +694 -0
  413. mindspore/numpy/__init__.py +121 -0
  414. mindspore/numpy/array_creations.py +2731 -0
  415. mindspore/numpy/array_ops.py +2629 -0
  416. mindspore/numpy/dtypes.py +185 -0
  417. mindspore/numpy/fft.py +966 -0
  418. mindspore/numpy/logic_ops.py +936 -0
  419. mindspore/numpy/math_ops.py +5911 -0
  420. mindspore/numpy/utils.py +214 -0
  421. mindspore/numpy/utils_const.py +565 -0
  422. mindspore/opencv_core452.dll +0 -0
  423. mindspore/opencv_imgcodecs452.dll +0 -0
  424. mindspore/opencv_imgproc452.dll +0 -0
  425. mindspore/ops/__init__.py +56 -0
  426. mindspore/ops/_constants.py +30 -0
  427. mindspore/ops/_grad_experimental/__init__.py +31 -0
  428. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  429. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  430. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  431. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  432. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  433. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  434. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  435. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  436. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  437. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  438. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  439. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  440. mindspore/ops/_op_impl/__init__.py +23 -0
  441. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  442. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  443. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  447. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  448. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  449. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  450. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  451. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  452. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  453. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  462. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  463. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  464. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  465. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  471. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  472. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  473. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  474. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  475. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  476. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  477. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  478. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  479. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  480. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  481. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  490. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  491. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  494. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  495. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  498. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  499. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  500. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  501. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  502. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  503. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  504. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  505. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  506. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  507. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  508. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  510. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  511. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  513. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  514. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  515. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  516. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  517. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  518. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  519. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  520. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  521. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  522. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  523. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  524. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  526. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  527. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  529. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  531. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  532. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  533. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  534. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  535. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  537. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  538. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  539. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  540. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  541. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  542. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  544. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  545. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  546. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  547. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  548. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  550. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  551. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  552. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  553. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  554. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  556. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  557. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  558. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  559. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  560. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  561. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  562. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  563. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  564. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  565. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  567. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  568. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  569. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  570. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  571. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  572. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  574. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  575. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  576. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  577. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  578. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  579. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  580. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  581. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  582. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  583. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  584. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  585. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  586. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  587. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  588. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  589. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  590. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  591. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  592. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  593. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  594. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  595. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  596. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  597. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  598. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  599. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  600. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  601. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  603. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  610. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  611. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  612. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  615. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  616. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  617. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  618. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  619. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  620. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  621. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  622. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  623. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  625. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  626. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  627. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  628. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  629. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  630. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  634. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  635. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  636. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  637. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  638. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  639. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  640. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  641. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  642. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  643. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  644. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  646. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  647. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  648. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  650. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  651. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  652. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  653. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  654. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  655. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  656. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  657. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  658. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  659. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  660. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  661. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  662. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  663. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  664. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  665. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  667. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  668. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  669. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  670. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  673. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  675. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  676. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  677. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  679. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  680. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  681. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  682. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  683. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  684. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  685. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  686. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  687. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  688. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  699. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  700. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  701. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  702. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  706. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  707. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  708. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  709. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  710. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  711. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  712. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  713. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  714. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  715. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  716. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  717. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  718. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  720. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  721. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  722. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  723. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  724. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  725. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  726. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  727. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  728. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  729. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  730. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  731. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  732. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  733. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  734. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  735. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  736. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  737. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  738. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  739. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  740. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  741. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  742. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  743. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  744. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  745. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  746. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  747. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  748. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  749. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  750. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  751. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  752. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  753. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  754. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  755. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  757. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  758. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  759. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  760. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  761. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  762. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  763. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  764. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  765. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  766. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  767. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  768. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  769. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  770. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  774. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  775. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  776. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  777. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  778. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  779. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  782. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  783. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  784. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  785. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  786. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  787. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  788. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  789. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  790. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  791. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  792. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  793. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  794. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  795. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  796. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  797. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  798. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  799. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  804. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  805. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  806. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  810. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  811. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  812. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  813. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  817. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  818. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  819. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  820. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  821. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  823. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  825. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  826. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  827. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  828. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  829. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  862. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  863. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  864. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  865. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  866. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  868. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  869. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  870. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  871. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  872. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  873. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  874. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  876. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1318 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/function/__init__.py +782 -0
  1098. mindspore/ops/function/array_func.py +7226 -0
  1099. mindspore/ops/function/clip_func.py +384 -0
  1100. mindspore/ops/function/debug_func.py +181 -0
  1101. mindspore/ops/function/fft_func.py +44 -0
  1102. mindspore/ops/function/grad/__init__.py +34 -0
  1103. mindspore/ops/function/grad/grad_func.py +1425 -0
  1104. mindspore/ops/function/image_func.py +292 -0
  1105. mindspore/ops/function/linalg_func.py +416 -0
  1106. mindspore/ops/function/math_func.py +12228 -0
  1107. mindspore/ops/function/nn_func.py +8609 -0
  1108. mindspore/ops/function/other_func.py +115 -0
  1109. mindspore/ops/function/parameter_func.py +134 -0
  1110. mindspore/ops/function/random_func.py +1715 -0
  1111. mindspore/ops/function/reshard_func.py +104 -0
  1112. mindspore/ops/function/sparse_func.py +884 -0
  1113. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1114. mindspore/ops/function/spectral_func.py +150 -0
  1115. mindspore/ops/function/vmap_func.py +117 -0
  1116. mindspore/ops/functional.py +464 -0
  1117. mindspore/ops/op_info_register.py +1572 -0
  1118. mindspore/ops/operations/__init__.py +722 -0
  1119. mindspore/ops/operations/_csr_ops.py +403 -0
  1120. mindspore/ops/operations/_custom_grad.py +181 -0
  1121. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1122. mindspore/ops/operations/_grad_ops.py +2978 -0
  1123. mindspore/ops/operations/_infer_ops.py +19 -0
  1124. mindspore/ops/operations/_inner_ops.py +2544 -0
  1125. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1126. mindspore/ops/operations/_ms_kernel.py +601 -0
  1127. mindspore/ops/operations/_ocr_ops.py +379 -0
  1128. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1129. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1130. mindspore/ops/operations/_quant_ops.py +1844 -0
  1131. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1132. mindspore/ops/operations/_scalar_ops.py +106 -0
  1133. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1134. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1135. mindspore/ops/operations/_tensor_array.py +359 -0
  1136. mindspore/ops/operations/_thor_ops.py +807 -0
  1137. mindspore/ops/operations/array_ops.py +6124 -0
  1138. mindspore/ops/operations/comm_ops.py +1985 -0
  1139. mindspore/ops/operations/control_ops.py +127 -0
  1140. mindspore/ops/operations/custom_ops.py +1129 -0
  1141. mindspore/ops/operations/debug_ops.py +678 -0
  1142. mindspore/ops/operations/image_ops.py +1041 -0
  1143. mindspore/ops/operations/inner_ops.py +697 -0
  1144. mindspore/ops/operations/linalg_ops.py +95 -0
  1145. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1146. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1147. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1148. mindspore/ops/operations/math_ops.py +5095 -0
  1149. mindspore/ops/operations/nn_ops.py +9575 -0
  1150. mindspore/ops/operations/other_ops.py +874 -0
  1151. mindspore/ops/operations/random_ops.py +1288 -0
  1152. mindspore/ops/operations/reshard_ops.py +53 -0
  1153. mindspore/ops/operations/rl_ops.py +288 -0
  1154. mindspore/ops/operations/sparse_ops.py +2753 -0
  1155. mindspore/ops/operations/spectral_ops.py +111 -0
  1156. mindspore/ops/primitive.py +1046 -0
  1157. mindspore/ops/signature.py +54 -0
  1158. mindspore/ops/vm_impl_registry.py +91 -0
  1159. mindspore/ops_generate/__init__.py +27 -0
  1160. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1161. mindspore/ops_generate/arg_handler.py +197 -0
  1162. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1163. mindspore/ops_generate/gen_constants.py +36 -0
  1164. mindspore/ops_generate/gen_ops.py +1099 -0
  1165. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1166. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1167. mindspore/ops_generate/gen_utils.py +209 -0
  1168. mindspore/ops_generate/op_proto.py +145 -0
  1169. mindspore/ops_generate/pyboost_utils.py +367 -0
  1170. mindspore/ops_generate/template.py +261 -0
  1171. mindspore/parallel/__init__.py +30 -0
  1172. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1173. mindspore/parallel/_cell_wrapper.py +174 -0
  1174. mindspore/parallel/_cost_model_context.py +700 -0
  1175. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1176. mindspore/parallel/_offload_context.py +275 -0
  1177. mindspore/parallel/_parallel_serialization.py +561 -0
  1178. mindspore/parallel/_ps_context.py +242 -0
  1179. mindspore/parallel/_recovery_context.py +110 -0
  1180. mindspore/parallel/_tensor.py +730 -0
  1181. mindspore/parallel/_transformer/__init__.py +35 -0
  1182. mindspore/parallel/_transformer/layers.py +765 -0
  1183. mindspore/parallel/_transformer/loss.py +251 -0
  1184. mindspore/parallel/_transformer/moe.py +693 -0
  1185. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1186. mindspore/parallel/_transformer/transformer.py +3119 -0
  1187. mindspore/parallel/_utils.py +612 -0
  1188. mindspore/parallel/algo_parameter_config.py +400 -0
  1189. mindspore/parallel/checkpoint_transform.py +650 -0
  1190. mindspore/parallel/cluster/__init__.py +15 -0
  1191. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1192. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1193. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1194. mindspore/parallel/cluster/run.py +136 -0
  1195. mindspore/parallel/mpi/__init__.py +14 -0
  1196. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1197. mindspore/parallel/parameter_broadcast.py +151 -0
  1198. mindspore/parallel/shard.py +481 -0
  1199. mindspore/parallel/transform_safetensors.py +993 -0
  1200. mindspore/perf_msvcbuildinsights.dll +0 -0
  1201. mindspore/pgodb140.dll +0 -0
  1202. mindspore/pgort140.dll +0 -0
  1203. mindspore/profiler/__init__.py +28 -0
  1204. mindspore/profiler/common/__init__.py +14 -0
  1205. mindspore/profiler/common/constant.py +29 -0
  1206. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1207. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1208. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1209. mindspore/profiler/common/process_pool.py +41 -0
  1210. mindspore/profiler/common/registry.py +47 -0
  1211. mindspore/profiler/common/singleton.py +28 -0
  1212. mindspore/profiler/common/struct_type.py +118 -0
  1213. mindspore/profiler/common/util.py +472 -0
  1214. mindspore/profiler/common/validator/__init__.py +14 -0
  1215. mindspore/profiler/common/validator/validate_path.py +84 -0
  1216. mindspore/profiler/dynamic_profiler.py +694 -0
  1217. mindspore/profiler/envprofiling.py +254 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1229. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1230. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1231. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1232. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1233. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1234. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1235. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1236. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1237. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1238. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1239. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1240. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1241. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1242. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1243. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1244. mindspore/profiler/parser/container.py +229 -0
  1245. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1246. mindspore/profiler/parser/flops_parser.py +531 -0
  1247. mindspore/profiler/parser/framework_enum.py +111 -0
  1248. mindspore/profiler/parser/framework_parser.py +464 -0
  1249. mindspore/profiler/parser/framework_struct.py +61 -0
  1250. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1251. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1252. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1253. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1254. mindspore/profiler/parser/hccl_parser.py +573 -0
  1255. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1256. mindspore/profiler/parser/integrator.py +526 -0
  1257. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1258. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1259. mindspore/profiler/parser/minddata_parser.py +186 -0
  1260. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1261. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1262. mindspore/profiler/parser/optime_parser.py +250 -0
  1263. mindspore/profiler/parser/profiler_info.py +213 -0
  1264. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1265. mindspore/profiler/profiler.py +153 -0
  1266. mindspore/profiler/profiling.py +1922 -0
  1267. mindspore/rewrite/__init__.py +28 -0
  1268. mindspore/rewrite/api/__init__.py +17 -0
  1269. mindspore/rewrite/api/node.py +519 -0
  1270. mindspore/rewrite/api/node_type.py +53 -0
  1271. mindspore/rewrite/api/pattern_engine.py +490 -0
  1272. mindspore/rewrite/api/scoped_value.py +181 -0
  1273. mindspore/rewrite/api/symbol_tree.py +497 -0
  1274. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1275. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1276. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1277. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1278. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1279. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1280. mindspore/rewrite/common/__init__.py +19 -0
  1281. mindspore/rewrite/common/config.py +24 -0
  1282. mindspore/rewrite/common/error_log.py +39 -0
  1283. mindspore/rewrite/common/event.py +28 -0
  1284. mindspore/rewrite/common/namer.py +271 -0
  1285. mindspore/rewrite/common/namespace.py +118 -0
  1286. mindspore/rewrite/common/observable.py +44 -0
  1287. mindspore/rewrite/common/observer.py +54 -0
  1288. mindspore/rewrite/node/__init__.py +22 -0
  1289. mindspore/rewrite/node/call_function.py +95 -0
  1290. mindspore/rewrite/node/cell_container.py +139 -0
  1291. mindspore/rewrite/node/control_flow.py +113 -0
  1292. mindspore/rewrite/node/node.py +1428 -0
  1293. mindspore/rewrite/node/node_manager.py +283 -0
  1294. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1295. mindspore/rewrite/parsers/__init__.py +29 -0
  1296. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1297. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1298. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1299. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1300. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1301. mindspore/rewrite/parsers/container_parser.py +88 -0
  1302. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1303. mindspore/rewrite/parsers/for_parser.py +61 -0
  1304. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1305. mindspore/rewrite/parsers/if_parser.py +85 -0
  1306. mindspore/rewrite/parsers/module_parser.py +117 -0
  1307. mindspore/rewrite/parsers/parser.py +43 -0
  1308. mindspore/rewrite/parsers/parser_register.py +86 -0
  1309. mindspore/rewrite/parsers/return_parser.py +37 -0
  1310. mindspore/rewrite/parsers/while_parser.py +59 -0
  1311. mindspore/rewrite/sparsify/__init__.py +0 -0
  1312. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1313. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1314. mindspore/rewrite/sparsify/utils.py +179 -0
  1315. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1316. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1317. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1318. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1319. mindspore/run_check/__init__.py +20 -0
  1320. mindspore/run_check/_check_version.py +507 -0
  1321. mindspore/run_check/run_check.py +66 -0
  1322. mindspore/safeguard/__init__.py +18 -0
  1323. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1324. mindspore/swresample-4.dll +0 -0
  1325. mindspore/swscale-6.dll +0 -0
  1326. mindspore/tbbmalloc.dll +0 -0
  1327. mindspore/tinyxml2.dll +0 -0
  1328. mindspore/train/__init__.py +48 -0
  1329. mindspore/train/_utils.py +465 -0
  1330. mindspore/train/amp.py +935 -0
  1331. mindspore/train/anf_ir_pb2.py +1517 -0
  1332. mindspore/train/callback/__init__.py +44 -0
  1333. mindspore/train/callback/_backup_and_restore.py +117 -0
  1334. mindspore/train/callback/_callback.py +613 -0
  1335. mindspore/train/callback/_checkpoint.py +814 -0
  1336. mindspore/train/callback/_cluster_monitor.py +201 -0
  1337. mindspore/train/callback/_dataset_graph.py +150 -0
  1338. mindspore/train/callback/_early_stop.py +239 -0
  1339. mindspore/train/callback/_flops_collector.py +239 -0
  1340. mindspore/train/callback/_history.py +92 -0
  1341. mindspore/train/callback/_lambda_callback.py +80 -0
  1342. mindspore/train/callback/_landscape.py +1049 -0
  1343. mindspore/train/callback/_loss_monitor.py +107 -0
  1344. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1345. mindspore/train/callback/_on_request_exit.py +298 -0
  1346. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1347. mindspore/train/callback/_summary_collector.py +1184 -0
  1348. mindspore/train/callback/_tft_register.py +352 -0
  1349. mindspore/train/callback/_time_monitor.py +141 -0
  1350. mindspore/train/checkpoint_pb2.py +233 -0
  1351. mindspore/train/data_sink.py +219 -0
  1352. mindspore/train/dataset_helper.py +692 -0
  1353. mindspore/train/lineage_pb2.py +1260 -0
  1354. mindspore/train/loss_scale_manager.py +213 -0
  1355. mindspore/train/memory_profiling_pb2.py +298 -0
  1356. mindspore/train/metrics/__init__.py +175 -0
  1357. mindspore/train/metrics/accuracy.py +133 -0
  1358. mindspore/train/metrics/auc.py +129 -0
  1359. mindspore/train/metrics/bleu_score.py +170 -0
  1360. mindspore/train/metrics/confusion_matrix.py +700 -0
  1361. mindspore/train/metrics/cosine_similarity.py +109 -0
  1362. mindspore/train/metrics/dice.py +116 -0
  1363. mindspore/train/metrics/error.py +175 -0
  1364. mindspore/train/metrics/fbeta.py +167 -0
  1365. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1366. mindspore/train/metrics/loss.py +97 -0
  1367. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1368. mindspore/train/metrics/metric.py +373 -0
  1369. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1370. mindspore/train/metrics/perplexity.py +133 -0
  1371. mindspore/train/metrics/precision.py +160 -0
  1372. mindspore/train/metrics/recall.py +159 -0
  1373. mindspore/train/metrics/roc.py +223 -0
  1374. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1375. mindspore/train/metrics/topk.py +167 -0
  1376. mindspore/train/mind_ir_pb2.py +1908 -0
  1377. mindspore/train/model.py +2252 -0
  1378. mindspore/train/node_strategy_pb2.py +653 -0
  1379. mindspore/train/print_pb2.py +184 -0
  1380. mindspore/train/profiling_parallel_pb2.py +151 -0
  1381. mindspore/train/serialization.py +3325 -0
  1382. mindspore/train/summary/__init__.py +23 -0
  1383. mindspore/train/summary/_lineage_adapter.py +41 -0
  1384. mindspore/train/summary/_summary_adapter.py +496 -0
  1385. mindspore/train/summary/_writer_pool.py +207 -0
  1386. mindspore/train/summary/enums.py +56 -0
  1387. mindspore/train/summary/summary_record.py +581 -0
  1388. mindspore/train/summary/writer.py +167 -0
  1389. mindspore/train/summary_pb2.py +1165 -0
  1390. mindspore/train/train_thor/__init__.py +20 -0
  1391. mindspore/train/train_thor/convert_utils.py +268 -0
  1392. mindspore/train/train_thor/dataset_helper.py +192 -0
  1393. mindspore/train/train_thor/model_thor.py +257 -0
  1394. mindspore/turbojpeg.dll +0 -0
  1395. mindspore/utils/__init__.py +21 -0
  1396. mindspore/utils/utils.py +60 -0
  1397. mindspore/vcmeta.dll +0 -0
  1398. mindspore/vcomp140.dll +0 -0
  1399. mindspore/vcruntime140.dll +0 -0
  1400. mindspore/vcruntime140_1.dll +0 -0
  1401. mindspore/version.py +1 -0
  1402. mindspore-2.4.0.dist-info/METADATA +352 -0
  1403. mindspore-2.4.0.dist-info/RECORD +1406 -0
  1404. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1405. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1406. mindspore-2.4.0.dist-info/top_level.txt +1 -0
mindspore/ops/_vmap/vmap_nn_ops.py (added)
@@ -0,0 +1,2250 @@
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+
+ # pylint: disable=unused-variable
+ """nn_ops vmap impl."""
+ from __future__ import absolute_import
+
+ import mindspore
+ from mindspore.common import Tensor
+ from mindspore.ops import operations as P
+ from mindspore.ops.operations import _grad_ops as G
+ from mindspore.ops.operations import nn_ops as NN
+ from mindspore.ops import functional as F
+ from mindspore.ops import constexpr
+ from mindspore.ops.primitive import _primexpr
+ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, \
+     _bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, \
+     _raise_value_error, _vmap_clone_prim, _get_reduce_batch_axis
+ from mindspore.ops.primitive import Primitive
+ from mindspore.ops.auto_generate.gen_arg_handler import Format
+ from mindspore.ops.auto_generate import Embedding
+ from mindspore.ops.auto_generate import gen_arg_handler as handler
+
+
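Every rule in this file receives inputs as `(value, dim)` pairs, where `dim` is the axis holding the vmap batch or `None` when the input carries no batch, and normalizes them so the batch sits on axis 0. A minimal sketch of that contract, using a hypothetical re-implementation rather than the real `_bdim_at_front` helper from `vmap_base`:

```python
# A minimal sketch of the (value, dim) convention, assuming mindspore.numpy
# semantics; the real helper is `_bdim_at_front`, imported above.
import mindspore.numpy as mnp
from mindspore import Tensor

def bdim_at_front_sketch(x, dim, axis_size):
    """Move batch axis `dim` of `x` to axis 0; broadcast when `dim` is None."""
    if dim is None:
        # No batch axis yet: replicate the value along a new leading axis.
        return mnp.broadcast_to(x, (axis_size,) + x.shape)
    return mnp.moveaxis(x, dim, 0)

x = Tensor([[1.0, 2.0], [3.0, 4.0]])           # batch on axis 1
print(bdim_at_front_sketch(x, 1, 2).shape)     # (2, 2): batch moved to front
print(bdim_at_front_sketch(x, None, 3).shape)  # (3, 2, 2): batch broadcast
```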
+ @vmap_rules_getters.register(P.ApplyAdaMax)
+ def get_apply_ada_max_rule(prim, axis_size):
+     """VmapRule for `ApplyAdaMax` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, beta1_power_bdim, lr_bdim, beta1_bdim, beta2_bdim,
+                   epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         lr, lr_dim = lr_bdim
+         beta1, beta1_dim = beta1_bdim
+         beta2, beta2_dim = beta2_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, v_dim, beta1_power_dim, lr_dim, beta1_dim, beta2_dim,
+                                                epsilon_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `m/v/beta1_power/lr/beta1/beta2/epsilon/grad` is not None. "
+                                  "The execution order of operator `{}` cannot be guaranteed.".format(prim_name))
+             var, m, v = prim(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, u_monad)
+             return (var, None), (m, None), (v, None)
+         if var_dim != 0 or m_dim != var_dim or v_dim != var_dim:
+             raise ValueError("For `{}`, the source axis of `var` must be 0 and equal to that of `m` and `v`, "
+                              "but got the source axis of `var`: {}, `m`: {}, `v`: {}."
+                              .format(prim_name, var_dim, m_dim, v_dim))
+
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
+         beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         var, m, v = batch_prim(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, u_monad)
+         return (var, 0), (m, 0), (v, 0)
+
+     return vmap_rule
+
+
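The in-place optimizer rules above and below share one pattern: `var` and its slot variables must already carry the batch on axis 0 (they are Parameters and cannot be transposed), while hyper-parameters are moved or broadcast to the front before the cloned `batch_prim` runs with an increased `batch_rank`. A hedged usage sketch with a functional stand-in for an AdaMax-style update (the names and the simplified formula are illustrative, not the primitive's exact semantics):

```python
# Usage sketch, assuming ops.vmap and a simplified, hypothetical update rule.
import mindspore as ms
import numpy as np
from mindspore import Tensor, ops

def ada_max_like_step(var, m, grad, lr, beta1):
    # Simplified stand-in for the AdaMax update (illustration only).
    m_new = beta1 * m + (1.0 - beta1) * grad
    return var - lr * m_new

# Batch axis 0 on var/m/grad/lr; beta1 is shared (in_axes=None).
batched_step = ops.vmap(ada_max_like_step, in_axes=(0, 0, 0, 0, None))
var = Tensor(np.ones((4, 3), np.float32))    # 4 independent parameter copies
m = Tensor(np.zeros((4, 3), np.float32))
grad = Tensor(np.ones((4, 3), np.float32))
lr = Tensor(np.full((4,), 0.1, np.float32))  # one learning rate per batch entry
print(batched_step(var, m, grad, lr, Tensor(0.9, ms.float32)).shape)  # (4, 3)
```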
+ @vmap_rules_getters.register(P.ApplyAdadelta)
+ def get_apply_adadelta_rule(prim, axis_size):
+     """VmapRule for `ApplyAdadelta` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, accum_update_bdim, lr_bdim, rho_bdim, epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         accum_update, accum_update_dim = accum_update_bdim
+         lr, lr_dim = lr_bdim
+         rho, rho_dim = rho_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, accum_update_dim, lr_dim, rho_dim, epsilon_dim,
+                                                grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/accum_update/lr/rho/epsilon/grad` is not None. The execution "
+                                  "order of operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum, accum_update = prim(var, accum, accum_update, lr, rho, epsilon, grad, u_monad)
+             return (var, None), (accum, None), (accum_update, None)
+         if var_dim != 0 or accum_dim != var_dim or accum_update_dim != var_dim:
+             raise ValueError(
+                 "For `{}`, the source axis of `var` must be 0 and equal to that of `accum` and `accum_update`, "
+                 "but got the source axis of `var`: {}, `accum`: {}, `accum_update`: {}.".format(
+                     prim_name, var_dim, accum_dim, accum_update_dim))
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         rho = _bdim_at_front(rho, rho_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         var, accum, accum_update = batch_prim(var, accum, accum_update, lr, rho, epsilon, grad, u_monad)
+         return (var, 0), (accum, 0), (accum_update, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyFtrl)
+ def get_apply_ftrl_rule(prim, axis_size):
+     """VmapRule for `ApplyFtrl` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, linear_bdim, grad_bdim, lr_bdim, l1_bdim, l2_bdim, lr_power_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         linear, linear_dim = linear_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         lr_power, lr_power_dim = lr_power_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, linear_dim, grad_dim, lr_dim, l1_dim, l2_dim,
+                                                lr_power_dim]):
+                 raise ValueError("The source axis of `var` is None, "
+                                  "but the source axis of `accum/linear/grad/lr/l1/l2/lr_power` is not None. "
+                                  "The execution order of operator `{}` cannot be guaranteed.".format(prim_name))
+             var = prim(var, accum, linear, grad, lr, l1, l2, lr_power, u_monad)
+             return var, None
+         if var_dim != 0 or accum_dim != var_dim or linear_dim != var_dim:
+             raise ValueError("For `{}`, the source axis of `var/accum/linear` must be 0, "
+                              "but got `var`: {}, `accum`: {}, `linear`: {}.".format(prim_name, var_dim, accum_dim,
+                                                                                    linear_dim))
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         lr_power = _bdim_at_front(lr_power, lr_power_dim, axis_size)
+
+         var = batch_prim(var, accum, linear, grad, lr, l1, l2, lr_power, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyProximalAdagrad)
+ def get_apply_proximal_adagrad_rule(prim, axis_size):
+     """VmapRule for `ApplyProximalAdagrad` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, lr_bdim, l1_bdim, l2_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         lr, lr_dim = lr_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, lr_dim, l1_dim, l2_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/lr/l1/l2/grad` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum = prim(var, accum, lr, l1, l2, grad, u_monad)
+             return (var, None), (accum, None)
+
+         if var_dim != 0 or accum_dim != var_dim:
+             raise ValueError("For `{}`, the source axis of `var` must be 0 and equal to that of `accum`, "
+                              "but got the source axis of `var`: {}, `accum`: {}.".format(prim_name, var_dim,
+                                                                                          accum_dim))
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         var, accum = batch_prim(var, accum, lr, l1, l2, grad, u_monad)
+         return (var, 0), (accum, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyGradientDescent)
+ def get_apply_gradient_descent_rule(prim, axis_size):
+     """VmapRule for `ApplyGradientDescent` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, alpha_bdim, delta_bdim, u_monad):
+         var, var_dim = var_bdim
+         alpha, alpha_dim = alpha_bdim
+         delta, delta_dim = delta_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [alpha_dim, delta_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `alpha/delta` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var = prim(var, alpha, delta, u_monad)
+             return var, None
+
+         if var_dim != 0:
+             raise ValueError("For `{}`, the source axis of `var` must be 0, "
+                              "but got the source axis of `var`: {}.".format(prim_name, var_dim))
+
+         alpha = _bdim_at_front(alpha, alpha_dim, axis_size)
+         delta = _bdim_at_front(delta, delta_dim, axis_size)
+
+         var = batch_prim(var, alpha, delta, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyProximalGradientDescent)
+ def get_apply_proximal_gradient_descent_rule(prim, axis_size):
+     """VmapRule for `ApplyProximalGradientDescent` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, alpha_bdim, l1_bdim, l2_bdim, delta_bdim, u_monad):
+         var, var_dim = var_bdim
+         alpha, alpha_dim = alpha_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         delta, delta_dim = delta_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [alpha_dim, l1_dim, l2_dim, delta_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `alpha/l1/l2/delta` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var = prim(var, alpha, l1, l2, delta, u_monad)
+             return var, None
+
+         if var_dim != 0:
+             raise ValueError("For `{}`, the source axis of `var` must be 0, "
+                              "but got the source axis of `var`: {}.".format(prim_name, var_dim))
+
+         alpha = _bdim_at_front(alpha, alpha_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         delta = _bdim_at_front(delta, delta_dim, axis_size)
+
+         var = batch_prim(var, alpha, l1, l2, delta, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.BCEWithLogitsLoss)
+ def get_bce_with_logits_loss_vmap_rule(prim, axis_size):
+     """VmapRule for `BCEWithLogitsLoss` operation."""
+
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+     prim_name = prim.name
+     bce_logits_with_loss_op = NN.BCEWithLogitsLoss('none')
+
+     def vmap_rule(logits_bdim, label_bdim, weight_bdim, pos_weight_bdim, reduction_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, logits_bdim, label_bdim, weight_bdim, pos_weight_bdim,
+                                                       reduction_bdim)
+         if is_all_none:
+             return result
+         logits, logits_dim = logits_bdim
+         label, label_dim = label_bdim
+         weight, weight_dim = weight_bdim
+         pos_weight, pos_weight_dim = pos_weight_bdim
+         prim_reduction, _ = reduction_bdim
+         logits_rank = F.rank(logits)
+         label_rank = F.rank(label)
+         weight_rank = F.rank(weight)
+         pos_weight_rank = F.rank(pos_weight)
+         max_rank = max(logits_rank, label_rank)
+         max_rank = max(max_rank, weight_rank)
+         max_rank = max(max_rank, pos_weight_rank)
+         reduce_indexes = None
+         # If rank is larger than 1, we need to reduce the result when reduction != 'none'.
+         if max_rank > 1:
+             reduce_indexes = tuple(range(1, max_rank))
+         logits_dim_ok = logits_dim == label_dim and logits_dim == weight_dim and logits_dim == pos_weight_dim
+         shape = F.shape(logits)
+         shape_ok = shape == F.shape(label) and shape == F.shape(weight) and shape == F.shape(pos_weight)
+         if logits_dim_ok and shape_ok:
+             if prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'none'):
+                 output = prim(logits, label, weight, pos_weight, prim_reduction)
+             elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'mean'):
+                 out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+                 output = P.ReduceMean()(out, reduce_indexes)
+             elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
+                 out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+                 output = P.ReduceSum()(out, reduce_indexes)
+             else:
+                 raise RuntimeError("For {} vmap, the attribute of reduction must be in "
+                                    "('none', 'mean', 'sum'), but got {}."
+                                    .format(prim_name, prim_reduction))
+             return output, logits_dim
+
+         logits = _bdim_at_front(logits, logits_dim, axis_size)
+         label = _bdim_at_front(label, label_dim, axis_size)
+         weight = _bdim_at_front(weight, weight_dim, axis_size)
+         pos_weight = _bdim_at_front(pos_weight, pos_weight_dim, axis_size)
+         logits_shape = F.shape(logits)
+         weight_shape = F.shape(weight)
+         pos_weight_shape = F.shape(pos_weight)
+         weight = _handle_broadcasting(weight, weight_shape, logits_shape)
+         pos_weight = _handle_broadcasting(pos_weight, pos_weight_shape, logits_shape)
+         if prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'none'):
+             output = prim(logits, label, weight, pos_weight, prim_reduction)
+         elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'mean'):
+             out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+             output = P.ReduceMean()(out, reduce_indexes)
+         elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
+             out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+             output = P.ReduceSum()(out, reduce_indexes)
+         else:
+             raise RuntimeError("For {} vmap, the attribute of reduction must be in "
+                                "('none', 'mean', 'sum'), but got {}."
+                                .format(prim_name, prim_reduction))
+         return output, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.BiasAdd)
+ def get_bias_add_vmap_rule(prim, axis_size):
+     """VmapRule for `BiasAdd` operation."""
+     add_op = P.Add()
+
+     @constexpr
+     def get_channel_pos_in_x(d_format, n_dims):
+         # With the vmap batch prepended, the channel sits at the last axis
+         # for NHWC and at index 2 (after batch and N) otherwise.
+         if d_format == Format.NHWC:
+             return n_dims - 1
+         return 2
+
+     @_primexpr
+     def get_bias_dst_shape(x_shape, n_dims, d_format, has_b_dim: bool):
+         pos = get_channel_pos_in_x(d_format, n_dims)
+
+         bias_shape = []
+         for i in range(n_dims):
+             if i != pos:
+                 bias_shape.append(1)
+             else:
+                 bias_shape.append(x_shape[i])
+
+         if has_b_dim:
+             bias_shape[0] = axis_size
+
+         return tuple(bias_shape)
+
+     def vmap_rule(x_bdim, bias_bdim, data_format_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, bias_bdim, data_format_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         b, b_dim = bias_bdim
+         data_format_data, _ = data_format_bdim
+
+         x = _bdim_at_front(x, x_dim, axis_size)
+         has_b_dim = False
+         if b_dim is not None:
+             b = _bdim_at_front(b, b_dim, axis_size)
+             has_b_dim = True
+
+         x_shape = x.shape
+         n_dims = len(x_shape)
+         b_shape = get_bias_dst_shape(x_shape, n_dims, data_format_data, has_b_dim)
+
+         b = b.reshape(b_shape)
+         result = add_op(x, b)
+
+         return (result, 0)
+
+     return vmap_rule
+
+
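The `BiasAdd` rule avoids calling the primitive on a batched input altogether: it reshapes the bias to a shape that broadcasts against the batched `x` and applies a plain `Add`. A hedged usage sketch through the functional API (assuming `ops.vmap` and `ops.bias_add` as the entry points that route through this rule):

```python
# Usage sketch: vmapping bias_add over a leading batch of NCHW inputs.
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.zeros((5, 2, 3, 4, 4), np.float32))  # (vmap batch, N, C, H, W)
b = Tensor(np.ones((3,), np.float32))              # unbatched bias, one per channel

vmapped = ops.vmap(lambda inp: ops.bias_add(inp, b), in_axes=0)
out = vmapped(x)
print(out.shape)  # (5, 2, 3, 4, 4); bias broadcast along the channel axis
```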
+ @vmap_rules_getters.register(G.BiasAddGrad)
+ def get_bias_add_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `BiasAddGrad` operation."""
+     @constexpr
+     def get_channel_pos(d_format, x_rank):
+         # With the vmap batch prepended, the channel sits at the last axis
+         # for NHWC and at index 2 otherwise.
+         if d_format == Format.NHWC:
+             return x_rank - 1
+         return 2
+
+     @_primexpr
+     def get_axis_for_reduce(x_shape_rank, data_format):
+         channel_pos = get_channel_pos(data_format, x_shape_rank)
+         axis_list = ()
+         for i in range(1, x_shape_rank):
+             if channel_pos == i:
+                 continue
+             axis_list += (i,)
+
+         return axis_list
+
+     def vmap_rule(x_bdim, data_format_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, data_format_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         data_format_data, _ = data_format_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape_rank = len(x.shape)
+
+         axis_for_reduce = get_axis_for_reduce(x_shape_rank, data_format_data)
+
+         result = x.sum(axis=axis_for_reduce)
+         return (result, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Dropout)
+ @vmap_rules_getters.register(P.Dropout2D)
+ @vmap_rules_getters.register(P.Dropout3D)
+ def get_dropout_nd_vmap_rule(prim, axis_size):
+     """VmapRule for `DropoutND` operation."""
+     prim_name = prim.name
+     dropout_nd_dim = 4
+     if prim_name == "Dropout3D":
+         dropout_nd_dim = 5
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_ndim = F.rank(x)
+         if x_ndim > dropout_nd_dim:
+             x_ori_shape = F.shape(x)
+             x = F.reshape(x, (-1,) + x_ori_shape[2:x_ndim])
+             output, mask = prim(x)
+             output = F.reshape(output, x_ori_shape)
+             mask = F.reshape(mask, x_ori_shape)
+         else:
+             output, mask = prim(x)
+
+         return (output, 0), (mask, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.InTopK)
+ def get_in_top_k_vmap_rule(prim, axis_size):
+     """VmapRule for `InTopK`."""
+
+     def vmap_rule(x1_bdim, x2_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x1_bdim, x2_bdim)
+         if is_all_none:
+             return result
+
+         x1, x1_dim = x1_bdim
+         x2, x2_dim = x2_bdim
+         x1 = _bdim_at_front(x1, x1_dim, axis_size)
+         x2 = _bdim_at_front(x2, x2_dim, axis_size)
+         x1_shape = F.shape(x1)
+         x2_shape = F.shape(x2)
+         x1 = F.reshape(x1, (-1, x1_shape[-1]))
+         x2 = F.reshape(x2, (-1,))
+         output = prim(x1, x2)
+         output = F.reshape(output, x2_shape)
+         return output, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(G.FastGeLUGrad)
+ @vmap_rules_getters.register(G.HSwishGrad)
+ def get_common_activation_grad_vmap_rule(prim, axis_size):
+     """VmapRule for common activation grad operation."""
+     prim_name = prim.name
+
+     def vmap_rule(x_bdim, dy_bdim):
+         x, x_dim = x_bdim
+         dy, dy_dim = dy_bdim
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_dim == dy_dim and x_shape == dy_shape:
+             out = prim(x, dy)
+             return out, x_dim
+
+         if F.rank(x):
+             x = _bdim_at_front(x, x_dim, 1)
+         if F.rank(dy):
+             dy = _bdim_at_front(dy, dy_dim, 1)
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_shape != dy_shape:
+             raise RuntimeError("For {} vmap, input x shape is supposed to be the same as input dy shape "
+                                "after batch transforming, but got x_shape {}, dy_shape {}"
+                                .format(prim_name, x_shape, dy_shape))
+         out = prim(x, dy)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("SoftShrink")
+ def get_softshrink_vmap_rule(prim, axis_size):
+     """VmapRule for `SoftShrink`."""
+     def vmap_rule(x_bdim, lambd_bdim):
+         var, dim = x_bdim
+         lambd, _ = lambd_bdim
+         out = prim(var, lambd)
+         return out, dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("SoftShrinkGrad")
+ def get_softshrink_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `SoftShrinkGrad`."""
+     prim_name = prim.name
+
+     def vmap_rule(dy_bdim, x_bdim, lambd_bdim):
+         x, x_dim = x_bdim
+         lambd, _ = lambd_bdim
+         dy, dy_dim = dy_bdim
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_dim == dy_dim and x_shape == dy_shape:
+             out = prim(dy, x, lambd)
+             return out, x_dim
+
+         if F.rank(x):
+             x = _bdim_at_front(x, x_dim, 1)
+         if F.rank(dy):
+             dy = _bdim_at_front(dy, dy_dim, 1)
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_shape != dy_shape:
+             raise RuntimeError("For {} vmap, input x shape is supposed to be the same as input dy shape "
+                                "after batch transforming, but got x_shape {}, dy_shape {}"
+                                .format(prim_name, x_shape, dy_shape))
+         out = prim(dy, x, lambd)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("HShrink")
+ def get_hshrink_vmap_rule(prim, axis_size):
+     """VmapRule for `HShrink`."""
+     def vmap_rule(x_bdim, lambd_bdim):
+         var, dim = x_bdim
+         lambd, _ = lambd_bdim
+         out = prim(var, lambd)
+         return out, dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("HShrinkGrad")
+ def get_hshrink_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `HShrinkGrad`."""
+     prim_name = prim.name
+
+     def vmap_rule(dy_bdim, x_bdim, lambd_bdim):
+         x, x_dim = x_bdim
+         lambd, _ = lambd_bdim
+         dy, dy_dim = dy_bdim
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_dim == dy_dim and x_shape == dy_shape:
+             out = prim(dy, x, lambd)
+             return out, x_dim
+
+         if F.rank(x):
+             x = _bdim_at_front(x, x_dim, 1)
+         if F.rank(dy):
+             dy = _bdim_at_front(dy, dy_dim, 1)
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_shape != dy_shape:
+             raise RuntimeError("For {} vmap, input x shape is supposed to be the same as input dy shape "
+                                "after batch transforming, but got x_shape {}, dy_shape {}"
+                                .format(prim_name, x_shape, dy_shape))
+         out = prim(dy, x, lambd)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Pad)
+ def get_pad_vmap_rule(prim, axis_size):
+     """VmapRule for `Pad`."""
+     paddings = prim.paddings
+
+     @constexpr
+     def _get_paddings(cur_paddings, x_dim):
+         """Insert a no-op padding pair at the batch position."""
+         new_paddings = list(cur_paddings)
+         new_paddings.insert(x_dim, (0, 0))
+         return tuple(new_paddings)
+
+     def vmap_rule(x_bdim):
+         x, x_dim = x_bdim
+         if x_dim is None:
+             # case 1: no batch axis exists
+             out = prim(x)
+         else:
+             # case 2: a batch axis exists
+             new_paddings = _get_paddings(paddings, x_dim)
+             op = P.Pad(new_paddings)
+             out = op(x)
+         return out, x_dim
+
+     return vmap_rule
+
+
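The `Pad` rule works purely on the attribute: it inserts a `(0, 0)` pair at the batch position so the rebuilt `P.Pad` leaves the vmap axis untouched. A minimal standalone check of that trick (a hypothetical copy of `_get_paddings`, shown without `@constexpr` so it runs as plain Python):

```python
# Standalone sketch of the paddings-insertion trick used by the rule above.
def get_paddings_sketch(cur_paddings, x_dim):
    new_paddings = list(cur_paddings)
    new_paddings.insert(x_dim, (0, 0))  # leave the batch axis unpadded
    return tuple(new_paddings)

# The original op pads a rank-2 input; with the batch on axis 0 the rank-3
# batched input needs a leading no-op pair.
print(get_paddings_sketch(((1, 1), (2, 2)), 0))  # ((0, 0), (1, 1), (2, 2))
```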
+ @vmap_rules_getters.register(NN.Pdist)
+ def get_pdist_vmap_rule(prim, axis_size):
+     """VmapRule for `Pdist`."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         prim.add_prim_attr('p', 2.0)
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         out = prim(x)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.DeformableOffsets)
+ def get_deformable_offsets_vmap_rule(prim, axis_size):
+     """VmapRule for `DeformableOffsets` operation."""
+     nchw_size = 4
+     chw_size = 3
+     chw_reverse_index = -chw_size
+
+     def vmap_rule(x_bdim, offsets_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, offsets_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         offsets, offsets_dim = offsets_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_ndim = F.rank(x)
+         x_origin_shape = F.shape(x)
+
+         offsets = _bdim_at_front(offsets, offsets_dim, axis_size)
+         offsets_ndim = F.rank(offsets)
+         offsets_origin_shape = F.shape(offsets)
+
+         batch_origin_shape = x_origin_shape
+         if x_ndim > nchw_size:
+             x = F.reshape(x, (-1,) + x_origin_shape[chw_reverse_index:])
+         if offsets_ndim > nchw_size:
+             offsets = F.reshape(offsets, (-1,) + offsets_origin_shape[chw_reverse_index:])
+             batch_origin_shape = offsets_origin_shape
+
+         out = prim(x, offsets)
+         out_shape = F.shape(out)
+         out = F.reshape(out, batch_origin_shape[:(nchw_size + 1 - chw_size)] + out_shape[chw_reverse_index:])
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("Softmax")
+ def get_softmax_vmap_rule(prim, axis_size):
+     """VmapRule for `Softmax`."""
+
+     def vmap_rule(x_bdim, axis_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         axis, _ = axis_bdim
+         x_ndim = F.rank(x)
+         if not F.isconstant(axis) or not F.isconstant(x_ndim):
+             raise ValueError("For Softmax vmap, both `axis` and the rank of `x` must be constant.")
+         batch_axis = _get_reduce_batch_axis(axis, x_dim, x_ndim)
+         out = prim(x, batch_axis)
+         return out, x_dim
+
+     return vmap_rule
+
+
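For `Softmax`, the user's `axis` refers to the logical, unbatched input, so `_get_reduce_batch_axis` shifts it to the corresponding axis of the physical batched tensor. A hedged sketch through `ops.vmap` and the functional `ops.softmax` (assumed equivalents of the primitive here):

```python
# Usage sketch: softmax over axis 0 of each logical (2, 3) input becomes a
# softmax over axis 1 of the physical (4, 2, 3) tensor under the hood.
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.random.rand(4, 2, 3).astype(np.float32))  # batch of 4 inputs
out = ops.vmap(lambda t: ops.softmax(t, axis=0), in_axes=0)(x)
print(out.shape)           # (4, 2, 3)
print(out.sum(axis=1)[0])  # ~1.0 along the remapped axis
```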
+ @vmap_rules_getters.register(P.AdaptiveAvgPool2D)
+ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveAvgPool2D` operation."""
+     chw_reverse_index = -3
+     hw_reverse_index = -2
+
+     def vmap_rule(input_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, input_bdim)
+         if is_all_none:
+             return result
+
+         input_x, x_dim = input_bdim
+         input_x = _bdim_at_front(input_x, x_dim, axis_size)
+         x_shape = F.shape(input_x)
+         input_shape = (-1,) + x_shape[chw_reverse_index:]
+         input_x = F.reshape(input_x, input_shape)
+         out = prim(input_x)
+         out_shape = F.shape(out)
+         real_out_shape = x_shape[:hw_reverse_index] + out_shape[hw_reverse_index:]
+         out = F.reshape(out, real_out_shape)
+         return out, 0
+
+     return vmap_rule
+
+
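`AdaptiveAvgPool2D`, like several rules in this file (`Dropout2D/3D`, `AvgPool`, `DeformableOffsets`), handles extra leading dimensions by collapsing them into one `-1` axis, running the rank-limited primitive, and restoring the original prefix afterwards. A hedged usage sketch of the visible effect (functional API, assumed to route through this rule under `ops.vmap`):

```python
# Usage sketch of the flatten-and-restore trick: leading dims survive,
# only H/W are pooled.
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.random.rand(5, 2, 3, 8, 8).astype(np.float32))  # (B, N, C, H, W)
out = ops.vmap(lambda t: ops.adaptive_avg_pool2d(t, (4, 4)), in_axes=0)(x)
print(out.shape)  # (5, 2, 3, 4, 4)
```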
+ @vmap_rules_getters.register(NN.AdaptiveAvgPool3D)
+ def get_adaptive_avgpool3d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveAvgPool3D` operation."""
+     dhw_reverse_index = -3
+     max_dims = 5
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         if F.rank(x) == max_dims:
+             out = prim(x)
+             return out, 0
+
+         x_shape = F.shape(x)
+         shape = (-1,) + x_shape[dhw_reverse_index:]
+         x = F.reshape(x, shape)
+         out = prim(x)
+         out_shape = F.shape(out)
+         real_out_shape = x_shape[:dhw_reverse_index] + out_shape[dhw_reverse_index:]
+         out = F.reshape(out, real_out_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("AvgPool")
+ def get_avgpool_vmap_rule(prim, axis_size):
+     """VmapRule for `AvgPool`."""
+     chw_reverse_index = -3
+
+     def vmap_rule(x_bdim, kernel_size_bdim, strides_bdim, pad_mode_bdim, data_format_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, kernel_size_bdim, strides_bdim, pad_mode_bdim,
+                                                       data_format_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         kernel_size, _ = kernel_size_bdim
+         strides, _ = strides_bdim
+         pad_mode, _ = pad_mode_bdim
+         data_format, _ = data_format_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[chw_reverse_index:]
+         x = F.reshape(x, input_shape)
+         out = prim(x, kernel_size, strides, pad_mode, data_format)
+         out_shape = F.shape(out)
+         real_out_shape = x_shape[:chw_reverse_index] + out_shape[chw_reverse_index:]
+         out = F.reshape(out, real_out_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.AdaptiveMaxPool3D)
+ def get_adaptive_max_pool3d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveMaxPool3D`."""
+     dhw_reverse_index = -3
+     max_dims = 5
+
+     @constexpr
+     def convert_shape_to_tensor(shape):
+         return Tensor(shape, dtype=mindspore.int32)
+
+     def vmap_rule(x_bdim, out_size_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, out_size_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         out_size, out_size_dim = out_size_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         if out_size_dim is not None:
+             _raise_value_error("The source axis of `output_size` in `AdaptiveMaxPool3D` must be None, "
+                                "but got {}.".format(out_size_dim))
+         if F.rank(x) == max_dims:
+             out, indices = prim(x, out_size)
+             return (out, 0), (indices, 0)
+
+         x_shape = F.shape(x)
+         shape = (-1,) + x_shape[dhw_reverse_index:]
+         x = F.reshape(x, shape)
+         out, indices = prim(x, out_size)
+         # AdaptiveMaxPool3D is a dynamic op, so the 'shape' argument of reshape must be a tensor.
+         front_shape = convert_shape_to_tensor(x_shape[:dhw_reverse_index])
+         output_shape = F.concat((front_shape, out_size))
+         out = F.reshape(out, output_shape)
+         indices = F.reshape(indices, output_shape)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.InstanceNorm)
+ def get_instance_norm_rule(prim, axis_size):
+     """VmapRule for `InstanceNorm` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(input_x_bdim, gamma_bdim, beta_bdim, mean_bdim, variance_bdim, u_monad):
+         input_x, input_x_dim = input_x_bdim
+         gamma, gamma_dim = gamma_bdim
+         beta, beta_dim = beta_bdim
+         mean, mean_dim = mean_bdim
+         variance, variance_dim = variance_bdim
+         if gamma_dim is None:
+             if any(dim is not None for dim in [input_x_dim, beta_dim, mean_dim, variance_dim]):
+                 raise ValueError("The source axis of `gamma` is None, but the source "
+                                  "axis of `input_x/beta/mean/variance` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             output_x, updated_moving_mean, updated_moving_variance = prim(input_x, gamma, beta, mean, variance,
+                                                                           u_monad)
+             return (output_x, None), (updated_moving_mean, None), (updated_moving_variance, None)
+
+         precondition = gamma_dim != 0 or beta_dim != gamma_dim or mean_dim != gamma_dim or variance_dim != gamma_dim
+         if precondition:
+             raise ValueError(
+                 "For `{}`, the source axis of `gamma` must be 0 and equal to that of `beta`, `mean` and "
+                 "`variance`, but got the source axis of `gamma`: {}, `beta`: {}, `mean`: {}, `variance`: {}."
+                 .format(prim_name, gamma_dim, beta_dim, mean_dim, variance_dim))
+         input_x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         output_x, updated_moving_mean, updated_moving_variance = batch_prim(input_x, gamma, beta, mean, variance,
+                                                                             u_monad)
+         return (output_x, 0), (updated_moving_mean, 0), (updated_moving_variance, 0)
+
+     return vmap_rule
+
+
901
+ @vmap_rules_getters.register(P.KLDivLoss)
902
+ def get_kl_div_loss_vmap_rule(prim, axis_size):
903
+ """VmapRule for `KLDivLoss` operation."""
904
+ if isinstance(prim, str):
905
+ prim = Primitive(prim)
906
+
907
+ prim_reduction = prim.reduction
908
+ if prim_reduction == "mean":
909
+ kl_div_loss_op = P.KLDivLoss("none")
910
+ reduce_op = P.ReduceMean()
911
+ elif prim_reduction == "sum":
912
+ kl_div_loss_op = P.KLDivLoss("none")
913
+ reduce_op = P.ReduceSum()
914
+ elif prim_reduction == "batchmean":
915
+ kl_div_loss_op = P.KLDivLoss("none")
916
+ reduce_op = P.ReduceSum()
917
+ factor_op = P.Div()
918
+
919
+ def vmap_rule(x_bdim, target_bdim):
920
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim, target_bdim)
921
+ if is_all_none:
922
+ return result
923
+
924
+ x, x_dim = x_bdim
925
+ target, target_dim = target_bdim
926
+ x_ndim = F.rank(x)
927
+ target_ndim = F.rank(target)
928
+ max_rank = max(x_ndim, target_ndim)
929
+ x = _bdim_at_front(x, x_dim, axis_size)
930
+ target = _bdim_at_front(target, target_dim, axis_size)
931
+ reduce_indexes = None
932
+ factor = 1
933
+ # if rank is larger than 1, we need to reduce result when reduction != 'none'
934
+ if max_rank > 1:
935
+ reduce_indexes = tuple(range(1, max_rank))
936
+ factor = F.shape(x)[1]
937
+
938
+ # elementwise style when reduction='none', otherwise reduce style
939
+ if prim_reduction == "none":
940
+ out = prim(x, target)
941
+ elif prim_reduction in ("mean", "sum"):
942
+ out = kl_div_loss_op(x, target)
943
+ if reduce_indexes is not None:
944
+ out = reduce_op(out, reduce_indexes)
945
+ elif prim_reduction == "batchmean":
946
+ out = kl_div_loss_op(x, target)
947
+ if reduce_indexes is not None:
948
+ out = reduce_op(out, reduce_indexes)
949
+ out = factor_op(out, factor)
950
+ else:
951
+ raise RuntimeError("For KLDivLoss vmap, reduction should be one of "
952
+ "['none', 'mean', 'batchmean', 'sum'], but got '{}'".format(prim_reduction))
953
+ return out, 0
954
+
955
+ return vmap_rule
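For reference, a minimal sketch of how this reduction-splitting rule behaves from user code, assuming a MindSpore build that exposes `mindspore.vmap` and `ops.KLDivLoss` (shapes are illustrative only):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

kl_loss = ops.KLDivLoss(reduction="batchmean")
x = Tensor(np.random.rand(4, 3, 5).astype(np.float32))       # 4 stacked samples
target = Tensor(np.random.rand(4, 3, 5).astype(np.float32))

# The rule computes the loss with reduction='none', reduces the per-sample
# axes, and divides by the per-sample leading dimension for 'batchmean'.
out = ms.vmap(kl_loss, in_axes=(0, 0), out_axes=0)(x, target)
print(out.shape)  # expected: (4,)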
+
+
+ @vmap_rules_getters.register(G.KLDivLossGrad)
+ def get_kl_div_loss_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `KLDivLossGrad`."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         reduction = "mean"
+     else:
+         reduction = prim.reduction
+
+     kldivloss_grad = G.KLDivLossGrad(reduction=reduction)
+
+     def vmap_rule(dy_bdim, x_bdim, target_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, dy_bdim, x_bdim, target_bdim)
+         if is_all_none:
+             return result
+
+         dy, dy_dim = dy_bdim
+         x, x_dim = x_bdim
+         target, target_dim = target_bdim
+         dy = _bdim_at_front(dy, dy_dim, axis_size)
+         x = _bdim_at_front(x, x_dim, axis_size)
+         target = _bdim_at_front(target, target_dim, axis_size)
+
+         out = kldivloss_grad(dy, x, target)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.SmoothL1Loss)
+ def get_smooth_l1_loss_vmap_rule(prim, axis_size):
+     """VmapRule for `SmoothL1Loss` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         prim_beta = 1.0
+         prim_reduction = 'none'
+     else:
+         prim_reduction = prim.reduction
+         prim_beta = prim.beta
+
+     smooth_l1_loss_op = P.SmoothL1Loss(prim_beta, 'none')
+     if prim_reduction == 'mean':
+         reduce_op = P.ReduceMean()
+     elif prim_reduction == "sum":
+         reduce_op = P.ReduceSum()
+
+     def vmap_rule(x_bdim, target_bdim):
+         is_all_none, result = vmap_general_preprocess(
+             prim, x_bdim, target_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         target, target_dim = target_bdim
+         x_ndim = F.rank(x)
+         target_ndim = F.rank(target)
+         max_rank = max(x_ndim, target_ndim)
+         x = _bdim_at_front(x, x_dim, axis_size)
+         target = _bdim_at_front(target, target_dim, axis_size)
+         reduce_indexes = None
+         # if rank is larger than 1, we need to reduce the result when reduction != 'none'
+         if max_rank > 1:
+             reduce_indexes = tuple(range(1, max_rank))
+
+         # elementwise style when reduction='none', otherwise reduce style
+         if prim_reduction == "none":
+             out = prim(x, target)
+         elif prim_reduction in ("mean", "sum"):
+             out = smooth_l1_loss_op(x, target)
+             if reduce_indexes is not None:
+                 out = reduce_op(out, reduce_indexes)
+         else:
+             raise RuntimeError("For SmoothL1Loss vmap, reduction should be one of "
+                                "['none', 'mean', 'sum'], but got '{}'".format(prim_reduction))
+         return out, 0
+
+     return vmap_rule
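The same pattern can be exercised for `SmoothL1Loss`; a hedged usage sketch, again assuming `mindspore.vmap` is available:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

loss = ops.SmoothL1Loss(beta=1.0, reduction="mean")
logits = Tensor(np.random.rand(6, 10).astype(np.float32))    # 6 stacked samples
labels = Tensor(np.random.rand(6, 10).astype(np.float32))

# Per sample, the rule runs the 'none' variant and then applies ReduceMean
# over the non-batch axes, so each mapped output is a scalar.
out = ms.vmap(loss, in_axes=(0, 0), out_axes=0)(logits, labels)
print(out.shape)  # expected: (6,)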
+
+
+ @vmap_rules_getters.register(G.SmoothL1LossGrad)
+ def get_smooth_l1_loss_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `SmoothL1LossGrad`."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         reduction = "none"
+         beta = 1.0
+     else:
+         reduction = prim.reduction
+         beta = prim.beta
+     smooth_l1_loss_grad = G.SmoothL1LossGrad(beta, reduction)
+
+     def vmap_rule(x_bdim, target_bdim, dy_bdim):
+         is_all_none, result = vmap_general_preprocess(
+             prim, dy_bdim, x_bdim, target_bdim)
+         if is_all_none:
+             return result
+
+         dy, dy_dim = dy_bdim
+         x, x_dim = x_bdim
+         target, target_dim = target_bdim
+         dy = _bdim_at_front(dy, dy_dim, axis_size)
+         x = _bdim_at_front(x, x_dim, axis_size)
+         target = _bdim_at_front(target, target_dim, axis_size)
+
+         out = smooth_l1_loss_grad(x, target, dy)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.nn_ops.LogSoftmax)
+ def get_log_softmax_vmap_rule(prim_func, axis_size):
+     """VmapRule for `LogSoftmax` operation."""
+     def vmap_rule(x_bdim, axis_bdim):
+         is_all_none, result = vmap_general_preprocess(prim_func, x_bdim, axis_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         axis, _ = axis_bdim
+         x_ndim = F.rank(x) - 1
+
+         batch_axis = axis + x_ndim if axis < 0 else axis
+         batch_axis = batch_axis if batch_axis < x_dim else batch_axis + 1
+
+         out = F.log_softmax(x, batch_axis)
+         return out, x_dim
+
+     return vmap_rule
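A small sketch of the axis remapping this rule performs, assuming the functional `ops.log_softmax` is available:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.random.rand(4, 2, 6).astype(np.float32))

# axis=-1 refers to the last axis of each logical (2, 6) slice; the rule
# shifts the requested axis past the batch dimension before calling the op.
fn = ms.vmap(lambda t: ops.log_softmax(t, axis=-1), in_axes=0, out_axes=0)
print(fn(x).shape)  # expected: (4, 2, 6)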
+
+
+ @vmap_rules_getters.register(P.RandomCategorical)
+ def get_random_categorical_vmap_rule(prim, axis_size):
+     """VmapRule for `RandomCategorical` operation."""
+
+     default_dim = 2
+
+     def vmap_rule(logits_bdim, num_sample_bdim, seed_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, logits_bdim, num_sample_bdim, seed_bdim)
+         if is_all_none:
+             return result
+         logits, logits_dim = logits_bdim
+         num_sample, num_sample_dim = num_sample_bdim
+         seed, seed_dim = seed_bdim
+         if num_sample_dim is not None or seed_dim is not None:
+             raise RuntimeError("For RandomCategorical vmap, the source axes of `num_sample` and `seed` "
+                                "must be None.")
+         # Move axis to first dim
+         logits = _bdim_at_front(logits, logits_dim, axis_size)
+         x_ndim = F.rank(logits)
+         if x_ndim > default_dim:
+             x_ori_shape = F.shape(logits)
+             logits = F.reshape(logits, (-1, x_ori_shape[-1]))
+             dx = prim(logits, num_sample, seed)
+             new_output_shape = (x_ori_shape[0], x_ori_shape[1], num_sample)
+             dx = F.reshape(dx, new_output_shape)
+         else:
+             dx = prim(logits, num_sample, seed)
+         return dx, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.LRN)
+ def get_lrn_vmap_rule(prim, axis_size):
+     """VmapRule for `LRN` operation."""
+     lrn_default_dim = 4
+     lrn_pre_remain_dim = 3
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+         input_x, input_x_dim = x_bdim
+         # Move axis to last dim
+         x = _bdim_at_back(input_x, input_x_dim, axis_size)
+         x_ndim = F.rank(x)
+         if x_ndim > lrn_default_dim:
+             x_ori_shape = F.shape(x)
+             x = F.reshape(x, x_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             out = prim(x)
+             out = F.reshape(out, x_ori_shape)
+         else:
+             out = prim(x)
+         return out, x_ndim - 1
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.PadV3)
+ def get_pad_v3_vmap_rule(prim, axis_size):
+     """VmapRule for `PadV3` operation."""
+     pad_pair = 2
+     input_max_dim = 4
+     mode = prim.mode
+
+     def vmap_rule(*params_bdim):
+         is_all_none, result = vmap_general_preprocess(
+             prim, params_bdim)
+         if is_all_none:
+             return result
+         if len(params_bdim) < 2:
+             _raise_value_error("The number of inputs to `PadV3` must be at least 2, "
+                                "but got {}.".format(len(params_bdim)))
+         input_x, input_x_dim = params_bdim[0]
+         paddings, paddings_dim = params_bdim[1]
+         values = None
+         out = None
+         x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         if paddings_dim is not None:
+             _raise_value_error("The source axis of `paddings` in `PadV3` must be None, "
+                                "but got {}.".format(paddings_dim))
+         if mode == "constant":
+             if len(params_bdim) != 3:
+                 _raise_value_error("The number of inputs to `PadV3` in constant mode must be 3, "
+                                    "but got {}.".format(len(params_bdim)))
+             values, values_dim = params_bdim[2]
+             if values_dim is not None:
+                 _raise_value_error("The source axis of `value` in `PadV3` must be None, "
+                                    "but got {}.".format(values_dim))
+         if isinstance(paddings, Tensor):
+             pad_dim = F.shape(paddings)[0] // pad_pair
+         else:
+             pad_dim = len(paddings) // pad_pair
+         x_ndim = F.rank(x)
+         # pylint: disable=chained-comparison
+         if pad_dim < x_ndim and x_ndim < input_max_dim:
+             if mode == "constant":
+                 out = prim(x, paddings, values)
+             else:
+                 out = prim(x, paddings)
+         elif x_ndim >= input_max_dim:
+             # reshape to 4 dims
+             x_shape = F.shape(x)
+             diff_dim = x_ndim - input_max_dim
+             first_shape = 1
+             for i in range(diff_dim + 1):
+                 first_shape *= x_shape[i]
+             input_shape = (first_shape,) + x_shape[(-input_max_dim + 1):]
+             x = F.reshape(x, input_shape)
+             if mode == "constant":
+                 out = prim(x, paddings, values)
+             else:
+                 out = prim(x, paddings)
+             out_shape = F.shape(out)
+             real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
+             out = F.reshape(out, real_out_shape)
+         else:
+             _raise_value_error("The dim of `input_x` in `PadV3` must be bigger than {}, "
+                                "but got {}.".format(pad_dim, x_ndim))
+         return out, 0
+
+     return vmap_rule
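`RandomCategorical`, `LRN` and `PadV3` above all rely on the same fold-and-restore trick: collapse the extra leading axes introduced by vmap into the kernel's fixed leading dimension, call the kernel, then reshape back. A standalone NumPy illustration of that round trip (the helper name here is illustrative, not part of the package):

import numpy as np

def fold_run_restore(x, kernel, keep_last):
    """Fold all leading axes into one, apply `kernel`, then restore them."""
    lead = x.shape[:-keep_last]                      # batch axes to fold
    folded = x.reshape((-1,) + x.shape[-keep_last:])
    out = kernel(folded)
    return out.reshape(lead + out.shape[1:])

x = np.random.rand(4, 2, 3, 8, 8)                    # (vmap, N, C, H, W)
out = fold_run_restore(x, lambda t: t[:, :, ::2, ::2], keep_last=3)
print(out.shape)                                     # (4, 2, 3, 4, 4)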
+
+
+ @vmap_rules_getters.register(NN.MirrorPad)
+ def get_mirror_pad_vmap_rule(prim, axis_size):
+     """VmapRule for `MirrorPad` operation."""
+     input_max_dim = 4
+
+     def vmap_rule(*params_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, params_bdim)
+         if is_all_none:
+             return result
+         if len(params_bdim) < 2:
+             _raise_value_error("The number of inputs to `{}` must be at least 2, but got {}.".format(prim.name, len(params_bdim)))
+         input_x, input_x_dim = params_bdim[0]
+         paddings, paddings_dim = params_bdim[1]
+
+         out = None
+         x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         if paddings_dim is not None:
+             _raise_value_error(
+                 "The source axis of `paddings` in `{}` must be None, but got {}.".format(prim.name, paddings_dim))
+         pad_dim = F.shape(paddings)[0]
+         x_ndim = F.rank(x)
+
+         if pad_dim == x_ndim and x_ndim <= input_max_dim:
+             out = prim(x, paddings)
+         elif x_ndim > input_max_dim:
+             # reshape to 4 dims
+             x_shape = F.shape(x)
+             diff_dim = x_ndim - input_max_dim
+             first_shape = 1
+             for i in range(diff_dim + 1):
+                 first_shape *= x_shape[i]
+             input_shape = (first_shape,) + x_shape[(-input_max_dim + 1):]
+             x = F.reshape(x, input_shape)
+             out = prim(x, paddings)
+             out_shape = F.shape(out)
+             real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
+             out = F.reshape(out, real_out_shape)
+         else:
+             _raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, "
+                                "but got {}.".format(prim.name, pad_dim, x_ndim))
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(G.LRNGrad)
+ def get_lrn_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `LRNGrad` operation."""
+     lrn_default_dim = 4
+     lrn_pre_remain_dim = 3
+
+     def vmap_rule(dout_bdim, x_bdim, out_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, dout_bdim, x_bdim, out_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         dy, dy_dim = dout_bdim
+         y, y_dim = out_bdim
+         # Move axis to last dim
+         x = _bdim_at_back(x, x_dim, axis_size)
+         dy = _bdim_at_back(dy, dy_dim, axis_size)
+         y = _bdim_at_back(y, y_dim, axis_size)
+         x_ndim = F.rank(x)
+         if x_ndim > lrn_default_dim:
+             x_ori_shape = F.shape(x)
+             dy_ori_shape = F.shape(dy)
+             y_ori_shape = F.shape(y)
+             x = F.reshape(x, x_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             dy = F.reshape(dy, dy_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             y = F.reshape(y, y_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             dx = prim(dy, x, y)
+             dx = F.reshape(dx, x_ori_shape)
+         else:
+             dx = prim(dy, x, y)
+         return dx, x_ndim - 1
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.BatchNorm)
+ def get_batchnorm_vmap_rule(prim, axis_size):
+     """VmapRule for `BatchNorm` operation."""
+     bn_min_dim = 3
+     bn_max_dim = 5
+     prim_name = "BatchNorm"
+     NCHW = Format.NCHW
+
+     def vmap_rule(*inputs):
+         is_all_none, result = vmap_general_preprocess(prim, *inputs)
+         if is_all_none:
+             return result
+         input_x, input_x_dim = inputs[0]
+         scale, scale_dim = inputs[1]
+         offset, offset_dim = inputs[2]
+         mean, mean_dim = inputs[3]
+         var, var_dim = inputs[4]
+         is_training, _ = inputs[5]
+         epsilon, _ = inputs[6]
+         momentum, _ = inputs[7]
+         data_format, _ = inputs[8]
+         if is_training:
+             raise ValueError("Operator {} does not support vmap during training, since the inputs `scale/offset/"
+                              "mean/var` of BatchNorm are parameters when is_training is True. If multiple batches "
+                              "of input data share the same parameters, please stack the batches into the N "
+                              "dimension manually.".format(prim_name))
+         x_ndim = F.rank(input_x)
+         if x_ndim < bn_min_dim or x_ndim > bn_max_dim:
+             raise ValueError("The dim of `input_x` in `{}` must be larger than {} and less than {}, "
+                              "but got {}.".format(prim_name, bn_min_dim - 1, bn_max_dim + 1, x_ndim))
+         # Move the batch axis of input_x to the dim in front of C
+         out_axis = 1 if data_format == NCHW else x_ndim - 2
+         input_x = _bdim_at_any(input_x, input_x_dim, out_axis, axis_size)
+         scale = _bdim_at_front(scale, scale_dim, axis_size)
+         offset = _bdim_at_front(offset, offset_dim, axis_size)
+         mean = _bdim_at_front(mean, mean_dim, axis_size)
+         var = _bdim_at_front(var, var_dim, axis_size)
+         x_shape = input_x.shape
+         other_shape = scale.shape
+         vmap_shape = (x_shape[0], -1,) + x_shape[3:] if data_format == NCHW else x_shape[:-2] + (-1,)
+         input_x = F.reshape(input_x, vmap_shape)
+         scale = scale.flatten()
+         offset = offset.flatten()
+         mean = mean.flatten()
+         var = var.flatten()
+         out, batch_mean, batch_var, rsv_1, rsv_2 = \
+             prim(input_x, scale, offset, mean, var, is_training, epsilon, momentum, data_format)
+         out = F.reshape(out, x_shape)
+         batch_mean = F.reshape(batch_mean, other_shape)
+         batch_var = F.reshape(batch_var, other_shape)
+         rsv_1 = F.reshape(rsv_1, other_shape)
+         rsv_2 = F.reshape(rsv_2, other_shape)
+         return (out, out_axis), (batch_mean, 0), (batch_var, 0), (rsv_1, 0), (rsv_2, 0)
+
+     return vmap_rule
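A rough usage sketch for the inference-only path, assuming `ops.BatchNorm` and `mindspore.vmap` behave as documented (shapes illustrative; as the rule above shows, the training path raises):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

bn = ops.BatchNorm(is_training=False)
x = Tensor(np.random.rand(8, 2, 3, 4, 4).astype(np.float32))  # vmap axis 0, then NCHW
scale = Tensor(np.ones((8, 3), np.float32))
offset = Tensor(np.zeros((8, 3), np.float32))
mean = Tensor(np.zeros((8, 3), np.float32))
var = Tensor(np.ones((8, 3), np.float32))

# Per mapped sample: a plain (2, 3, 4, 4) NCHW batch norm in inference mode.
out = ms.vmap(bn, in_axes=0, out_axes=0)(x, scale, offset, mean, var)
print(out[0].shape)  # expected: (8, 2, 3, 4, 4)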
+
+
+ @vmap_rules_getters.register(P.ApplyAdamWithAmsgrad)
+ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
+     """VmapRule for `ApplyAdamWithAmsgrad` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, vhat_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, grad_bdim,
+                   u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         vhat, vhat_dim = vhat_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         beta2_power, beta2_power_dim = beta2_power_bdim
+         lr, lr_dim = lr_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, v_dim, vhat_dim, beta1_power_dim,
+                                                beta2_power_dim, lr_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, "
+                                  "but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/grad` is not None. "
+                                  "The execution of operator `{}` cannot be guaranteed.".format(prim_name))
+             out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
+             return (out_var, None), (out_m, None), (out_v, None), (out_vhat, None)
+
+         if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]):
+             raise ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, "
+                              "but got `var`: {}, `m`: {}, `v`: {}, `vhat`: {}.".format(prim_name, var_dim,
+                                                                                       m_dim, v_dim, vhat_dim))
+
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
+         return (out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyAdamWithAmsgradV2)
+ def get_apply_adam_with_amsgrad_v2_rule(prim, axis_size):
+     """VmapRule for `ApplyAdamWithAmsgradV2` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, vhat_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, beta1_bdim,
+                   beta2_bdim, epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         vhat, vhat_dim = vhat_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         beta2_power, beta2_power_dim = beta2_power_bdim
+         lr, lr_dim = lr_bdim
+         beta1, beta1_dim = beta1_bdim
+         beta2, beta2_dim = beta2_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, v_dim, vhat_dim, beta1_power_dim, beta2_power_dim,
+                                                lr_dim, beta1_dim, beta2_dim, epsilon_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source axis of "
+                                  "`m/v/vhat/beta1_power/beta2_power/lr/beta1/beta2/epsilon/grad` is not None. "
+                                  "The execution of operator `{}` cannot be guaranteed.".format(prim_name))
+             out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2,
+                                                    epsilon, grad, u_monad)
+             return (out_var, None), (out_m, None), (out_v, None), (out_vhat, None)
+
+         if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]):
+             raise ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, "
+                              "but got `var`: {}, `m`: {}, `v`: {}, `vhat`: {}.".format(prim_name, var_dim,
+                                                                                       m_dim, v_dim, vhat_dim))
+
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
+         beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2,
+                                                      epsilon, grad, u_monad)
+         return (out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Adam)
+ def get_adam_rule(prim, axis_size):
+     """VmapRule for `Adam` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, beta1_bdim,
+                   beta2_bdim, epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         beta2_power, beta2_power_dim = beta2_power_bdim
+         lr, lr_dim = lr_bdim
+         beta1, beta1_dim = beta1_bdim
+         beta2, beta2_dim = beta2_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         all_dim = [m_dim, v_dim, beta1_power_dim, beta2_power_dim, lr_dim, beta1_dim, beta2_dim, epsilon_dim,
+                    grad_dim]
+         if var_dim is None:
+             if any(dim is not None for dim in all_dim):
+                 raise ValueError("The source axis of `var` is None, but the source axis of "
+                                  "`m/v/beta1_power/beta2_power/lr/beta1/beta2/epsilon/grad` is not None. "
+                                  "The execution of operator `{}` cannot be guaranteed.".format(prim_name))
+             out_var, out_m, out_v = prim(
+                 var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
+             return (out_var, None), (out_m, None), (out_v, None)
+
+         if any(dim != 0 for dim in [var_dim, m_dim, v_dim]):
+             raise ValueError("For `{}`, the source axis of `var/m/v` must be 0, "
+                              "but got `var`: {}, `m`: {}, `v`: {}.".format(prim_name, var_dim, m_dim, v_dim))
+
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
+         beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         out_var, out_m, out_v = batch_prim(
+             var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
+         return (out_var, 0), (out_m, 0), (out_v, 0)
+
+     return vmap_rule
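The three optimizer rules above share the `batch_rank` convention: a clone of the primitive is annotated with how many leading axes are batch axes, so its kernel treats each slice of `var`/`m`/`v` as an independent optimizer state. A minimal sketch of that annotation using the public `add_prim_attr` (the second instantiation stands in for `_vmap_clone_prim`, which is internal):

from mindspore import ops

prim = ops.Adam()
batch_prim = ops.Adam()                    # stands in for _vmap_clone_prim(prim)
batch_prim.add_prim_attr("batch_rank", 1)  # one leading batch axis
print(batch_prim.attrs.get("batch_rank"))  # expected: 1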
+
+
+ @vmap_rules_getters.register(P.ApplyPowerSign)
+ def get_apply_power_sign_rule(prim, axis_size):
+     """VmapRule for `ApplyPowerSign` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, lr_bdim, logbase_bdim, sign_decay_bdim, beta_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         lr, lr_dim = lr_bdim
+         logbase, logbase_dim = logbase_bdim
+         sign_decay, sign_decay_dim = sign_decay_bdim
+         beta, beta_dim = beta_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, lr_dim, logbase_dim, sign_decay_dim, beta_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `m/lr/logbase/sign_decay/beta/grad` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, m = prim(var, m, lr, logbase, sign_decay, beta, grad, u_monad)
+             return (var, None), (m, None)
+         if var_dim != 0 or m_dim != var_dim:
+             raise ValueError("For `{}`, the source axis of `var` must be 0 and equal to the source axis of `m`, "
+                              "but got the source axis of `var`: {}, `m`: {}.".format(prim_name, var_dim, m_dim))
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         logbase = _bdim_at_front(logbase, logbase_dim, axis_size)
+         sign_decay = _bdim_at_front(sign_decay, sign_decay_dim, axis_size)
+         beta = _bdim_at_front(beta, beta_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         var, m = batch_prim(var, m, lr, logbase, sign_decay, beta, grad, u_monad)
+         return (var, 0), (m, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyAdagradV2)
+ def get_apply_adagrad_v2_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyAdagradV2` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+     prim_name = prim.name
+
+     def vmap_rule(var_bdim, accum_bdim, lr_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         lr, lr_dim = lr_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, lr_dim, grad_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'accum/lr/grad' is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+             var, accum = prim(var, accum, lr, grad, u_monad)
+             return (var, None), (accum, None)
+         if var_dim != 0 or var_dim != accum_dim:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to the source axis of "
+                 f"'accum', but got the source axis of 'var': {var_dim}, 'accum': {accum_dim}.")
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         var, accum = batch_prim(var, accum, lr, grad, u_monad)
+         return (var, 0), (accum, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyAdagradDA)
+ def get_apply_adagrad_da_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyAdagradDA` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     attr = prim.init_attrs
+     batch_prim = P.ApplyAdagradDA(**attr)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+     prim_name = prim.name
+
+     def vmap_rule(var_bdim, gradient_accumulator_bdim, gradient_squared_accumulator_bdim, grad_bdim, lr_bdim,
+                   l1_bdim, l2_bdim, global_step_bdim, u_monad):
+         var, var_dim = var_bdim
+         gradient_accumulator, gradient_accumulator_dim = gradient_accumulator_bdim
+         gradient_squared_accumulator, gradient_squared_accumulator_dim = gradient_squared_accumulator_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         global_step, global_step_dim = global_step_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in
+                    [gradient_accumulator_dim, gradient_squared_accumulator_dim, grad_dim, lr_dim, l1_dim, l2_dim,
+                     global_step_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'gradient_accumulator/gradient_squared_accumulator/grad/lr/l1/l2/"
+                                  "global_step' is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+             var, gradient_accumulator, gradient_squared_accumulator = prim(
+                 var, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step,
+                 u_monad)  # Low dimensional operator
+             return (var, None), (gradient_accumulator, None), (gradient_squared_accumulator, None)
+         if var_dim != 0 or var_dim != gradient_accumulator_dim or var_dim != gradient_squared_accumulator_dim:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to "
+                 f"'gradient_accumulator_dim' and 'gradient_squared_accumulator_dim', "
+                 f"but got the source axis of 'var': {var_dim}, "
+                 f"'gradient_accumulator_dim': {gradient_accumulator_dim}, "
+                 f"'gradient_squared_accumulator_dim': {gradient_squared_accumulator_dim}.")
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         global_step = _bdim_at_front(global_step, global_step_dim, axis_size)
+
+         var, gradient_accumulator, gradient_squared_accumulator = batch_prim(
+             var, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step,
+             u_monad)  # High dimensional operator
+         return (var, 0), (gradient_accumulator, 0), (gradient_squared_accumulator, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.AdaptiveMaxPool2D)
+ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveMaxPool2D`."""
+     nchw_index = 4
+     chw_reverse_index = -3
+     hw_size = 2
+     output_size = prim.output_size
+
+     @_primexpr
+     def get_output_shape(x_ori_shape, output_size):
+         if isinstance(output_size, tuple):
+             h_out, w_out = output_size
+         else:
+             h_out = output_size
+             w_out = output_size
+
+         rank = len(x_ori_shape)
+         output_shape = x_ori_shape[:rank - hw_size]
+         if h_out is None or h_out == -1:
+             output_shape += (x_ori_shape[-2],)
+         else:
+             output_shape += (h_out,)
+
+         if w_out is None or w_out == -1:
+             output_shape += (x_ori_shape[-1],)
+         else:
+             output_shape += (w_out,)
+         return output_shape
+
+     def vmap_rule(input_x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, input_x_bdim)
+         if is_all_none:
+             return result
+
+         input_x, input_x_dim = input_x_bdim
+         x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         x_ndim = F.rank(x)
+
+         if x_ndim > nchw_index:
+             # for the batched NCHW case: fold the extra leading axes into N
+             x_ori_shape = F.shape(x)
+             x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
+             output_shape = get_output_shape(x_ori_shape, output_size)
+             out, indices = prim(x)
+             out = F.reshape(out, output_shape)
+             indices = F.reshape(indices, output_shape)
+             return (out, 0), (indices, 0)
+
+         # for the case of CHW
+         out, indices = prim(x)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
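A hedged user-level check for the adaptive pooling rule, assuming the functional `ops.adaptive_max_pool2d` is available in this build:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.random.rand(4, 3, 8, 8).astype(np.float32))   # 4 stacked CHW inputs
# Each mapped slice is a (3, 8, 8) CHW tensor pooled to a (3, 2, 2) output.
fn = ms.vmap(lambda t: ops.adaptive_max_pool2d(t, (2, 2)), in_axes=0, out_axes=0)
print(fn(x).shape)  # expected: (4, 3, 2, 2)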
+
+
+ @vmap_rules_getters.register(NN.MaxPool3DWithArgmax)
+ def get_max_pool3d_with_argmax_vmap_rule(prim, axis_size):
+     """VmapRule for `MaxPool3DWithArgmax`."""
+     cdhw_reverse_index = -4
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[cdhw_reverse_index:]
+         x = F.reshape(x, input_shape)
+         out, indices = prim(x)
+         out_shape = F.shape(out)
+         return_shape = x_shape[:cdhw_reverse_index] + out_shape[cdhw_reverse_index:]
+         out = F.reshape(out, return_shape)
+         indices = F.reshape(indices, return_shape)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyRMSProp)
+ def get_rmsprop_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyRMSProp` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+     prim_name = prim.name
+
+     def vmap_rule(var_bdim, mean_square_bdim, moment_bdim, lr_bdim, grad_bdim, decay_bdim, momentum_bdim,
+                   epsilon_bdim, u_monad):
+         var, var_dim = var_bdim
+         mean_square, mean_square_dim = mean_square_bdim
+         moment, moment_dim = moment_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         decay, decay_dim = decay_bdim
+         momentum, momentum_dim = momentum_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in
+                    [mean_square_dim, moment_dim, grad_dim, lr_dim, decay_dim, momentum_dim, epsilon_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'mean_square/moment/lr/grad/decay/momentum/epsilon'"
+                                  " is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+
+             res = prim(var, mean_square, moment, lr, grad, decay, momentum, epsilon,
+                        u_monad)  # Low dimensional operator
+             return (res, None)
+         precondition = var_dim != 0 or var_dim != mean_square_dim or var_dim != moment_dim or var_dim != grad_dim
+         if precondition:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to 'mean_square_dim', "
+                 f"'moment_dim' and 'grad_dim', but got the source axis of 'var': {var_dim}, "
+                 f"'mean_square_dim': {mean_square_dim}, 'moment_dim': {moment_dim}, "
+                 f"'grad_dim': {grad_dim}.")
+
+         mean_square = _bdim_at_front(mean_square, mean_square_dim, axis_size)
+         moment = _bdim_at_front(moment, moment_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+
+         res = batch_prim(var, mean_square, moment, lr, grad, decay, momentum, epsilon,
+                          u_monad)  # High dimensional operator
+
+         return res, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyCenteredRMSProp)
+ def get_apply_centered_rmsprop_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyCenteredRMSProp` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, mean_grad_bdim, mean_square_bdim, mom_bdim, grad_bdim, lr_bdim, rho_bdim,
+                   momentum_bdim, eps_bdim, u_monad):
+         var, var_dim = var_bdim
+         mean_grad, mean_grad_dim = mean_grad_bdim
+         mean_square, mean_square_dim = mean_square_bdim
+         mom, mom_dim = mom_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         rho, rho_dim = rho_bdim
+         momentum, momentum_dim = momentum_bdim
+         eps, eps_dim = eps_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in
+                    [mean_grad_dim, mean_square_dim, mom_dim, grad_dim, lr_dim, rho_dim,
+                     momentum_dim, eps_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'mean_gradient/mean_square/mom/grad/lr/rho/momentum/eps'"
+                                  " is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+             var = prim(var, mean_grad, mean_square,
+                        mom, grad, lr, rho, momentum, eps, u_monad)
+             return (var, None)
+         precondition = var_dim != 0 or var_dim != mean_grad_dim or var_dim != mean_square_dim or var_dim != mom_dim
+         if precondition:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to 'mean_grad_dim', "
+                 f"'mean_square_dim' and 'mom_dim', but got the source axis of 'var': {var_dim}, "
+                 f"'mean_grad_dim': {mean_grad_dim}, 'mean_square_dim': {mean_square_dim}, "
+                 f"'mom_dim': {mom_dim}.")
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         rho = _bdim_at_front(rho, rho_dim, axis_size)
+         momentum = _bdim_at_front(momentum, momentum_dim, axis_size)
+         eps = _bdim_at_front(eps, eps_dim, axis_size)
+
+         var = batch_prim(var, mean_grad, mean_square,
+                          mom, grad, lr, rho, momentum, eps, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.MaxPool)
+ @vmap_rules_getters.register(P.MaxPoolWithArgmax)
+ @vmap_rules_getters.register(P.MaxPoolWithArgmaxV2)
+ def get_max_pool_vmap_rule(prim, axis_size):
+     """VmapRule for `MaxPool`, `MaxPoolWithArgmax` and `MaxPoolWithArgmaxV2` operations."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+
+     prim_name = prim.name
+
+     @_primexpr
+     def get_original_shape(x_shape, out_shape):
+         h_new = out_shape[2]
+         w_new = out_shape[3]
+         original_shape = x_shape[:3] + (h_new,) + (w_new,)
+         return original_shape
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape = x.shape
+         x_new_shape = (-1,) + x_shape[2:]
+         x = x.reshape(x_new_shape)
+         if prim_name == "MaxPool":
+             out = prim(x)
+             out_shape = out.shape
+             original_shape = get_original_shape(x_shape, out_shape)
+             out = out.reshape(original_shape)
+             return out, 0
+         out, indices = prim(x)
+         out_shape = out.shape
+         original_shape = get_original_shape(x_shape, out_shape)
+         out = out.reshape(original_shape)
+         indices = indices.reshape(original_shape)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
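A short sketch of the pooling rules' fold-into-N behavior from user code (assuming `mindspore.vmap`; shapes illustrative):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

pool = ops.MaxPool(kernel_size=2, strides=2)
x = Tensor(np.random.rand(3, 2, 1, 4, 4).astype(np.float32))  # vmap axis, then NCHW

# The rule folds (3, 2) into N=6, pools, and restores the leading axes.
out = ms.vmap(pool, in_axes=0, out_axes=0)(x)
print(out.shape)  # expected: (3, 2, 1, 2, 2)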
+
+
+ @vmap_rules_getters.register("LayerNorm")
+ def get_layernorm_vmap_rule(prim, axis_size):
+     """VmapRule for `LayerNorm` operation."""
+
+     def process_attr_axis(prim_attr_axis):
+         if prim_attr_axis < 0:
+             return prim_attr_axis
+         return prim_attr_axis + 1
+
+     @_primexpr
+     def get_logical_shape(var_shape):
+         return var_shape[1:]
+
+     def vmap_rule(x_bdim, gamma_bdim, beta_bdim, begin_norm_axis_bdim, begin_params_axis_bdim, epsilon_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, gamma_bdim, beta_bdim, begin_norm_axis_bdim,
+                                                       begin_params_axis_bdim, epsilon_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         g, g_dim = gamma_bdim
+         b, b_dim = beta_bdim
+         begin_norm_axis, _ = begin_norm_axis_bdim
+         begin_params_axis, _ = begin_params_axis_bdim
+         epsilon, _ = epsilon_bdim
+
+         begin_norm_axis = process_attr_axis(begin_norm_axis)
+         begin_params_axis = process_attr_axis(begin_params_axis)
+
+         x = _bdim_at_front(x, x_dim, axis_size)
+
+         if g_dim is None and b_dim is None:
+             output, mean, var = prim(x, g, b, begin_norm_axis, begin_params_axis, epsilon)
+             return (output, 0), (mean, 0), (var, 0)
+
+         g = _bdim_at_front(g, g_dim, axis_size)
+         b = _bdim_at_front(b, b_dim, axis_size)
+         g_logical_shape = get_logical_shape(F.shape(g))
+         b_logical_shape = get_logical_shape(F.shape(b))
+
+         ones_like_g = F.ones(g_logical_shape, F.dtype(g))
+         zeros_like_b = F.zeros(b_logical_shape, F.dtype(b))
+         output_tmp, mean, var = prim(x, ones_like_g, zeros_like_b, begin_norm_axis, begin_params_axis, epsilon)
+
+         x_shape = F.shape(x)
+         g_shape = F.shape(g)
+         b_shape = F.shape(b)
+         g = _handle_broadcasting(g, g_shape, x_shape)
+         b = _handle_broadcasting(b, b_shape, x_shape)
+         output = F.add(F.mul(output_tmp, g), b)
+
+         return (output, 0), (mean, 0), (var, 0)
+
+     return vmap_rule
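The batched-affine trick above (normalize with gamma=1 and beta=0, then apply the per-sample affine parameters by hand) is equivalent to the following NumPy computation (the epsilon value is illustrative):

import numpy as np

x = np.random.rand(4, 5, 6)          # (vmap, ..., normalized dim)
gamma = np.random.rand(4, 6)         # per-sample affine parameters
beta = np.random.rand(4, 6)

mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
normalized = (x - mean) / np.sqrt(var + 1e-7)
y = normalized * gamma[:, None, :] + beta[:, None, :]   # F.add(F.mul(...))
print(y.shape)  # (4, 5, 6)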
+
+
+ @vmap_rules_getters.register(NN.GridSampler2D)
+ @vmap_rules_getters.register(NN.GridSampler3D)
+ def get_grid_sampler_vmap_rule(prim, axis_size):
+     """VmapRule for `GridSampler2D` and `GridSampler3D`."""
+     prim_name = prim.name
+     if prim_name == "GridSampler2D":
+         non_batch_dim_index = -3
+     elif prim_name == "GridSampler3D":
+         non_batch_dim_index = -4
+     else:
+         _raise_value_error(
+             "The prim name must be `GridSampler2D` or `GridSampler3D`, but got {}.".format(prim_name))
+
+     def vmap_rule(input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim):
+         is_all_none, result = vmap_general_preprocess(
+             prim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim)
+         if is_all_none:
+             return result
+
+         input_x, input_x_dim = input_x_bdim
+         grid, grid_dim = grid_bdim
+         interpolation_mode, _ = interpolation_mode_bdim
+         padding_mode, _ = padding_mode_bdim
+         align_corners, _ = align_corners_bdim
+
+         input_x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         input_x_shape = F.shape(input_x)
+         input_x = F.reshape(input_x, (-1,) + input_x_shape[non_batch_dim_index:])
+
+         grid = _bdim_at_front(grid, grid_dim, axis_size)
+         grid_shape = F.shape(grid)
+         grid = F.reshape(grid, (-1,) + grid_shape[non_batch_dim_index:])
+
+         out = prim(input_x, grid, interpolation_mode, padding_mode, align_corners)
+         out_shape = F.shape(out)
+         return_shape = input_x_shape[:non_batch_dim_index] + out_shape[non_batch_dim_index:]
+         out = F.reshape(out, return_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.UpsampleNearest1D)
+ @vmap_rules_getters.register(NN.UpsampleNearest2D)
+ @vmap_rules_getters.register(NN.UpsampleNearest3D)
+ def get_upsample_nearest_vmap_rule(prim, axis_size):
+     """VmapRule for `UpsampleNearest1D`, `UpsampleNearest2D` and `UpsampleNearest3D`."""
+     prim_name = prim.name
+     if prim_name == "UpsampleNearest1D":
+         reverse_index = -2
+     elif prim_name == "UpsampleNearest2D":
+         reverse_index = -3
+     else:
+         reverse_index = -4
+
+     def vmap_rule(x_bdim, size_bdim, scales_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, size_bdim,
+                                                       scales_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         size, size_dim = size_bdim
+         scales, scales_dim = scales_bdim
+         if size_dim is not None or scales_dim is not None:
+             _raise_value_error(
+                 "For {0}, the source axes of `output_size` and `scales` must be None,"
+                 " but got {1} and {2}.".format(prim_name, size_dim, scales_dim))
+
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[reverse_index:]
+         x = F.reshape(x, input_shape)
+         out = prim(x, size, scales)
+         out_shape = F.shape(out)
+         return_shape = x_shape[:reverse_index] + out_shape[reverse_index:]
+         out = F.reshape(out, return_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.UpsampleLinear1D)
+ @vmap_rules_getters.register(NN.UpsampleBilinear2D)
+ @vmap_rules_getters.register(NN.UpsampleTrilinear3D)
+ def get_upsample_linear_vmap_rule(prim, axis_size):
+     """VmapRule for `UpsampleLinear1D`, `UpsampleBilinear2D` and `UpsampleTrilinear3D`."""
+     prim_name = prim.name
+     if prim_name == "UpsampleLinear1D":
+         reverse_index = -2
+     elif prim_name == "UpsampleBilinear2D":
+         reverse_index = -3
+     else:
+         reverse_index = -4
+
+     def vmap_rule(x_bdim, size_bdim, scales_bdim, align_corners_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, size_bdim,
+                                                       scales_bdim, align_corners_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         size, size_dim = size_bdim
+         scales, scales_dim = scales_bdim
+         align_corners, align_corners_dim = align_corners_bdim
+         if size_dim is not None or scales_dim is not None or align_corners_dim is not None:
+             _raise_value_error(
+                 "For {0}, the source axes of `output_size`, `scales` and `align_corners` must "
+                 "be None, but got {1}, {2} and {3}.".format(prim_name, size_dim, scales_dim, align_corners_dim))
+
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[reverse_index:]
+         x = F.reshape(x, input_shape)
+         out = prim(x, size, scales, align_corners)
+         out_shape = F.shape(out)
+         return_shape = x_shape[:reverse_index] + out_shape[reverse_index:]
+         out = F.reshape(out, return_shape)
+         return out, 0
+
+     return vmap_rule
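The grid-sampler and upsample rules all apply the same fold-and-restore reshape over the non-spatial leading axes. A hedged check for the 2-D case, assuming the functional `ops.grid_sample` exists in this build:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.random.rand(5, 2, 3, 8, 8).astype(np.float32))          # vmap, N, C, H, W
grid = Tensor(np.random.uniform(-1, 1, (5, 2, 4, 4, 2)).astype(np.float32))

# Per mapped sample: a standard (2, 3, 8, 8) x (2, 4, 4, 2) grid sample.
fn = ms.vmap(ops.grid_sample, in_axes=(0, 0), out_axes=0)
print(fn(x, grid).shape)  # expected: (5, 2, 3, 4, 4)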
+
+
+ @vmap_rules_getters.register(NN.SparseApplyAdagrad)
+ @vmap_rules_getters.register(NN.SparseApplyAdagradV2)
+ def get_sparse_apply_adagrad_vmap_rule(prim, axis_size):
+     """VmapRule for `SparseApplyAdagrad` and `SparseApplyAdagradV2`."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, grad_bdim, indices_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         grad, grad_dim = grad_bdim
+         indices, indices_dim = indices_bdim
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, grad_dim, indices_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/grad/indices` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum = prim(var, accum, grad, indices, u_monad)
+             return (var, None), (accum, None)
+         if var_dim != 0 or accum_dim != var_dim:
+             raise ValueError("For `{}`, the source axis of `var` must be 0 and equal to the source axis of "
+                              "`accum`, but got the source axis of `var`: {}, `accum`: {}."
+                              .format(prim_name, var_dim, accum_dim))
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         indices = _bdim_at_front(indices, indices_dim, axis_size)
+
+         var, accum = batch_prim(var, accum, grad, indices, u_monad)
+         return (var, 0), (accum, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.SparseApplyFtrl)
+ def get_sparse_apply_ftrl_vmap_rule(prim, axis_size):
+     """VmapRule for `SparseApplyFtrl`."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, linear_bdim, grad_bdim, indices_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         linear, linear_dim = linear_bdim
+         grad, grad_dim = grad_bdim
+         indices, indices_dim = indices_bdim
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, linear_dim, grad_dim, indices_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/linear/grad/indices` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum, linear = prim(var, accum, linear, grad, indices, u_monad)
+             return (var, None), (accum, None), (linear, None)
+         if var_dim != 0 or accum_dim != var_dim or linear_dim != var_dim:
+             raise ValueError("For `{}`, the source axes of `var`, `accum` and `linear` must all be 0, "
+                              "but got the source axis of `var`: {}, `accum`: {}, "
+                              "`linear`: {}.".format(prim_name, var_dim, accum_dim, linear_dim))
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         indices = _bdim_at_front(indices, indices_dim, axis_size)
+
+         var, accum, linear = batch_prim(var, accum, linear, grad, indices, u_monad)
+         return (var, 0), (accum, 0), (linear, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Dense)
+ def get_dense_vmap_rule(prim, axis_size):
+     """VmapRule for `Dense` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+
+     batch_matmul = P.BatchMatMul(transpose_b=True)
+
+     @_primexpr
+     def get_start_mid_end(x_shape):
+         start = x_shape[0]
+         mid = 1
+         for shp in x_shape[1:-1]:
+             mid *= shp
+         end = x_shape[-1]
+         return start, mid, end
+
+     def vmap_rule(x_bdim, w_bdim, b_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, w_bdim, b_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         w, w_dim = w_bdim
+         b, b_dim = b_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         w = _bdim_at_front(w, w_dim, axis_size)
+         if b is not None:
+             b = _bdim_at_front(b, b_dim, axis_size)
+
+         x_shape = x.shape
+         start, mid, end = get_start_mid_end(x_shape)
+
+         x = x.reshape(start, mid, end)
+
+         out = batch_matmul(x, w)
+         out_shape = tuple(x_shape[:-1]) + (out.shape[-1],)
+         out = out.reshape(out_shape)
+
+         if b is not None:
+             b_shape = b.shape
+             b_shape = (start,) + (1,) * (len(out_shape) - 2) + (b_shape[-1],)
+             b = b.reshape(b_shape)
+             out = out + b
+
+         return out, 0
+
+     return vmap_rule
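The Dense rule lowers a batched dense layer to `BatchMatMul(transpose_b=True)` plus a broadcast bias add; in index form, out[i] = x[i] @ w[i].T + b[i]. A NumPy equivalent of that contraction:

import numpy as np

x = np.random.rand(4, 7, 3)      # (vmap, tokens, in_features)
w = np.random.rand(4, 5, 3)      # (vmap, out_features, in_features)
b = np.random.rand(4, 5)

out = np.einsum('bti,boi->bto', x, w) + b[:, None, :]
print(out.shape)  # (4, 7, 5)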
+
+
+ @vmap_rules_getters.register(P.CeLU)
+ def get_celu_vmap_rule(prim, axis_size):
+     """VmapRule for `CeLU` operation."""
+
+     def vmap_rule(x_bdim, alpha_bdim):
+         x_data, x_dim = x_bdim
+         alpha_data, _ = alpha_bdim
+         out = F.celu(x_data, alpha_data)
+         return out, x_dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Elu)
+ def get_elu_vmap_rule(prim, axis_size):
+     """VmapRule for `Elu` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+
+     def vmap_rule(x_bdim, alpha_bdim):
+         var, dim = x_bdim
+         alpha, alpha_dim = alpha_bdim
+
+         if alpha_dim is not None:
+             _raise_value_error("The source axis of `alpha` in `Elu` must be None, but got {}.".format(alpha_dim))
+
+         out = prim(var, alpha)
+         return out, dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(Embedding)
+ def get_embedding_vmap_rule(prim, axis_size):
+     """VmapRule for `Embedding` operation."""
+     if isinstance(prim, str):
+         prim_name = prim
+     else:
+         prim_name = prim.name
+     raise RuntimeError(f"The operator {prim_name} does not support vmap.")
+
+
+ # Unary vmap
+ get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU6)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.SeLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.HSigmoid)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.Softplus)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.Softsign)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.GeLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.FastGeLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.HSwish)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.Tanh)(get_unop_vmap_rule)
+ # UnaryGrad vmap
+ get_unary_grad_vmap_rule = vmap_rules_getters.register(G.TanhGrad)(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register(G.SoftplusGrad)(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register('ReluGrad')(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register('ReLU6Grad')(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register('RsqrtGrad')(get_unary_grad_vmap_rule)
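Element-wise activations need no shape bookkeeping, which is why they all share the generic unary rule: the tensor passes straight through and the source axis is preserved. A minimal check (assuming `mindspore.vmap`):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.random.rand(3, 4).astype(np.float32) - 0.5)
# Mapping over axis 1 and returning axis 1: the unary rule forwards the
# tensor unchanged, so the output layout matches the input exactly.
out = ms.vmap(ops.relu, in_axes=1, out_axes=1)(x)
print(out.shape)  # (3, 4)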