mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (966)
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -14,19 +14,29 @@
14
14
  # ============================================================================
15
15
 
16
16
  """Operators for gradients."""
17
- from __future__ import division
18
17
  from __future__ import absolute_import
19
18
 
19
+ from __future__ import division
20
20
  from mindspore._checkparam import _check_3d_int_or_tuple
21
21
  from mindspore.ops.operations.nn_ops import _check_positive_int_or_tuple
22
22
  from mindspore.ops import signature as sig
23
23
  from mindspore.ops._utils import get_concat_offset
24
24
  from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
25
25
  import mindspore.context as context
26
- from mindspore._checkparam import Validator as validator, Rel
26
+ from mindspore import _checkparam as validator
27
27
  from mindspore.common import dtype as mstype
28
28
  from mindspore.communication.management import GlobalComm
29
- from mindspore.ops._utils import is_shape_unknown
29
+ from mindspore.common._utils import is_shape_unknown, is_dim_unknown
30
+
31
+
32
+ class SparseFillEmptyRowsGrad(Primitive):
33
+ """Performs grad of SparseFillEmptyRows operation."""
34
+
35
+ @prim_attr_register
36
+ def __init__(self):
37
+ """Initialize SparseFillEmptyRowsGrad."""
38
+ self.init_prim_io_names(inputs=['reverse_index_map', 'grad_values'],
39
+ outputs=['y_values', 'y_default_value'])
30
40
 
31
41
 
32
42
  class AbsGrad(PrimitiveWithInfer):
@@ -104,6 +114,7 @@ class ReciprocalGrad(Primitive):
104
114
  @prim_attr_register
105
115
  def __init__(self):
106
116
  """Initialize ReciprocalGrad"""
117
+ self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
107
118
 
108
119
 
109
120
  class RsqrtGrad(Primitive):
@@ -145,7 +156,7 @@ class BatchNormGrad(Primitive):
145
156
  @prim_attr_register
146
157
  def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'):
147
158
  self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
148
- self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
159
+ self.epsilon = validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
149
160
  self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
150
161
 
151
162
 
@@ -155,7 +166,7 @@ class BatchNormGradGrad(Primitive):
155
166
  @prim_attr_register
156
167
  def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'):
157
168
  self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
158
- self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
169
+ self.epsilon = validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
159
170
  self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
160
171
 
161
172
 
@@ -164,10 +175,10 @@ class SyncBatchNormGrad(Primitive):
164
175
 
165
176
  @prim_attr_register
166
177
  def __init__(self, epsilon=1e-5, group="group0", device_num=2):
167
- validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
178
+ validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
168
179
  if not isinstance(group, str):
169
180
  raise TypeError("The group attr of SyncBatchNormGrad must be str.")
170
- validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)
181
+ validator.check_int(device_num, 2, validator.GE, "device_num", self.name)
171
182
 
172
183
 
173
184
  class BiasAddGrad(Primitive):
@@ -177,8 +188,6 @@ class BiasAddGrad(Primitive):
177
188
  def __init__(self, data_format="NCHW"):
178
189
  self.init_prim_io_names(inputs=['dout'], outputs=['output'])
179
190
  self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
180
- if context.get_context("device_target") != "GPU" and self.format == "NHWC":
181
- raise ValueError("NHWC format only support in GPU target.")
182
191
  if self.format == "NCDHW":
183
192
  self.format = "NCHW"
184
193
  self.add_prim_attr('data_format', self.format)
@@ -232,10 +241,27 @@ class ConcatOffset(PrimitiveWithInfer):
232
241
  x_type = input_x['dtype']
233
242
  self.add_prim_attr('T', x_type[0].element_type())
234
243
 
244
+ # input_x is dynamic rank
245
+ rank = -1
246
+ is_dyn_rank = False
247
+ for _, sh in enumerate(x_shp):
248
+ if is_dim_unknown(sh):
249
+ is_dyn_rank = True
250
+ else:
251
+ rank = len(sh)
252
+ if is_dyn_rank:
253
+ return {
254
+ 'shape': [len(x_shp), rank],
255
+ 'dtype': mstype.int64,
256
+ 'value': None
257
+ }
258
+
235
259
  # if the dimension of input_x on the axis is dynamic
236
- rank_base = len(x_shp[0])
260
+ if axis < -rank or axis >= rank:
261
+ raise ValueError("For 'ConcatOffset', 'axis' must be in range [{}, {}), but got {}"
262
+ .format(-rank, rank, axis))
237
263
  if axis < 0:
238
- axis = axis + rank_base
264
+ axis = axis + rank
239
265
  for each in x_shp:
240
266
  if each[axis] == -1:
241
267
  return {
@@ -466,9 +492,6 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
466
492
  self.group = group
467
493
  self.add_prim_attr('data_format', "NCHW")
468
494
 
469
- def __call__(self, x, w_size, dout):
470
- raise NotImplementedError
471
-
472
495
  def __infer__(self, x, w_size, dout):
473
496
  w_size_v = w_size['value']
474
497
  args = {'x': x['dtype'], 'dout': dout['dtype']}
@@ -533,9 +556,6 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
533
556
  self.group = group
534
557
  self.add_prim_attr('data_format', "NCHW")
535
558
 
536
- def __call__(self, x_size, w, dout):
537
- raise NotImplementedError
538
-
539
559
  def __infer__(self, x_size, w, dout):
540
560
  args = {'w': w['dtype'], 'dout': dout['dtype']}
541
561
  validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
@@ -571,7 +591,7 @@ class DropoutGrad(Primitive):
571
591
 
572
592
  @prim_attr_register
573
593
  def __init__(self, keep_prob=0.5):
574
- self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
594
+ self.keep_prob = validator.check_float_range(keep_prob, 0, 1, validator.INC_RIGHT, "keep_prob", self.name)
575
595
 
576
596
 
577
597
  class FlattenGrad(PrimitiveWithInfer):
@@ -599,6 +619,18 @@ class InstanceNormGrad(PrimitiveWithInfer):
599
619
  outputs=['dx', 'bn_gamma', 'bn_beta'])
600
620
 
601
621
 
622
+ class InstanceNormV2Grad(Primitive):
623
+ """Gradients of InstanceNormV2 operation."""
624
+
625
+ @prim_attr_register
626
+ def __init__(self, is_training=True, epsilon=1e-5):
627
+ self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'mean', 'variance', 'save_mean', 'save_variance'],
628
+ outputs=['pd_x', 'pd_gamma', 'pd_beta'])
629
+ validator.check_is_float(epsilon, 'epsilon', self.name)
630
+ validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
631
+ validator.check_bool(is_training, "is_training", self.name)
632
+
633
+
602
634
  class EinsumGrad(PrimitiveWithInfer):
603
635
  """Gradients of Einsum."""
604
636
 
@@ -626,9 +658,6 @@ class UniqueGrad(Primitive):
626
658
  def __init__(self):
627
659
  self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
628
660
 
629
- def __call__(self, dy, x, scale, save_mean, save_inv_variance):
630
- raise NotImplementedError
631
-
632
661
 
633
662
  class BNTrainingReduceGrad(Primitive):
634
663
  """Gradients of FusedBatchNorm operation."""
@@ -666,7 +695,7 @@ class NeighborExchangeV2Grad(PrimitiveWithInfer):
666
695
 
667
696
  def __infer__(self, dy):
668
697
  dy_shape = dy['shape']
669
- validator.check(f'dy_shape.size()', len(dy_shape), f'4', 4, Rel.EQ, self.name)
698
+ validator.check(f'dy_shape.size()', len(dy_shape), f'4', 4, validator.EQ, self.name)
670
699
  if self.send_rank_ids[5] != -1 or self.send_rank_ids[6] != -1 or self.send_rank_ids[7] != -1:
671
700
  dy_shape[3] -= self.send_lens[2]
672
701
 
@@ -683,9 +712,6 @@ class NeighborExchangeV2Grad(PrimitiveWithInfer):
683
712
  'dtype': dy['dtype'],
684
713
  'value': None}
685
714
 
686
- def __call__(self, tensor):
687
- raise NotImplementedError
688
-
689
715
 
690
716
  class GeLUGrad(Primitive):
691
717
  """Gradients of GeLU operation."""
@@ -853,18 +879,13 @@ class AvgPoolGradV1(Primitive):
853
879
  self.add_prim_attr("strides", self.strides_adapt)
854
880
 
855
881
 
856
- class AdaptiveAvgPool2DGrad(PrimitiveWithInfer):
882
+ class AdaptiveAvgPool2DGrad(Primitive):
857
883
  """Gradients of the adaptive avg pool 2D operation."""
858
884
 
859
885
  @prim_attr_register
860
886
  def __init__(self):
861
887
  """Initialize AdaptiveAvgPool2DGrad"""
862
-
863
- def infer_shape(self, x1_shape, grad_shape):
864
- return x1_shape
865
-
866
- def infer_dtype(self, x1_dtype, grad_dtype):
867
- return x1_dtype
888
+ self.init_prim_io_names(inputs=['input_grad', 'orig_input_shape'], outputs=['output_grad'])
868
889
 
869
890
 
870
891
  class AdaptiveAvgPool3DGrad(Primitive):
@@ -872,7 +893,6 @@ class AdaptiveAvgPool3DGrad(Primitive):
872
893
  @prim_attr_register
873
894
  def __init__(self):
874
895
  self.init_prim_io_names(inputs=['y_grad', 'orig_input_shape'], outputs=['x_grad'])
875
- self.set_const_input_indexes([1])
876
896
 
877
897
 
878
898
  class AvgPool3DGrad(Primitive):
@@ -1180,9 +1200,6 @@ class MaximumGrad(Primitive):
1180
1200
  """Initialize MaximumGrad"""
1181
1201
  self.init_prim_io_names(inputs=['x1', 'x2', 'grads'], outputs=['y1', 'y2'])
1182
1202
 
1183
- def __call__(self, x, y, dout):
1184
- raise NotImplementedError
1185
-
1186
1203
 
1187
1204
  class MaximumGradGrad(Primitive):
1188
1205
  """Grad for maximum grad."""
@@ -1194,13 +1211,71 @@ class MaximumGradGrad(Primitive):
1194
1211
  self.init_prim_io_names(inputs=['x1', 'x2', 'dy1', 'dy2'], outputs=['sopd_x1', 'sopd_x2', 'sopd_grad'])
1195
1212
 
1196
1213
 
1197
- class MaxPoolGradWithArgmax(_PoolGrad):
1214
+ class MaxPoolGradWithArgmax(Primitive):
1198
1215
  """Computes the gradients of MaxPoolWithArgmax."""
1216
+ @prim_attr_register
1217
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
1218
+ self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
1219
+ validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
1220
+ validator.check_value_type('strides', strides, [int, tuple], self.name)
1221
+ validator.check_value_type('pad_mode', pad_mode, [str], self.name)
1222
+ self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
1223
+ self.add_prim_attr("pad_mode", self.pad_mode)
1224
+ self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
1225
+ if context.get_context("device_target") != "GPU" and self.format == "NHWC":
1226
+ raise ValueError("NHWC format only support in GPU target.")
1227
+ self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
1228
+ if not self.is_maxpoolgradwithargmax:
1229
+ self.add_prim_attr('data_format', self.format)
1230
+
1231
+ def _grad_check_int_or_tuple(arg_name, arg_val):
1232
+ validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
1233
+ error_msg = ValueError(f"For '{self.name}' the '{arg_name}' must be an positive int number "
1234
+ f"or a tuple of two or four positive int numbers, but got {arg_val}")
1235
+ if isinstance(arg_val, int):
1236
+ ret = (1, arg_val, arg_val, 1)
1237
+ elif len(arg_val) == 2:
1238
+ ret = (1, arg_val[0], arg_val[1], 1)
1239
+ elif len(arg_val) == 4:
1240
+ ret = arg_val
1241
+ else:
1242
+ raise error_msg
1243
+ # whether all elements of tuple are positive integers
1244
+ for item in ret:
1245
+ if not isinstance(item, int) or item <= 0:
1246
+ raise error_msg
1247
+ return ret
1248
+
1249
+ kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
1250
+ self.kernel_size = kernel_size
1251
+ self.add_prim_attr("kernel_size", self.kernel_size)
1252
+
1253
+ strides = _grad_check_int_or_tuple("strides", strides)
1254
+ self.strides = strides
1255
+ self.add_prim_attr("strides", self.strides)
1256
+
1257
+
1258
+ class MaxPoolGradWithArgmaxV2(Primitive):
1259
+ """Gradients of the MaxPoolWithArgmaxV2 operation."""
1199
1260
 
1200
1261
  @prim_attr_register
1201
- def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
1202
- self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
1203
- super(MaxPoolGradWithArgmax, self).__init__(kernel_size, strides, pad_mode)
1262
+ def __init__(self, kernel_size, strides=None, pads=0, dilation=(1, 1), ceil_mode=False, argmax_type=mstype.int64):
1263
+ self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['y'])
1264
+ self.kernel_size = _check_positive_int_or_tuple("kernel_size", kernel_size, self.name, allow_four=True,
1265
+ ret_four=True)
1266
+ self.add_prim_attr('kernel_size', self.kernel_size)
1267
+ if strides is None:
1268
+ strides = kernel_size
1269
+ self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=True, ret_four=True)
1270
+ self.add_prim_attr('strides', self.strides)
1271
+ self.pads = _check_positive_int_or_tuple("pads", pads, self.name, allow_four=True, ret_four=True,
1272
+ strict_positive=False)
1273
+ self.add_prim_attr('pads', self.pads)
1274
+ validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
1275
+ self.add_prim_attr('ceil_mode', self.ceil_mode)
1276
+ self.dilation = _check_positive_int_or_tuple("dilation", dilation, self.name, allow_four=True, ret_four=True)
1277
+ self.add_prim_attr('dilation', self.dilation)
1278
+ self.add_prim_attr('argmax_type', self.argmax_type)
1204
1279
 
1205
1280
 
1206
1281
  class MaxPool3DGradWithArgmax(Primitive):
@@ -1292,9 +1367,6 @@ class MinimumGrad(Primitive):
1292
1367
  """Initialize MinimumGrad"""
1293
1368
  self.init_prim_io_names(inputs=['x1', 'x2', 'grads'], outputs=['y1', 'y2'])
1294
1369
 
1295
- def __call__(self, x, y, dout):
1296
- raise NotImplementedError
1297
-
1298
1370
 
1299
1371
  class MinimumGradGrad(Primitive):
1300
1372
  """Grad for minimum_grad."""
@@ -1354,9 +1426,6 @@ class LayerNormGrad(Primitive):
1354
1426
  self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
1355
1427
  self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
1356
1428
 
1357
- def __call__(self, x, dy, variance, mean, gamma):
1358
- raise NotImplementedError
1359
-
1360
1429
 
1361
1430
  class LayerNormGradGrad(Primitive):
1362
1431
  """
@@ -1389,7 +1458,7 @@ class LayerNormGradGrad(Primitive):
1389
1458
  ValueError: If gamma, d_dg, d_db don't have the same shape.
1390
1459
 
1391
1460
  Supported Platforms:
1392
- ``Ascend`` ``CPU`` ``GPU``
1461
+ ``Ascend`` ``GPU`` ``CPU``
1393
1462
  """
1394
1463
 
1395
1464
  @prim_attr_register
@@ -1397,6 +1466,8 @@ class LayerNormGradGrad(Primitive):
1397
1466
  """init"""
1398
1467
  self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
1399
1468
  self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
1469
+ self.init_prim_io_names(inputs=['x', 'dy', 'variance', 'mean', 'gamma', 'd_dx', 'd_dg', 'd_db'],
1470
+ outputs=['sopd_x', 'sopd_dy', 'sopd_gamma'])
1400
1471
 
1401
1472
 
1402
1473
  class LogSoftmaxGrad(Primitive):
@@ -1408,7 +1479,7 @@ class LogSoftmaxGrad(Primitive):
1408
1479
  validator.check_value_type("axis", axis, [int], self.name)
1409
1480
 
1410
1481
 
1411
- class LSTMGradData(PrimitiveWithInfer):
1482
+ class LSTMGradData(Primitive):
1412
1483
  """Computes the data gradients of LSTM."""
1413
1484
 
1414
1485
  @prim_attr_register
@@ -1419,43 +1490,15 @@ class LSTMGradData(PrimitiveWithInfer):
1419
1490
  self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1420
1491
  self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1421
1492
  self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1422
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1493
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1423
1494
 
1424
1495
  if bidirectional:
1425
1496
  self.num_directions = 2
1426
1497
  else:
1427
1498
  self.num_directions = 1
1428
1499
 
1429
- def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
1430
- hx_shape, cx_shape, reserve_shape, state_shape):
1431
- # dhy and dcy should be same shape
1432
- validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1433
- validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
1434
- validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
1435
- validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
1436
- validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
1437
-
1438
- validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
1439
- validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
1440
-
1441
- validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1442
- validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1443
- validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
1444
-
1445
- dx_shape = (y_shape[0], y_shape[1], self.input_size)
1446
- dhx_shape = dhy_shape
1447
- dcx_shape = dcy_shape
1448
1500
 
1449
- return (dx_shape, dhx_shape, dcx_shape)
1450
-
1451
- def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
1452
- hx_dtype, cx_dtype, reserve_dtype, state_dtype):
1453
- args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
1454
- validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
1455
- return (dy_dtype, dy_dtype, dy_dtype)
1456
-
1457
-
1458
- class LSTMGradWeight(PrimitiveWithInfer):
1501
+ class LSTMGradWeight(Primitive):
1459
1502
  """Computes the weight gradients of LSTM."""
1460
1503
 
1461
1504
  @prim_attr_register
@@ -1466,31 +1509,15 @@ class LSTMGradWeight(PrimitiveWithInfer):
1466
1509
  self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1467
1510
  self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1468
1511
  self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1469
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1512
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1470
1513
 
1471
1514
  if bidirectional:
1472
1515
  self.num_directions = 2
1473
1516
  else:
1474
1517
  self.num_directions = 1
1475
1518
 
1476
- def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
1477
- weight_size = 0
1478
- gate_size = 4 * self.hidden_size
1479
- for layer in range(self.num_layers):
1480
- for _ in range(self.num_directions):
1481
- input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
1482
- weight_size += gate_size * input_layer_size
1483
- weight_size += gate_size * self.hidden_size
1484
- if self.has_bias:
1485
- weight_size += 2 * gate_size
1486
-
1487
- return (weight_size, 1, 1)
1488
1519
 
1489
- def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
1490
- return hx_dtype
1491
-
1492
-
1493
- class LSTMGrad(PrimitiveWithInfer):
1520
+ class LSTMGrad(Primitive):
1494
1521
  """Computes the data and weight gradients of LSTM."""
1495
1522
 
1496
1523
  @prim_attr_register
@@ -1501,50 +1528,15 @@ class LSTMGrad(PrimitiveWithInfer):
1501
1528
  self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1502
1529
  self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1503
1530
  self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1504
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1531
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1505
1532
 
1506
1533
  if bidirectional:
1507
1534
  self.num_directions = 2
1508
1535
  else:
1509
1536
  self.num_directions = 1
1510
1537
 
1511
- def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
1512
- dcy_shape, reserve_shape):
1513
- # dhy and dcy should be same shape
1514
- validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1515
- validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
1516
- validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
1517
- validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
1518
- validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
1519
-
1520
- validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
1521
- validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
1522
1538
 
1523
- validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1524
- validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1525
- validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
1526
-
1527
- dx_shape = (y_shape[0], y_shape[1], self.input_size)
1528
- dhx_shape = dhy_shape
1529
- dcx_shape = dcy_shape
1530
- weight_size = 0
1531
- gate_size = 4 * self.hidden_size
1532
- for layer in range(self.num_layers):
1533
- for _ in range(self.num_directions):
1534
- input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
1535
- weight_size += gate_size * input_layer_size
1536
- weight_size += gate_size * self.hidden_size
1537
- if self.has_bias:
1538
- weight_size += gate_size
1539
-
1540
- return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1))
1541
-
1542
- def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype,
1543
- dcy_dtype, reserve_dtype):
1544
- return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)
1545
-
1546
-
1547
- class DynamicRNNGrad(PrimitiveWithInfer):
1539
+ class DynamicRNNGrad(Primitive):
1548
1540
  """Computes the input gradients of DynamicRNN."""
1549
1541
 
1550
1542
  @prim_attr_register
@@ -1560,35 +1552,6 @@ class DynamicRNNGrad(PrimitiveWithInfer):
1560
1552
  forget_bias=0.0):
1561
1553
  self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
1562
1554
 
1563
- def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
1564
- c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
1565
- validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
1566
- num_step, batch_size, input_size = x_shape
1567
- hidden_size = w_shape[-1] // 4
1568
- if w_shape[-1] % 4 != 0:
1569
- raise ValueError(f"For {self.name}, w_shape[-1] should multiple of 4.")
1570
- validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
1571
- input_size + hidden_size, Rel.EQ, self.name)
1572
- valid_shape = [num_step, batch_size, hidden_size]
1573
- validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
1574
- validator.check("y_shape", y_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1575
- validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1576
- validator.check("c_shape", c_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1577
- validator.check("i_shape", i_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1578
- validator.check("j_shape", j_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1579
- validator.check("f_shape", f_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1580
- validator.check("o_shape", o_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1581
- validator.check("tanhc_shape", tanhc_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1582
- validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1583
- validator.check("dh_shape", dh_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name)
1584
- validator.check("dc_shape", dc_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name)
1585
-
1586
- return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape
1587
-
1588
- def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype,
1589
- c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype):
1590
- return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
1591
-
1592
1555
 
1593
1556
  class GruGradData(PrimitiveWithInfer):
1594
1557
  """Computes the data gradients of GRU."""
@@ -1601,7 +1564,7 @@ class GruGradData(PrimitiveWithInfer):
1601
1564
  self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1602
1565
  self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1603
1566
  self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1604
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1567
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1605
1568
 
1606
1569
  if bidirectional:
1607
1570
  self.num_directions = 2
@@ -1613,12 +1576,12 @@ class GruGradData(PrimitiveWithInfer):
1613
1576
  # dhy and dcy should be same shape
1614
1577
  validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1615
1578
 
1616
- validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
1579
+ validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, validator.EQ, "h_shape[0]", self.name)
1617
1580
  validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
1618
1581
 
1619
1582
  validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1620
1583
  validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1621
- validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
1584
+ validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, validator.EQ, "dy[2]", self.name)
1622
1585
 
1623
1586
  dx_shape = (y_shape[0], y_shape[1], self.input_size)
1624
1587
  dhx_shape = dhy_shape
@@ -1643,7 +1606,7 @@ class GruGradWeight(PrimitiveWithInfer):
1643
1606
  self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1644
1607
  self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1645
1608
  self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1646
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1609
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1647
1610
 
1648
1611
  if bidirectional:
1649
1612
  self.num_directions = 2
@@ -1667,7 +1630,26 @@ class GruGradWeight(PrimitiveWithInfer):
1667
1630
  return hx_dtype
1668
1631
 
1669
1632
 
1670
- class DynamicGRUV2Grad(PrimitiveWithInfer):
1633
+ class GRUV2Grad(Primitive):
1634
+ """Computes the grad gradients of GRU."""
1635
+
1636
+ @prim_attr_register
1637
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1638
+ self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1639
+ self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1640
+ self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1641
+ self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1642
+ self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1643
+ self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1644
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1645
+
1646
+ if bidirectional:
1647
+ self.num_directions = 2
1648
+ else:
1649
+ self.num_directions = 1
1650
+
1651
+
1652
+ class DynamicGRUV2Grad(Primitive):
1671
1653
  r"""
1672
1654
  Computes the input gradients of DynamicGRUV2.
1673
1655
 
@@ -1719,13 +1701,13 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
1719
1701
  - **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`.
1720
1702
  Has the same type with input `x`.
1721
1703
  - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
1722
- Has the same type with input `x`.
1704
+ Has the same type with input `init\_h`.
1723
1705
  - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
1724
- Has the same type with input `x`.
1706
+ Has the same type with input `init\_h`.
1725
1707
  - **dx** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1726
1708
  Has the same type with input `x`.
1727
1709
  - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch\_size, hidden\_size)`.
1728
- Has the same type with input `x`.
1710
+ Has the same type with input `init\_h`.
1729
1711
  """
1730
1712
 
1731
1713
  @prim_attr_register
@@ -1746,62 +1728,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
1746
1728
  self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
1747
1729
  self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
1748
1730
  self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
1749
-
1750
- def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
1751
- dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape):
1752
- validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
1753
- validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
1754
- validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
1755
- validator.check_int(len(y_shape), 3, Rel.EQ, "y shape rank", self.name)
1756
- num_step, batch_size, input_size = x_shape
1757
- hidden_size = whidden_shape[0]
1758
- validator.check("weight_hidden_shape[-1]", whidden_shape[-1], "3 * hidden_size",
1759
- 3 * hidden_size, Rel.EQ, self.name)
1760
- validator.check("weight_input_shape", winput_shape, "excepted shape",
1761
- [input_size, 3 * hidden_size], Rel.EQ, self.name)
1762
- if self.num_proj > 0:
1763
- valid_y_shape = [num_step, batch_size, min(hidden_size, self.num_proj)]
1764
- else:
1765
- valid_y_shape = [num_step, batch_size, hidden_size]
1766
- validator.check("y_shape", y_shape, "excepted shape", valid_y_shape, Rel.EQ, self.name)
1767
-
1768
- validator.check("init_h_shape", init_h_shape, "excepted shape",
1769
- [batch_size, hidden_size], Rel.EQ, self.name)
1770
- valid_shape = [num_step, batch_size, hidden_size]
1771
- validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1772
- validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1773
- validator.check("dh_shape", dh_shape, "excepted shape",
1774
- [batch_size, hidden_size], Rel.EQ, self.name)
1775
- validator.check("update_shape", update_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1776
- validator.check("reset_shape", reset_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1777
- validator.check("new_shape", new_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1778
- validator.check("hnew_shape", hnew_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1779
- if seq_shape is not None:
1780
- validator.check("seq_shape", seq_shape, "batch_size", batch_size, Rel.EQ, self.name)
1781
-
1782
- dx_shape = (num_step, batch_size, input_size)
1783
- dh_shape = (batch_size, hidden_size)
1784
- dwinput_shape = (input_size, 3 * hidden_size)
1785
- dwhidden_shape = (hidden_size, 3 * hidden_size)
1786
- db_shape = (3 * hidden_size,)
1787
- return dwinput_shape, dwhidden_shape, db_shape, db_shape, dx_shape, dh_shape
1788
-
1789
- def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype,
1790
- dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype):
1791
- valid_types = (mstype.float16, mstype.float32)
1792
- args = {"y_dtype": y_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype,
1793
- "dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype,
1794
- "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
1795
- validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
1796
- validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
1797
- validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
1798
- validator.check_tensor_dtype_valid("init_h_dtype", init_h_dtype, valid_types, self.name)
1799
- validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
1800
- if seq_dtype is not None:
1801
- validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
1802
- if mask_dtype is not None:
1803
- validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name)
1804
- return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
1731
+ self.init_prim_io_names(inputs=[
1732
+ "x", "weight_input", "weight_hidden", "y", "init_h", "h", "dy",
1733
+ "dh", "update", "reset", "new", "hidden_new", "seq_length", "mask"
1734
+ ],
1735
+ outputs=[
1736
+ "dw_input", "dw_hidden", "db_input",
1737
+ "db_hidden", "dx", "dh_prev"
1738
+ ])
1805
1739
 
1806
1740
 
1807
1741
  class PReLUGrad(Primitive):
@@ -1825,6 +1759,44 @@ class PReLUGrad(Primitive):
1825
1759
  pass
1826
1760
 
1827
1761
 
1762
+ class RandomGammaGrad(Primitive):
1763
+ r"""
1764
+ Computes the derivative of a random sample of Gamma with respect to alpha.:
1765
+
1766
+ Inputs:
1767
+ - **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
1768
+ It must be greater than 0. Must be one of the following types: float32, float64.
1769
+ - **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
1770
+ following types: float32, float64.
1771
+
1772
+ Outputs:
1773
+ The dtype is the same type as alpha.
1774
+ The output shape is derived from the input through broadcasting.
1775
+
1776
+ Raises:
1777
+ TypeError: If data type of `alpha` and `sample` is not float32 or float64.
1778
+ TypeError: If data type of `alpha` and `sample` is not same.
1779
+ ValueError: If the shape last dim of `sample` and `alpha` is not equal.
1780
+
1781
+ Supported Platforms:
1782
+ ``GPU``
1783
+
1784
+ Examples:
1785
+ >>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
1786
+ >>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
1787
+ >>> randomgammagrad = ops.RandomGammaGrad()
1788
+ >>> output = randomgammagrad(alpha, sample)
1789
+ >>> print(output)
1790
+ [2.5142431 3.4334087 1.8847835 0.07780622]
1791
+ """
1792
+
1793
+ @prim_attr_register
1794
+ def __init__(self):
1795
+ """Initialize RandomGammaGrad"""
1796
+ self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
1797
+ self.add_prim_attr("side_effect_hidden", True)
1798
+
1799
+
1828
1800
  class ReluGrad(Primitive):
1829
1801
  """Performs grad of Relu operation."""
1830
1802
 
@@ -1833,8 +1805,14 @@ class ReluGrad(Primitive):
1833
1805
  """Initialize ReluGrad"""
1834
1806
  self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])
1835
1807
 
1836
- def __call__(self, y_backprop, x):
1837
- raise NotImplementedError
1808
+
1809
+ class SiLUGrad(Primitive):
1810
+ """Performs grad of SiLU operation."""
1811
+
1812
+ @prim_attr_register
1813
+ def __init__(self):
1814
+ """Initialize SiLUGrad"""
1815
+ self.init_prim_io_names(inputs=['dout', 'out'], outputs=['output'])
1838
1816
 
1839
1817
 
1840
1818
  class ReLU6Grad(Primitive):
@@ -1844,9 +1822,6 @@ class ReLU6Grad(Primitive):
1844
1822
  def __init__(self):
1845
1823
  self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
1846
1824
 
1847
- def __call__(self, y_grad, x):
1848
- raise NotImplementedError
1849
-
1850
1825
 
1851
1826
  class ReluGradV2(Primitive):
1852
1827
  """Performs grad of ReLUV2 operation."""
@@ -1855,9 +1830,6 @@ class ReluGradV2(Primitive):
1855
1830
  def __init__(self):
1856
1831
  self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])
1857
1832
 
1858
- def __call__(self, gradients, mask):
1859
- raise NotImplementedError
1860
-
1861
1833
 
1862
1834
  class EluGrad(Primitive):
1863
1835
  """Performs grad of Elu operation."""
@@ -1885,12 +1857,9 @@ class GatherDGradV2(Primitive):
1885
1857
  """Performs grad of GatherD operation."""
1886
1858
 
1887
1859
  @prim_attr_register
1888
- def __init__(self, dim=0):
1860
+ def __init__(self):
1889
1861
  """Initialize GatherDGradV2"""
1890
- validator.check_is_int(dim, int)
1891
- self.add_prim_attr("dim", dim)
1892
- self.dim = dim
1893
- self.init_prim_io_names(inputs=['x', 'index', 'grad'], outputs=['output'])
1862
+ self.init_prim_io_names(inputs=['x', 'dim', 'index', 'grad'], outputs=['output'])
1894
1863
 
1895
1864
 
1896
1865
  class ResizeBilinearGrad(Primitive):
@@ -1907,9 +1876,6 @@ class ResizeBilinearGrad(Primitive):
1907
1876
  self.init_prim_io_names(inputs=['grads', 'original_image'], outputs=['y'])
1908
1877
  if half_pixel_centers and align_corners:
1909
1878
  raise ValueError(f"If half_pixel_centers is True, align_corners must be False, but got {align_corners}")
1910
- target = context.get_context("device_target")
1911
- if half_pixel_centers and target.lower() == "cpu":
1912
- raise ValueError(f"Currently `half_pixel_centers`=True not support in cpu device_target")
1913
1879
 
1914
1880
 
1915
1881
  class ResizeNearestNeighborGrad(Primitive):
@@ -1934,12 +1900,12 @@ class ResizeLinear1DGrad(Primitive):
1934
1900
  """
1935
1901
  Compute gradient of `ResizeLinear1D` operator.
1936
1902
 
1937
- Note:
1938
- This is an experimental feature and is subjected to change.
1903
+ .. warning::
1904
+ This is an experimental API that is subject to change.
1939
1905
 
1940
1906
  Args:
1941
1907
  coordinate_transformation_mode (string): Default is 'align_corners'. Describes how to transform the coordinate
1942
- in the resized tensor to the coordinate in the original tensor. Other optional: 'half_pixel', 'asymmetric'.
1908
+ in the resized tensor to the coordinate in the original tensor. Other optional: 'half_pixel'.
1943
1909
  """
1944
1910
 
1945
1911
  @prim_attr_register
@@ -1949,7 +1915,7 @@ class ResizeLinear1DGrad(Primitive):
1949
1915
  inputs=['grads', 'input_x'], outputs=['y'])
1950
1916
  validator.check_value_type(
1951
1917
  "coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
1952
- validator.check_string(coordinate_transformation_mode, ["align_corners", "half_pixel", "asymmetric"],
1918
+ validator.check_string(coordinate_transformation_mode, ["align_corners", "half_pixel"],
1953
1919
  "coordinate_transformation_mode", self.name)
1954
1920
 
1955
1921
 
@@ -1960,7 +1926,7 @@ class ResizeNearestNeighborV2Grad(Primitive):
1960
1926
  Args:
1961
1927
  align_corners (bool): Whether the centers of the 4 corner pixels of the input
1962
1928
  and output tensors are aligned. Default: False.
1963
- half_pixel_centers (bool): Default :False.
1929
+ half_pixel_centers (bool): Default: False.
1964
1930
  data_format: An optional `string` that describes the format of the input `x` Defaults to `NHWC`.
1965
1931
  """
1966
1932
 
@@ -2012,7 +1978,7 @@ class UpsampleNearest3DGrad(Primitive):
2012
1978
  ValueError: If shape of `x` is not 5D.
2013
1979
 
2014
1980
  Supported Platforms:
2015
- ``GPU`` ``Ascend`` ``CPU``
1981
+ ``Ascend`` ``GPU`` ``CPU``
2016
1982
  """
2017
1983
  @prim_attr_register
2018
1984
  def __init__(self, input_size, output_size=None, scales=None):
@@ -2146,13 +2112,14 @@ class SliceGrad(PrimitiveWithInfer):
2146
2112
  def __infer__(self, dy, x, begin, size):
2147
2113
  dy_shape, x_shape, size_value, begin_v = dy['shape'], x['shape'], size['value'], begin['value']
2148
2114
  dy_shape_len = len(dy_shape)
2149
- if (size_value is not None) and not is_shape_unknown(x_shape):
2115
+ if size_value is not None and not is_shape_unknown(x_shape) and not is_shape_unknown(dy_shape):
2150
2116
  size_value = list(size_value)
2151
2117
  for i in range(dy_shape_len):
2152
2118
  if size_value[i] == -1:
2153
2119
  size_value[i] = x_shape[i] - begin_v[i]
2154
- validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name)
2155
- validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name)
2120
+ validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], validator.LE, self.name)
2121
+ validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]',
2122
+ size_value[i], validator.EQ, self.name)
2156
2123
 
2157
2124
  return {'shape': x_shape,
2158
2125
  'dtype': x['dtype'],
@@ -2175,6 +2142,7 @@ class SmoothL1LossGrad(Primitive):
2175
2142
 
2176
2143
  @prim_attr_register
2177
2144
  def __init__(self, beta=1.0, reduction='none'):
2145
+ self.add_prim_attr('sigma', self.beta)
2178
2146
  self.reduction = validator.check_string(
2179
2147
  reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2180
2148
 
@@ -2188,6 +2156,36 @@ class SoftMarginLossGrad(Primitive):
2188
2156
  self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2189
2157
 
2190
2158
 
2159
+ class StridedSliceV2Grad(Primitive):
2160
+ """
2161
+ Performs grad of StridedSliceV2 operation.
2162
+
2163
+ Inputs:
2164
+ - **shapex** (Tensor) - StridedSliceV2 shape of input
2165
+ - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
2166
+ constant value is allowed.
2167
+ - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
2168
+ Only constant value is allowed.
2169
+ - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
2170
+ before reaching the maximum location. Only constant value is allowed.
2171
+ - **dy** (Tensor) - The output of StridedSliceV2
2172
+
2173
+ Outputs:
2174
+ Tensor, the shape same as the input of StridedSliceV2
2175
+ """
2176
+
2177
+ @prim_attr_register
2178
+ def __init__(self,
2179
+ begin_mask=0,
2180
+ end_mask=0,
2181
+ ellipsis_mask=0,
2182
+ new_axis_mask=0,
2183
+ shrink_axis_mask=0):
2184
+ """Initialize StridedSliceV2Grad"""
2185
+ self.set_const_input_indexes([0])
2186
+ self.init_prim_io_names(inputs=['shapex', 'begin', 'end', 'strides', 'dy'], outputs=['output'])
2187
+
2188
+
2191
2189
  class StridedSliceGrad(Primitive):
2192
2190
  """
2193
2191
  Performs grad of StridedSlice operation.
@@ -2256,7 +2254,7 @@ class PadV3Grad(Primitive):
2256
2254
  """Initialize Padv3Grad"""
2257
2255
  self.add_prim_attr("cust_aicpu", self.name)
2258
2256
  self.init_prim_io_names(inputs=['x', 'paddings'], outputs=['y'])
2259
- validator.check_string(mode, ['reflect', 'edge'], 'mode', self.name)
2257
+ validator.check_string(mode, ['reflect', 'edge', 'circular'], 'mode', self.name)
2260
2258
  validator.check_bool(paddings_contiguous, "paddings_contiguous", self.name)
2261
2259
  self.set_const_input_indexes([1])
2262
2260
  self.mode = mode
@@ -2274,7 +2272,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
2274
2272
  @prim_attr_register
2275
2273
  def __init__(self):
2276
2274
  self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
2277
- self.add_prim_attr('primitive_target', 'CPU')
2275
+ self.set_device('CPU')
2278
2276
  self.tuple_setitem = Primitive('tuple_setitem')
2279
2277
 
2280
2278
  def __infer__(self, dy, split_num):
@@ -2352,20 +2350,20 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
2352
2350
  def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
2353
2351
  # dhy and dcy should be same shape
2354
2352
  validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
2355
- validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2356
- validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2357
- validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2358
- validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2359
- validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2360
- validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2361
- validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2362
- validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
2363
- validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
2364
- validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
2365
- validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
2366
- validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
2367
- validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
2368
- validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)
2353
+ validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), validator.EQ, self.name)
2354
+ validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), validator.EQ, self.name)
2355
+ validator.check("it rank", len(it_shape), "c rank", len(c_shape), validator.EQ, self.name)
2356
+ validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), validator.EQ, self.name)
2357
+ validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), validator.EQ, self.name)
2358
+ validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), validator.EQ, self.name)
2359
+ validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), validator.EQ, self.name)
2360
+ validator.check("dht shape", dht_shape, "c shape", c_shape, validator.EQ, self.name)
2361
+ validator.check("dct shape", dct_shape, "c shape", c_shape, validator.EQ, self.name)
2362
+ validator.check("it shape", it_shape, "c shape", c_shape, validator.EQ, self.name)
2363
+ validator.check("jt shape", jt_shape, "c shape", c_shape, validator.EQ, self.name)
2364
+ validator.check("ft shape", ft_shape, "c shape", c_shape, validator.EQ, self.name)
2365
+ validator.check("ot shape", ot_shape, "c shape", c_shape, validator.EQ, self.name)
2366
+ validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, validator.EQ, self.name)
2369
2367
 
2370
2368
  dgate_shape = (c_shape[0], 4 * c_shape[1])
2371
2369
  dct_1_shape = c_shape
@@ -2401,11 +2399,11 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
2401
2399
 
2402
2400
  def infer_shape(self, x_shape, h_shape, dgate_shape):
2403
2401
  validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
2404
- validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
2405
- validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
2406
- validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
2407
- validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
2408
- validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
2402
+ validator.check("h rank", len(h_shape), " x rank", len(x_shape), validator.EQ, self.name)
2403
+ validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), validator.EQ, self.name)
2404
+ validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], validator.EQ, self.name)
2405
+ validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], validator.EQ, self.name)
2406
+ validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], validator.EQ, self.name)
2409
2407
  input_size = x_shape[1]
2410
2408
  hidden_size = h_shape[1]
2411
2409
  dw_shape = (input_size + hidden_size, 4 * hidden_size)
@@ -2428,12 +2426,12 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
2428
2426
  @prim_attr_register
2429
2427
  def __init__(self, keep_prob):
2430
2428
  self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
2431
- self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
2429
+ self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, validator.INC_BOTH, "keep_prob", self.name)
2432
2430
 
2433
2431
  def infer_shape(self, dgate_shape, w_shape):
2434
2432
  validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
2435
2433
  validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
2436
- validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
2434
+ validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], validator.EQ, self.name)
2437
2435
  batch_size = dgate_shape[0]
2438
2436
  hidden_size = dgate_shape[1] // 4
2439
2437
  input_size = w_shape[0] - hidden_size
@@ -2457,7 +2455,7 @@ class InvGrad(Primitive):
2457
2455
  self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
2458
2456
 
2459
2457
 
2460
- class LRNGrad(PrimitiveWithInfer):
2458
+ class LRNGrad(Primitive):
2461
2459
  """Computes gradients for LRN operation."""
2462
2460
 
2463
2461
  @prim_attr_register
@@ -2468,14 +2466,6 @@ class LRNGrad(PrimitiveWithInfer):
2468
2466
  validator.check_value_type("alpha", alpha, [float], self.name)
2469
2467
  validator.check_value_type("beta", beta, [float], self.name)
2470
2468
 
2471
- def infer_dtype(self, grads, x, y):
2472
- args = {"grads": grads, "x": x, "y": y}
2473
- validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name)
2474
- return x
2475
-
2476
- def infer_shape(self, grads, x, y):
2477
- return x
2478
-
2479
2469
 
2480
2470
  class MvlgammaGrad(Primitive):
2481
2471
  r"""
@@ -2557,7 +2547,7 @@ class SoftShrinkGrad(Primitive):
2557
2547
  def __init__(self, lambd=0.5):
2558
2548
  self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
2559
2549
  validator.check_value_type("lambd", lambd, [float], self.name)
2560
- validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
2550
+ validator.check_number("lambd", lambd, 0, validator.GE, self.name)
2561
2551
 
2562
2552
 
2563
2553
  class CdistGrad(Primitive):
@@ -2594,7 +2584,7 @@ class PdistGrad(Primitive):
2594
2584
  ValueError: If dimension of `x` is not 2.
2595
2585
 
2596
2586
  Supported Platforms:
2597
- ``Ascend`` ``CPU`` ``GPU``
2587
+ ``Ascend`` ``GPU`` ``CPU``
2598
2588
  """
2599
2589
 
2600
2590
  @prim_attr_register
@@ -2669,7 +2659,7 @@ class HShrinkGrad(Primitive):
2669
2659
  TypeError: If dtype of `gradients` or `features` is neither float16 nor float32.
2670
2660
 
2671
2661
  Supported Platforms:
2672
- ``Ascend`` ``CPU`` ``GPU``
2662
+ ``Ascend`` ``GPU`` ``CPU``
2673
2663
  """
2674
2664
 
2675
2665
  @prim_attr_register
@@ -2726,7 +2716,7 @@ class Dilation2DBackpropInput(Primitive):
2726
2716
  ValueError: If `data_format` is not the str of 'NCHW'.
2727
2717
 
2728
2718
  Supported Platforms:
2729
- ``Ascend`` ``GPU``
2719
+ ``Ascend`` ``GPU`` ``CPU``
2730
2720
 
2731
2721
  Examples:
2732
2722
  (pad_mode="SAME", data_format="NCHW")
@@ -2843,7 +2833,7 @@ class Dilation2DBackpropFilter(Primitive):
2843
2833
 
2844
2834
 
2845
2835
  Supported Platforms:
2846
- ``Ascend`` ``GPU``
2836
+ ``Ascend`` ``GPU`` ``CPU``
2847
2837
 
2848
2838
  Examples:
2849
2839
  (pad_mode="SAME", data_format="NCHW")
@@ -2990,7 +2980,7 @@ class MultiMarginLossGrad(Primitive):
2990
2980
  def __init__(self, p=1, margin=1.0, reduction="mean"):
2991
2981
  """Initialize MultiMarginLossGrad"""
2992
2982
  self.p = validator.check_value_type('p', p, [int], self.name)
2993
- validator.check_int(p, {1, 2}, Rel.IN, 'p', self.name)
2983
+ validator.check_int(p, {1, 2}, validator.IN, 'p', self.name)
2994
2984
  self.margin = validator.check_value_type('margin', margin, [float], self.name)
2995
2985
  self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2996
2986
  self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'weight'], outputs=['x_grad'])
@@ -3043,7 +3033,7 @@ class UpsampleTrilinear3DGrad(Primitive):
3043
3033
  ValueError: If elements number of `input_size` is not 5.
3044
3034
 
3045
3035
  Supported Platforms:
3046
- ``Ascend`` ``CPU`` ``GPU``
3036
+ ``Ascend`` ``GPU`` ``CPU``
3047
3037
  """
3048
3038
  @prim_attr_register
3049
3039
  def __init__(self, input_size, output_size=None, scales=None, align_corners=False):
@@ -3111,7 +3101,7 @@ class GridSampler3DGrad(Primitive):
3111
3101
  ValueError: If the shape of `grad` is inconsistent with the shape of the output result of forward calculation.
3112
3102
 
3113
3103
  Supported Platforms:
3114
- ``CPU````GPU``
3104
+ ``GPU`` ``CPU``
3115
3105
  """
3116
3106
 
3117
3107
  @prim_attr_register
@@ -3156,7 +3146,7 @@ class SparseSegmentMeanGrad(Primitive):
3156
3146
  ValueError: If `indices` is out of range of `output_dim0`.
3157
3147
 
3158
3148
  Supported Platforms:
3159
- ``Ascend`` ``CPU``
3149
+ ``Ascend`` ``GPU`` ``CPU``
3160
3150
  """
3161
3151
 
3162
3152
  @prim_attr_register
@@ -3199,9 +3189,9 @@ class MaxUnpool2DGrad(Primitive):
3199
3189
  validator.check_value_type("pads", pads, [int, tuple], self.name)
3200
3190
  validator.check_value_type("output_shape", output_shape, [tuple], self.name)
3201
3191
  validator.check_string(data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
3202
- validator.check_int(len(ksize), 4, Rel.EQ, "ksize rank", self.name)
3203
- validator.check_int(len(strides), 4, Rel.EQ, "strides rank", self.name)
3204
- validator.check_int(len(pads), 4, Rel.EQ, "pads rank", self.name)
3192
+ validator.check_int(len(ksize), 4, validator.EQ, "ksize rank", self.name)
3193
+ validator.check_int(len(strides), 4, validator.EQ, "strides rank", self.name)
3194
+ validator.check_int(len(pads), 4, validator.EQ, "pads rank", self.name)
3205
3195
 
3206
3196
 
3207
3197
  class MaxUnpool3DGrad(Primitive):
@@ -3218,9 +3208,9 @@ class MaxUnpool3DGrad(Primitive):
3218
3208
  validator.check_value_type("pads", pads, [int, tuple], self.name)
3219
3209
  validator.check_value_type("output_shape", output_shape, [tuple], self.name)
3220
3210
  validator.check_string(data_format, ['NCDHW', 'NDHWC'], 'data_format', self.name)
3221
- validator.check_int(len(ksize), 5, Rel.EQ, "ksize rank", self.name)
3222
- validator.check_int(len(strides), 5, Rel.EQ, "strides rank", self.name)
3223
- validator.check_int(len(pads), 5, Rel.EQ, "pads rank", self.name)
3211
+ validator.check_int(len(ksize), 5, validator.EQ, "ksize rank", self.name)
3212
+ validator.check_int(len(strides), 5, validator.EQ, "strides rank", self.name)
3213
+ validator.check_int(len(pads), 5, validator.EQ, "pads rank", self.name)
3224
3214
 
3225
3215
 
3226
3216
  class FractionalAvgPoolGrad(Primitive):
@@ -3269,6 +3259,7 @@ class AdaptiveMaxPool3DGrad(Primitive):
3269
3259
  @prim_attr_register
3270
3260
  def __init__(self):
3271
3261
  """Initialize AdaptiveMaxPool3DGrad"""
3262
+ self.init_prim_io_names(inputs=['input_grad', 'x', 'argmax'], outputs=['output_grad'])
3272
3263
 
3273
3264
 
3274
3265
  class TraceGrad(Primitive):
@@ -3292,7 +3283,7 @@ class TraceGrad(Primitive):
3292
3283
  ValueError: If length of shape of `x_shape` is not equal to 2.
3293
3284
 
3294
3285
  Support Platforms:
3295
- ``Ascend`` ``CPU`` ``GPU``
3286
+ ``Ascend`` ``GPU`` ``CPU``
3296
3287
  """
3297
3288
 
3298
3289
  @prim_attr_register
@@ -3564,7 +3555,7 @@ class GridSampler2DGrad(Primitive):
3564
3555
  ValueError: If the shape of `grad` is inconsistent with the shape of the output result of forward calculation.
3565
3556
 
3566
3557
  Supported Platforms:
3567
- ``CPU````GPU``
3558
+ ``GPU`` ``CPU``
3568
3559
  """
3569
3560
 
3570
3561
  @prim_attr_register
@@ -3605,7 +3596,7 @@ class ResizeBicubicGrad(Primitive):
3605
3596
  ValueError: If `size` dim is not 4.
3606
3597
 
3607
3598
  Supported Platforms:
3608
- ``Ascend`` ``CPU`` ``GPU``
3599
+ ``Ascend`` ``GPU`` ``CPU``
3609
3600
  """
3610
3601
  @prim_attr_register
3611
3602
  def __init__(self, align_corners=False, half_pixel_centers=False):
@@ -3634,15 +3625,16 @@ class ResizeBicubicGrad(Primitive):
3634
3625
  validator.check_tensor_dtype_valid("original_image", original_image_dtype,
3635
3626
  [mstype.float32, mstype.float64], self.name)
3636
3627
  # check input shape rank
3637
- validator.check("grads rank", len(grads_shape), "expected", 4, Rel.EQ, self.name)
3638
- validator.check("original_image rank", len(original_image_shape), "expected", 4, Rel.EQ, self.name)
3639
- validator.check("batch_size equal", grads_shape[0], "expected", original_image_shape[0], Rel.EQ, self.name)
3640
- validator.check("channel equal", grads_shape[3], "expected", original_image_shape[3], Rel.EQ, self.name)
3628
+ validator.check("grads rank", len(grads_shape), "expected", 4, validator.EQ, self.name)
3629
+ validator.check("original_image rank", len(original_image_shape), "expected", 4, validator.EQ, self.name)
3630
+ validator.check("batch_size equal", grads_shape[0], "expected",
3631
+ original_image_shape[0], validator.EQ, self.name)
3632
+ validator.check("channel equal", grads_shape[3], "expected", original_image_shape[3], validator.EQ, self.name)
3641
3633
  # check original_image_shape and grads_shape
3642
3634
  validator.check("original_image[0] and grads[0]", original_image_shape[0],
3643
- "expected", grads_shape[0], Rel.EQ, self.name)
3635
+ "expected", grads_shape[0], validator.EQ, self.name)
3644
3636
  validator.check("original_image[3] and grads[3]", original_image_shape[3],
3645
- "expected", grads_shape[3], Rel.EQ, self.name)
3637
+ "expected", grads_shape[3], validator.EQ, self.name)
3646
3638
 
3647
3639
  batch_size = grads_shape[0]
3648
3640
  height = original_image_shape[1]
@@ -3660,38 +3652,41 @@ class SparseSliceGrad(Primitive):
3660
3652
 
3661
3653
  Inputs:
3662
3654
  - **backprop_val_grad** (Tensor) - A 1D Tensor.
3663
- The shape should be :math:`(n,)`.
3664
- - **indices** (Tensor) - A 2D Tensor of type int64. The indices of the SparseTensor.
3655
+ The shape should be :math:`(N,)`.
3656
+ - **indices** (Tensor) - A 2D Tensor (N x R matrix) of type int64. The indices of the SparseTensor.
3665
3657
  Support int64, each element value should be a non-negative int number. This tensor should be sorted.
3666
- The shape is :math:`(n, 2)`.
3658
+ The shape is :math:`(N, R)`.
3667
3659
  - **start** (Tensor) - A 1D Tensor of type int64, represents the start of the indices.
3668
- - **new_indices** (Tensor) - A 2D Tensor of type int64. The indices of the SparseTensor.
3660
+ The shape should be :math:`(R,)`.
3661
+ - **new_indices** (Tensor) - A 2D Tensor (N x C matrix) of type int64. The indices of the SparseTensor.
3669
3662
  Support int64, each element value should be a non-negative int number. This tensor should be sorted.
3670
- The shape is :math:`(n, 2)`.
3663
+ The shape is :math:`(N, C)`.
3671
3664
 
3672
3665
  Outputs:
3673
- - *y_grad_val: A Tensor. Has the same type as "backprop_val_grad".
3666
+ - *y_grad_val: A Tensor. Has the same type as `backprop_val_grad`.
3667
+ Has the same number as `indices`.
3674
3668
 
3675
3669
  Raises:
3676
3670
  TypeError: If the dtype of `indices`, `start`, `new_indices` are not int64.
3677
3671
  ValueError: If `indices`, `new_indices` are not 2-D tensor.
3678
3672
  ValueError: If `backprop_val_grad`, `start` is not a 1-D tensor.
3679
3673
  ValueError: If the number of `backprop_val_grad` is not corresponding to the number of `new_indices`.
3680
- RunTimeError: If the `backprop_val_grad` is not all backpropagated, because `indices` or `new_indices`
3674
+ ValueError: If the shape of `indices[1]` is not corresponding to `start[1]`.
3675
+ ValueError: If the shape of `indices[1]` is not corresponding to `new_indices[1]`.
3676
+ RuntimeError: If the `backprop_val_grad` is not all backpropagated, because `indices` or `new_indices`
3681
3677
  is not sorted.
3682
3678
 
3683
3679
  Supported Platforms:
3684
- ``GPU``
3680
+ ``Ascend`` ``GPU`` ``CPU``
3685
3681
  Examples:
3686
- >>> backprop_val_grad = Tensor([4, 2, 3])
3687
- >>> indices = Tensor([[0, 0], [0, 2], [1, 2], [1, 3], [2, 3], [2, 4]], dtype=ms.int64)
3688
- >>> values = Tensor([1, 2, 3, 4])
3689
- >>> start = Tensor([0, 0], dtype=ms.int64)
3690
- >>> new_indices = Tensor([0, 2], [1, 2], [1, 3], dtype=ms.int64)
3682
+ >>> backprop_val_grad = Tensor(np.array([1, 2, 3, 4]).astype(np.int64))
3683
+ >>> indices = Tensor(np.array([[0, 0], [0, 2], [1, 2], [1, 3], [2, 3], [2, 4]]).astype(np.int64))
3684
+ >>> start = Tensor(np.array([0, 0]).astype(np.int64))
3685
+ >>> new_indices = Tensor(np.array([[0, 2], [1, 2], [1, 3], [2, 4]]).astype(np.int64))
3691
3686
  >>> grad = SparseSliceGrad()
3692
3687
  >>> output = grad(backprop_val_grad, indices, start, new_indices)
3693
3688
  >>> print(output)
3694
- [0, 4, 2, 3, 0, 0]
3689
+ [0 1 2 3 0 4]
3695
3690
  """
3696
3691
 
3697
3692
  @prim_attr_register
@@ -3726,7 +3721,7 @@ class FractionalMaxPoolGradWithFixedKsize(Primitive):
3726
3721
  ValueError: If the second dimension size of `origin_input` and `out_backprop` is not equal.
3727
3722
 
3728
3723
  Supported Platforms:
3729
- ``Ascend`` ``CPU``
3724
+ ``Ascend`` ``GPU`` ``CPU``
3730
3725
  """
3731
3726
 
3732
3727
  @prim_attr_register
@@ -3736,9 +3731,139 @@ class FractionalMaxPoolGradWithFixedKsize(Primitive):
3736
3731
  self.init_prim_io_names(inputs=['origin_input', 'out_backprop', 'argmax'], outputs=['y'])
3737
3732
 
3738
3733
 
3734
+ class AffineGridGrad(Primitive):
3735
+ r"""
3736
+ Computes gradients for AffineGrid operation.
3737
+
3738
+ Args:
3739
+ align_corners (bool): if True, consider -1 and 1 to refer to the centers
3740
+ of the corner pixels rather than the image corners. Default: False.
3741
+
3742
+ Inputs:
3743
+ - **y_grad** (Tensor) - Data type must be float16 or float32.
3744
+ - **x_size** (tuple) - Data type must be int32 or int64.
3745
+
3746
+ Outputs:
3747
+ Tensor, with data type same as `y_grad`.
3748
+
3749
+ Supported Platforms:
3750
+ ``CPU``
3751
+
3752
+ Examples:
3753
+ >>> import mindspore.ops.operations._grad_ops as _grad_ops
3754
+ >>> affinegridgrad = _grad_ops.AffineGridGrad()
3755
+ >>> y_grad = Tensor(np.ones([1, 2, 2, 2]), mindspore.float32)
3756
+ >>> x_size = (1, 2, 2, 2)
3757
+ >>> x_grad = affinegridgrad(y_grad, x_size)
3758
+ >>> print(x_grad)
3759
+ [[[0. 0. 4.]
3760
+ [0. 0. 4.]]]
3761
+ """
3762
+
3763
+ @prim_attr_register
3764
+ def __init__(self, align_corners=False):
3765
+ """Initialize AffineGridGrad."""
3766
+ validator.check_value_type("align_corners", align_corners, [bool], self.name)
3767
+ self.init_prim_io_names(inputs=['y_grad', 'x_size'], outputs=['x_grad'])
3768
+
3769
+
3739
3770
  class HSigmoidGrad(Primitive):
3740
3771
  """Gets the gradient of HSigmoid operation."""
3741
3772
  @prim_attr_register
3742
3773
  def __init__(self):
3743
3774
  """Initialize HSigmoidGrad"""
3744
3775
  self.init_prim_io_names(inputs=['grads', 'input_x'], outputs=['output'])
3776
+
3777
+
3778
+ class GluGrad(Primitive):
3779
+ """
3780
+ Computes grad for Glu operation.
3781
+ """
3782
+
3783
+ @prim_attr_register
3784
+ def __init__(self, axis):
3785
+ self.add_prim_attr("cust_aicpu", self.name)
3786
+ self.init_prim_io_names(inputs=["grads", "x"], outputs=["y"])
3787
+ validator.check_value_type("axis", axis, [int], self.name)
3788
+
3789
+
3790
+ class CholeskyGrad(Primitive):
3791
+ r"""
3792
+ Computes the reverse mode backpropgated gradient of the Cholesky algorithm.
3793
+
3794
+ Inputs:
3795
+ - **x** (Tensor) - A tensor with float32 or float64 data type.
3796
+ - **grad** (Tensor) - A tensor with float32 or float64 data type. `x` should have
3797
+ the same dtype with `a`.
3798
+
3799
+ Outputs:
3800
+ Tensor, has the same dtype as `a` and `x`.
3801
+
3802
+ Raises:
3803
+ TypeError: If x is not Tensor.
3804
+ TypeError: If grad is not Tensor.
3805
+ TypeError: If dtype of input x and grad is not float64 nor float32,
3806
+ TypeError: If x has different dtype with grad.
3807
+ ValueError: If input tensor's last two dims are not equal,
3808
+ ValueError: If the shape of x and grad mismatch.
3809
+
3810
+ Supported Platforms:
3811
+ ``Ascend``
3812
+
3813
+ Examples:
3814
+ >>> x = Tensor(np.array([[4, 2],[2, 3]]), mstype.float64)
3815
+ >>> grad = Tensor(np.array([[4, 2],[2, 3]]), mstype.float64)
3816
+ >>> choleskygrad = G.CholeskyGrad()
3817
+ >>> output = choleskygrad(x, grad)
3818
+ >>> print (output)
3819
+ [[0.5 0. ]
3820
+ [0. 0.5]]
3821
+
3822
+ """
3823
+
3824
+ @prim_attr_register
3825
+ def __init__(self):
3826
+ """Initialize CholeskyGrad"""
3827
+ self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
3828
+
3829
+
3830
+ class MapTensorGetGrad(Primitive):
3831
+ """
3832
+ Computes gradients for MapTensorGet operation.
3833
+
3834
+ Inputs:
3835
+ - **map_tensor** (MapTensor) - The input `map_tensor` of the forward operator MapTensorGet.
3836
+ - **key_tensor** (Tensor) - The input `key_tensor` of the forward operator MapTensorGet.
3837
+ - **default_value** (Scalar) - The input `default_value` of the forward operator MapTensorGet.
3838
+ - **grad** (Tensor) - The grad value according the forward operator MapTensorGet.
3839
+
3840
+ Outputs:
3841
+ - **output** (MapTensor) - MapTensor with grad values.
3842
+ """
3843
+ @prim_attr_register
3844
+ def __init__(self):
3845
+ """Initialize MapTensorGetGrad"""
3846
+ self.init_prim_io_names(inputs=['map_tensor', 'key_tensor', 'default_value', 'grad'], outputs=['output'])
3847
+ self.add_prim_attr('side_effect_mem', True)
3848
+
3849
+
3850
+ class ResizeV2Grad(Primitive):
3851
+ r"""
3852
+ Calculates the gradient of ResizeV2 operation.
3853
+
3854
+ Supported Platforms:
3855
+ ``CPU``
3856
+ """
3857
+
3858
+ @prim_attr_register
3859
+ def __init__(self, coordinate_transformation_mode="half_pixel", mode="nearest"):
3860
+ """Initialize ResizeV2Grad."""
3861
+ self.init_prim_io_names(inputs=["grads", "roi", "scales", "original_size"], outputs=["y"])
3862
+ self.add_prim_attr("nearest_mode", "floor")
3863
+ self.add_prim_attr("cubic_coeff_a", -0.75)
3864
+ validator.check_value_type(
3865
+ "coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
3866
+ validator.check_string(coordinate_transformation_mode,
3867
+ ["align_corners", "half_pixel"], "coordinate_transformation_mode", self.name)
3868
+ validator.check_value_type("mode", mode, [str], self.name)
3869
+ validator.check_string(mode, ["nearest", "linear", "cubic"], "mode", self.name)