mindspore 1.10.0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (944):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +9 -4
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/builtin_operations.py +32 -4
  13. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  21. mindspore/_extends/parse/__init__.py +5 -3
  22. mindspore/_extends/parse/namespace.py +17 -2
  23. mindspore/_extends/parse/parser.py +193 -34
  24. mindspore/_extends/parse/resources.py +7 -8
  25. mindspore/_extends/parse/standard_method.py +1780 -435
  26. mindspore/_extends/parse/trope.py +3 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  28. mindspore/amp.py +53 -58
  29. mindspore/bin/cache_admin +0 -0
  30. mindspore/bin/cache_server +0 -0
  31. mindspore/boost/adasum.py +3 -2
  32. mindspore/boost/boost.py +2 -2
  33. mindspore/boost/boost_cell_wrapper.py +46 -26
  34. mindspore/boost/dim_reduce.py +6 -5
  35. mindspore/boost/grad_accumulation.py +2 -1
  36. mindspore/boost/group_loss_scale_manager.py +1 -1
  37. mindspore/common/__init__.py +11 -10
  38. mindspore/common/_decorator.py +2 -0
  39. mindspore/common/_register_for_adapter.py +55 -0
  40. mindspore/common/_stub_tensor.py +201 -0
  41. mindspore/common/_utils.py +57 -0
  42. mindspore/common/api.py +582 -297
  43. mindspore/common/dtype.py +66 -18
  44. mindspore/common/dump.py +2 -2
  45. mindspore/common/initializer.py +38 -1
  46. mindspore/common/jit_config.py +25 -13
  47. mindspore/common/mutable.py +53 -24
  48. mindspore/common/parameter.py +60 -37
  49. mindspore/common/seed.py +8 -24
  50. mindspore/common/sparse_tensor.py +927 -0
  51. mindspore/common/tensor.py +1627 -3900
  52. mindspore/communication/__init__.py +10 -5
  53. mindspore/communication/_comm_helper.py +78 -214
  54. mindspore/communication/_hccl_management.py +2 -1
  55. mindspore/communication/management.py +136 -47
  56. mindspore/config/op_info.config +501 -1008
  57. mindspore/config/super_bar_config.json +512 -0
  58. mindspore/context.py +291 -56
  59. mindspore/dataset/__init__.py +12 -8
  60. mindspore/dataset/audio/__init__.py +9 -9
  61. mindspore/dataset/audio/transforms.py +1090 -228
  62. mindspore/dataset/audio/utils.py +87 -39
  63. mindspore/dataset/audio/validators.py +223 -1
  64. mindspore/dataset/callback/ds_callback.py +17 -15
  65. mindspore/dataset/core/config.py +246 -17
  66. mindspore/dataset/core/py_util_helpers.py +4 -3
  67. mindspore/dataset/core/validator_helpers.py +10 -10
  68. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  69. mindspore/dataset/debug/debug_hook.py +65 -0
  70. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  71. mindspore/dataset/engine/__init__.py +7 -3
  72. mindspore/dataset/engine/cache_client.py +9 -9
  73. mindspore/dataset/engine/datasets.py +648 -477
  74. mindspore/dataset/engine/datasets_audio.py +165 -167
  75. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  76. mindspore/dataset/engine/datasets_text.py +492 -342
  77. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  78. mindspore/dataset/engine/datasets_vision.py +1224 -699
  79. mindspore/dataset/engine/graphdata.py +134 -69
  80. mindspore/dataset/engine/iterators.py +50 -9
  81. mindspore/dataset/engine/offload.py +52 -31
  82. mindspore/dataset/engine/samplers.py +27 -24
  83. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  84. mindspore/dataset/engine/validators.py +213 -52
  85. mindspore/dataset/text/__init__.py +10 -8
  86. mindspore/dataset/text/transforms.py +152 -57
  87. mindspore/dataset/text/utils.py +98 -49
  88. mindspore/dataset/text/validators.py +25 -0
  89. mindspore/dataset/transforms/__init__.py +4 -2
  90. mindspore/dataset/transforms/c_transforms.py +11 -13
  91. mindspore/dataset/transforms/py_transforms.py +2 -2
  92. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  93. mindspore/dataset/transforms/transforms.py +13 -15
  94. mindspore/dataset/transforms/validators.py +7 -7
  95. mindspore/dataset/utils/__init__.py +2 -1
  96. mindspore/dataset/utils/browse_dataset.py +13 -13
  97. mindspore/dataset/utils/line_reader.py +121 -0
  98. mindspore/dataset/vision/__init__.py +8 -7
  99. mindspore/dataset/vision/c_transforms.py +125 -126
  100. mindspore/dataset/vision/py_transforms.py +37 -37
  101. mindspore/dataset/vision/py_transforms_util.py +23 -20
  102. mindspore/dataset/vision/transforms.py +316 -315
  103. mindspore/dataset/vision/utils.py +313 -17
  104. mindspore/dataset/vision/validators.py +6 -6
  105. mindspore/default_config.py +0 -1
  106. mindspore/{compression → experimental}/__init__.py +6 -5
  107. mindspore/experimental/map_parameter.py +275 -0
  108. mindspore/include/OWNERS +0 -1
  109. mindspore/include/api/callback/callback.h +9 -13
  110. mindspore/include/api/callback/ckpt_saver.h +2 -2
  111. mindspore/include/api/callback/loss_monitor.h +2 -2
  112. mindspore/include/api/callback/lr_scheduler.h +5 -5
  113. mindspore/include/api/callback/time_monitor.h +2 -2
  114. mindspore/include/api/callback/train_accuracy.h +4 -6
  115. mindspore/include/api/cfg.h +19 -6
  116. mindspore/include/api/context.h +70 -9
  117. mindspore/include/api/delegate.h +8 -1
  118. mindspore/include/api/dual_abi_helper.h +8 -24
  119. mindspore/include/api/metrics/accuracy.h +2 -2
  120. mindspore/include/api/metrics/metrics.h +4 -3
  121. mindspore/include/api/model.h +9 -4
  122. mindspore/include/api/model_group.h +68 -0
  123. mindspore/include/api/model_parallel_runner.h +17 -17
  124. mindspore/include/api/net.h +12 -11
  125. mindspore/include/api/serialization.h +20 -4
  126. mindspore/include/api/status.h +7 -1
  127. mindspore/include/api/types.h +25 -21
  128. mindspore/include/api/visible.h +4 -0
  129. mindspore/include/c_api/model_c.h +5 -0
  130. mindspore/include/c_api/status_c.h +1 -1
  131. mindspore/include/dataset/config.h +1 -1
  132. mindspore/include/dataset/constants.h +14 -0
  133. mindspore/include/dataset/text.h +59 -0
  134. mindspore/include/dataset/vision.h +56 -117
  135. mindspore/include/dataset/vision_lite.h +102 -0
  136. mindspore/include/mindapi/base/type_id.h +42 -3
  137. mindspore/lib/libdnnl.so.2 +0 -0
  138. mindspore/lib/libicudata.so.69 +0 -0
  139. mindspore/lib/libicui18n.so.69 +0 -0
  140. mindspore/lib/libicuuc.so.69 +0 -0
  141. mindspore/lib/libmindspore.so +0 -0
  142. mindspore/lib/libmindspore_backend.so +0 -0
  143. mindspore/lib/libmindspore_common.so +0 -0
  144. mindspore/lib/libmindspore_core.so +0 -0
  145. mindspore/lib/libmindspore_glog.so.0 +0 -0
  146. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  147. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  148. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  149. mindspore/lib/libmindspore_shared_lib.so +0 -0
  150. mindspore/lib/libmpi_adapter.so +0 -0
  151. mindspore/lib/libmpi_collective.so +0 -0
  152. mindspore/lib/libnnacl.so +0 -0
  153. mindspore/lib/libopencv_core.so.4.5 +0 -0
  154. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  155. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  156. mindspore/lib/libps_cache.so +0 -0
  157. mindspore/lib/plugin/ascend/libakg.so +0 -0
  158. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  159. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  160. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  161. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  163. mindspore/lib/{libakg.so → plugin/cpu/libakg.so} +0 -0
  164. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  165. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  166. mindspore/log.py +28 -28
  167. mindspore/mindrecord/common/exceptions.py +2 -4
  168. mindspore/mindrecord/filereader.py +19 -1
  169. mindspore/mindrecord/filewriter.py +250 -88
  170. mindspore/mindrecord/mindpage.py +13 -13
  171. mindspore/mindrecord/shardheader.py +15 -15
  172. mindspore/mindrecord/shardreader.py +9 -0
  173. mindspore/mindrecord/shardwriter.py +29 -29
  174. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  175. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  176. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  177. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  178. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  179. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  180. mindspore/nn/__init__.py +1 -5
  181. mindspore/nn/cell.py +297 -234
  182. mindspore/nn/dynamic_lr.py +1 -1
  183. mindspore/nn/grad/cell_grad.py +17 -42
  184. mindspore/nn/layer/__init__.py +7 -4
  185. mindspore/nn/layer/activation.py +131 -88
  186. mindspore/nn/layer/basic.py +313 -613
  187. mindspore/nn/layer/channel_shuffle.py +103 -0
  188. mindspore/nn/layer/combined.py +1 -1
  189. mindspore/nn/layer/container.py +52 -6
  190. mindspore/nn/layer/conv.py +112 -43
  191. mindspore/nn/layer/dense.py +10 -9
  192. mindspore/nn/layer/embedding.py +36 -34
  193. mindspore/nn/layer/image.py +123 -27
  194. mindspore/nn/layer/math.py +108 -107
  195. mindspore/nn/layer/normalization.py +212 -366
  196. mindspore/nn/layer/padding.py +370 -42
  197. mindspore/nn/layer/pooling.py +1443 -219
  198. mindspore/nn/layer/rnn_cells.py +11 -16
  199. mindspore/nn/layer/rnns.py +38 -39
  200. mindspore/nn/layer/thor_layer.py +24 -25
  201. mindspore/nn/layer/timedistributed.py +5 -5
  202. mindspore/nn/layer/transformer.py +701 -0
  203. mindspore/nn/learning_rate_schedule.py +8 -8
  204. mindspore/nn/loss/__init__.py +9 -6
  205. mindspore/nn/loss/loss.py +678 -142
  206. mindspore/nn/metrics.py +53 -0
  207. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  208. mindspore/nn/optim/ada_grad.py +8 -8
  209. mindspore/nn/optim/adadelta.py +2 -3
  210. mindspore/nn/optim/adafactor.py +18 -14
  211. mindspore/nn/optim/adam.py +429 -87
  212. mindspore/nn/optim/adamax.py +5 -6
  213. mindspore/nn/optim/adasum.py +10 -8
  214. mindspore/nn/optim/asgd.py +7 -7
  215. mindspore/nn/optim/ftrl.py +81 -11
  216. mindspore/nn/optim/lamb.py +7 -8
  217. mindspore/nn/optim/lars.py +4 -4
  218. mindspore/nn/optim/lazyadam.py +82 -7
  219. mindspore/nn/optim/momentum.py +8 -7
  220. mindspore/nn/optim/optimizer.py +19 -10
  221. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  222. mindspore/nn/optim/rmsprop.py +3 -3
  223. mindspore/nn/optim/rprop.py +20 -16
  224. mindspore/nn/optim/sgd.py +21 -15
  225. mindspore/nn/optim/thor.py +23 -21
  226. mindspore/nn/probability/__init__.py +0 -2
  227. mindspore/nn/probability/bijector/bijector.py +7 -6
  228. mindspore/nn/probability/bijector/invert.py +4 -2
  229. mindspore/nn/probability/bijector/softplus.py +2 -2
  230. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  231. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  232. mindspore/nn/probability/distribution/__init__.py +6 -0
  233. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  234. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  235. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  236. mindspore/nn/probability/distribution/beta.py +1 -1
  237. mindspore/nn/probability/distribution/categorical.py +9 -9
  238. mindspore/nn/probability/distribution/cauchy.py +8 -8
  239. mindspore/nn/probability/distribution/distribution.py +12 -6
  240. mindspore/nn/probability/distribution/exponential.py +5 -5
  241. mindspore/nn/probability/distribution/gamma.py +3 -3
  242. mindspore/nn/probability/distribution/geometric.py +6 -5
  243. mindspore/nn/probability/distribution/gumbel.py +5 -5
  244. mindspore/nn/probability/distribution/half_normal.py +133 -0
  245. mindspore/nn/probability/distribution/laplace.py +128 -0
  246. mindspore/nn/probability/distribution/log_normal.py +0 -1
  247. mindspore/nn/probability/distribution/logistic.py +4 -5
  248. mindspore/nn/probability/distribution/normal.py +11 -15
  249. mindspore/nn/probability/distribution/poisson.py +6 -2
  250. mindspore/nn/probability/distribution/student_t.py +150 -0
  251. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  252. mindspore/nn/probability/distribution/uniform.py +5 -5
  253. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  254. mindspore/nn/reinforcement/tensor_array.py +2 -2
  255. mindspore/nn/sparse/sparse.py +8 -1
  256. mindspore/nn/wrap/cell_wrapper.py +55 -27
  257. mindspore/nn/wrap/grad_reducer.py +20 -11
  258. mindspore/nn/wrap/loss_scale.py +47 -30
  259. mindspore/numpy/array_creations.py +33 -22
  260. mindspore/numpy/array_ops.py +46 -42
  261. mindspore/numpy/logic_ops.py +6 -27
  262. mindspore/numpy/math_ops.py +26 -19
  263. mindspore/numpy/utils.py +1 -8
  264. mindspore/numpy/utils_const.py +112 -62
  265. mindspore/ops/__init__.py +6 -3
  266. mindspore/ops/_constants.py +0 -6
  267. mindspore/ops/_grad/__init__.py +2 -1
  268. mindspore/ops/_grad/grad_array_ops.py +209 -152
  269. mindspore/ops/_grad/grad_base.py +55 -17
  270. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  271. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  272. mindspore/ops/_grad/grad_implementations.py +21 -61
  273. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  274. mindspore/ops/_grad/grad_math_ops.py +306 -161
  275. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  276. mindspore/ops/_grad/grad_other_ops.py +1 -1
  277. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  278. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  279. mindspore/ops/_grad/grad_sparse.py +15 -9
  280. mindspore/ops/_grad_experimental/__init__.py +1 -0
  281. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  282. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  283. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  284. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  285. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  286. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  287. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  288. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  289. mindspore/ops/_op_impl/__init__.py +3 -3
  290. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  291. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  292. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  293. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  294. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  295. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  296. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  297. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  298. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  299. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  300. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  302. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  303. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  304. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  305. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  307. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  314. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  315. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  317. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  318. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  320. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  321. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  322. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  324. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  327. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  328. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  329. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  330. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  331. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  332. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  335. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  336. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  337. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  338. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  339. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  340. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  341. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  342. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  343. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  344. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  345. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  346. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  347. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  349. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  350. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  351. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  352. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  353. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  354. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  355. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  356. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  357. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  358. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  359. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  360. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  361. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  362. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  363. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  365. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  366. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  367. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  368. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  369. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  370. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  371. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  372. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  373. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  374. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  375. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  376. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  377. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  378. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  379. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  380. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  381. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  382. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  383. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  384. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  385. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  386. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  387. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  388. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  389. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  390. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  391. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  392. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  393. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  394. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  395. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  396. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  397. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  398. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  399. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  400. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  401. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  402. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  403. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  404. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  405. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  407. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  408. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  409. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  410. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  411. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  412. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  413. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  414. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  415. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  416. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  417. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  418. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  419. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  420. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  421. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  422. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  423. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  424. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  425. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  426. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  427. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  428. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  429. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  430. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  431. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  432. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  433. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  434. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  435. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  436. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  437. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  441. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  442. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  443. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  444. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  445. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  446. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  447. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  448. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  449. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  450. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  451. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  452. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  453. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  454. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  455. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  456. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  457. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  458. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  459. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  460. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  461. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  462. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  463. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  464. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  465. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  466. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  467. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  468. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  469. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  470. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  471. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  472. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  473. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  474. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  475. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  476. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  477. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  478. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  479. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  480. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  481. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  482. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  483. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  484. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  485. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  486. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  487. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  488. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  489. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  490. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  491. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  492. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  493. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  494. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  495. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  496. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  497. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  498. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  499. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  500. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  501. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  502. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  503. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  504. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  505. mindspore/ops/_primitive_cache.py +3 -2
  506. mindspore/ops/_register_for_op.py +11 -0
  507. mindspore/ops/_utils/__init__.py +1 -1
  508. mindspore/ops/_utils/utils.py +20 -41
  509. mindspore/ops/_vmap/__init__.py +2 -2
  510. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  511. mindspore/ops/_vmap/vmap_base.py +24 -10
  512. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  513. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  514. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  515. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  516. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  517. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  518. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  519. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  520. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  521. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  522. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  524. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  525. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  526. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  527. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  528. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  529. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  530. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  532. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  534. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  536. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  537. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  538. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  541. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  542. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  543. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  544. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  545. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  546. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  547. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  548. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  549. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  550. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  551. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  552. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  554. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  555. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  556. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  560. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  561. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  563. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  564. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  565. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  566. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  567. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  568. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  569. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  571. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  572. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  573. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  574. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  575. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  577. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  578. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  579. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  580. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  581. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  582. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  583. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  585. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  586. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  587. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  588. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  589. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  590. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  591. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  592. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  593. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  594. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  595. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  596. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  597. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  598. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  599. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  600. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  601. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  604. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  605. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  606. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  607. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  609. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  610. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  611. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  612. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  613. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  614. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  615. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  616. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  617. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  618. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  619. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  620. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  621. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  622. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  623. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  624. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  625. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  626. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  627. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  628. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  629. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  630. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  632. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  633. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  634. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  635. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  637. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  638. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  639. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  640. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  643. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  644. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  645. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  646. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  647. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  648. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  649. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  650. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  651. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  652. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  653. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  654. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  655. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  656. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  657. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  658. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  660. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  661. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  662. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  663. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  664. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  665. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  666. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  667. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  668. mindspore/ops/bprop_mindir/__init__.py +1 -4
  669. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  670. mindspore/ops/composite/__init__.py +12 -13
  671. mindspore/ops/composite/base.py +261 -254
  672. mindspore/ops/composite/env_ops.py +41 -0
  673. mindspore/ops/composite/math_ops.py +197 -156
  674. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  675. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  676. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  677. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  678. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  679. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  680. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  681. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  682. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  683. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  684. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  685. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  686. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  687. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  688. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  689. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  690. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  691. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  692. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  693. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  694. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  695. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  696. mindspore/ops/function/__init__.py +323 -8
  697. mindspore/ops/function/array_func.py +3511 -780
  698. mindspore/ops/function/clip_func.py +329 -0
  699. mindspore/ops/function/debug_func.py +6 -6
  700. mindspore/ops/function/grad/__init__.py +5 -1
  701. mindspore/ops/function/grad/grad_func.py +736 -65
  702. mindspore/ops/function/image_func.py +270 -0
  703. mindspore/ops/function/linalg_func.py +268 -8
  704. mindspore/ops/function/math_func.py +8032 -3164
  705. mindspore/ops/function/nn_func.py +5619 -1855
  706. mindspore/ops/function/other_func.py +115 -0
  707. mindspore/ops/function/parameter_func.py +11 -10
  708. mindspore/ops/function/random_func.py +939 -77
  709. mindspore/ops/function/sparse_func.py +249 -84
  710. mindspore/ops/function/sparse_unary_func.py +2303 -0
  711. mindspore/ops/function/spectral_func.py +146 -0
  712. mindspore/ops/function/vmap_func.py +114 -0
  713. mindspore/ops/functional.py +182 -254
  714. mindspore/ops/op_info_register.py +79 -34
  715. mindspore/ops/operations/__init__.py +210 -118
  716. mindspore/ops/operations/_csr_ops.py +7 -7
  717. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  718. mindspore/ops/operations/_grad_ops.py +447 -322
  719. mindspore/ops/operations/_inner_ops.py +547 -176
  720. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  721. mindspore/ops/operations/_ms_kernel.py +29 -27
  722. mindspore/ops/operations/_ocr_ops.py +11 -11
  723. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  724. mindspore/ops/operations/_quant_ops.py +186 -101
  725. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  726. mindspore/ops/operations/_scalar_ops.py +466 -0
  727. mindspore/ops/operations/_sequence_ops.py +1047 -0
  728. mindspore/ops/operations/_tensor_array.py +10 -11
  729. mindspore/ops/operations/_thor_ops.py +4 -4
  730. mindspore/ops/operations/array_ops.py +1428 -1226
  731. mindspore/ops/operations/comm_ops.py +180 -117
  732. mindspore/ops/operations/control_ops.py +4 -2
  733. mindspore/ops/operations/custom_ops.py +185 -98
  734. mindspore/ops/operations/debug_ops.py +92 -54
  735. mindspore/ops/operations/image_ops.py +406 -211
  736. mindspore/ops/operations/inner_ops.py +42 -53
  737. mindspore/ops/operations/linalg_ops.py +32 -29
  738. mindspore/ops/operations/math_ops.py +2076 -897
  739. mindspore/ops/operations/nn_ops.py +1282 -1252
  740. mindspore/ops/operations/other_ops.py +124 -278
  741. mindspore/ops/operations/random_ops.py +345 -178
  742. mindspore/ops/operations/rl_ops.py +8 -9
  743. mindspore/ops/operations/sparse_ops.py +502 -157
  744. mindspore/ops/operations/spectral_ops.py +107 -0
  745. mindspore/ops/primitive.py +192 -15
  746. mindspore/ops/vm_impl_registry.py +23 -2
  747. mindspore/parallel/__init__.py +6 -1
  748. mindspore/parallel/_auto_parallel_context.py +199 -92
  749. mindspore/parallel/_cell_wrapper.py +4 -2
  750. mindspore/parallel/_cost_model_context.py +3 -0
  751. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  752. mindspore/parallel/_offload_context.py +185 -0
  753. mindspore/parallel/_parallel_serialization.py +167 -28
  754. mindspore/parallel/_ps_context.py +9 -5
  755. mindspore/parallel/_recovery_context.py +1 -1
  756. mindspore/parallel/_tensor.py +9 -1
  757. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  758. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  759. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  760. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  761. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  762. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  763. mindspore/parallel/_utils.py +47 -7
  764. mindspore/parallel/algo_parameter_config.py +5 -1
  765. mindspore/parallel/checkpoint_transform.py +329 -0
  766. mindspore/parallel/shard.py +229 -0
  767. mindspore/profiler/__init__.py +2 -1
  768. mindspore/profiler/common/util.py +4 -3
  769. mindspore/profiler/common/validator/validate_path.py +2 -2
  770. mindspore/profiler/envprofiling.py +249 -0
  771. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  772. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  773. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  774. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  775. mindspore/profiler/parser/framework_parser.py +42 -16
  776. mindspore/profiler/parser/hccl_parser.py +158 -158
  777. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  778. mindspore/profiler/parser/integrator.py +18 -1579
  779. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  780. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  781. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  782. mindspore/profiler/parser/optime_parser.py +17 -18
  783. mindspore/profiler/parser/profiler_info.py +108 -0
  784. mindspore/profiler/parser/step_trace_parser.py +1 -1
  785. mindspore/profiler/profiling.py +396 -194
  786. mindspore/rewrite/__init__.py +6 -2
  787. mindspore/rewrite/api/node.py +51 -110
  788. mindspore/rewrite/api/node_type.py +10 -6
  789. mindspore/rewrite/api/pattern_engine.py +51 -7
  790. mindspore/rewrite/api/scoped_value.py +64 -53
  791. mindspore/rewrite/api/symbol_tree.py +108 -61
  792. mindspore/rewrite/api/tree_node_helper.py +2 -3
  793. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  794. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  795. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  796. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  797. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  798. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  799. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  800. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  801. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  802. mindspore/rewrite/common/__init__.py +2 -0
  803. mindspore/rewrite/common/event.py +1 -1
  804. mindspore/rewrite/common/observable.py +1 -1
  805. mindspore/rewrite/common/observer.py +1 -1
  806. mindspore/rewrite/common/rewrite_elog.py +35 -0
  807. mindspore/rewrite/namer.py +2 -2
  808. mindspore/rewrite/namespace.py +14 -4
  809. mindspore/rewrite/node.py +161 -13
  810. mindspore/rewrite/parser.py +0 -1
  811. mindspore/rewrite/parser_register.py +0 -1
  812. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  813. mindspore/rewrite/parsers/assign_parser.py +267 -67
  814. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  815. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  816. mindspore/rewrite/parsers/constant_parser.py +101 -0
  817. mindspore/rewrite/parsers/container_parser.py +88 -0
  818. mindspore/rewrite/parsers/for_parser.py +28 -15
  819. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  820. mindspore/rewrite/parsers/if_parser.py +11 -28
  821. mindspore/rewrite/parsers/module_parser.py +9 -6
  822. mindspore/rewrite/parsers/return_parser.py +3 -2
  823. mindspore/rewrite/sparsify/__init__.py +0 -0
  824. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  825. mindspore/rewrite/sparsify/sparsify.py +109 -0
  826. mindspore/rewrite/sparsify/utils.py +173 -0
  827. mindspore/rewrite/symbol_tree.py +322 -109
  828. mindspore/rewrite/symbol_tree_builder.py +45 -8
  829. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  830. mindspore/rewrite/topological_manager.py +1 -2
  831. mindspore/run_check/_check_version.py +209 -112
  832. mindspore/run_check/run_check.py +2 -1
  833. mindspore/scipy/linalg.py +13 -117
  834. mindspore/scipy/ops.py +5 -71
  835. mindspore/scipy/ops_grad.py +1 -25
  836. mindspore/scipy/ops_wrapper.py +1 -1
  837. mindspore/scipy/optimize/_bfgs.py +1 -1
  838. mindspore/scipy/optimize/_lagrange.py +200 -0
  839. mindspore/scipy/optimize/line_search.py +3 -2
  840. mindspore/scipy/optimize/minimize.py +43 -6
  841. mindspore/scipy/sparse/__init__.py +2 -2
  842. mindspore/scipy/sparse/linalg.py +5 -465
  843. mindspore/scipy/utils.py +2 -1
  844. mindspore/scipy/utils_const.py +7 -1
  845. mindspore/train/__init__.py +6 -4
  846. mindspore/train/_utils.py +28 -5
  847. mindspore/train/amp.py +321 -50
  848. mindspore/train/callback/__init__.py +3 -1
  849. mindspore/train/callback/_backup_and_restore.py +120 -0
  850. mindspore/train/callback/_callback.py +8 -8
  851. mindspore/train/callback/_checkpoint.py +12 -9
  852. mindspore/train/callback/_early_stop.py +13 -7
  853. mindspore/train/callback/_history.py +8 -8
  854. mindspore/train/callback/_lambda_callback.py +6 -6
  855. mindspore/train/callback/_landscape.py +36 -38
  856. mindspore/train/callback/_loss_monitor.py +12 -6
  857. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  858. mindspore/train/callback/_on_request_exit.py +212 -0
  859. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  860. mindspore/train/callback/_summary_collector.py +27 -19
  861. mindspore/train/callback/_time_monitor.py +13 -7
  862. mindspore/train/checkpoint_pb2.py +68 -8
  863. mindspore/train/data_sink.py +122 -33
  864. mindspore/train/dataset_helper.py +28 -87
  865. mindspore/train/loss_scale_manager.py +4 -7
  866. mindspore/{nn → train}/metrics/__init__.py +20 -20
  867. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  868. mindspore/{nn → train}/metrics/auc.py +4 -4
  869. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  870. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  871. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  872. mindspore/{nn → train}/metrics/dice.py +6 -5
  873. mindspore/{nn → train}/metrics/error.py +7 -5
  874. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  875. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  876. mindspore/{nn → train}/metrics/loss.py +4 -3
  877. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  878. mindspore/{nn → train}/metrics/metric.py +6 -5
  879. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  880. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  881. mindspore/{nn → train}/metrics/precision.py +5 -4
  882. mindspore/{nn → train}/metrics/recall.py +5 -4
  883. mindspore/{nn → train}/metrics/roc.py +7 -6
  884. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  885. mindspore/{nn → train}/metrics/topk.py +7 -5
  886. mindspore/train/mind_ir_pb2.py +339 -32
  887. mindspore/train/model.py +113 -84
  888. mindspore/train/serialization.py +547 -167
  889. mindspore/train/summary/_summary_adapter.py +1 -1
  890. mindspore/train/summary/summary_record.py +43 -12
  891. mindspore/train/train_thor/convert_utils.py +7 -1
  892. mindspore/train/train_thor/dataset_helper.py +3 -3
  893. mindspore/train/train_thor/model_thor.py +0 -4
  894. mindspore/version.py +1 -1
  895. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  896. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +899 -675
  897. mindspore/compression/common/constant.py +0 -124
  898. mindspore/compression/export/__init__.py +0 -19
  899. mindspore/compression/export/quant_export.py +0 -514
  900. mindspore/compression/quant/qat.py +0 -636
  901. mindspore/compression/quant/quant_utils.py +0 -462
  902. mindspore/compression/quant/quantizer.py +0 -68
  903. mindspore/nn/layer/quant.py +0 -1868
  904. mindspore/nn/layer/rnn_utils.py +0 -90
  905. mindspore/nn/probability/dpn/__init__.py +0 -22
  906. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  907. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  908. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  909. mindspore/nn/probability/infer/__init__.py +0 -22
  910. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  911. mindspore/nn/probability/infer/variational/svi.py +0 -84
  912. mindspore/nn/probability/toolbox/__init__.py +0 -22
  913. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  914. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  915. mindspore/nn/probability/transforms/__init__.py +0 -22
  916. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  917. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  918. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  919. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  920. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  921. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  922. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  923. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  924. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  925. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  926. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  927. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  928. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  929. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  930. mindspore/ops/composite/array_ops.py +0 -210
  931. mindspore/ops/composite/clip_ops.py +0 -238
  932. mindspore/ops/composite/random_ops.py +0 -426
  933. mindspore/ops/composite/vmap_ops.py +0 -38
  934. mindspore/ops/operations/sponge_ops.py +0 -3531
  935. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  936. mindspore/parallel/nn/__init__.py +0 -42
  937. mindspore/parallel/nn/loss.py +0 -22
  938. mindspore/parallel/nn/moe.py +0 -21
  939. mindspore/parallel/nn/op_parallel_config.py +0 -22
  940. mindspore/parallel/nn/transformer.py +0 -31
  941. mindspore/run_check/_check_deps_version.py +0 -84
  942. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  943. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  944. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -15,23 +15,24 @@
15
15
 
16
16
  """Define the grad rules of math related operations."""
17
17
 
18
- from functools import reduce
19
18
  import numpy as np
20
19
  import mindspore as ms
21
20
  from mindspore import nn
22
- from .. import functional as F
23
- from .. import operations as P
24
- from ..operations import _grad_ops as G
25
- from ..composite.multitype_ops.zeros_like_impl import zeros_like
26
- from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
27
- from .grad_base import bprop_getters
28
- from .grad_base import convert_to_tensor
29
- from ..primitive import constexpr
30
- from ..composite.multitype_ops import _constexpr_utils as const_utils
31
- from ..operations._inner_ops import DynamicStitch, DynamicBroadcastGradientArgs, DynamicBroadcastTo
32
- from ...common import Tensor
33
- from .._utils.utils import is_shape_unknown
34
- from ...common import dtype as mstype
21
+ from mindspore.common import Tensor
22
+ from mindspore.common import dtype as mstype
23
+ from mindspore.ops import functional as F
24
+ from mindspore.ops import operations as P
25
+ from mindspore.ops.operations import _grad_ops as G
26
+ from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
27
+ from mindspore.ops.functional import broadcast_gradient_args, reduced_shape, tuple_div
28
+ from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_invert_permutation
29
+ from mindspore.ops._grad.grad_base import convert_to_tensor
30
+ from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
31
+ from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
32
+ from mindspore.ops.primitive import _primexpr
33
+ from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
34
+ from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
35
+ from mindspore.ops.operations import array_ops as A
35
36
 
36
37
  shape_op = P.Shape()
37
38
  dyn_shape_op = P.TensorShape()
@@ -39,7 +40,7 @@ reduce_prod = P.ReduceProd()
39
40
  reduce_sum = P.ReduceSum()
40
41
  reshape = P.Reshape()
41
42
  tile = P.Tile()
42
- is_sub_class = P.IsSubClass()
43
+ is_sub_class = IsSubClass()
43
44
  to_array = P.TupleToArray()
44
45
  real_div = P.RealDiv()
45
46
 
@@ -56,17 +57,17 @@ def dyn_binop_grad_common(x, y, dx, dy):
56
57
  dx_origin_dtype = dx.dtype
57
58
  if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
58
59
  dx = F.cast(dx, mstype.float32)
59
- dx = reduce_sum(dx, rx)
60
+ dx = sum_grad_reduce_axis(dx, rx)
60
61
  dx = F.cast(dx, dx_origin_dtype)
61
62
  else:
62
- dx = reduce_sum(dx, rx)
63
+ dx = sum_grad_reduce_axis(dx, rx)
63
64
  dy_origin_dtype = dy.dtype
64
65
  if dy_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
65
66
  dy = F.cast(dy, mstype.float32)
66
- dy = reduce_sum(dy, ry)
67
+ dy = sum_grad_reduce_axis(dy, ry)
67
68
  dy = F.cast(dy, dy_origin_dtype)
68
69
  else:
69
- dy = reduce_sum(dy, ry)
70
+ dy = sum_grad_reduce_axis(dy, ry)
70
71
  reduce_dx = reshape(dx, shape_of_x)
71
72
  reduce_dy = reshape(dy, shape_of_y)
72
73
  return reduce_dx, reduce_dy
@@ -83,8 +84,8 @@ def dyn_binop_grad_common_with_shift(x, y, dx, dy, shift):
83
84
  broadcast_shape_of_x = shape_of_x[:-shift]
84
85
  broadcast_shape_of_y = shape_of_y[:-shift]
85
86
  rx, ry = DynamicBroadcastGradientArgs()(broadcast_shape_of_x, broadcast_shape_of_y)
86
- dx = reduce_sum(dx, rx)
87
- dy = reduce_sum(dy, ry)
87
+ dx = sum_grad_reduce_axis(dx, rx)
88
+ dy = sum_grad_reduce_axis(dy, ry)
88
89
  reduce_dx = reshape(dx, shape_of_x)
89
90
  reduce_dy = reshape(dy, shape_of_y)
90
91
  return reduce_dx, reduce_dy
@@ -111,7 +112,7 @@ def binop_grad_common(x, y, dx, dy):
111
112
  # if input shape is the same as dout shape, do not need to reduce
112
113
  reduce_dx = dx
113
114
  reduce_dy = dy
114
- if not (is_shape_unknown(shape_of_x) or is_shape_unknown(shape_of_y)):
115
+ if not (F.is_sequence_value_unknown(shape_of_x) or F.is_sequence_value_unknown(shape_of_y)):
115
116
  rx = broadcast_gradient_args(shape_of_x, shape_of_y)
116
117
  if rx[0]:
117
118
  # if dx is scalar whose shape is (), do not need reduce
@@ -124,11 +125,12 @@ def binop_grad_common(x, y, dx, dy):
124
125
  dy = _reduce_sum_with_cast(dy, rx[1])
125
126
  reduce_dy = reshape(dy, shape_of_y)
126
127
  return reduce_dx, reduce_dy
127
- if not shape_of_x or not shape_of_y:
128
+
129
+ if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
128
130
  # x or y is scalar
129
- if not shape_of_x:
131
+ if not isinstance(shape_of_x, tuple):
130
132
  reduce_dx = _reduce_sum_with_cast(dx, ())
131
- if not shape_of_y:
133
+ if not isinstance(shape_of_y, tuple):
132
134
  reduce_dy = _reduce_sum_with_cast(dy, ())
133
135
  return reduce_dx, reduce_dy
134
136
 
@@ -148,7 +150,7 @@ def binop_grad_common_with_shift(x, y, dx, dy, shift):
148
150
  # if input shape is the same as dout shape, do not need to reduce
149
151
  reduce_dx = dx
150
152
  reduce_dy = dy
151
- if not (is_shape_unknown(broadcast_shape_of_x) or is_shape_unknown(broadcast_shape_of_y)):
153
+ if not (F.is_sequence_value_unknown(broadcast_shape_of_x) or F.is_sequence_value_unknown(broadcast_shape_of_y)):
152
154
  rx = broadcast_gradient_args(broadcast_shape_of_x, broadcast_shape_of_y)
153
155
  if rx[0]:
154
156
  # if dx is scalar whose shape is (), do not need reduce
@@ -161,49 +163,56 @@ def binop_grad_common_with_shift(x, y, dx, dy, shift):
161
163
  dy = _reduce_sum_with_cast(dy, rx[1])
162
164
  reduce_dy = reshape(dy, shape_of_y)
163
165
  return reduce_dx, reduce_dy
164
- if not shape_of_x or not shape_of_y:
166
+
167
+ if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
165
168
  # x or y is scalar
166
- if not shape_of_x:
169
+ if not isinstance(shape_of_x, tuple):
167
170
  reduce_dx = _reduce_sum_with_cast(dx, ())
168
- if not shape_of_y:
171
+ if not isinstance(shape_of_y, tuple):
169
172
  reduce_dy = _reduce_sum_with_cast(dy, ())
170
173
  return reduce_dx, reduce_dy
171
174
 
172
175
  return dyn_binop_grad_common_with_shift(x, y, dx, dy, shift)
173
176
 
174
177
 
175
- def _dyn_reduced_shape(input_shape, axis):
178
+ def _dyn_reduced_shape(input_shape, axis, x):
176
179
  """Dynamic reduce shape"""
177
180
  input_shape = P.Cast()(input_shape, ms.int32)
178
- if isinstance(axis, Tensor):
179
- if is_shape_unknown(shape_op(axis)):
180
- expanded_axis = P.ExpandDims()(axis, 1)
181
- update = P.Cast()(P.OnesLike()(axis), ms.int32)
182
- return P.TensorScatterUpdate()(input_shape, expanded_axis, update)
183
- input_rank = P.Rank()(input_shape)
184
- real_axis = (axis + input_rank) % input_rank
185
- axis_shape = shape_op(real_axis)
181
+ if x is not None and not F.is_sequence_shape_unknown(shape_op(x)):
182
+ input_rank = len(shape_op(x))
186
183
  else:
187
- real_axis = ()
188
- input_rank = len(input_shape)
189
- if isinstance(axis, int):
190
- axis = (axis,)
191
- elif not axis:
192
- axis = range(input_rank)
193
- for i in axis:
194
- real_axis += ((i + input_rank) % input_rank,)
195
- axis_shape = (len(real_axis),)
196
- return DynamicStitch()([to_array(range(input_rank)), to_array(real_axis)],
197
- [input_shape, P.Fill()(ms.int32, axis_shape, 1)])
184
+ input_rank = dyn_rank(x)
185
+ input_rank = P.Cast()(input_rank, ms.int32)
186
+
187
+ if (isinstance(axis, tuple) and axis == ()) or (isinstance(axis, list) and axis == []):
188
+ res_shape = P.ExpandDims()(input_rank, 0)
189
+ return dyn_ones(res_shape, res_shape.dtype)
190
+
191
+ if isinstance(axis, int):
192
+ axis = (axis,)
193
+
194
+ real_axis = axis
195
+ if not isinstance(axis, Tensor):
196
+ real_axis = Tensor(axis, ms.int32)
197
+
198
+ real_axis = (real_axis + input_rank) % input_rank
199
+ if real_axis.ndim == 0:
200
+ real_axis = P.ExpandDims()(real_axis, 0)
201
+ expanded_axis = P.ExpandDims()(real_axis, 1)
202
+ expanded_axis = P.Cast()(expanded_axis, ms.int32)
203
+ update = P.Cast()(P.OnesLike()(real_axis), ms.float32)
204
+ input_shape = P.Cast()(input_shape, ms.float32)
205
+ return P.TensorScatterUpdate()(input_shape, expanded_axis, update)
198
206
 
199
207
 
200
208
  def _sum_grad(x, axis, dout):
201
209
  """Grad definition for `Sum` operation."""
202
210
  input_shape = shape_op(x)
203
211
  is_mutable, axis = convert_to_tensor(axis)
204
- if is_shape_unknown(input_shape) or is_mutable:
212
+ if F.is_sequence_value_unknown(input_shape) or is_mutable:
205
213
  input_shape = dyn_shape_op(x)
206
- output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis)
214
+ output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
215
+ output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
207
216
  grad = reshape(dout, output_shape_kept_dims)
208
217
  return DynamicBroadcastTo()(grad, input_shape)
209
218
 
@@ -216,15 +225,40 @@ def _sum_grad(x, axis, dout):
216
225
  def _min_or_max_grad(x, axis, out, dout):
217
226
  """Grad definition for `Min` and `Max` operations."""
218
227
  input_shape = shape_op(x)
219
- output_shape_kept_dims = reduced_shape(input_shape, axis)
228
+ output_shape_kept_dims = ()
229
+ if F.is_sequence_value_unknown(input_shape):
230
+ input_shape = dyn_shape_op(x)
231
+ output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
232
+ output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
233
+ else:
234
+ output_shape_kept_dims = reduced_shape(input_shape, axis)
235
+
220
236
  y = reshape(out, output_shape_kept_dims)
221
237
  grad = reshape(dout, output_shape_kept_dims)
222
238
  indicators = F.cast(F.equal(y, x), F.dtype(grad))
223
- min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
239
+ min_num = F.cast(F.scalar_to_tensor(1e-24), F.dtype(grad))
224
240
  num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
225
241
  return indicators / num_selected * grad
226
242
 
227
243
 
244
+ def _onehot_with_neg_axis(axis, indices, depth, on_value_dtype):
245
+ """onehot support tensor axis"""
246
+ depth_range = P.Range()(F.cast(0, depth.dtype), depth, F.cast(1, depth.dtype))
247
+ indices_expand = P.ExpandDims()(indices, axis)
248
+ indices_expand_rank = dyn_rank_1d(indices_expand)
249
+ broad_shape = dyn_ones(indices_expand_rank, mstype.int64)
250
+ # It should use int64 dtype, but the TensorScatterUpdate op does not support the int64
251
+ # dtype on Ascend device, so the float32 dtype is used here.
252
+ update_dtype = mstype.float32
253
+ broad_shape = dyn_ones(indices_expand_rank, update_dtype)
254
+ broad_shape[axis] = F.cast(depth, update_dtype)
255
+ broad_shape = F.cast(broad_shape, mstype.int64)
256
+ depth_broad = P.Reshape()(depth_range, broad_shape)
257
+ one_hot_bool = P.Equal()(indices_expand, depth_broad)
258
+ one_hot_res = F.cast(one_hot_bool, on_value_dtype)
259
+ return one_hot_res
260
+
261
+
228
262
  def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
229
263
  """ArgMinWiwhValue and ArgMaxWithValue grad."""
230
264
  expand = P.ExpandDims()
@@ -232,53 +266,48 @@ def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
232
266
  x_shape = F.shape(x)
233
267
  x_dim = len(x_shape)
234
268
  x_axis = axis
269
+ onehot_axis_is_neg = False
235
270
  if x_axis < 0:
236
- x_axis = axis + x_dim
271
+ if not F.is_sequence_shape_unknown(x_shape):
272
+ x_axis = axis + x_dim
273
+ else:
274
+ onehot_axis_is_neg = True
237
275
  onehot_axis = x_axis
238
- depth = 1
239
- if x_shape:
240
- depth = x_shape[axis]
241
276
  if keep_dims:
242
277
  dout_expand = dout[1]
243
278
  out = op(x)
244
279
  else:
245
280
  dout_expand = expand(dout[1], onehot_axis)
246
- if onehot_axis >= len(shape_op(out[0])):
247
- onehot_axis = -1
248
- onehot = P.OneHot(onehot_axis)
281
+ out_shape = shape_op(out[0])
282
+ if not F.is_sequence_shape_unknown(out_shape):
283
+ if onehot_axis >= len(out_shape):
284
+ onehot_axis = -1
249
285
  type_x = F.dtype(x)
250
- on_value = F.cast(F.scalar_to_array(1.0), type_x)
251
- off_value = F.cast(F.scalar_to_array(0.0), type_x)
252
- dx = dout_expand * onehot(out[0], depth, on_value, off_value)
253
- if not x_shape:
254
- dx = squeeze(dx)
286
+ on_value = F.cast(F.scalar_to_tensor(1.0), type_x)
287
+ off_value = F.cast(F.scalar_to_tensor(0.0), type_x)
288
+ if not F.is_sequence_value_unknown(x_shape):
289
+ depth = 1
290
+ if x_shape:
291
+ depth = x_shape[axis]
292
+ onehot = P.OneHot(onehot_axis)
293
+ dx = dout_expand * onehot(out[0], depth, on_value, off_value)
294
+ if not x_shape:
295
+ dx = squeeze(dx)
296
+ return dx
297
+ x_tensor_shape = P.TensorShape()(x)
298
+ depth = x_tensor_shape[axis]
299
+ if not onehot_axis_is_neg:
300
+ onehot = P.OneHot(onehot_axis)
301
+ dx = dout_expand * onehot(out[0], depth, on_value, off_value)
302
+ else:
303
+ if out[0].value is not None:
304
+ # It is a temporary method: In the pynative mode, out may be a constant tensor. Constant
305
+ # folding occurs in ExpandDims op, but such scenarios are not supported currently.
306
+ out = op(x)
307
+ dx = dout_expand * _onehot_with_neg_axis(onehot_axis, out[0], depth, on_value.dtype)
255
308
  return dx
256
309
 
257
310
 
258
- @bprop_getters.register(P.MatMul)
259
- def bprop_matmul(self):
260
- """Grad definition for `MatMul` operation."""
261
- ta = self.transpose_a
262
- tb = self.transpose_b
263
- mul1 = P.MatMul(transpose_a=(ta and tb),
264
- transpose_b=(ta or (not tb)))
265
- mul2 = P.MatMul(transpose_a=((not ta) or tb),
266
- transpose_b=(ta and tb))
267
-
268
- def bprop(x, w, out, dout):
269
- if ta:
270
- dx = mul1(w, dout)
271
- else:
272
- dx = mul1(dout, w)
273
- if tb:
274
- dw = mul2(dout, x)
275
- else:
276
- dw = mul2(x, dout)
277
- return dx, dw
278
-
279
- return bprop
280
-
281
-
282
311
  @bprop_getters.register(P.BatchMatMul)
283
312
  def bprop_batchmatmul(self):
284
313
  """Grad definition for `BatchMatMul` operation."""
@@ -303,16 +332,6 @@ def bprop_batchmatmul(self):
303
332
  return bprop
304
333
 
305
334
 
306
- @bprop_getters.register(P.Add)
307
- def get_bprop_add(self):
308
- """Grad definition for `Add` operation."""
309
-
310
- def bprop(x, y, out, dout):
311
- return binop_grad_common(x, y, dout, dout)
312
-
313
- return bprop
314
-
315
-
316
335
  @bprop_getters.register(P.TensorAdd)
317
336
  def get_bprop_tensor_add(self):
318
337
  """Grad definition for `Add` operation."""
@@ -339,35 +358,14 @@ def get_bprop_matrix_inverse(self):
339
358
  return bprop
340
359
 
341
360
 
342
- @bprop_getters.register(P.Neg)
343
- def get_bprop_neg(self):
344
- """Grad definition for `Neg` operation."""
345
- neg_grad = P.Neg()
346
-
347
- def bprop(x, out, dout):
348
- dx = neg_grad(dout)
349
- return (dx,)
350
-
351
- return bprop
352
-
353
-
354
- @bprop_getters.register(P.Sub)
355
- def get_bprop_sub(self):
356
- """Grad definition for `Sub` operation."""
357
- neg_func = P.Neg()
358
-
359
- def bprop(x, y, out, dout):
360
- return binop_grad_common(x, y, dout, neg_func(dout))
361
-
362
- return bprop
363
-
364
-
365
361
  @bprop_getters.register(P.Mul)
366
362
  def get_bprop_mul(self):
367
363
  """Grad definition for `Mul` operation."""
368
364
  mul_func = P.Mul()
369
365
 
370
366
  def bprop(x, y, out, dout):
367
+ if x.dtype in (mstype.complex64, mstype.complex128):
368
+ raise TypeError("For 'Mul', gradient not support for complex type currently.")
371
369
  bc_dx = mul_func(y, dout)
372
370
  bc_dy = mul_func(x, dout)
373
371
  return binop_grad_common(x, y, bc_dx, bc_dy)
@@ -383,6 +381,8 @@ def get_bprop_real_div(self):
383
381
  mul_op = P.Mul()
384
382
 
385
383
  def bprop(x, y, out, dout):
384
+ if x.dtype in (mstype.complex64, mstype.complex128):
385
+ raise TypeError("For 'RealDiv', gradient not support for complex type currently.")
386
386
  bc_x = div_op(dout, y)
387
387
  bc_y = neg(mul_op(bc_x, out))
388
388
  return binop_grad_common(x, y, bc_x, bc_y)
@@ -443,7 +443,10 @@ def get_bprop_floor(self):
443
443
  dtype_ = P.DType()
444
444
 
445
445
  def bprop(x, out, dout):
446
- bc_x = fill_(dtype_(x), shape_(x), 0.)
446
+ if F.is_sequence_value_unknown(shape_(x)):
447
+ bc_x = zeros_like(x)
448
+ else:
449
+ bc_x = fill_(dtype_(x), shape_(x), 0.)
447
450
  return (bc_x,)
448
451
 
449
452
  return bprop
@@ -457,7 +460,10 @@ def get_bprop_ceil(self):
457
460
  dtype_ = P.DType()
458
461
 
459
462
  def bprop(x, out, dout):
460
- bc_x = fill_(dtype_(x), shape_(x), 0.)
463
+ if F.is_sequence_value_unknown(shape_(x)):
464
+ bc_x = zeros_like(x)
465
+ else:
466
+ bc_x = fill_(dtype_(x), shape_(x), 0.)
461
467
  return (bc_x,)
462
468
 
463
469
  return bprop
@@ -473,6 +479,36 @@ def get_bprop_floordiv(self):
473
479
  return bprop
474
480
 
475
481
 
482
+ @bprop_getters.register(P.BitwiseAnd)
483
+ def get_bprop_bitwiseand(self):
484
+ """Grad definition for `BitwiseAnd` operation."""
485
+
486
+ def bprop(x, y, out, dout):
487
+ return zeros_like(x), zeros_like(y)
488
+
489
+ return bprop
490
+
491
+
492
+ @bprop_getters.register(P.BitwiseOr)
493
+ def get_bprop_bitwiseor(self):
494
+ """Grad definition for `BitwiseOr` operation."""
495
+
496
+ def bprop(x, y, out, dout):
497
+ return zeros_like(x), zeros_like(y)
498
+
499
+ return bprop
500
+
501
+
502
+ @bprop_getters.register(P.BitwiseXor)
503
+ def get_bprop_bitwisexor(self):
504
+ """Grad definition for `BitwiseXor` operation."""
505
+
506
+ def bprop(x, y, out, dout):
507
+ return zeros_like(x), zeros_like(y)
508
+
509
+ return bprop
510
+
511
+
476
512
  @bprop_getters.register(P.FloorMod)
477
513
  def get_bprop_floormod(self):
478
514
  """Grad definition for `FloorMod` operation."""
@@ -529,7 +565,12 @@ def get_bprop_square(self):
529
565
 
530
566
  def bprop(x, out, dout):
531
567
  temp = mul_func(dout, x)
532
- dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
568
+ shape_x = shape_op(x)
569
+ if F.is_sequence_value_unknown(shape_x):
570
+ fill_value = dyn_fill(dtype(temp), dyn_shape_op(x), 2.0)
571
+ else:
572
+ fill_value = fill_func(dtype(temp), shape_x, 2.0)
573
+ dx = mul_func(fill_value, temp)
533
574
  return (dx,)
534
575
 
535
576
  return bprop
@@ -575,8 +616,15 @@ def get_bprop_square_sum_all(self):
575
616
  def bprop(x, y, out, dout):
576
617
  temp_x = mul_func(dout[0], x)
577
618
  temp_y = mul_func(dout[1], y)
578
- dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
579
- dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
619
+ if F.is_sequence_value_unknown(shape_op(x)):
620
+ dx = mul_func(dyn_fill(dtype(temp_x), dyn_shape_op(x), 2.0), temp_x)
621
+ else:
622
+ dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
623
+
624
+ if F.is_sequence_value_unknown(shape_op(y)):
625
+ dy = mul_func(dyn_fill(dtype(temp_y), dyn_shape_op(y), 2.0), temp_y)
626
+ else:
627
+ dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
580
628
  return (dx, dy)
581
629
 
582
630
  return bprop
@@ -716,8 +764,14 @@ def get_bprop_pow(self):
716
764
  ln = P.Log()
717
765
 
718
766
  def bprop(x, power, out, dout):
767
+ if x.dtype in (mstype.complex64, mstype.complex128):
768
+ raise TypeError("For 'Pow', gradient not support for complex type currently.")
719
769
  bc_dx = power * pow_op(x, power - 1.0) * dout
720
- x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
770
+ shape_x = shape_op(x)
771
+ if F.is_sequence_value_unknown(shape_x):
772
+ x = F.select(x < 0, dyn_fill(F.dtype(x), dyn_shape_op(x), 1), x)
773
+ else:
774
+ x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
721
775
  bc_dpower = out * ln(x) * dout
722
776
  return binop_grad_common(x, power, bc_dx, bc_dpower)
723
777
 
@@ -808,21 +862,31 @@ def get_bprop_cumsum(self):
808
862
  return bprop
809
863
 
810
864
 
811
- @constexpr
865
+ @_primexpr
812
866
  def _split_shape_index(input_shape, axis):
813
867
  """Calculate reduce_prod grad transpose indices and perm shape."""
814
868
  rank = len(input_shape)
815
869
  if isinstance(axis, int):
816
870
  axis = tuple([axis])
817
871
  reduction_indices = tuple([(i + rank) % rank for i in axis])
818
- other_indices = tuple(set(range(rank)) - set(reduction_indices))
819
- reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
820
- other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
872
+ other_indices_list = []
873
+ for i in range(rank):
874
+ if i not in reduction_indices and i not in other_indices_list:
875
+ other_indices_list.append(i)
876
+ other_indices = tuple(other_indices_list)
877
+ reduced_list = [1] + [input_shape[i] for i in reduction_indices]
878
+ other_list = [1] + [input_shape[i] for i in other_indices]
879
+ reduced_num = 1
880
+ for i in reduced_list:
881
+ reduced_num = reduced_num * i
882
+ other_num = 1
883
+ for i in other_list:
884
+ other_num = other_num * i
821
885
  perm = reduction_indices + other_indices
822
886
  return tuple([reduced_num, other_num]), perm
823
887
 
824
888
 
825
- @constexpr
889
+ @_primexpr
826
890
  def _invert_permutation(perm):
827
891
  """Calculate invert permutation."""
828
892
  out = [0] * len(perm)
@@ -831,6 +895,26 @@ def _invert_permutation(perm):
831
895
  return tuple(out)
832
896
 
833
897
 
898
+ def _split_dyn_shape_index(x, axis):
899
+ """Calculate reduce prod grad invert permutation."""
900
+ input_shape = dyn_shape_op(x)
901
+ rank = dyn_rank(x)
902
+ if not isinstance(axis, Tensor):
903
+ axis = Tensor(axis, dtype=mstype.int64)
904
+ reduction_indices = reshape(axis, (-1,))
905
+ reduction_indices = (reduction_indices + rank) % rank
906
+ reduced = P.Cast()(reduction_indices, mstype.int64)
907
+
908
+ start = Tensor(0, dtype=mstype.int64)
909
+ delta = Tensor(1, dtype=mstype.int64)
910
+ idx = P.Range()(start, rank, delta)
911
+ other, _ = A.ListDiff()(idx, reduced)
912
+ perm = P.Concat()((reduced, other))
913
+ reduced_num = reduce_prod(P.Cast()(P.Gather()(input_shape, reduced, 0), mstype.int64), ())
914
+ other_num = reduce_prod(P.Cast()(P.Gather()(input_shape, other, 0), mstype.int64), ())
915
+ return (reduced_num, other_num), perm
916
+
917
+
834
918
  @bprop_getters.register(P.ReduceProd)
835
919
  def get_bprop_reduceprod(self):
836
920
  """Grad definition for `ReduceProd` operation."""
@@ -840,17 +924,35 @@ def get_bprop_reduceprod(self):
840
924
 
841
925
  def bprop(x, axis, out, dout):
842
926
  """Grad definition for `Product` operation."""
927
+ if x.dtype in (mstype.complex64, mstype.complex128):
928
+ raise TypeError("The 'ReduceProd', gradient not support for complex type currently.")
843
929
  # Expand dout to full input shape
844
930
  input_shape = shape_op(x)
845
- output_shape_kept_dims = reduced_shape(input_shape, axis)
931
+ if input_shape == ():
932
+ dx = _sum_grad(x, axis, dout)
933
+ return dx, zeros_like(axis)
934
+
935
+ if F.is_sequence_value_unknown(input_shape):
936
+ input_shape = dyn_shape_op(x)
937
+ input_shape = P.Cast()(input_shape, ms.int64)
938
+ output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
939
+ output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int64)
940
+ else:
941
+ output_shape_kept_dims = reduced_shape(input_shape, axis)
942
+
846
943
  dout = reshape(dout, output_shape_kept_dims)
847
- tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
848
- grad = tile(dout, tile_scaling)
849
944
 
850
945
  # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
851
- pack_shape, perm = _split_shape_index(input_shape, axis)
946
+ if F.is_sequence_value_unknown(shape_op(x)):
947
+ pack_shape, perm = _split_dyn_shape_index(x, axis)
948
+ else:
949
+ pack_shape, perm = _split_shape_index(shape_op(x), axis)
950
+
852
951
  permuted = transpose(x, perm)
853
952
  permuted_shape = shape_op(permuted)
953
+ if F.is_sequence_value_unknown(permuted_shape):
954
+ permuted_shape = dyn_shape_op(permuted)
955
+ pack_shape = create_tensor_by_element(pack_shape)
854
956
  reshaped = reshape(permuted, pack_shape)
855
957
 
856
958
  # Calculate product, leaving out the current entry
@@ -860,7 +962,14 @@ def get_bprop_reduceprod(self):
860
962
 
861
963
  # Invert the transpose and reshape operations.
862
964
  # Make sure to set the statically known shape information through a reshape.
863
- out = transpose(y, _invert_permutation(perm)) * grad
965
+ if F.is_sequence_value_unknown(shape_op(permuted)):
966
+ dout = DynamicBroadcastTo()(dout, input_shape)
967
+ out = transpose(y, dyn_invert_permutation(perm)) * dout
968
+ else:
969
+ tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
970
+ grad = tile(dout, tile_scaling)
971
+ out = transpose(y, _invert_permutation(perm)) * grad
972
+
864
973
  dx = reshape(out, input_shape)
865
974
  return dx, zeros_like(axis)
866
975
 
@@ -908,6 +1017,8 @@ def get_bprop_reducemax(self):
908
1017
  """Grad definition for `Max` operation."""
909
1018
 
910
1019
  def bprop(x, axis, out, dout):
1020
+ if x.dtype in (mstype.complex64, mstype.complex128):
1021
+ raise TypeError("The 'ReduceMax', gradient not support for complex type currently.")
911
1022
  dx = _min_or_max_grad(x, axis, out, dout)
912
1023
  return (dx, zeros_like(axis))
913
1024
 
@@ -933,6 +1044,8 @@ def get_bprop_reducemin(self):
933
1044
  """Grad definition for `ReduceMin` operation."""
934
1045
 
935
1046
  def bprop(x, axis, out, dout):
1047
+ if x.dtype in (mstype.complex64, mstype.complex128):
1048
+ raise TypeError("The 'ReduceMin', gradient not support for complex type currently.")
936
1049
  dx = _min_or_max_grad(x, axis, out, dout)
937
1050
  return (dx, zeros_like(axis))
938
1051
 
@@ -961,17 +1074,20 @@ def get_bprop_reduce_mean(self):
961
1074
  dtype = P.DType()
962
1075
 
963
1076
  def bprop(x, axis, out, dout):
1077
+ if x.dtype in (mstype.complex64, mstype.complex128):
1078
+ raise TypeError("The 'ReduceMean', gradient not support for complex type currently.")
964
1079
  grad = _sum_grad(x, axis, dout)
965
1080
  shape_x = shape_op(x)
966
1081
  shape_out = shape_op(out)
967
- if is_shape_unknown(shape_x):
1082
+ if F.is_sequence_value_unknown(shape_x) or F.is_sequence_value_unknown(shape_out):
968
1083
  shape_x = dyn_shape_op(x)
969
1084
  shape_out = dyn_shape_op(out)
970
- div_shape = reduce_prod(shape_x) / reduce_prod(shape_out)
1085
+ div_shape = reduce_prod(cast(shape_x, mstype.float32), ()) /\
1086
+ reduce_prod(cast(shape_out, mstype.float32), ())
971
1087
  dx = div_op(grad, cast(div_shape, dtype(grad)))
972
1088
  else:
973
1089
  div_shape = F.shape_mul(shape_x) / F.shape_mul(shape_out)
974
- dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
1090
+ dx = div_op(grad, cast(F.scalar_to_tensor(div_shape), dtype(grad)))
975
1091
  return dx, zeros_like(axis)
976
1092
 
977
1093
  return bprop
@@ -1097,16 +1213,6 @@ def get_bprop_logical_and(self):
1097
1213
  return bprop
1098
1214
 
1099
1215
 
1100
- @bprop_getters.register(P.LogicalOr)
1101
- def get_bprop_logical_or(self):
1102
- """Grad definition for `LogicalOr` operation."""
1103
-
1104
- def bprop(x, y, out, dout):
1105
- return zeros_like(x), zeros_like(y)
1106
-
1107
- return bprop
1108
-
1109
-
1110
1216
  @bprop_getters.register(P.NPUAllocFloatStatus)
1111
1217
  def get_bprop_npu_alloc_float_status(self):
1112
1218
  """Grad definition for `NPUAllocFloatStatus` operation."""
@@ -1304,6 +1410,9 @@ def get_bprop_cosh(self):
1304
1410
  sinh = P.Sinh()
1305
1411
 
1306
1412
  def bprop(x, out, dout):
1413
+ if x.dtype in (mstype.complex64, mstype.complex128):
1414
+ raise TypeError("The 'Cosh', gradient not support for complex type currently.")
1415
+
1307
1416
  dx = sinh(x) * dout
1308
1417
  return (dx,)
1309
1418
 
@@ -1334,16 +1443,6 @@ def get_bprop_conj(self):
1334
1443
  return bprop
1335
1444
 
1336
1445
 
1337
- @bprop_getters.register(P.ScalarCast)
1338
- def get_bprop_scalar_cast(self):
1339
- """Generate bprop for ScalarCast"""
1340
-
1341
- def bprop(x, t, out, dout):
1342
- return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
1343
-
1344
- return bprop
1345
-
1346
-
1347
1446
  @bprop_getters.register(P.AccumulateNV2)
1348
1447
  def get_bprop_scalar_accumulatenv2(self):
1349
1448
  """Generate bprop for AccumulateNV2"""
@@ -1457,6 +1556,9 @@ def get_bprop_tan(self):
1457
1556
  cos = P.Cos()
1458
1557
 
1459
1558
  def bprop(x, out, dout):
1559
+ if x.dtype in (mstype.complex64, mstype.complex128):
1560
+ raise TypeError("For 'Tan', gradient not support for complex type currently.")
1561
+
1460
1562
  cosx = cos(x)
1461
1563
  secx2 = square(reciprocal(cosx))
1462
1564
  dx = secx2 * dout
@@ -1498,6 +1600,9 @@ def get_bprop_atanh(self):
1498
1600
  div = P.Div()
1499
1601
 
1500
1602
  def bprop(x, out, dout):
1603
+ if x.dtype in (mstype.complex64, mstype.complex128):
1604
+ raise TypeError("For 'Atanh', gradient not support for complex type currently.")
1605
+
1501
1606
  tmp = 1 - power(x, 2)
1502
1607
  dx = div(1, tmp) * dout
1503
1608
  return (dx,)
@@ -1537,3 +1642,43 @@ def get_bprop_index_add(self):
1537
1642
  return dout, zeros_like(indices), gather(dout, indices, _axis)
1538
1643
 
1539
1644
  return bprop
1645
+
1646
+
1647
+ @bprop_getters.register(P.InplaceUpdate)
1648
+ def get_bprop_inplace_update(self):
1649
+ """Grad definition for `InplaceUpdate` operation."""
1650
+
1651
+ def bprop(x, v, out, dout):
1652
+ return zeros_like(x), zeros_like(v)
1653
+
1654
+ return bprop
1655
+
1656
+
1657
+ @bprop_getters.register(P.InplaceUpdateV2)
1658
+ def get_bprop_inplace_update_v2(self):
1659
+ """Grad definition for `InplaceUpdateV2` operation."""
1660
+
1661
+ def bprop(x, indices, v, out, dout):
1662
+ return zeros_like(x), zeros_like(indices), zeros_like(v)
1663
+
1664
+ return bprop
1665
+
1666
+
1667
+ @bprop_getters.register(P.InplaceSub)
1668
+ def get_bprop_inplace_sub(self):
1669
+ """Grad definition for `InplaceSub` operation."""
1670
+
1671
+ def bprop(x, input_v, out, dout):
1672
+ return zeros_like(x), zeros_like(input_v)
1673
+
1674
+ return bprop
1675
+
1676
+
1677
+ @bprop_getters.register(P.InplaceAdd)
1678
+ def get_bprop_inplace_add(self):
1679
+ """Grad definition for `InplaceAdd` operation."""
1680
+
1681
+ def bprop(x, input_v, out, dout):
1682
+ return zeros_like(x), zeros_like(input_v)
1683
+
1684
+ return bprop