mindspore-2.3.0-cp310-cp310-win_amd64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
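
Since a .whl file is a standard zip archive, a file listing like the one below can be reproduced locally. A minimal sketch in Python using only the standard library (the local wheel path is an assumption, taken from the header of this diff):

    import zipfile

    # Local path to the downloaded wheel -- assumed filename.
    WHEEL = "mindspore-2.3.0-cp310-cp310-win_amd64.whl"

    # Enumerate every file packaged in the wheel with its uncompressed size.
    with zipfile.ZipFile(WHEEL) as whl:
        for info in whl.infolist():
            print(f"{info.filename}  ({info.file_size} bytes)")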
Files changed (1400)
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +51 -0
  18. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1378 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +258 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1446 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4437 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +419 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +84 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_register_for_adapter.py +74 -0
  89. mindspore/common/_register_for_recompute.py +48 -0
  90. mindspore/common/_register_for_tensor.py +45 -0
  91. mindspore/common/_stub_tensor.py +210 -0
  92. mindspore/common/_utils.py +122 -0
  93. mindspore/common/api.py +2049 -0
  94. mindspore/common/auto_dynamic_shape.py +507 -0
  95. mindspore/common/dtype.py +422 -0
  96. mindspore/common/dump.py +131 -0
  97. mindspore/common/file_system.py +48 -0
  98. mindspore/common/generator.py +260 -0
  99. mindspore/common/hook_handle.py +155 -0
  100. mindspore/common/initializer.py +880 -0
  101. mindspore/common/jit_config.py +98 -0
  102. mindspore/common/lazy_inline.py +240 -0
  103. mindspore/common/mindir_util.py +111 -0
  104. mindspore/common/mutable.py +234 -0
  105. mindspore/common/no_inline.py +54 -0
  106. mindspore/common/np_dtype.py +25 -0
  107. mindspore/common/parameter.py +1048 -0
  108. mindspore/common/recompute.py +262 -0
  109. mindspore/common/seed.py +260 -0
  110. mindspore/common/sparse_tensor.py +1171 -0
  111. mindspore/common/symbol.py +122 -0
  112. mindspore/common/tensor.py +4859 -0
  113. mindspore/communication/__init__.py +37 -0
  114. mindspore/communication/_comm_helper.py +466 -0
  115. mindspore/communication/_hccl_management.py +297 -0
  116. mindspore/communication/comm_func.py +1140 -0
  117. mindspore/communication/management.py +673 -0
  118. mindspore/config/op_info.config +533 -0
  119. mindspore/context.py +1976 -0
  120. mindspore/d3dcompiler_47.dll +0 -0
  121. mindspore/dataset/__init__.py +90 -0
  122. mindspore/dataset/audio/__init__.py +61 -0
  123. mindspore/dataset/audio/transforms.py +3690 -0
  124. mindspore/dataset/audio/utils.py +386 -0
  125. mindspore/dataset/audio/validators.py +1172 -0
  126. mindspore/dataset/callback/__init__.py +20 -0
  127. mindspore/dataset/callback/ds_callback.py +368 -0
  128. mindspore/dataset/callback/validators.py +32 -0
  129. mindspore/dataset/core/__init__.py +13 -0
  130. mindspore/dataset/core/config.py +1088 -0
  131. mindspore/dataset/core/datatypes.py +101 -0
  132. mindspore/dataset/core/py_util_helpers.py +65 -0
  133. mindspore/dataset/core/validator_helpers.py +774 -0
  134. mindspore/dataset/debug/__init__.py +21 -0
  135. mindspore/dataset/debug/debug_hook.py +97 -0
  136. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  137. mindspore/dataset/engine/__init__.py +124 -0
  138. mindspore/dataset/engine/cache_admin.py +47 -0
  139. mindspore/dataset/engine/cache_client.py +129 -0
  140. mindspore/dataset/engine/datasets.py +4554 -0
  141. mindspore/dataset/engine/datasets_audio.py +911 -0
  142. mindspore/dataset/engine/datasets_standard_format.py +493 -0
  143. mindspore/dataset/engine/datasets_text.py +2161 -0
  144. mindspore/dataset/engine/datasets_user_defined.py +1114 -0
  145. mindspore/dataset/engine/datasets_vision.py +4816 -0
  146. mindspore/dataset/engine/iterators.py +342 -0
  147. mindspore/dataset/engine/obs/__init__.py +23 -0
  148. mindspore/dataset/engine/obs/config_loader.py +68 -0
  149. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  150. mindspore/dataset/engine/obs/util.py +475 -0
  151. mindspore/dataset/engine/offload.py +596 -0
  152. mindspore/dataset/engine/queue.py +250 -0
  153. mindspore/dataset/engine/samplers.py +895 -0
  154. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  155. mindspore/dataset/engine/validators.py +2875 -0
  156. mindspore/dataset/text/__init__.py +54 -0
  157. mindspore/dataset/text/transforms.py +1703 -0
  158. mindspore/dataset/text/utils.py +715 -0
  159. mindspore/dataset/text/validators.py +642 -0
  160. mindspore/dataset/transforms/__init__.py +48 -0
  161. mindspore/dataset/transforms/c_transforms.py +638 -0
  162. mindspore/dataset/transforms/py_transforms.py +393 -0
  163. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  164. mindspore/dataset/transforms/transforms.py +1260 -0
  165. mindspore/dataset/transforms/validators.py +410 -0
  166. mindspore/dataset/utils/__init__.py +19 -0
  167. mindspore/dataset/utils/browse_dataset.py +190 -0
  168. mindspore/dataset/utils/line_reader.py +124 -0
  169. mindspore/dataset/vision/__init__.py +68 -0
  170. mindspore/dataset/vision/c_transforms.py +2641 -0
  171. mindspore/dataset/vision/py_transforms.py +2120 -0
  172. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  173. mindspore/dataset/vision/transforms.py +7295 -0
  174. mindspore/dataset/vision/utils.py +863 -0
  175. mindspore/dataset/vision/validators.py +1482 -0
  176. mindspore/default_config.py +2 -0
  177. mindspore/dnnl.dll +0 -0
  178. mindspore/dpcmi.dll +0 -0
  179. mindspore/experimental/__init__.py +20 -0
  180. mindspore/experimental/map_parameter.py +309 -0
  181. mindspore/experimental/optim/__init__.py +40 -0
  182. mindspore/experimental/optim/adadelta.py +161 -0
  183. mindspore/experimental/optim/adagrad.py +168 -0
  184. mindspore/experimental/optim/adam.py +193 -0
  185. mindspore/experimental/optim/adamax.py +170 -0
  186. mindspore/experimental/optim/adamw.py +205 -0
  187. mindspore/experimental/optim/asgd.py +153 -0
  188. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  189. mindspore/experimental/optim/nadam.py +157 -0
  190. mindspore/experimental/optim/optimizer.py +259 -0
  191. mindspore/experimental/optim/radam.py +194 -0
  192. mindspore/experimental/optim/rmsprop.py +154 -0
  193. mindspore/experimental/optim/rprop.py +164 -0
  194. mindspore/experimental/optim/sgd.py +156 -0
  195. mindspore/hal/__init__.py +40 -0
  196. mindspore/hal/_ascend.py +57 -0
  197. mindspore/hal/_base.py +57 -0
  198. mindspore/hal/_cpu.py +56 -0
  199. mindspore/hal/_gpu.py +57 -0
  200. mindspore/hal/device.py +356 -0
  201. mindspore/hal/event.py +179 -0
  202. mindspore/hal/memory.py +326 -0
  203. mindspore/hal/stream.py +339 -0
  204. mindspore/include/OWNERS +7 -0
  205. mindspore/include/api/allocator.h +97 -0
  206. mindspore/include/api/callback/callback.h +93 -0
  207. mindspore/include/api/callback/ckpt_saver.h +41 -0
  208. mindspore/include/api/callback/loss_monitor.h +33 -0
  209. mindspore/include/api/callback/lr_scheduler.h +51 -0
  210. mindspore/include/api/callback/time_monitor.h +34 -0
  211. mindspore/include/api/callback/train_accuracy.h +37 -0
  212. mindspore/include/api/cell.h +90 -0
  213. mindspore/include/api/cfg.h +82 -0
  214. mindspore/include/api/context.h +602 -0
  215. mindspore/include/api/data_type.h +47 -0
  216. mindspore/include/api/delegate.h +178 -0
  217. mindspore/include/api/delegate_api.h +75 -0
  218. mindspore/include/api/dual_abi_helper.h +208 -0
  219. mindspore/include/api/format.h +28 -0
  220. mindspore/include/api/graph.h +46 -0
  221. mindspore/include/api/kernel.h +58 -0
  222. mindspore/include/api/kernel_api.h +168 -0
  223. mindspore/include/api/metrics/accuracy.h +36 -0
  224. mindspore/include/api/metrics/metrics.h +41 -0
  225. mindspore/include/api/model.h +438 -0
  226. mindspore/include/api/model_group.h +79 -0
  227. mindspore/include/api/model_parallel_runner.h +168 -0
  228. mindspore/include/api/serialization.h +185 -0
  229. mindspore/include/api/status.h +192 -0
  230. mindspore/include/api/types.h +431 -0
  231. mindspore/include/api/visible.h +41 -0
  232. mindspore/include/c_api/context_c.h +179 -0
  233. mindspore/include/c_api/data_type_c.h +52 -0
  234. mindspore/include/c_api/format_c.h +46 -0
  235. mindspore/include/c_api/model_c.h +347 -0
  236. mindspore/include/c_api/ms/abstract.h +67 -0
  237. mindspore/include/c_api/ms/attribute.h +197 -0
  238. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  239. mindspore/include/c_api/ms/base/macros.h +32 -0
  240. mindspore/include/c_api/ms/base/status.h +33 -0
  241. mindspore/include/c_api/ms/base/types.h +283 -0
  242. mindspore/include/c_api/ms/context.h +102 -0
  243. mindspore/include/c_api/ms/graph.h +160 -0
  244. mindspore/include/c_api/ms/node.h +606 -0
  245. mindspore/include/c_api/ms/tensor.h +161 -0
  246. mindspore/include/c_api/ms/value.h +84 -0
  247. mindspore/include/c_api/status_c.h +79 -0
  248. mindspore/include/c_api/tensor_c.h +146 -0
  249. mindspore/include/c_api/types_c.h +67 -0
  250. mindspore/include/dataset/config.h +163 -0
  251. mindspore/include/dataset/constants.h +363 -0
  252. mindspore/include/dataset/execute.h +196 -0
  253. mindspore/include/dataset/text.h +1092 -0
  254. mindspore/include/dataset/transforms.h +638 -0
  255. mindspore/include/dataset/vision.h +2125 -0
  256. mindspore/include/dataset/vision_ascend.h +206 -0
  257. mindspore/include/dataset/vision_lite.h +625 -0
  258. mindspore/jpeg62.dll +0 -0
  259. mindspore/log.py +633 -0
  260. mindspore/mindrecord/__init__.py +43 -0
  261. mindspore/mindrecord/common/__init__.py +17 -0
  262. mindspore/mindrecord/common/constant.py +20 -0
  263. mindspore/mindrecord/common/enums.py +44 -0
  264. mindspore/mindrecord/common/exceptions.py +311 -0
  265. mindspore/mindrecord/config.py +809 -0
  266. mindspore/mindrecord/filereader.py +174 -0
  267. mindspore/mindrecord/filewriter.py +705 -0
  268. mindspore/mindrecord/mindpage.py +210 -0
  269. mindspore/mindrecord/shardheader.py +141 -0
  270. mindspore/mindrecord/shardindexgenerator.py +74 -0
  271. mindspore/mindrecord/shardreader.py +117 -0
  272. mindspore/mindrecord/shardsegment.py +128 -0
  273. mindspore/mindrecord/shardutils.py +185 -0
  274. mindspore/mindrecord/shardwriter.py +237 -0
  275. mindspore/mindrecord/tools/__init__.py +17 -0
  276. mindspore/mindrecord/tools/cifar10.py +140 -0
  277. mindspore/mindrecord/tools/cifar100.py +153 -0
  278. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  279. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  280. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  281. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  282. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  283. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  284. mindspore/mindspore_backend.dll +0 -0
  285. mindspore/mindspore_common.dll +0 -0
  286. mindspore/mindspore_core.dll +0 -0
  287. mindspore/mindspore_glog.dll +0 -0
  288. mindspore/mindspore_np_dtype.dll +0 -0
  289. mindspore/mindspore_shared_lib.dll +0 -0
  290. mindspore/mint/__init__.py +1137 -0
  291. mindspore/mint/linalg/__init__.py +22 -0
  292. mindspore/mint/nn/__init__.py +512 -0
  293. mindspore/mint/nn/functional.py +573 -0
  294. mindspore/mint/optim/__init__.py +24 -0
  295. mindspore/mint/optim/adamw.py +185 -0
  296. mindspore/msobj140.dll +0 -0
  297. mindspore/mspdb140.dll +0 -0
  298. mindspore/mspdbcore.dll +0 -0
  299. mindspore/mspdbst.dll +0 -0
  300. mindspore/mspft140.dll +0 -0
  301. mindspore/msvcdis140.dll +0 -0
  302. mindspore/msvcp140.dll +0 -0
  303. mindspore/msvcp140_1.dll +0 -0
  304. mindspore/msvcp140_2.dll +0 -0
  305. mindspore/msvcp140_atomic_wait.dll +0 -0
  306. mindspore/msvcp140_codecvt_ids.dll +0 -0
  307. mindspore/multiprocessing/__init__.py +72 -0
  308. mindspore/nn/__init__.py +48 -0
  309. mindspore/nn/cell.py +2605 -0
  310. mindspore/nn/dynamic_lr.py +482 -0
  311. mindspore/nn/extend/__init__.py +29 -0
  312. mindspore/nn/extend/basic.py +140 -0
  313. mindspore/nn/extend/embedding.py +143 -0
  314. mindspore/nn/extend/layer/__init__.py +27 -0
  315. mindspore/nn/extend/layer/normalization.py +109 -0
  316. mindspore/nn/extend/pooling.py +117 -0
  317. mindspore/nn/grad/__init__.py +21 -0
  318. mindspore/nn/grad/cell_grad.py +196 -0
  319. mindspore/nn/layer/__init__.py +63 -0
  320. mindspore/nn/layer/activation.py +1655 -0
  321. mindspore/nn/layer/basic.py +1519 -0
  322. mindspore/nn/layer/channel_shuffle.py +90 -0
  323. mindspore/nn/layer/combined.py +248 -0
  324. mindspore/nn/layer/container.py +734 -0
  325. mindspore/nn/layer/conv.py +1505 -0
  326. mindspore/nn/layer/dense.py +204 -0
  327. mindspore/nn/layer/embedding.py +751 -0
  328. mindspore/nn/layer/embedding_service.py +531 -0
  329. mindspore/nn/layer/embedding_service_layer.py +393 -0
  330. mindspore/nn/layer/image.py +661 -0
  331. mindspore/nn/layer/math.py +1069 -0
  332. mindspore/nn/layer/normalization.py +1177 -0
  333. mindspore/nn/layer/padding.py +894 -0
  334. mindspore/nn/layer/pooling.py +2148 -0
  335. mindspore/nn/layer/rnn_cells.py +388 -0
  336. mindspore/nn/layer/rnns.py +849 -0
  337. mindspore/nn/layer/thor_layer.py +963 -0
  338. mindspore/nn/layer/timedistributed.py +155 -0
  339. mindspore/nn/layer/transformer.py +823 -0
  340. mindspore/nn/learning_rate_schedule.py +512 -0
  341. mindspore/nn/loss/__init__.py +36 -0
  342. mindspore/nn/loss/loss.py +2846 -0
  343. mindspore/nn/metrics.py +53 -0
  344. mindspore/nn/optim/__init__.py +44 -0
  345. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  346. mindspore/nn/optim/ada_grad.py +217 -0
  347. mindspore/nn/optim/adadelta.py +206 -0
  348. mindspore/nn/optim/adafactor.py +448 -0
  349. mindspore/nn/optim/adam.py +1297 -0
  350. mindspore/nn/optim/adamax.py +220 -0
  351. mindspore/nn/optim/adasum.py +548 -0
  352. mindspore/nn/optim/asgd.py +216 -0
  353. mindspore/nn/optim/ftrl.py +401 -0
  354. mindspore/nn/optim/lamb.py +296 -0
  355. mindspore/nn/optim/lars.py +202 -0
  356. mindspore/nn/optim/lazyadam.py +533 -0
  357. mindspore/nn/optim/momentum.py +239 -0
  358. mindspore/nn/optim/optimizer.py +1034 -0
  359. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  360. mindspore/nn/optim/rmsprop.py +264 -0
  361. mindspore/nn/optim/rprop.py +251 -0
  362. mindspore/nn/optim/sgd.py +237 -0
  363. mindspore/nn/optim/thor.py +1310 -0
  364. mindspore/nn/probability/__init__.py +22 -0
  365. mindspore/nn/probability/bijector/__init__.py +35 -0
  366. mindspore/nn/probability/bijector/bijector.py +337 -0
  367. mindspore/nn/probability/bijector/exp.py +65 -0
  368. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  369. mindspore/nn/probability/bijector/invert.py +126 -0
  370. mindspore/nn/probability/bijector/power_transform.py +196 -0
  371. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  372. mindspore/nn/probability/bijector/softplus.py +189 -0
  373. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  374. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  375. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  376. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  377. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  378. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  379. mindspore/nn/probability/distribution/__init__.py +56 -0
  380. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  381. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  382. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  383. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  384. mindspore/nn/probability/distribution/beta.py +391 -0
  385. mindspore/nn/probability/distribution/categorical.py +435 -0
  386. mindspore/nn/probability/distribution/cauchy.py +383 -0
  387. mindspore/nn/probability/distribution/distribution.py +827 -0
  388. mindspore/nn/probability/distribution/exponential.py +350 -0
  389. mindspore/nn/probability/distribution/gamma.py +391 -0
  390. mindspore/nn/probability/distribution/geometric.py +335 -0
  391. mindspore/nn/probability/distribution/gumbel.py +257 -0
  392. mindspore/nn/probability/distribution/half_normal.py +133 -0
  393. mindspore/nn/probability/distribution/laplace.py +128 -0
  394. mindspore/nn/probability/distribution/log_normal.py +272 -0
  395. mindspore/nn/probability/distribution/logistic.py +379 -0
  396. mindspore/nn/probability/distribution/normal.py +336 -0
  397. mindspore/nn/probability/distribution/poisson.py +288 -0
  398. mindspore/nn/probability/distribution/student_t.py +149 -0
  399. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  400. mindspore/nn/probability/distribution/uniform.py +375 -0
  401. mindspore/nn/reinforcement/__init__.py +24 -0
  402. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  403. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  404. mindspore/nn/reinforcement/tensor_array.py +145 -0
  405. mindspore/nn/sparse/__init__.py +23 -0
  406. mindspore/nn/sparse/sparse.py +147 -0
  407. mindspore/nn/wrap/__init__.py +49 -0
  408. mindspore/nn/wrap/cell_wrapper.py +979 -0
  409. mindspore/nn/wrap/grad_reducer.py +608 -0
  410. mindspore/nn/wrap/loss_scale.py +680 -0
  411. mindspore/numpy/__init__.py +121 -0
  412. mindspore/numpy/array_creations.py +2734 -0
  413. mindspore/numpy/array_ops.py +2625 -0
  414. mindspore/numpy/dtypes.py +185 -0
  415. mindspore/numpy/fft.py +431 -0
  416. mindspore/numpy/logic_ops.py +935 -0
  417. mindspore/numpy/math_ops.py +5910 -0
  418. mindspore/numpy/utils.py +214 -0
  419. mindspore/numpy/utils_const.py +565 -0
  420. mindspore/opencv_core452.dll +0 -0
  421. mindspore/opencv_imgcodecs452.dll +0 -0
  422. mindspore/opencv_imgproc452.dll +0 -0
  423. mindspore/ops/__init__.py +54 -0
  424. mindspore/ops/_constants.py +30 -0
  425. mindspore/ops/_grad_experimental/__init__.py +31 -0
  426. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  427. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  428. mindspore/ops/_grad_experimental/grad_comm_ops.py +670 -0
  429. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  430. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  431. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  432. mindspore/ops/_grad_experimental/grad_math_ops.py +824 -0
  433. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  434. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  435. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  436. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  437. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  438. mindspore/ops/_op_impl/__init__.py +23 -0
  439. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  440. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  441. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  442. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  443. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  447. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  448. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  449. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  450. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  451. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  452. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  453. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  462. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  463. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  464. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  465. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  471. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  472. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  473. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  474. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  475. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  476. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  477. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  478. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  479. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  480. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  481. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  489. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  490. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  491. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  492. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  494. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  495. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  498. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  499. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  500. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  501. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  502. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  503. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  504. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  505. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  506. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  507. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  508. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  510. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  511. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  512. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  513. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  514. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  515. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  516. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  519. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  520. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  521. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  522. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  523. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  524. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  525. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  526. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  527. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  529. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  531. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  532. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  533. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  534. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  535. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  536. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  537. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  538. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  539. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  540. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  541. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  542. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  544. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  545. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  546. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  547. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  548. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  550. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  551. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  552. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  553. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  554. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  556. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  557. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  558. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  559. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  560. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  561. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  562. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  563. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  564. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  565. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  567. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  568. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  569. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  570. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  571. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  572. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  573. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  574. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  575. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  576. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  577. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  578. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  579. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  580. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  581. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  582. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  583. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  584. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  585. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  586. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  587. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  588. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  589. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  590. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  591. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  592. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  593. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  594. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  595. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  596. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  597. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  598. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  599. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  600. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  601. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  602. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  603. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  610. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  611. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  612. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  615. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  616. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  617. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  618. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  619. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  620. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  621. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  622. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  623. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  625. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  626. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  627. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  628. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  629. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  630. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  634. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  635. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  636. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  637. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  638. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  640. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  641. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  643. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  644. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  646. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  647. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  648. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  649. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  650. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  651. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  652. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  653. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  654. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  655. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  656. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  658. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  659. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  660. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  661. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  662. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  663. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  664. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  665. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  667. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  668. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  669. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  670. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  671. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  672. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  673. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  674. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  675. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  676. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  677. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  678. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  679. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  680. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  681. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  682. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  683. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  684. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  685. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  686. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  687. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  688. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  699. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  700. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  701. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  702. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  706. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  707. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  708. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  709. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  710. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  711. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  712. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  713. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  714. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  715. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  716. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  717. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  718. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  720. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  721. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  722. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  723. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  724. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  725. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  727. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  729. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  730. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  732. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  733. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  734. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  735. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  736. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  737. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  738. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  739. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  740. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  741. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  742. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  744. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  745. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  746. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  747. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  749. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  750. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  751. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  752. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  753. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  754. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  755. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  757. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  758. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  759. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  760. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  761. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  762. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  763. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  764. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  765. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  766. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  767. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  768. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  769. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  771. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  774. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  775. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  776. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  777. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  778. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  779. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  782. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  783. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  784. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  785. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  786. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  787. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  788. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  789. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  790. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  791. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  793. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  794. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  795. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  796. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  797. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  798. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  799. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  804. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  805. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  806. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  807. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  810. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  811. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  812. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  813. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  817. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  818. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  819. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  820. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  821. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  823. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  824. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  825. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  826. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  827. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  862. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  863. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  865. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  866. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  867. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  868. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  869. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  870. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  871. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  873. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  874. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  875. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  876. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2151 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +977 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2209 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +980 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1281 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1458 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/extend/__init__.py +53 -0
  1098. mindspore/ops/extend/array_func.py +218 -0
  1099. mindspore/ops/extend/math_func.py +76 -0
  1100. mindspore/ops/extend/nn_func.py +308 -0
  1101. mindspore/ops/function/__init__.py +760 -0
  1102. mindspore/ops/function/array_func.py +6889 -0
  1103. mindspore/ops/function/clip_func.py +384 -0
  1104. mindspore/ops/function/debug_func.py +69 -0
  1105. mindspore/ops/function/fft_func.py +31 -0
  1106. mindspore/ops/function/grad/__init__.py +34 -0
  1107. mindspore/ops/function/grad/grad_func.py +1424 -0
  1108. mindspore/ops/function/image_func.py +292 -0
  1109. mindspore/ops/function/linalg_func.py +416 -0
  1110. mindspore/ops/function/math_func.py +11877 -0
  1111. mindspore/ops/function/nn_func.py +8175 -0
  1112. mindspore/ops/function/other_func.py +114 -0
  1113. mindspore/ops/function/parameter_func.py +134 -0
  1114. mindspore/ops/function/random_func.py +1539 -0
  1115. mindspore/ops/function/reshard_func.py +102 -0
  1116. mindspore/ops/function/sparse_func.py +884 -0
  1117. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1118. mindspore/ops/function/spectral_func.py +150 -0
  1119. mindspore/ops/function/vmap_func.py +116 -0
  1120. mindspore/ops/functional.py +454 -0
  1121. mindspore/ops/op_info_register.py +1572 -0
  1122. mindspore/ops/operations/__init__.py +717 -0
  1123. mindspore/ops/operations/_csr_ops.py +403 -0
  1124. mindspore/ops/operations/_custom_grad.py +181 -0
  1125. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1126. mindspore/ops/operations/_grad_ops.py +3052 -0
  1127. mindspore/ops/operations/_infer_ops.py +19 -0
  1128. mindspore/ops/operations/_inner_ops.py +2567 -0
  1129. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1130. mindspore/ops/operations/_ms_kernel.py +601 -0
  1131. mindspore/ops/operations/_ocr_ops.py +379 -0
  1132. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1133. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1134. mindspore/ops/operations/_quant_ops.py +1844 -0
  1135. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1136. mindspore/ops/operations/_scalar_ops.py +106 -0
  1137. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1138. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1139. mindspore/ops/operations/_tensor_array.py +359 -0
  1140. mindspore/ops/operations/_thor_ops.py +807 -0
  1141. mindspore/ops/operations/array_ops.py +6258 -0
  1142. mindspore/ops/operations/comm_ops.py +1996 -0
  1143. mindspore/ops/operations/control_ops.py +127 -0
  1144. mindspore/ops/operations/custom_ops.py +1065 -0
  1145. mindspore/ops/operations/debug_ops.py +646 -0
  1146. mindspore/ops/operations/image_ops.py +1041 -0
  1147. mindspore/ops/operations/inner_ops.py +697 -0
  1148. mindspore/ops/operations/linalg_ops.py +95 -0
  1149. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1150. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  1151. mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
  1152. mindspore/ops/operations/math_ops.py +5306 -0
  1153. mindspore/ops/operations/nn_ops.py +9669 -0
  1154. mindspore/ops/operations/other_ops.py +871 -0
  1155. mindspore/ops/operations/random_ops.py +1243 -0
  1156. mindspore/ops/operations/reshard_ops.py +53 -0
  1157. mindspore/ops/operations/rl_ops.py +288 -0
  1158. mindspore/ops/operations/sparse_ops.py +2753 -0
  1159. mindspore/ops/operations/spectral_ops.py +111 -0
  1160. mindspore/ops/primitive.py +1034 -0
  1161. mindspore/ops/signature.py +54 -0
  1162. mindspore/ops/silent_check.py +162 -0
  1163. mindspore/ops/vm_impl_registry.py +91 -0
  1164. mindspore/ops_generate/__init__.py +27 -0
  1165. mindspore/ops_generate/arg_dtype_cast.py +250 -0
  1166. mindspore/ops_generate/arg_handler.py +197 -0
  1167. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1168. mindspore/ops_generate/gen_ops.py +1084 -0
  1169. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1170. mindspore/ops_generate/gen_pyboost_func.py +968 -0
  1171. mindspore/ops_generate/gen_utils.py +209 -0
  1172. mindspore/ops_generate/op_proto.py +138 -0
  1173. mindspore/ops_generate/pyboost_utils.py +354 -0
  1174. mindspore/ops_generate/template.py +239 -0
  1175. mindspore/parallel/__init__.py +28 -0
  1176. mindspore/parallel/_auto_parallel_context.py +1466 -0
  1177. mindspore/parallel/_cell_wrapper.py +91 -0
  1178. mindspore/parallel/_cost_model_context.py +700 -0
  1179. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1180. mindspore/parallel/_offload_context.py +275 -0
  1181. mindspore/parallel/_parallel_serialization.py +533 -0
  1182. mindspore/parallel/_ps_context.py +242 -0
  1183. mindspore/parallel/_recovery_context.py +110 -0
  1184. mindspore/parallel/_tensor.py +660 -0
  1185. mindspore/parallel/_transformer/__init__.py +35 -0
  1186. mindspore/parallel/_transformer/layers.py +765 -0
  1187. mindspore/parallel/_transformer/loss.py +251 -0
  1188. mindspore/parallel/_transformer/moe.py +693 -0
  1189. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1190. mindspore/parallel/_transformer/transformer.py +3119 -0
  1191. mindspore/parallel/_utils.py +600 -0
  1192. mindspore/parallel/algo_parameter_config.py +400 -0
  1193. mindspore/parallel/checkpoint_transform.py +643 -0
  1194. mindspore/parallel/cluster/__init__.py +15 -0
  1195. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1196. mindspore/parallel/cluster/process_entity/_api.py +344 -0
  1197. mindspore/parallel/cluster/process_entity/_utils.py +126 -0
  1198. mindspore/parallel/cluster/run.py +136 -0
  1199. mindspore/parallel/mpi/__init__.py +14 -0
  1200. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1201. mindspore/parallel/parameter_broadcast.py +152 -0
  1202. mindspore/parallel/shard.py +350 -0
  1203. mindspore/perf_msvcbuildinsights.dll +0 -0
  1204. mindspore/pgodb140.dll +0 -0
  1205. mindspore/pgort140.dll +0 -0
  1206. mindspore/profiler/__init__.py +27 -0
  1207. mindspore/profiler/common/__init__.py +14 -0
  1208. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1209. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1210. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1211. mindspore/profiler/common/process_pool.py +41 -0
  1212. mindspore/profiler/common/singleton.py +28 -0
  1213. mindspore/profiler/common/struct_type.py +118 -0
  1214. mindspore/profiler/common/util.py +444 -0
  1215. mindspore/profiler/common/validator/__init__.py +14 -0
  1216. mindspore/profiler/common/validator/validate_path.py +84 -0
  1217. mindspore/profiler/envprofiling.py +256 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
  1227. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
  1228. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
  1230. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1231. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1232. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1233. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1234. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1235. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1236. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1237. mindspore/profiler/parser/ascend_msprof_exporter.py +281 -0
  1238. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1239. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1240. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1241. mindspore/profiler/parser/ascend_timeline_generator.py +543 -0
  1242. mindspore/profiler/parser/base_timeline_generator.py +489 -0
  1243. mindspore/profiler/parser/container.py +229 -0
  1244. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  1245. mindspore/profiler/parser/flops_parser.py +531 -0
  1246. mindspore/profiler/parser/framework_enum.py +111 -0
  1247. mindspore/profiler/parser/framework_parser.py +854 -0
  1248. mindspore/profiler/parser/framework_struct.py +61 -0
  1249. mindspore/profiler/parser/hccl_parser.py +573 -0
  1250. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1251. mindspore/profiler/parser/integrator.py +526 -0
  1252. mindspore/profiler/parser/memory_usage_parser.py +431 -0
  1253. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1254. mindspore/profiler/parser/minddata_parser.py +186 -0
  1255. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1256. mindspore/profiler/parser/msadvisor_analyzer.py +82 -0
  1257. mindspore/profiler/parser/msadvisor_parser.py +240 -0
  1258. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1259. mindspore/profiler/parser/optime_parser.py +250 -0
  1260. mindspore/profiler/parser/profiler_info.py +141 -0
  1261. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1262. mindspore/profiler/profiling.py +2054 -0
  1263. mindspore/rewrite/__init__.py +29 -0
  1264. mindspore/rewrite/api/__init__.py +17 -0
  1265. mindspore/rewrite/api/node.py +519 -0
  1266. mindspore/rewrite/api/node_type.py +53 -0
  1267. mindspore/rewrite/api/pattern_engine.py +490 -0
  1268. mindspore/rewrite/api/scoped_value.py +181 -0
  1269. mindspore/rewrite/api/symbol_tree.py +497 -0
  1270. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1271. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1272. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1273. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1274. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1275. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1276. mindspore/rewrite/common/__init__.py +19 -0
  1277. mindspore/rewrite/common/config.py +24 -0
  1278. mindspore/rewrite/common/error_log.py +39 -0
  1279. mindspore/rewrite/common/event.py +28 -0
  1280. mindspore/rewrite/common/namer.py +271 -0
  1281. mindspore/rewrite/common/namespace.py +118 -0
  1282. mindspore/rewrite/common/observable.py +44 -0
  1283. mindspore/rewrite/common/observer.py +54 -0
  1284. mindspore/rewrite/node/__init__.py +22 -0
  1285. mindspore/rewrite/node/call_function.py +95 -0
  1286. mindspore/rewrite/node/cell_container.py +139 -0
  1287. mindspore/rewrite/node/control_flow.py +113 -0
  1288. mindspore/rewrite/node/node.py +1428 -0
  1289. mindspore/rewrite/node/node_manager.py +283 -0
  1290. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1291. mindspore/rewrite/parsers/__init__.py +29 -0
  1292. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1293. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1294. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1295. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1296. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1297. mindspore/rewrite/parsers/container_parser.py +88 -0
  1298. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1299. mindspore/rewrite/parsers/for_parser.py +61 -0
  1300. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1301. mindspore/rewrite/parsers/if_parser.py +85 -0
  1302. mindspore/rewrite/parsers/module_parser.py +117 -0
  1303. mindspore/rewrite/parsers/parser.py +43 -0
  1304. mindspore/rewrite/parsers/parser_register.py +86 -0
  1305. mindspore/rewrite/parsers/return_parser.py +37 -0
  1306. mindspore/rewrite/parsers/while_parser.py +59 -0
  1307. mindspore/rewrite/sparsify/__init__.py +0 -0
  1308. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1309. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1310. mindspore/rewrite/sparsify/utils.py +179 -0
  1311. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1312. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1313. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1314. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1315. mindspore/run_check/__init__.py +20 -0
  1316. mindspore/run_check/_check_version.py +574 -0
  1317. mindspore/run_check/run_check.py +66 -0
  1318. mindspore/safeguard/__init__.py +18 -0
  1319. mindspore/safeguard/rewrite_obfuscation.py +531 -0
  1320. mindspore/swresample-4.dll +0 -0
  1321. mindspore/swscale-6.dll +0 -0
  1322. mindspore/tbbmalloc.dll +0 -0
  1323. mindspore/tinyxml2.dll +0 -0
  1324. mindspore/train/__init__.py +47 -0
  1325. mindspore/train/_utils.py +439 -0
  1326. mindspore/train/amp.py +817 -0
  1327. mindspore/train/anf_ir_pb2.py +1517 -0
  1328. mindspore/train/callback/__init__.py +44 -0
  1329. mindspore/train/callback/_backup_and_restore.py +117 -0
  1330. mindspore/train/callback/_callback.py +613 -0
  1331. mindspore/train/callback/_checkpoint.py +751 -0
  1332. mindspore/train/callback/_cluster_monitor.py +201 -0
  1333. mindspore/train/callback/_dataset_graph.py +150 -0
  1334. mindspore/train/callback/_early_stop.py +239 -0
  1335. mindspore/train/callback/_flops_collector.py +238 -0
  1336. mindspore/train/callback/_history.py +92 -0
  1337. mindspore/train/callback/_lambda_callback.py +80 -0
  1338. mindspore/train/callback/_landscape.py +1049 -0
  1339. mindspore/train/callback/_loss_monitor.py +107 -0
  1340. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1341. mindspore/train/callback/_mindio_ttp.py +443 -0
  1342. mindspore/train/callback/_on_request_exit.py +195 -0
  1343. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1344. mindspore/train/callback/_summary_collector.py +1184 -0
  1345. mindspore/train/callback/_time_monitor.py +141 -0
  1346. mindspore/train/checkpoint_pb2.py +233 -0
  1347. mindspore/train/data_sink.py +219 -0
  1348. mindspore/train/dataset_helper.py +688 -0
  1349. mindspore/train/lineage_pb2.py +1260 -0
  1350. mindspore/train/loss_scale_manager.py +213 -0
  1351. mindspore/train/memory_profiling_pb2.py +298 -0
  1352. mindspore/train/metrics/__init__.py +175 -0
  1353. mindspore/train/metrics/accuracy.py +133 -0
  1354. mindspore/train/metrics/auc.py +129 -0
  1355. mindspore/train/metrics/bleu_score.py +170 -0
  1356. mindspore/train/metrics/confusion_matrix.py +700 -0
  1357. mindspore/train/metrics/cosine_similarity.py +109 -0
  1358. mindspore/train/metrics/dice.py +116 -0
  1359. mindspore/train/metrics/error.py +175 -0
  1360. mindspore/train/metrics/fbeta.py +167 -0
  1361. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1362. mindspore/train/metrics/loss.py +97 -0
  1363. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1364. mindspore/train/metrics/metric.py +373 -0
  1365. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1366. mindspore/train/metrics/perplexity.py +133 -0
  1367. mindspore/train/metrics/precision.py +160 -0
  1368. mindspore/train/metrics/recall.py +159 -0
  1369. mindspore/train/metrics/roc.py +223 -0
  1370. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1371. mindspore/train/metrics/topk.py +167 -0
  1372. mindspore/train/mind_ir_pb2.py +1903 -0
  1373. mindspore/train/model.py +2176 -0
  1374. mindspore/train/node_strategy_pb2.py +653 -0
  1375. mindspore/train/print_pb2.py +184 -0
  1376. mindspore/train/profiling_parallel_pb2.py +151 -0
  1377. mindspore/train/serialization.py +3101 -0
  1378. mindspore/train/summary/__init__.py +23 -0
  1379. mindspore/train/summary/_lineage_adapter.py +41 -0
  1380. mindspore/train/summary/_summary_adapter.py +496 -0
  1381. mindspore/train/summary/_writer_pool.py +207 -0
  1382. mindspore/train/summary/enums.py +56 -0
  1383. mindspore/train/summary/summary_record.py +581 -0
  1384. mindspore/train/summary/writer.py +167 -0
  1385. mindspore/train/summary_pb2.py +1165 -0
  1386. mindspore/train/train_thor/__init__.py +20 -0
  1387. mindspore/train/train_thor/convert_utils.py +268 -0
  1388. mindspore/train/train_thor/dataset_helper.py +192 -0
  1389. mindspore/train/train_thor/model_thor.py +257 -0
  1390. mindspore/turbojpeg.dll +0 -0
  1391. mindspore/vcmeta.dll +0 -0
  1392. mindspore/vcomp140.dll +0 -0
  1393. mindspore/vcruntime140.dll +0 -0
  1394. mindspore/vcruntime140_1.dll +0 -0
  1395. mindspore/version.py +1 -0
  1396. mindspore-2.3.0.dist-info/METADATA +351 -0
  1397. mindspore-2.3.0.dist-info/RECORD +1400 -0
  1398. mindspore-2.3.0.dist-info/WHEEL +5 -0
  1399. mindspore-2.3.0.dist-info/entry_points.txt +4 -0
  1400. mindspore-2.3.0.dist-info/top_level.txt +1 -0
mindspore/ops/_vmap/vmap_nn_ops.py
@@ -0,0 +1,2209 @@
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+
+ # pylint: disable=unused-variable
+ """nn_ops vmap impl."""
+ from __future__ import absolute_import
+
+ import mindspore
+ from mindspore.common import Tensor
+ from mindspore.ops import operations as P
+ from mindspore.ops.operations import _grad_ops as G
+ from mindspore.ops.operations import nn_ops as NN
+ from mindspore.ops import functional as F
+ from mindspore.ops import constexpr
+ from mindspore.ops.primitive import _primexpr
+ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, \
+     _bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \
+     _vmap_clone_prim, _get_reduce_batch_axis
+ from mindspore.ops.primitive import Primitive
+ from mindspore.ops.auto_generate.gen_arg_handler import Format
+ from mindspore.ops.auto_generate import Embedding
+ from mindspore.ops.auto_generate import gen_arg_handler as handler
+
+
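+ # Every rule below follows the same contract: `vmap_rules_getters.register`
+ # maps a primitive to a getter taking (prim, axis_size), and the getter
+ # returns a `vmap_rule` closure in which each logical input arrives as a
+ # (value, dim) pair, where `dim` is the batch axis inside `value`, or None
+ # if that input is not batched.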
+ @vmap_rules_getters.register(P.ApplyAdaMax)
+ def get_apply_ada_max_rule(prim, axis_size):
+     """VmapRule for `ApplyAdaMax` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, beta1_power_bdim, lr_bdim, beta1_bdim, beta2_bdim,
+                   epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         lr, lr_dim = lr_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         beta1, beta1_dim = beta1_bdim
+         beta2, beta2_dim = beta2_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, v_dim, beta1_power_dim, lr_dim, beta1_dim, beta2_dim,
+                                                epsilon_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `m/v/beta1_power/lr/beta1/beta2/epsilon/grad` is not None. "
+                                  "The execution order of operator `{}` cannot be guaranteed.".format(prim_name))
+             var, m, v = prim(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, u_monad)
+             return (var, None), (m, None), (v, None)
+         if var_dim != 0 or m_dim != var_dim or v_dim != var_dim:
+             raise ValueError("For `{}`, the source axes of `var`, `m` and `v` must all be 0, "
+                              "but got `var`: {}, `m`: {}, `v`: {}.".format(prim_name, var_dim, m_dim, v_dim))
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
+         beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         var, m, v = batch_prim(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, u_monad)
+         return (var, 0), (m, 0), (v, 0)
+
+     return vmap_rule
+
+
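+ # The optimizer rules below all share this batching scheme: bump `batch_rank`
+ # on a cloned primitive, require the stacked state tensors (`var` and friends)
+ # to carry their batch axis at 0, and move every batched hyper-parameter to
+ # the front with `_bdim_at_front`.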
+ @vmap_rules_getters.register(P.ApplyAdadelta)
+ def get_apply_adadelta_rule(prim, axis_size):
+     """VmapRule for `ApplyAdadelta` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, accum_update_bdim, lr_bdim, rho_bdim, epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         accum_update, accum_update_dim = accum_update_bdim
+         lr, lr_dim = lr_bdim
+         rho, rho_dim = rho_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, accum_update_dim, lr_dim, rho_dim, epsilon_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/accum_update/lr/rho/epsilon/grad` is not None. The execution "
+                                  "order of operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum, accum_update = prim(var, accum, accum_update, lr, rho, epsilon, grad, u_monad)
+             return (var, None), (accum, None), (accum_update, None)
+         if var_dim != 0 or accum_dim != var_dim or accum_update_dim != var_dim:
+             raise ValueError(
+                 "For `{}`, the source axes of `var`, `accum` and `accum_update` must all be 0, "
+                 "but got `var`: {}, `accum`: {}, `accum_update`: {}.".format(
+                     prim_name, var_dim, accum_dim, accum_update_dim))
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         rho = _bdim_at_front(rho, rho_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         var, accum, accum_update = batch_prim(var, accum, accum_update, lr, rho, epsilon, grad, u_monad)
+         return (var, 0), (accum, 0), (accum_update, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyFtrl)
+ def get_apply_ftrl_rule(prim, axis_size):
+     """VmapRule for `ApplyFtrl` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, linear_bdim, grad_bdim, lr_bdim, l1_bdim, l2_bdim, lr_power_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         linear, linear_dim = linear_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         lr_power, lr_power_dim = lr_power_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, linear_dim, grad_dim, lr_dim, l1_dim, l2_dim, lr_power_dim]):
+                 raise ValueError("The source axis of `var` is None, "
+                                  "but the source axis of `accum/linear/grad/lr/l1/l2/lr_power` is not None. "
+                                  "The execution order of operator `{}` cannot be guaranteed.".format(prim_name))
+             var = prim(var, accum, linear, grad, lr, l1, l2, lr_power, u_monad)
+             return var, None
+         if var_dim != 0 or accum_dim != var_dim or linear_dim != var_dim:
+             raise ValueError("For `{}`, the source axes of `var/accum/linear` must all be 0, "
+                              "but got `var`: {}, `accum`: {}, `linear`: {}.".format(prim_name, var_dim, accum_dim,
+                                                                                     linear_dim))
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         lr_power = _bdim_at_front(lr_power, lr_power_dim, axis_size)
+
+         var = batch_prim(var, accum, linear, grad, lr, l1, l2, lr_power, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyProximalAdagrad)
+ def get_apply_proximal_adagrad_rule(prim, axis_size):
+     """VmapRule for `ApplyProximalAdagrad` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, lr_bdim, l1_bdim, l2_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         lr, lr_dim = lr_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, lr_dim, l1_dim, l2_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/lr/l1/l2/grad` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum = prim(var, accum, lr, l1, l2, grad, u_monad)
+             return (var, None), (accum, None)
+
+         if var_dim != 0 or accum_dim != var_dim:
+             raise ValueError("For `{}`, the source axes of `var` and `accum` must both be 0, "
+                              "but got `var`: {}, `accum`: {}.".format(prim_name, var_dim, accum_dim))
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         var, accum = batch_prim(var, accum, lr, l1, l2, grad, u_monad)
+         return (var, 0), (accum, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyGradientDescent)
+ def get_apply_gradient_descent_rule(prim, axis_size):
+     """VmapRule for `ApplyGradientDescent` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, alpha_bdim, delta_bdim, u_monad):
+         var, var_dim = var_bdim
+         alpha, alpha_dim = alpha_bdim
+         delta, delta_dim = delta_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [alpha_dim, delta_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `alpha/delta` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var = prim(var, alpha, delta, u_monad)
+             return var, None
+
+         if var_dim != 0:
+             raise ValueError("For `{}`, the source axis of `var` must be 0, "
+                              "but got the source axis of `var`: {}.".format(prim_name, var_dim))
+
+         alpha = _bdim_at_front(alpha, alpha_dim, axis_size)
+         delta = _bdim_at_front(delta, delta_dim, axis_size)
+
+         var = batch_prim(var, alpha, delta, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyProximalGradientDescent)
+ def get_apply_proximal_gradient_descent_rule(prim, axis_size):
+     """VmapRule for `ApplyProximalGradientDescent` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, alpha_bdim, l1_bdim, l2_bdim, delta_bdim, u_monad):
+         var, var_dim = var_bdim
+         alpha, alpha_dim = alpha_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         delta, delta_dim = delta_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [alpha_dim, l1_dim, l2_dim, delta_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `alpha/l1/l2/delta` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var = prim(var, alpha, l1, l2, delta, u_monad)
+             return var, None
+
+         if var_dim != 0:
+             raise ValueError("For `{}`, the source axis of `var` must be 0, "
+                              "but got the source axis of `var`: {}.".format(prim_name, var_dim))
+
+         alpha = _bdim_at_front(alpha, alpha_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         delta = _bdim_at_front(delta, delta_dim, axis_size)
+
+         var = batch_prim(var, alpha, l1, l2, delta, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
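+ # For the loss rules that follow, reductions must never collapse the vmap
+ # axis: when reduction != 'none' the result is reduced only over axes
+ # 1..max_rank-1, so each sample along axis 0 keeps its own loss value.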
+ @vmap_rules_getters.register(NN.BCEWithLogitsLoss)
+ def get_bce_with_logits_loss_vmap_rule(prim, axis_size):
+     """VmapRule for `BCEWithLogitsLoss`."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+     prim_name = prim.name
+     bce_logits_with_loss_op = NN.BCEWithLogitsLoss('none')
+
+     def vmap_rule(logits_bdim, label_bdim, weight_bdim, pos_weight_bdim, reduction_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, logits_bdim, label_bdim, weight_bdim, pos_weight_bdim,
+                                                       reduction_bdim)
+         if is_all_none:
+             return result
+         logits, logits_dim = logits_bdim
+         label, label_dim = label_bdim
+         weight, weight_dim = weight_bdim
+         pos_weight, pos_weight_dim = pos_weight_bdim
+         prim_reduction, _ = reduction_bdim
+         logits_rank = F.rank(logits)
+         label_rank = F.rank(label)
+         weight_rank = F.rank(weight)
+         pos_weight_rank = F.rank(pos_weight)
+         max_rank = max(logits_rank, label_rank, weight_rank, pos_weight_rank)
+         reduce_indexes = None
+         # If rank is larger than 1, we need to reduce result when reduction != 'none'
+         if max_rank > 1:
+             reduce_indexes = tuple(range(1, max_rank))
+         logits_dim_ok = logits_dim == label_dim and logits_dim == weight_dim and logits_dim == pos_weight_dim
+         shape = F.shape(logits)
+         shape_ok = shape == F.shape(label) and shape == F.shape(weight) and shape == F.shape(pos_weight)
+         if logits_dim_ok and shape_ok:
+             if prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'none'):
+                 output = prim(logits, label, weight, pos_weight, prim_reduction)
+             elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'mean'):
+                 out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+                 output = P.ReduceMean()(out, reduce_indexes)
+             elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
+                 out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+                 output = P.ReduceSum()(out, reduce_indexes)
+             else:
+                 raise RuntimeError("For {} vmap, the attribute of reduction must be in "
+                                    "('none', 'mean', 'sum'), but got {}."
+                                    .format(prim_name, prim_reduction))
+             return output, logits_dim
+
+         logits = _bdim_at_front(logits, logits_dim, axis_size)
+         label = _bdim_at_front(label, label_dim, axis_size)
+         weight = _bdim_at_front(weight, weight_dim, axis_size)
+         pos_weight = _bdim_at_front(pos_weight, pos_weight_dim, axis_size)
+         logits_shape = F.shape(logits)
+         weight_shape = F.shape(weight)
+         pos_weight_shape = F.shape(pos_weight)
+         weight = _handle_broadcasting(weight, weight_shape, logits_shape)
+         pos_weight = _handle_broadcasting(pos_weight, pos_weight_shape, logits_shape)
+         if prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'none'):
+             output = prim(logits, label, weight, pos_weight, prim_reduction)
+         elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'mean'):
+             out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+             output = P.ReduceMean()(out, reduce_indexes)
+         elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
+             out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
+             output = P.ReduceSum()(out, reduce_indexes)
+         else:
+             raise RuntimeError("For {} vmap, the attribute of reduction must be in "
+                                "('none', 'mean', 'sum'), but got {}."
+                                .format(prim_name, prim_reduction))
+         return output, 0
+
+     return vmap_rule
+
+
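+ # The BiasAdd rule below avoids a batched kernel entirely: the bias is
+ # reshaped so that a plain broadcast Add reproduces BiasAdd per sample,
+ # with the channel entry placed according to the data format.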
+ @vmap_rules_getters.register(P.BiasAdd)
+ def get_bias_add_vmap_rule(prim, axis_size):
+     """VmapRule for `BiasAdd` operation."""
+     add_op = P.Add()
+
+     @constexpr
+     def get_channel_pos_in_x(d_format, n_dims):
+         # With the vmap batch axis moved to the front, the channel axis is the
+         # last one for NHWC and axis 2 (after batch and N) for NCHW.
+         if d_format == Format.NHWC:
+             return n_dims - 1
+         return 2
+
+     @_primexpr
+     def get_bias_dst_shape(x_shape, n_dims, d_format, has_b_dim: bool):
+         pos = get_channel_pos_in_x(d_format, n_dims)
+
+         bias_shape = []
+         for i in range(n_dims):
+             if i != pos:
+                 bias_shape.append(1)
+             else:
+                 bias_shape.append(x_shape[i])
+
+         if has_b_dim:
+             bias_shape[0] = axis_size
+
+         return tuple(bias_shape)
+
+     def vmap_rule(x_bdim, bias_bdim, data_format_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, bias_bdim, data_format_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         b, b_dim = bias_bdim
+         data_format_data, _ = data_format_bdim
+
+         x = _bdim_at_front(x, x_dim, axis_size)
+         has_b_dim = False
+         if b_dim is not None:
+             b = _bdim_at_front(b, b_dim, axis_size)
+             has_b_dim = True
+
+         x_shape = x.shape
+         n_dims = len(x_shape)
+         b_shape = get_bias_dst_shape(x_shape, n_dims, data_format_data, has_b_dim)
+
+         b = b.reshape(b_shape)
+         result = add_op(x, b)
+
+         return (result, 0)
+
+     return vmap_rule
+
+
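+ # BiasAddGrad is the reduction dual of BiasAdd: the incoming gradient is
+ # summed over every axis except channel, and axis 0 (the vmap axis) is
+ # likewise excluded from the reduction.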
+ @vmap_rules_getters.register(G.BiasAddGrad)
+ def get_bias_add_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `BiasAddGrad` operation."""
+     @constexpr
+     def get_channel_pos(d_format, x_rank):
+         # With the batch axis at the front, the channel axis is the last one
+         # for NHWC and axis 2 for NCHW.
+         if d_format == Format.NHWC:
+             return x_rank - 1
+         return 2
+
+     @_primexpr
+     def get_axis_for_reduce(x_shape_rank, data_format):
+         channel_pos = get_channel_pos(data_format, x_shape_rank)
+         axis_list = ()
+         for i in range(1, x_shape_rank):
+             if channel_pos == i:
+                 continue
+             axis_list += (i,)
+
+         return axis_list
+
+     def vmap_rule(x_bdim, data_format_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, data_format_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         data_format_data, _ = data_format_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape_rank = len(x.shape)
+
+         axis_for_reduce = get_axis_for_reduce(x_shape_rank, data_format_data)
+
+         result = x.sum(axis=axis_for_reduce)
+         return (result, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Dropout)
+ @vmap_rules_getters.register(P.Dropout2D)
+ @vmap_rules_getters.register(P.Dropout3D)
+ def get_dropout_nd_vmap_rule(prim, axis_size):
+     """VmapRule for `Dropout`, `Dropout2D` and `Dropout3D` operations."""
+     prim_name = prim.name
+     dropout_nd_dim = 4
+     if prim_name == "Dropout3D":
+         dropout_nd_dim = 5
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_ndim = F.rank(x)
+         if x_ndim > dropout_nd_dim:
+             # Fold the vmap batch axis into the sample axis, run the primitive
+             # on the expected rank, then restore the original shape.
+             x_ori_shape = F.shape(x)
+             x = F.reshape(x, (-1,) + x_ori_shape[2:x_ndim])
+             output, mask = prim(x)
+             output = F.reshape(output, x_ori_shape)
+             mask = F.reshape(mask, x_ori_shape)
+         else:
+             output, mask = prim(x)
+
+         return (output, 0), (mask, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.InTopK)
+ def get_in_top_k_vmap_rule(prim, axis_size):
+     """VmapRule for `InTopK`."""
+
+     def vmap_rule(x1_bdim, x2_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x1_bdim, x2_bdim)
+         if is_all_none:
+             return result
+
+         x1, x1_dim = x1_bdim
+         x2, x2_dim = x2_bdim
+         x1 = _bdim_at_front(x1, x1_dim, axis_size)
+         x2 = _bdim_at_front(x2, x2_dim, axis_size)
+         x1_shape = F.shape(x1)
+         x2_shape = F.shape(x2)
+         x1 = F.reshape(x1, (-1, x1_shape[-1]))
+         x2 = F.reshape(x2, (-1,))
+         output = prim(x1, x2)
+         output = F.reshape(output, x2_shape)
+         return output, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(G.FastGeLUGrad)
+ @vmap_rules_getters.register(G.HSwishGrad)
+ @vmap_rules_getters.register(G.SoftShrinkGrad)
+ def get_common_activation_grad_vmap_rule(prim, axis_size):
+     """VmapRule for common activation grad operations."""
+     prim_name = prim.name
+
+     def vmap_rule(x_bdim, dy_bdim):
+         x, x_dim = x_bdim
+         dy, dy_dim = dy_bdim
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_dim == dy_dim and x_shape == dy_shape:
+             out = prim(x, dy)
+             return out, x_dim
+
+         if F.rank(x):
+             x = _bdim_at_front(x, x_dim, 1)
+         if F.rank(dy):
+             dy = _bdim_at_front(dy, dy_dim, 1)
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_shape != dy_shape:
+             raise RuntimeError("For {} vmap, input x shape is supposed to be the same as input dy shape "
+                                "after batch transforming, but got x_shape {}, dy_shape {}."
+                                .format(prim_name, x_shape, dy_shape))
+         out = prim(x, dy)
+         return out, 0
+
+     return vmap_rule
+
+
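+ # HShrink and HShrinkGrad take `lambd` as a scalar input; the rules below
+ # treat the op as elementwise, passing the batch axis straight through or
+ # aligning x and dy the same way as the activation grads above.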
+ @vmap_rules_getters.register("HShrink")
+ def get_hshrink_vmap_rule(prim, axis_size):
+     """VmapRule for `HShrink`."""
+     def vmap_rule(x_bdim, lambd_bdim):
+         var, dim = x_bdim
+         lambd, _ = lambd_bdim
+         out = prim(var, lambd)
+         return out, dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("HShrinkGrad")
+ def get_hshrink_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `HShrinkGrad`."""
+     prim_name = prim.name
+
+     def vmap_rule(dy_bdim, x_bdim, lambd_bdim):
+         x, x_dim = x_bdim
+         lambd, _ = lambd_bdim
+         dy, dy_dim = dy_bdim
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_dim == dy_dim and x_shape == dy_shape:
+             out = prim(dy, x, lambd)
+             return out, x_dim
+
+         if F.rank(x):
+             x = _bdim_at_front(x, x_dim, 1)
+         if F.rank(dy):
+             dy = _bdim_at_front(dy, dy_dim, 1)
+         x_shape = F.shape(x)
+         dy_shape = F.shape(dy)
+         if x_shape != dy_shape:
+             raise RuntimeError("For {} vmap, input x shape is supposed to be the same as input dy shape "
+                                "after batch transforming, but got x_shape {}, dy_shape {}."
+                                .format(prim_name, x_shape, dy_shape))
+         out = prim(dy, x, lambd)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Pad)
+ def get_pad_vmap_rule(prim, axis_size):
+     """VmapRule for `Pad`."""
+     paddings = prim.paddings
+
+     @constexpr
+     def _get_paddings(cur_paddings, x_dim):
+         """Insert a (0, 0) pair so the batch axis is left unpadded."""
+         new_paddings = list(cur_paddings)
+         new_paddings.insert(x_dim, (0, 0))
+         return tuple(new_paddings)
+
+     def vmap_rule(x_bdim):
+         x, x_dim = x_bdim
+         if x_dim is None:
+             # case 1: no batch axis, call the original primitive directly
+             out = prim(x)
+         else:
+             # case 2: batch axis present, pad every axis except the batch axis
+             new_paddings = _get_paddings(paddings, x_dim)
+             op = P.Pad(new_paddings)
+             out = op(x)
+         return out, x_dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.Pdist)
+ def get_pdist_vmap_rule(prim, axis_size):
+     """VmapRule for `Pdist`."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         prim.add_prim_attr('p', 2.0)
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         out = prim(x)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.DeformableOffsets)
+ def get_deformable_offsets_vmap_rule(prim, axis_size):
+     """VmapRule for `DeformableOffsets` operation."""
+     nchw_size = 4
+     chw_size = 3
+     chw_reverse_index = -chw_size
+
+     def vmap_rule(x_bdim, offsets_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, offsets_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         offsets, offsets_dim = offsets_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_ndim = F.rank(x)
+         x_origin_shape = F.shape(x)
+
+         offsets = _bdim_at_front(offsets, offsets_dim, axis_size)
+         offsets_ndim = F.rank(offsets)
+         offsets_origin_shape = F.shape(offsets)
+
+         batch_origin_shape = x_origin_shape
+         if x_ndim > nchw_size:
+             x = F.reshape(x, (-1,) + x_origin_shape[chw_reverse_index:])
+         if offsets_ndim > nchw_size:
+             offsets = F.reshape(offsets, (-1,) + offsets_origin_shape[chw_reverse_index:])
+             batch_origin_shape = offsets_origin_shape
+
+         out = prim(x, offsets)
+         out_shape = F.shape(out)
+         out = F.reshape(out, batch_origin_shape[:(nchw_size + 1 - chw_size)] + out_shape[chw_reverse_index:])
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("Softmax")
+ def get_softmax_vmap_rule(prim, axis_size):
+     """VmapRule for `Softmax`."""
+
+     def vmap_rule(x_bdim, axis_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         axis, _ = axis_bdim
+         x_ndim = F.rank(x)
+         if not F.isconstant(axis) or not F.isconstant(x_ndim):
+             raise ValueError("For Softmax vmap, `axis` and the rank of `x` must be constants, "
+                              "since the batch axis has to be folded into `axis` at compile time.")
+         batch_axis = _get_reduce_batch_axis(axis, x_dim, x_ndim)
+         out = prim(x, batch_axis)
+         return out, x_dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.AdaptiveAvgPool2D)
+ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveAvgPool2D` operation."""
+     chw_reverse_index = -3
+     hw_reverse_index = -2
+
+     def vmap_rule(input_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, input_bdim)
+         if is_all_none:
+             return result
+
+         input_x, x_dim = input_bdim
+         input_x = _bdim_at_front(input_x, x_dim, axis_size)
+         x_shape = F.shape(input_x)
+         input_shape = (-1,) + x_shape[chw_reverse_index:]
+         input_x = F.reshape(input_x, input_shape)
+         out = prim(input_x)
+         out_shape = F.shape(out)
+         real_out_shape = x_shape[:hw_reverse_index] + out_shape[hw_reverse_index:]
+         out = F.reshape(out, real_out_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.AdaptiveAvgPool3D)
+ def get_adaptive_avgpool3d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveAvgPool3D` operation."""
+     dhw_reverse_index = -3
+     max_dims = 5
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         if F.rank(x) == max_dims:
+             out = prim(x)
+             return out, 0
+
+         x_shape = F.shape(x)
+         shape = (-1,) + x_shape[dhw_reverse_index:]
+         x = F.reshape(x, shape)
+         out = prim(x)
+         out_shape = F.shape(out)
+         real_out_shape = x_shape[:dhw_reverse_index] + out_shape[dhw_reverse_index:]
+         out = F.reshape(out, real_out_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register("AvgPool")
+ def get_avgpool_vmap_rule(prim, axis_size):
+     """VmapRule for `AvgPool`."""
+     chw_reverse_index = -3
+
+     def vmap_rule(x_bdim, kernel_size_bdim, strides_bdim, pad_mode_bdim, data_format_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, kernel_size_bdim, strides_bdim, pad_mode_bdim,
+                                                       data_format_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         kernel_size, _ = kernel_size_bdim
+         strides, _ = strides_bdim
+         pad_mode, _ = pad_mode_bdim
+         data_format, _ = data_format_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[chw_reverse_index:]
+         x = F.reshape(x, input_shape)
+         out = prim(x, kernel_size, strides, pad_mode, data_format)
+         out_shape = F.shape(out)
+         real_out_shape = x_shape[:chw_reverse_index] + out_shape[chw_reverse_index:]
+         out = F.reshape(out, real_out_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.AdaptiveMaxPool3D)
+ def get_adaptive_max_pool3d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveMaxPool3D`."""
+     dhw_reverse_index = -3
+     max_dims = 5
+
+     @constexpr
+     def convert_shape_to_tensor(shape):
+         return Tensor(shape, dtype=mindspore.int32)
+
+     def vmap_rule(x_bdim, out_size_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, out_size_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         out_size, out_size_dim = out_size_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         if out_size_dim is not None:
+             _raise_value_error("The source axis of `output_size` in `AdaptiveMaxPool3D` must be None, "
+                                "but got {}.".format(out_size_dim))
+         if F.rank(x) == max_dims:
+             out, indices = prim(x, out_size)
+             return (out, 0), (indices, 0)
+
+         x_shape = F.shape(x)
+         shape = (-1,) + x_shape[dhw_reverse_index:]
+         x = F.reshape(x, shape)
+         out, indices = prim(x, out_size)
+         # AdaptiveMaxPool3D is a dynamic op, so the 'shape' passed to reshape must be a tensor
+         front_shape = convert_shape_to_tensor(x_shape[:dhw_reverse_index])
+         output_shape = F.concat((front_shape, out_size))
+         out = F.reshape(out, output_shape)
+         indices = F.reshape(indices, output_shape)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.InstanceNorm)
+ def get_instance_norm_rule(prim, axis_size):
+     """VmapRule for `InstanceNorm` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(input_x_bdim, gamma_bdim, beta_bdim, mean_bdim, variance_bdim, u_monad):
+         input_x, input_x_dim = input_x_bdim
+         gamma, gamma_dim = gamma_bdim
+         beta, beta_dim = beta_bdim
+         mean, mean_dim = mean_bdim
+         variance, variance_dim = variance_bdim
+         if gamma_dim is None:
+             if any(dim is not None for dim in [input_x_dim, beta_dim, mean_dim, variance_dim]):
+                 raise ValueError("The source axis of `gamma` is None, but the source "
+                                  "axis of `input_x/beta/mean/variance` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             output_x, updated_moving_mean, updated_moving_variance = prim(input_x, gamma, beta, mean, variance,
+                                                                           u_monad)
+             return (output_x, None), (updated_moving_mean, None), (updated_moving_variance, None)
+
+         precondition = gamma_dim != 0 or beta_dim != gamma_dim or mean_dim != gamma_dim or variance_dim != gamma_dim
+         if precondition:
+             raise ValueError(
+                 "For `{}`, the source axes of `gamma`, `beta`, `mean` and `variance` must all be 0, "
+                 "but got `gamma`: {}, `beta`: {}, `mean`: {}, `variance`: {}.".format(
+                     prim_name, gamma_dim, beta_dim, mean_dim, variance_dim))
+         input_x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         output_x, updated_moving_mean, updated_moving_variance = batch_prim(input_x, gamma, beta, mean, variance,
+                                                                             u_monad)
+         return (output_x, 0), (updated_moving_mean, 0), (updated_moving_variance, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.KLDivLoss)
+ def get_kl_div_loss_vmap_rule(prim, axis_size):
+     """VmapRule for `KLDivLoss` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+
+     prim_reduction = prim.reduction
+     if prim_reduction == "mean":
+         kl_div_loss_op = P.KLDivLoss("none")
+         reduce_op = P.ReduceMean()
+     elif prim_reduction == "sum":
+         kl_div_loss_op = P.KLDivLoss("none")
+         reduce_op = P.ReduceSum()
+     elif prim_reduction == "batchmean":
+         kl_div_loss_op = P.KLDivLoss("none")
+         reduce_op = P.ReduceSum()
+         factor_op = P.Div()
+
+     def vmap_rule(x_bdim, target_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, target_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         target, target_dim = target_bdim
+         x_ndim = F.rank(x)
+         target_ndim = F.rank(target)
+         max_rank = max(x_ndim, target_ndim)
+         x = _bdim_at_front(x, x_dim, axis_size)
+         target = _bdim_at_front(target, target_dim, axis_size)
+         reduce_indexes = None
+         factor = 1
+         # if rank is larger than 1, we need to reduce result when reduction != 'none'
+         if max_rank > 1:
+             reduce_indexes = tuple(range(1, max_rank))
+             factor = F.shape(x)[1]
+
+         # elementwise style when reduction='none', otherwise reduce style
+         if prim_reduction == "none":
+             out = prim(x, target)
+         elif prim_reduction in ("mean", "sum"):
+             out = kl_div_loss_op(x, target)
+             if reduce_indexes is not None:
+                 out = reduce_op(out, reduce_indexes)
+         elif prim_reduction == "batchmean":
+             out = kl_div_loss_op(x, target)
+             if reduce_indexes is not None:
+                 out = reduce_op(out, reduce_indexes)
+             out = factor_op(out, factor)
+         else:
+             raise RuntimeError("For KLDivLoss vmap, reduction should be one of "
+                                "['none', 'mean', 'batchmean', 'sum'], but got '{}'".format(prim_reduction))
+         return out, 0
+
+     return vmap_rule
+
+
916
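+ # Illustrative usage sketch (an editorial addition, not part of the original
+ # file). With reduction="mean" the rule computes an unreduced loss and then
+ # reduces over the non-batch axes, giving each vmapped instance its own
+ # scalar loss. Hypothetical call:
+ #
+ #     import numpy as np
+ #     from mindspore import ops, Tensor
+ #
+ #     loss_fn = ops.KLDivLoss(reduction="mean")
+ #     x = Tensor(np.random.rand(4, 3, 5).astype(np.float32))       # batch of 4
+ #     target = Tensor(np.random.rand(4, 3, 5).astype(np.float32))
+ #     per_sample = ops.vmap(loss_fn, in_axes=0)(x, target)         # shape (4,)
+
+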
+ @vmap_rules_getters.register(G.KLDivLossGrad)
+ def get_kl_div_loss_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `KLDivLossGrad`."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         reduction = "mean"
+     else:
+         reduction = prim.reduction
+
+     kldivloss_grad = G.KLDivLossGrad(reduction=reduction)
+
+     def vmap_rule(dy_bdim, x_bdim, target_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, dy_bdim, x_bdim, target_bdim)
+         if is_all_none:
+             return result
+
+         dy, dy_dim = dy_bdim
+         x, x_dim = x_bdim
+         target, target_dim = target_bdim
+         dy = _bdim_at_front(dy, dy_dim, axis_size)
+         x = _bdim_at_front(x, x_dim, axis_size)
+         target = _bdim_at_front(target, target_dim, axis_size)
+
+         out = kldivloss_grad(dy, x, target)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.SmoothL1Loss)
+ def get_smooth_l1_loss_vmap_rule(prim, axis_size):
+     """VmapRule for `SmoothL1Loss` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         prim_beta = 1.0
+         prim_reduction = 'none'
+     else:
+         prim_reduction = prim.reduction
+         prim_beta = prim.beta
+
+     smooth_l1_loss_op = P.SmoothL1Loss(prim_beta, 'none')
+     if prim_reduction == 'mean':
+         reduce_op = P.ReduceMean()
+     elif prim_reduction == "sum":
+         reduce_op = P.ReduceSum()
+
+     def vmap_rule(x_bdim, target_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, target_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         target, target_dim = target_bdim
+         x_ndim = F.rank(x)
+         target_ndim = F.rank(target)
+         max_rank = max(x_ndim, target_ndim)
+         x = _bdim_at_front(x, x_dim, axis_size)
+         target = _bdim_at_front(target, target_dim, axis_size)
+         reduce_indexes = None
+         # if rank is larger than 1, we need to reduce result when reduction != 'none'
+         if max_rank > 1:
+             reduce_indexes = tuple(range(1, max_rank))
+
+         # elementwise style when reduction='none', otherwise reduce style
+         if prim_reduction == "none":
+             out = prim(x, target)
+         elif prim_reduction in ("mean", "sum"):
+             out = smooth_l1_loss_op(x, target)
+             if reduce_indexes is not None:
+                 out = reduce_op(out, reduce_indexes)
+         else:
+             raise RuntimeError("For SmoothL1Loss vmap, reduction should be one of "
+                                "['none', 'mean', 'sum'], but got '{}'".format(prim_reduction))
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(G.SmoothL1LossGrad)
+ def get_smooth_l1_loss_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `SmoothL1LossGrad`."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+         reduction = "none"
+         beta = 1.0
+     else:
+         reduction = prim.reduction
+         beta = prim.beta
+     smooth_l1_loss_grad = G.SmoothL1LossGrad(beta, reduction)
+
+     def vmap_rule(x_bdim, target_bdim, dy_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, target_bdim, dy_bdim)
+         if is_all_none:
+             return result
+
+         dy, dy_dim = dy_bdim
+         x, x_dim = x_bdim
+         target, target_dim = target_bdim
+         dy = _bdim_at_front(dy, dy_dim, axis_size)
+         x = _bdim_at_front(x, x_dim, axis_size)
+         target = _bdim_at_front(target, target_dim, axis_size)
+
+         out = smooth_l1_loss_grad(x, target, dy)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.nn_ops.LogSoftmax)
+ def get_log_softmax_vmap_rule(prim_func, axis_size):
+     """VmapRule for `LogSoftmax` operation."""
+     def vmap_rule(x_bdim, axis_bdim):
+         is_all_none, result = vmap_general_preprocess(prim_func, x_bdim, axis_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         axis, _ = axis_bdim
+         x_ndim = F.rank(x) - 1
+
+         batch_axis = axis + x_ndim if axis < 0 else axis
+         batch_axis = batch_axis if batch_axis < x_dim else batch_axis + 1
+
+         out = F.log_softmax(x, batch_axis)
+         return out, x_dim
+
+     return vmap_rule
+
+
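+ # Illustrative usage sketch (an editorial addition, not part of the original
+ # file). The rule only remaps the softmax axis around the inserted batch
+ # dim: for a logical (3, 5) input vmapped over a leading axis of size 4,
+ # axis=-1 maps to physical axis 2. Hypothetical call:
+ #
+ #     import numpy as np
+ #     from mindspore import ops, Tensor
+ #
+ #     x = Tensor(np.random.rand(4, 3, 5).astype(np.float32))
+ #     out = ops.vmap(lambda t: ops.log_softmax(t, -1), in_axes=0)(x)  # (4, 3, 5)
+
+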
+ @vmap_rules_getters.register(P.RandomCategorical)
+ def get_random_categorical_vmap_rule(prim, axis_size):
+     """VmapRule for `RandomCategorical` operation."""
+
+     default_dim = 2
+
+     def vmap_rule(logits_bdim, num_sample_bdim, seed_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, logits_bdim, num_sample_bdim, seed_bdim)
+         if is_all_none:
+             return result
+         logits, logits_dim = logits_bdim
+         num_sample, num_sample_dim = num_sample_bdim
+         seed, seed_dim = seed_bdim
+         if num_sample_dim is not None or seed_dim is not None:
+             raise RuntimeError("For RandomCategorical vmap, the source axes of `num_sample` and `seed` "
+                                "must be None.")
+         # Move the batch axis to the first dim.
+         logits = _bdim_at_front(logits, logits_dim, axis_size)
+         x_ndim = F.rank(logits)
+         if x_ndim > default_dim:
+             # Flatten the leading dims so the operator sees 2-D logits, then
+             # restore them on the output.
+             x_ori_shape = F.shape(logits)
+             logits = F.reshape(logits, (-1, x_ori_shape[-1]))
+             dx = prim(logits, num_sample, seed)
+             new_output_shape = x_ori_shape[:-1] + (num_sample,)
+             dx = F.reshape(dx, new_output_shape)
+         else:
+             dx = prim(logits, num_sample, seed)
+         return dx, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.LRN)
+ def get_lrn_vmap_rule(prim, axis_size):
+     """VmapRule for `LRN` operation."""
+     lrn_default_dim = 4
+     lrn_pre_remain_dim = 3
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+         input_x, input_x_dim = x_bdim
+         # Move axis to last dim
+         x = _bdim_at_back(input_x, input_x_dim, axis_size)
+         x_ndim = F.rank(x)
+         if x_ndim > lrn_default_dim:
+             x_ori_shape = F.shape(x)
+             x = F.reshape(x, x_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             out = prim(x)
+             out = F.reshape(out, x_ori_shape)
+         else:
+             out = prim(x)
+         return out, x_ndim - 1
+
+     return vmap_rule
+
+
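+ # Editorial note (an illustrative addition, not part of the original file):
+ # LRN only accepts 4-D NCHW input, so the rule above folds the trailing
+ # batch axis into the last dimension and unfolds it afterwards. The same
+ # fold/unfold pattern recurs in the PadV3, MirrorPad, LRNGrad, GridSampler
+ # and Upsample rules below. Hypothetical shapes for axis_size = 6 and a
+ # logical (2, 3, 8, 8) input:
+ #     physical input : (2, 3, 8, 8, 6)   # batch moved to the back
+ #     folded input   : (2, 3, 8, 48)     # tail dims merged for the 4-D kernel
+ #     output         : (2, 3, 8, 8, 6), batch axis = rank - 1
+
+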
+ @vmap_rules_getters.register(NN.PadV3)
+ def get_pad_v3_vmap_rule(prim, axis_size):
+     """VmapRule for `PadV3` operation."""
+     pad_pair = 2
+     input_max_dim = 4
+     mode = prim.mode
+
+     def vmap_rule(*params_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, params_bdim)
+         if is_all_none:
+             return result
+         if len(params_bdim) < 2:
+             _raise_value_error("The number of inputs to `PadV3` must be >= 2, "
+                                "but got {}.".format(len(params_bdim)))
+         input_x, input_x_dim = params_bdim[0]
+         paddings, paddings_dim = params_bdim[1]
+         values = None
+         out = None
+         x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         if paddings_dim is not None:
+             _raise_value_error("The source axis of `paddings` in `PadV3` must be None, "
+                                "but got {}.".format(paddings_dim))
+         if mode == "constant":
+             if len(params_bdim) != 3:
+                 _raise_value_error("The number of inputs to `PadV3` in constant mode must be 3, "
+                                    "but got {}.".format(len(params_bdim)))
+             values, values_dim = params_bdim[2]
+             if values_dim is not None:
+                 _raise_value_error("The source axis of `value` in `PadV3` must be None, "
+                                    "but got {}.".format(values_dim))
+         if isinstance(paddings, Tensor):
+             pad_dim = F.shape(paddings)[0] // pad_pair
+         else:
+             pad_dim = len(paddings) // pad_pair
+         x_ndim = F.rank(x)
+         if pad_dim < x_ndim < input_max_dim:
+             if mode == "constant":
+                 out = prim(x, paddings, values)
+             else:
+                 out = prim(x, paddings)
+         elif x_ndim >= input_max_dim:
+             # Fold the leading dims so the operator sees a 4-D input.
+             x_shape = F.shape(x)
+             diff_dim = x_ndim - input_max_dim
+             first_shape = 1
+             for i in range(diff_dim + 1):
+                 first_shape *= x_shape[i]
+             input_shape = (first_shape,) + x_shape[(-input_max_dim + 1):]
+             x = F.reshape(x, input_shape)
+             if mode == "constant":
+                 out = prim(x, paddings, values)
+             else:
+                 out = prim(x, paddings)
+             out_shape = F.shape(out)
+             real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
+             out = F.reshape(out, real_out_shape)
+         else:
+             _raise_value_error("The dim of `input_x` in `PadV3` must be bigger than {}, "
+                                "but got {}.".format(pad_dim, x_ndim))
+         return out, 0
+
+     return vmap_rule
+
+
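+ # Illustrative usage sketch (an editorial addition, not part of the original
+ # file). A hypothetical vmap over constant-mode padding, with the paddings
+ # shared across the batch as the rule requires:
+ #
+ #     import numpy as np
+ #     from mindspore import ops, Tensor
+ #
+ #     def pad_last_dim(t):
+ #         return ops.pad(t, (1, 2), mode="constant", value=0.0)
+ #
+ #     x = Tensor(np.ones((5, 3, 4), np.float32))      # batch of 5
+ #     out = ops.vmap(pad_last_dim, in_axes=0)(x)      # shape (5, 3, 7)
+
+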
+ @vmap_rules_getters.register(NN.MirrorPad)
+ def get_mirror_pad_vmap_rule(prim, axis_size):
+     """VmapRule for `MirrorPad` operation."""
+     input_max_dim = 4
+
+     def vmap_rule(*params_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, params_bdim)
+         if is_all_none:
+             return result
+         if len(params_bdim) < 2:
+             _raise_value_error("The number of inputs to `{}` must be >= 2, "
+                                "but got {}.".format(prim.name, len(params_bdim)))
+         input_x, input_x_dim = params_bdim[0]
+         paddings, paddings_dim = params_bdim[1]
+
+         out = None
+         x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         if paddings_dim is not None:
+             _raise_value_error(
+                 "The source axis of `paddings` in `{}` must be None, but got {}.".format(prim.name, paddings_dim))
+         pad_dim = F.shape(paddings)[0]
+         x_ndim = F.rank(x)
+
+         if pad_dim == x_ndim and x_ndim <= input_max_dim:
+             out = prim(x, paddings)
+         elif x_ndim > input_max_dim:
+             # Fold the leading dims so the operator sees a 4-D input.
+             x_shape = F.shape(x)
+             diff_dim = x_ndim - input_max_dim
+             first_shape = 1
+             for i in range(diff_dim + 1):
+                 first_shape *= x_shape[i]
+             input_shape = (first_shape,) + x_shape[(-input_max_dim + 1):]
+             x = F.reshape(x, input_shape)
+             out = prim(x, paddings)
+             out_shape = F.shape(out)
+             real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
+             out = F.reshape(out, real_out_shape)
+         else:
+             _raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, "
+                                "but got {}.".format(prim.name, pad_dim, x_ndim))
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(G.LRNGrad)
+ def get_lrn_grad_vmap_rule(prim, axis_size):
+     """VmapRule for `LRNGrad` operation."""
+     lrn_default_dim = 4
+     lrn_pre_remain_dim = 3
+
+     def vmap_rule(dout_bdim, x_bdim, out_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, dout_bdim, x_bdim, out_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         dy, dy_dim = dout_bdim
+         y, y_dim = out_bdim
+         # Move axis to last dim
+         x = _bdim_at_back(x, x_dim, axis_size)
+         dy = _bdim_at_back(dy, dy_dim, axis_size)
+         y = _bdim_at_back(y, y_dim, axis_size)
+         x_ndim = F.rank(x)
+         if x_ndim > lrn_default_dim:
+             x_ori_shape = F.shape(x)
+             dy_ori_shape = F.shape(dy)
+             y_ori_shape = F.shape(y)
+             x = F.reshape(x, x_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             dy = F.reshape(dy, dy_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             y = F.reshape(y, y_ori_shape[:lrn_pre_remain_dim] + (-1,))
+             dx = prim(dy, x, y)
+             dx = F.reshape(dx, x_ori_shape)
+         else:
+             dx = prim(dy, x, y)
+         return dx, x_ndim - 1
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.BatchNorm)
+ def get_batchnorm_vmap_rule(prim, axis_size):
+     """VmapRule for `BatchNorm` operation."""
+     bn_min_dim = 3
+     bn_max_dim = 5
+     prim_name = "BatchNorm"
+     NCHW = Format.NCHW
+
+     def vmap_rule(*inputs):
+         is_all_none, result = vmap_general_preprocess(prim, *inputs)
+         if is_all_none:
+             return result
+         input_x, input_x_dim = inputs[0]
+         scale, scale_dim = inputs[1]
+         offset, offset_dim = inputs[2]
+         mean, mean_dim = inputs[3]
+         var, var_dim = inputs[4]
+         is_training, _ = inputs[5]
+         epsilon, _ = inputs[6]
+         momentum, _ = inputs[7]
+         data_format, _ = inputs[8]
+         if is_training:
+             raise ValueError("Operator {} does not support vmap during training, since the inputs "
+                              "`scale/offset/mean/var` of BatchNorm are Parameters when `is_training` is true. "
+                              "If multiple batches of input data share the same parameters, please stack the "
+                              "batches into the N dimension manually.".format(prim_name))
+         x_ndim = F.rank(input_x)
+         if x_ndim < bn_min_dim or x_ndim > bn_max_dim:
+             raise ValueError("The dim of `input_x` in `{}` must be larger than {} and less than {}, "
+                              "but got {}.".format(prim_name, bn_min_dim - 1, bn_max_dim + 1, x_ndim))
+         # Move the batch axis of input_x next to the C dimension.
+         out_axis = 1 if data_format == NCHW else x_ndim - 2
+         input_x = _bdim_at_any(input_x, input_x_dim, out_axis, axis_size)
+         scale = _bdim_at_front(scale, scale_dim, axis_size)
+         offset = _bdim_at_front(offset, offset_dim, axis_size)
+         mean = _bdim_at_front(mean, mean_dim, axis_size)
+         var = _bdim_at_front(var, var_dim, axis_size)
+         x_shape = input_x.shape
+         other_shape = scale.shape
+         # Fold the batch axis into C so the kernel sees a regular BatchNorm.
+         vmap_shape = (x_shape[0], -1,) + x_shape[3:] if data_format == NCHW else x_shape[:-2] + (-1,)
+         input_x = F.reshape(input_x, vmap_shape)
+         scale = scale.flatten()
+         offset = offset.flatten()
+         mean = mean.flatten()
+         var = var.flatten()
+         out, batch_mean, batch_var, rsv_1, rsv_2 = \
+             prim(input_x, scale, offset, mean, var, is_training, epsilon, momentum, data_format)
+         out = F.reshape(out, x_shape)
+         batch_mean = F.reshape(batch_mean, other_shape)
+         batch_var = F.reshape(batch_var, other_shape)
+         rsv_1 = F.reshape(rsv_1, other_shape)
+         rsv_2 = F.reshape(rsv_2, other_shape)
+         return (out, out_axis), (batch_mean, 0), (batch_var, 0), (rsv_1, 0), (rsv_2, 0)
+
+     return vmap_rule
+
+
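+ # Illustrative usage sketch (an editorial addition, not part of the original
+ # file). vmap over BatchNorm is only supported in inference mode; the rule
+ # folds the vmap axis into the channel dimension. A hypothetical call via
+ # the functional interface:
+ #
+ #     import numpy as np
+ #     from mindspore import ops, Tensor
+ #
+ #     x = Tensor(np.ones((4, 2, 3, 8, 8), np.float32))   # vmap axis of size 4
+ #     mean = Tensor(np.zeros((4, 3), np.float32))
+ #     var = Tensor(np.ones((4, 3), np.float32))
+ #     scale = Tensor(np.ones((4, 3), np.float32))
+ #     offset = Tensor(np.zeros((4, 3), np.float32))
+ #     out = ops.vmap(
+ #         lambda t, m, v, w, b: ops.batch_norm(t, m, v, w, b, training=False),
+ #         in_axes=0)(x, mean, var, scale, offset)
+
+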
+ @vmap_rules_getters.register(P.ApplyAdamWithAmsgrad)
+ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
+     """VmapRule for `ApplyAdamWithAmsgrad` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, vhat_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, grad_bdim,
+                   u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         vhat, vhat_dim = vhat_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         beta2_power, beta2_power_dim = beta2_power_bdim
+         lr, lr_dim = lr_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, v_dim, vhat_dim, beta1_power_dim,
+                                                beta2_power_dim, lr_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, "
+                                  "but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/grad` is not None. "
+                                  "The execution order of operator `{}` cannot be guaranteed.".format(prim_name))
+             out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
+             return (out_var, None), (out_m, None), (out_v, None), (out_vhat, None)
+
+         if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]):
+             raise ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, "
+                              "but got `var`: {}, `m`: {}, `v`: {}, `vhat`: {}.".format(prim_name, var_dim,
+                                                                                       m_dim, v_dim, vhat_dim))
+
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
+         return (out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyAdamWithAmsgradV2)
+ def get_apply_adam_with_amsgrad_v2_rule(prim, axis_size):
+     """VmapRule for `ApplyAdamWithAmsgradV2` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, vhat_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, beta1_bdim,
+                   beta2_bdim, epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         vhat, vhat_dim = vhat_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         beta2_power, beta2_power_dim = beta2_power_bdim
+         lr, lr_dim = lr_bdim
+         beta1, beta1_dim = beta1_bdim
+         beta2, beta2_dim = beta2_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, v_dim, vhat_dim, beta1_power_dim,
+                                                beta2_power_dim, lr_dim, beta1_dim, beta2_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, "
+                                  "but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/beta1/beta2/grad` "
+                                  "is not None. The execution order of operator `{}` cannot be "
+                                  "guaranteed.".format(prim_name))
+             out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2,
+                                                    epsilon, grad, u_monad)
+             return (out_var, None), (out_m, None), (out_v, None), (out_vhat, None)
+
+         if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]):
+             raise ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, "
+                              "but got `var`: {}, `m`: {}, `v`: {}, `vhat`: {}.".format(prim_name, var_dim,
+                                                                                       m_dim, v_dim, vhat_dim))
+
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
+         beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2,
+                                                      epsilon, grad, u_monad)
+         return (out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Adam)
+ def get_adam_rule(prim, axis_size):
+     """VmapRule for `Adam` operation."""
+     if hasattr(prim, "batch_rank"):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, v_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, beta1_bdim,
+                   beta2_bdim, epsilon_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         v, v_dim = v_bdim
+         beta1_power, beta1_power_dim = beta1_power_bdim
+         beta2_power, beta2_power_dim = beta2_power_bdim
+         lr, lr_dim = lr_bdim
+         beta1, beta1_dim = beta1_bdim
+         beta2, beta2_dim = beta2_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+         grad, grad_dim = grad_bdim
+
+         all_dim = [m_dim, v_dim, beta1_power_dim, beta2_power_dim, lr_dim, beta1_dim, beta2_dim, epsilon_dim,
+                    grad_dim]
+         if var_dim is None:
+             if any(dim is not None for dim in all_dim):
+                 raise ValueError("The source axis of `var` is None, but the source axis of "
+                                  "`m/v/beta1_power/beta2_power/lr/beta1/beta2/epsilon/grad` is not None. "
+                                  "The execution order of operator `{}` cannot be guaranteed.".format(prim_name))
+             out_var, out_m, out_v = prim(
+                 var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
+             return (out_var, None), (out_m, None), (out_v, None)
+
+         if any(dim != 0 for dim in [var_dim, m_dim, v_dim]):
+             raise ValueError("For `{}`, the source axis of `var/m/v` must be 0, "
+                              "but got `var`: {}, `m`: {}, `v`: {}.".format(prim_name, var_dim, m_dim, v_dim))
+
+         beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
+         beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
+         beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
+         epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         out_var, out_m, out_v = batch_prim(
+             var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
+         return (out_var, 0), (out_m, 0), (out_v, 0)
+
+     return vmap_rule
+
+
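+ # Editorial note (an illustrative addition, not part of the original file):
+ # for the optimizer rules, the in-place state tensors (`var`, `m`, `v`)
+ # must already carry the batch on axis 0, while scalar hyper-parameters are
+ # broadcast to the front by the rule itself. Hypothetical shapes for
+ # axis_size = 4 and a (2, 3) parameter:
+ #     var/m/v : (4, 2, 3), source axis 0
+ #     lr      : (4,), source axis 0 (or a shared scalar with source axis None)
+ #     grad    : (4, 2, 3), source axis 0
+
+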
+ @vmap_rules_getters.register(P.ApplyPowerSign)
+ def get_apply_power_sign_rule(prim, axis_size):
+     """VmapRule for `ApplyPowerSign` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, m_bdim, lr_bdim, logbase_bdim, sign_decay_bdim, beta_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         m, m_dim = m_bdim
+         lr, lr_dim = lr_bdim
+         logbase, logbase_dim = logbase_bdim
+         sign_decay, sign_decay_dim = sign_decay_bdim
+         beta, beta_dim = beta_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [m_dim, lr_dim, logbase_dim, sign_decay_dim, beta_dim, grad_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `m/lr/logbase/sign_decay/beta/grad` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, m = prim(var, m, lr, logbase, sign_decay, beta, grad, u_monad)
+             return (var, None), (m, None)
+         if var_dim != 0 or m_dim != var_dim:
+             raise ValueError("For `{}`, the source axis of `var` must be 0 and equal to the source axis of `m`, "
+                              "but got the source axis of `var`: {}, `m`: {}.".format(prim_name, var_dim, m_dim))
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         logbase = _bdim_at_front(logbase, logbase_dim, axis_size)
+         sign_decay = _bdim_at_front(sign_decay, sign_decay_dim, axis_size)
+         beta = _bdim_at_front(beta, beta_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         var, m = batch_prim(var, m, lr, logbase, sign_decay, beta, grad, u_monad)
+         return (var, 0), (m, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyAdagradV2)
+ def get_apply_adagrad_v2_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyAdagradV2` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+     prim_name = prim.name
+
+     def vmap_rule(var_bdim, accum_bdim, lr_bdim, grad_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         lr, lr_dim = lr_bdim
+         grad, grad_dim = grad_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, lr_dim, grad_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'accum/lr/grad' is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+             var, accum = prim(var, accum, lr, grad, u_monad)
+             return (var, None), (accum, None)
+         if var_dim != 0 or var_dim != accum_dim:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to the source axis of "
+                 f"'accum', but got the source axis of 'var': {var_dim}, 'accum': {accum_dim}.")
+
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+
+         var, accum = batch_prim(var, accum, lr, grad, u_monad)
+         return (var, 0), (accum, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyAdagradDA)
+ def get_apply_adagrad_da_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyAdagradDA` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     attr = prim.init_attrs
+     batch_prim = P.ApplyAdagradDA(**attr)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+     prim_name = prim.name
+
+     def vmap_rule(var_bdim, gradient_accumulator_bdim, gradient_squared_accumulator_bdim, grad_bdim, lr_bdim,
+                   l1_bdim, l2_bdim, global_step_bdim, u_monad):
+         var, var_dim = var_bdim
+         gradient_accumulator, gradient_accumulator_dim = gradient_accumulator_bdim
+         gradient_squared_accumulator, gradient_squared_accumulator_dim = gradient_squared_accumulator_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         l1, l1_dim = l1_bdim
+         l2, l2_dim = l2_bdim
+         global_step, global_step_dim = global_step_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in
+                    [gradient_accumulator_dim, gradient_squared_accumulator_dim, grad_dim, lr_dim, l1_dim, l2_dim,
+                     global_step_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'gradient_accumulator/gradient_squared_accumulator/grad/lr/l1/l2/"
+                                  "global_step' is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+             # Low-dimensional operator.
+             var, gradient_accumulator, gradient_squared_accumulator = prim(
+                 var, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, u_monad)
+             return (var, None), (gradient_accumulator, None), (gradient_squared_accumulator, None)
+         if var_dim != 0 or var_dim != gradient_accumulator_dim or var_dim != gradient_squared_accumulator_dim:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to the source axes of "
+                 f"'gradient_accumulator' and 'gradient_squared_accumulator', "
+                 f"but got the source axis of 'var': {var_dim}, "
+                 f"'gradient_accumulator': {gradient_accumulator_dim}, "
+                 f"'gradient_squared_accumulator': {gradient_squared_accumulator_dim}.")
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         l1 = _bdim_at_front(l1, l1_dim, axis_size)
+         l2 = _bdim_at_front(l2, l2_dim, axis_size)
+         global_step = _bdim_at_front(global_step, global_step_dim, axis_size)
+
+         # High-dimensional operator.
+         var, gradient_accumulator, gradient_squared_accumulator = batch_prim(
+             var, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, u_monad)
+         return (var, 0), (gradient_accumulator, 0), (gradient_squared_accumulator, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.AdaptiveMaxPool2D)
+ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
+     """VmapRule for `AdaptiveMaxPool2D`."""
+     nchw_index = 4
+     chw_reverse_index = -3
+     hw_size = 2
+     output_size = prim.output_size
+
+     @_primexpr
+     def get_output_shape(x_ori_shape, output_size):
+         if isinstance(output_size, tuple):
+             h_out, w_out = output_size
+         else:
+             h_out = output_size
+             w_out = output_size
+
+         rank = len(x_ori_shape)
+         output_shape = x_ori_shape[:rank - hw_size]
+         if h_out is None or h_out == -1:
+             output_shape += (x_ori_shape[-2],)
+         else:
+             output_shape += (h_out,)
+
+         if w_out is None or w_out == -1:
+             output_shape += (x_ori_shape[-1],)
+         else:
+             output_shape += (w_out,)
+         return output_shape
+
+     def vmap_rule(input_x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, input_x_bdim)
+         if is_all_none:
+             return result
+
+         input_x, input_x_dim = input_x_bdim
+         x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         x_ndim = F.rank(x)
+
+         if x_ndim > nchw_index:
+             # for the case of NCHW
+             x_ori_shape = F.shape(x)
+             x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
+             output_shape = get_output_shape(x_ori_shape, output_size)
+             out, indices = prim(x)
+             out = F.reshape(out, output_shape)
+             indices = F.reshape(indices, output_shape)
+             return (out, 0), (indices, 0)
+
+         # for the case of CHW
+         out, indices = prim(x)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.MaxPool3DWithArgmax)
+ def get_max_pool3d_with_argmax_vmap_rule(prim, axis_size):
+     """VmapRule for `MaxPool3DWithArgmax`."""
+     cdhw_reverse_index = -4
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[cdhw_reverse_index:]
+         x = F.reshape(x, input_shape)
+         out, indices = prim(x)
+         out_shape = F.shape(out)
+         return_shape = x_shape[:cdhw_reverse_index] + out_shape[cdhw_reverse_index:]
+         out = F.reshape(out, return_shape)
+         indices = F.reshape(indices, return_shape)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyRMSProp)
+ def get_rmsprop_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyRMSProp` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+     prim_name = prim.name
+
+     def vmap_rule(var_bdim, mean_square_bdim, moment_bdim, lr_bdim, grad_bdim, decay_bdim, momentum_bdim,
+                   epsilon_bdim, u_monad):
+         var, var_dim = var_bdim
+         mean_square, mean_square_dim = mean_square_bdim
+         moment, moment_dim = moment_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         decay, decay_dim = decay_bdim
+         momentum, momentum_dim = momentum_bdim
+         epsilon, epsilon_dim = epsilon_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in
+                    [mean_square_dim, moment_dim, grad_dim, lr_dim, decay_dim, momentum_dim, epsilon_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'mean_square/moment/lr/grad/decay/momentum/epsilon' "
+                                  "is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+
+             # Low-dimensional operator.
+             res = prim(var, mean_square, moment, lr, grad, decay, momentum, epsilon, u_monad)
+             return (res, None)
+         precondition = var_dim != 0 or var_dim != mean_square_dim or var_dim != moment_dim or var_dim != grad_dim
+         if precondition:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to the source axes of "
+                 f"'mean_square', 'moment' and 'grad', but got the source axis of 'var': {var_dim}, "
+                 f"'mean_square': {mean_square_dim}, 'moment': {moment_dim}, 'grad': {grad_dim}.")
+
+         mean_square = _bdim_at_front(mean_square, mean_square_dim, axis_size)
+         moment = _bdim_at_front(moment, moment_dim, axis_size)
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+
+         # High-dimensional operator.
+         res = batch_prim(var, mean_square, moment, lr, grad, decay, momentum, epsilon, u_monad)
+
+         return res, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.ApplyCenteredRMSProp)
+ def get_apply_centered_rmsprop_vmap_rule(prim, axis_size):
+     """VmapRule for `ApplyCenteredRMSProp` operation."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr("batch_rank", batch_rank)
+
+     def vmap_rule(var_bdim, mean_grad_bdim, mean_square_bdim, mom_bdim, grad_bdim, lr_bdim, rho_bdim,
+                   momentum_bdim, eps_bdim, u_monad):
+         var, var_dim = var_bdim
+         mean_grad, mean_grad_dim = mean_grad_bdim
+         mean_square, mean_square_dim = mean_square_bdim
+         mom, mom_dim = mom_bdim
+         grad, grad_dim = grad_bdim
+         lr, lr_dim = lr_bdim
+         rho, rho_dim = rho_bdim
+         momentum, momentum_dim = momentum_bdim
+         eps, eps_dim = eps_bdim
+
+         if var_dim is None:
+             if any(dim is not None for dim in
+                    [mean_grad_dim, mean_square_dim, mom_dim, grad_dim, lr_dim, rho_dim,
+                     momentum_dim, eps_dim]):
+                 raise ValueError("The source axis of 'var' is None, but the source "
+                                  "axis of 'mean_gradient/mean_square/mom/grad/lr/rho/momentum/eps' "
+                                  "is not None. The execution order of "
+                                  "operator '{}' cannot be guaranteed.".format(prim_name))
+             var = prim(var, mean_grad, mean_square, mom, grad, lr, rho, momentum, eps, u_monad)
+             return (var, None)
+         precondition = var_dim != 0 or var_dim != mean_grad_dim or var_dim != mean_square_dim or var_dim != mom_dim
+         if precondition:
+             raise ValueError(
+                 f"For '{prim_name}', the source axis of 'var' must be 0 and equal to the source axes of "
+                 f"'mean_grad', 'mean_square' and 'mom', but got the source axis of 'var': {var_dim}, "
+                 f"'mean_grad': {mean_grad_dim}, 'mean_square': {mean_square_dim}, 'mom': {mom_dim}.")
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         lr = _bdim_at_front(lr, lr_dim, axis_size)
+         rho = _bdim_at_front(rho, rho_dim, axis_size)
+         momentum = _bdim_at_front(momentum, momentum_dim, axis_size)
+         eps = _bdim_at_front(eps, eps_dim, axis_size)
+
+         var = batch_prim(var, mean_grad, mean_square, mom, grad, lr, rho, momentum, eps, u_monad)
+         return var, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.MaxPool)
+ @vmap_rules_getters.register(P.MaxPoolWithArgmax)
+ @vmap_rules_getters.register(P.MaxPoolWithArgmaxV2)
+ def get_max_pool_vmap_rule(prim, axis_size):
+     """VmapRule for `MaxPool` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+
+     prim_name = prim.name
+
+     @_primexpr
+     def get_original_shape(x_shape, out_shape):
+         h_new = out_shape[2]
+         w_new = out_shape[3]
+         original_shape = x_shape[:3] + (h_new,) + (w_new,)
+         return original_shape
+
+     def vmap_rule(x_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+         if is_all_none:
+             return result
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         x_shape = x.shape
+         # Fold the batch axis into N so the kernel sees a 4-D NCHW input.
+         x_new_shape = (-1,) + x_shape[2:]
+         x = x.reshape(x_new_shape)
+         if prim_name == "MaxPool":
+             out = prim(x)
+             out_shape = out.shape
+             original_shape = get_original_shape(x_shape, out_shape)
+             out = out.reshape(original_shape)
+             return out, 0
+         out, indices = prim(x)
+         out_shape = out.shape
+         original_shape = get_original_shape(x_shape, out_shape)
+         out = out.reshape(original_shape)
+         indices = indices.reshape(original_shape)
+         return (out, 0), (indices, 0)
+
+     return vmap_rule
+
+
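+ # Illustrative usage sketch (an editorial addition, not part of the original
+ # file). The rule folds the vmap axis into N, runs the pooling kernel once,
+ # and unfolds the result. Hypothetical call:
+ #
+ #     import numpy as np
+ #     from mindspore import ops, Tensor
+ #
+ #     max_pool = ops.MaxPool(kernel_size=2, strides=2)
+ #     x = Tensor(np.ones((3, 2, 1, 8, 8), np.float32))  # vmap axis of size 3
+ #     out = ops.vmap(max_pool, in_axes=0)(x)            # shape (3, 2, 1, 4, 4)
+
+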
+ @vmap_rules_getters.register("LayerNorm")
+ def get_layernorm_vmap_rule(prim, axis_size):
+     """VmapRule for `LayerNorm` operation."""
+
+     def process_attr_axis(prim_attr_axis):
+         if prim_attr_axis < 0:
+             return prim_attr_axis
+         return prim_attr_axis + 1
+
+     @_primexpr
+     def get_logical_shape(var_shape):
+         return var_shape[1:]
+
+     def vmap_rule(x_bdim, gamma_bdim, beta_bdim, begin_norm_axis_bdim, begin_params_axis_bdim, epsilon_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, gamma_bdim, beta_bdim, begin_norm_axis_bdim,
+                                                       begin_params_axis_bdim, epsilon_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         g, g_dim = gamma_bdim
+         b, b_dim = beta_bdim
+         begin_norm_axis, _ = begin_norm_axis_bdim
+         begin_params_axis, _ = begin_params_axis_bdim
+         epsilon, _ = epsilon_bdim
+
+         # Shift non-negative axes past the inserted batch dim.
+         begin_norm_axis = process_attr_axis(begin_norm_axis)
+         begin_params_axis = process_attr_axis(begin_params_axis)
+
+         x = _bdim_at_front(x, x_dim, axis_size)
+
+         if g_dim is None and b_dim is None:
+             output, mean, var = prim(x, g, b, begin_norm_axis, begin_params_axis, epsilon)
+             return (output, 0), (mean, 0), (var, 0)
+
+         # gamma/beta carry a batch axis: normalize with unit gamma and zero
+         # beta, then apply the batched affine transform manually.
+         g = _bdim_at_front(g, g_dim, axis_size)
+         b = _bdim_at_front(b, b_dim, axis_size)
+         g_logical_shape = get_logical_shape(F.shape(g))
+         b_logical_shape = get_logical_shape(F.shape(b))
+
+         ones_like_g = F.ones(g_logical_shape, F.dtype(g))
+         zeros_like_b = F.zeros(b_logical_shape, F.dtype(b))
+         output_tmp, mean, var = prim(x, ones_like_g, zeros_like_b, begin_norm_axis, begin_params_axis, epsilon)
+
+         x_shape = F.shape(x)
+         g_shape = F.shape(g)
+         b_shape = F.shape(b)
+         g = _handle_broadcasting(g, g_shape, x_shape)
+         b = _handle_broadcasting(b, b_shape, x_shape)
+         output = F.add(F.mul(output_tmp, g), b)
+
+         return (output, 0), (mean, 0), (var, 0)
+
+     return vmap_rule
+
+
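+ # Editorial note (an illustrative addition, not part of the original file):
+ # when gamma/beta carry a batch axis, the rule above normalizes with unit
+ # gamma and zero beta, then applies the batched affine transform itself:
+ #     output = normalize(x) * gamma + beta
+ # where gamma/beta are broadcast from (batch, *params_shape) to the shape
+ # of x before the multiply-add.
+
+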
+ @vmap_rules_getters.register(NN.GridSampler2D)
+ @vmap_rules_getters.register(NN.GridSampler3D)
+ def get_grid_sampler_vmap_rule(prim, axis_size):
+     """VmapRule for `GridSampler2D` and `GridSampler3D`."""
+     prim_name = prim.name
+     if prim_name == "GridSampler2D":
+         non_batch_dim_index = -3
+     elif prim_name == "GridSampler3D":
+         non_batch_dim_index = -4
+     else:
+         _raise_value_error(
+             "The prim name must be `GridSampler2D` or `GridSampler3D`, but got {}.".format(prim_name))
+
+     def vmap_rule(input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim):
+         is_all_none, result = vmap_general_preprocess(
+             prim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim)
+         if is_all_none:
+             return result
+
+         input_x, input_x_dim = input_x_bdim
+         grid, grid_dim = grid_bdim
+         interpolation_mode, _ = interpolation_mode_bdim
+         padding_mode, _ = padding_mode_bdim
+         align_corners, _ = align_corners_bdim
+
+         input_x = _bdim_at_front(input_x, input_x_dim, axis_size)
+         input_x_shape = F.shape(input_x)
+         input_x = F.reshape(input_x, (-1,) + input_x_shape[non_batch_dim_index:])
+
+         grid = _bdim_at_front(grid, grid_dim, axis_size)
+         grid_shape = F.shape(grid)
+         grid = F.reshape(grid, (-1,) + grid_shape[non_batch_dim_index:])
+
+         out = prim(input_x, grid, interpolation_mode, padding_mode, align_corners)
+         out_shape = F.shape(out)
+         return_shape = input_x_shape[:non_batch_dim_index] + out_shape[non_batch_dim_index:]
+         out = F.reshape(out, return_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.UpsampleNearest1D)
+ @vmap_rules_getters.register(NN.UpsampleNearest2D)
+ @vmap_rules_getters.register(NN.UpsampleNearest3D)
+ def get_upsample_nearest_3d_vmap_rule(prim, axis_size):
+     """VmapRule for `UpsampleNearest1D`, `UpsampleNearest2D` and `UpsampleNearest3D`."""
+     prim_name = prim.name
+     if prim_name == "UpsampleNearest1D":
+         reverse_index = -2
+     elif prim_name == "UpsampleNearest2D":
+         reverse_index = -3
+     else:
+         reverse_index = -4
+
+     def vmap_rule(x_bdim, size_bdim, scales_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, size_bdim, scales_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         size, size_dim = size_bdim
+         scales, scales_dim = scales_bdim
+         if size_dim is not None or scales_dim is not None:
+             _raise_value_error(
+                 "For {0}, the source axes of `output_size` and `scales` must be None,"
+                 " but got {1} and {2}.".format(prim_name, size_dim, scales_dim))
+
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[reverse_index:]
+         x = F.reshape(x, input_shape)
+         out = prim(x, size, scales)
+         out_shape = F.shape(out)
+         return_shape = x_shape[:reverse_index] + out_shape[reverse_index:]
+         out = F.reshape(out, return_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.UpsampleLinear1D)
+ @vmap_rules_getters.register(NN.UpsampleBilinear2D)
+ @vmap_rules_getters.register(NN.UpsampleTrilinear3D)
+ def get_upsample_linear_vmap_rule(prim, axis_size):
+     """VmapRule for `UpsampleLinear1D`, `UpsampleBilinear2D` and `UpsampleTrilinear3D`."""
+     prim_name = prim.name
+     if prim_name == "UpsampleLinear1D":
+         reverse_index = -2
+     elif prim_name == "UpsampleBilinear2D":
+         reverse_index = -3
+     else:
+         reverse_index = -4
+
+     def vmap_rule(x_bdim, size_bdim, scales_bdim, align_corners_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, size_bdim,
+                                                       scales_bdim, align_corners_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         size, size_dim = size_bdim
+         scales, scales_dim = scales_bdim
+         align_corners, align_corners_dim = align_corners_bdim
+         if size_dim is not None or scales_dim is not None or align_corners_dim is not None:
+             _raise_value_error(
+                 "For {0}, the source axes of `output_size`, `scales` and `align_corners` must "
+                 "be None, but got {1}, {2} and {3}.".format(prim_name, size_dim, scales_dim, align_corners_dim))
+
+         x_shape = F.shape(x)
+         input_shape = (-1,) + x_shape[reverse_index:]
+         x = F.reshape(x, input_shape)
+         out = prim(x, size, scales, align_corners)
+         out_shape = F.shape(out)
+         return_shape = x_shape[:reverse_index] + out_shape[reverse_index:]
+         out = F.reshape(out, return_shape)
+         return out, 0
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.SparseApplyAdagrad)
+ @vmap_rules_getters.register(NN.SparseApplyAdagradV2)
+ def get_sparse_apply_adagrad_vmap_rule(prim, axis_size):
+     """VmapRule for `SparseApplyAdagrad`."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, grad_bdim, indices_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         grad, grad_dim = grad_bdim
+         indices, indices_dim = indices_bdim
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, grad_dim, indices_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/grad/indices` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum = prim(var, accum, grad, indices, u_monad)
+             return (var, None), (accum, None)
+         if var_dim != 0 or accum_dim != var_dim:
+             raise ValueError("For `{}`, the source axis of `var` must be 0 and equal to the source axis of "
+                              "`accum`, but got the source axis of `var`: {}, `accum`: {}."
+                              .format(prim_name, var_dim, accum_dim))
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         indices = _bdim_at_front(indices, indices_dim, axis_size)
+
+         var, accum = batch_prim(var, accum, grad, indices, u_monad)
+         return (var, 0), (accum, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(NN.SparseApplyFtrl)
+ def get_sparse_apply_ftrl_vmap_rule(prim, axis_size):
+     """VmapRule for `SparseApplyFtrl`."""
+     if hasattr(prim, 'batch_rank'):
+         batch_rank = prim.batch_rank + 1
+     else:
+         batch_rank = 1
+
+     prim_name = prim.name
+     batch_prim = _vmap_clone_prim(prim)
+     batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+     def vmap_rule(var_bdim, accum_bdim, linear_bdim, grad_bdim, indices_bdim, u_monad):
+         var, var_dim = var_bdim
+         accum, accum_dim = accum_bdim
+         linear, linear_dim = linear_bdim
+         grad, grad_dim = grad_bdim
+         indices, indices_dim = indices_bdim
+         if var_dim is None:
+             if any(dim is not None for dim in [accum_dim, linear_dim, grad_dim, indices_dim]):
+                 raise ValueError("The source axis of `var` is None, but the source "
+                                  "axis of `accum/linear/grad/indices` is not None. The execution order of "
+                                  "operator `{}` cannot be guaranteed.".format(prim_name))
+             var, accum, linear = prim(var, accum, linear, grad, indices, u_monad)
+             return (var, None), (accum, None), (linear, None)
+         if var_dim != 0 or accum_dim != var_dim or linear_dim != var_dim:
+             raise ValueError("For `{}`, the source axes of `var`, `accum` and `linear` must all be 0, "
+                              "but got the source axis of `var`: {}, `accum`: {}, "
+                              "`linear`: {}.".format(prim_name, var_dim, accum_dim, linear_dim))
+
+         grad = _bdim_at_front(grad, grad_dim, axis_size)
+         indices = _bdim_at_front(indices, indices_dim, axis_size)
+
+         var, accum, linear = batch_prim(var, accum, linear, grad, indices, u_monad)
+         return (var, 0), (accum, 0), (linear, 0)
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Dense)
+ def get_dense_vmap_rule(prim, axis_size):
+     """VmapRule for `Dense` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+
+     batch_matmul = P.BatchMatMul(transpose_b=True)
+
+     @_primexpr
+     def get_start_mid_end(x_shape):
+         start = x_shape[0]
+         mid = 1
+         for shp in x_shape[1:-1]:
+             mid *= shp
+         end = x_shape[-1]
+         return start, mid, end
+
+     def vmap_rule(x_bdim, w_bdim, b_bdim):
+         is_all_none, result = vmap_general_preprocess(prim, x_bdim, w_bdim, b_bdim)
+         if is_all_none:
+             return result
+
+         x, x_dim = x_bdim
+         w, w_dim = w_bdim
+         b, b_dim = b_bdim
+         x = _bdim_at_front(x, x_dim, axis_size)
+         w = _bdim_at_front(w, w_dim, axis_size)
+         if b is not None:
+             b = _bdim_at_front(b, b_dim, axis_size)
+
+         x_shape = x.shape
+         start, mid, end = get_start_mid_end(x_shape)
+
+         # Flatten the middle dims so BatchMatMul sees (batch, rows, in_features).
+         x = x.reshape(start, mid, end)
+
+         out = batch_matmul(x, w)
+         out_shape = tuple(x_shape[:-1]) + (out.shape[-1],)
+         out = out.reshape(out_shape)
+
+         if b is not None:
+             b_shape = b.shape
+             b_shape = (start,) + (1,) * (len(out_shape) - 2) + (b_shape[-1],)
+             b = b.reshape(b_shape)
+             out = out + b
+
+         return out, 0
+
+     return vmap_rule
+
+
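+ # Editorial note (an illustrative addition, not part of the original file):
+ # with a batched weight, Dense reduces to a batch matmul. Hypothetical
+ # shapes for a vmap axis of size B:
+ #     x : (B, d1, d2, in) -> flattened to (B, d1*d2, in)
+ #     w : (B, out, in)
+ #     BatchMatMul(transpose_b=True)(x, w) -> (B, d1*d2, out), then the bias
+ #     is broadcast over the last axis and the original dims are restored.
+
+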
+ @vmap_rules_getters.register(P.CeLU)
+ def get_celu_vmap_rule(prim, axis_size):
+     """VmapRule for `CeLU` operation."""
+
+     def vmap_rule(x_bdim, alpha_bdim):
+         x_data, x_dim = x_bdim
+         alpha_data, _ = alpha_bdim
+         out = F.celu(x_data, alpha_data)
+         return out, x_dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(P.Elu)
+ def get_elu_vmap_rule(prim, axis_size):
+     """VmapRule for `Elu` operation."""
+     if isinstance(prim, str):
+         prim = Primitive(prim)
+
+     def vmap_rule(x_bdim, alpha_bdim):
+         var, dim = x_bdim
+         alpha, alpha_dim = alpha_bdim
+
+         if alpha_dim is not None:
+             _raise_value_error("The source axis of `alpha` in `Elu` must be None, but got {}.".format(alpha_dim))
+
+         out = prim(var, alpha)
+         return out, dim
+
+     return vmap_rule
+
+
+ @vmap_rules_getters.register(Embedding)
+ def get_embedding_vmap_rule(prim, axis_size):
+     """VmapRule for `Embedding` operation."""
+     if isinstance(prim, str):
+         prim_name = prim
+     else:
+         prim_name = prim.name
+     raise RuntimeError(f"The operator {prim_name} does not support vmap.")
+
+
+ # Unary vmap
+ get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU6)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.SeLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.HSigmoid)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.Softplus)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.Softsign)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.SoftShrink)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.GeLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.FastGeLU)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.HSwish)(get_unop_vmap_rule)
+ get_unop_vmap_rule = vmap_rules_getters.register(P.Tanh)(get_unop_vmap_rule)
+ # UnaryGrad vmap
+ get_unary_grad_vmap_rule = vmap_rules_getters.register(G.TanhGrad)(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register(G.SoftplusGrad)(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register('ReluGrad')(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register('ReLU6Grad')(get_unary_grad_vmap_rule)
+ get_unary_grad_vmap_rule = vmap_rules_getters.register('RsqrtGrad')(get_unary_grad_vmap_rule)