mindspore 1.10.0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (944)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +9 -4
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/builtin_operations.py +32 -4
  13. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  21. mindspore/_extends/parse/__init__.py +5 -3
  22. mindspore/_extends/parse/namespace.py +17 -2
  23. mindspore/_extends/parse/parser.py +193 -34
  24. mindspore/_extends/parse/resources.py +7 -8
  25. mindspore/_extends/parse/standard_method.py +1780 -435
  26. mindspore/_extends/parse/trope.py +3 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  28. mindspore/amp.py +53 -58
  29. mindspore/bin/cache_admin +0 -0
  30. mindspore/bin/cache_server +0 -0
  31. mindspore/boost/adasum.py +3 -2
  32. mindspore/boost/boost.py +2 -2
  33. mindspore/boost/boost_cell_wrapper.py +46 -26
  34. mindspore/boost/dim_reduce.py +6 -5
  35. mindspore/boost/grad_accumulation.py +2 -1
  36. mindspore/boost/group_loss_scale_manager.py +1 -1
  37. mindspore/common/__init__.py +11 -10
  38. mindspore/common/_decorator.py +2 -0
  39. mindspore/common/_register_for_adapter.py +55 -0
  40. mindspore/common/_stub_tensor.py +201 -0
  41. mindspore/common/_utils.py +57 -0
  42. mindspore/common/api.py +582 -297
  43. mindspore/common/dtype.py +66 -18
  44. mindspore/common/dump.py +2 -2
  45. mindspore/common/initializer.py +38 -1
  46. mindspore/common/jit_config.py +25 -13
  47. mindspore/common/mutable.py +53 -24
  48. mindspore/common/parameter.py +60 -37
  49. mindspore/common/seed.py +8 -24
  50. mindspore/common/sparse_tensor.py +927 -0
  51. mindspore/common/tensor.py +1627 -3900
  52. mindspore/communication/__init__.py +10 -5
  53. mindspore/communication/_comm_helper.py +78 -214
  54. mindspore/communication/_hccl_management.py +2 -1
  55. mindspore/communication/management.py +136 -47
  56. mindspore/config/op_info.config +501 -1008
  57. mindspore/config/super_bar_config.json +512 -0
  58. mindspore/context.py +291 -56
  59. mindspore/dataset/__init__.py +12 -8
  60. mindspore/dataset/audio/__init__.py +9 -9
  61. mindspore/dataset/audio/transforms.py +1090 -228
  62. mindspore/dataset/audio/utils.py +87 -39
  63. mindspore/dataset/audio/validators.py +223 -1
  64. mindspore/dataset/callback/ds_callback.py +17 -15
  65. mindspore/dataset/core/config.py +246 -17
  66. mindspore/dataset/core/py_util_helpers.py +4 -3
  67. mindspore/dataset/core/validator_helpers.py +10 -10
  68. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  69. mindspore/dataset/debug/debug_hook.py +65 -0
  70. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  71. mindspore/dataset/engine/__init__.py +7 -3
  72. mindspore/dataset/engine/cache_client.py +9 -9
  73. mindspore/dataset/engine/datasets.py +648 -477
  74. mindspore/dataset/engine/datasets_audio.py +165 -167
  75. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  76. mindspore/dataset/engine/datasets_text.py +492 -342
  77. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  78. mindspore/dataset/engine/datasets_vision.py +1224 -699
  79. mindspore/dataset/engine/graphdata.py +134 -69
  80. mindspore/dataset/engine/iterators.py +50 -9
  81. mindspore/dataset/engine/offload.py +52 -31
  82. mindspore/dataset/engine/samplers.py +27 -24
  83. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  84. mindspore/dataset/engine/validators.py +213 -52
  85. mindspore/dataset/text/__init__.py +10 -8
  86. mindspore/dataset/text/transforms.py +152 -57
  87. mindspore/dataset/text/utils.py +98 -49
  88. mindspore/dataset/text/validators.py +25 -0
  89. mindspore/dataset/transforms/__init__.py +4 -2
  90. mindspore/dataset/transforms/c_transforms.py +11 -13
  91. mindspore/dataset/transforms/py_transforms.py +2 -2
  92. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  93. mindspore/dataset/transforms/transforms.py +13 -15
  94. mindspore/dataset/transforms/validators.py +7 -7
  95. mindspore/dataset/utils/__init__.py +2 -1
  96. mindspore/dataset/utils/browse_dataset.py +13 -13
  97. mindspore/dataset/utils/line_reader.py +121 -0
  98. mindspore/dataset/vision/__init__.py +8 -7
  99. mindspore/dataset/vision/c_transforms.py +125 -126
  100. mindspore/dataset/vision/py_transforms.py +37 -37
  101. mindspore/dataset/vision/py_transforms_util.py +23 -20
  102. mindspore/dataset/vision/transforms.py +316 -315
  103. mindspore/dataset/vision/utils.py +313 -17
  104. mindspore/dataset/vision/validators.py +6 -6
  105. mindspore/default_config.py +0 -1
  106. mindspore/{compression → experimental}/__init__.py +6 -5
  107. mindspore/experimental/map_parameter.py +275 -0
  108. mindspore/include/OWNERS +0 -1
  109. mindspore/include/api/callback/callback.h +9 -13
  110. mindspore/include/api/callback/ckpt_saver.h +2 -2
  111. mindspore/include/api/callback/loss_monitor.h +2 -2
  112. mindspore/include/api/callback/lr_scheduler.h +5 -5
  113. mindspore/include/api/callback/time_monitor.h +2 -2
  114. mindspore/include/api/callback/train_accuracy.h +4 -6
  115. mindspore/include/api/cfg.h +19 -6
  116. mindspore/include/api/context.h +70 -9
  117. mindspore/include/api/delegate.h +8 -1
  118. mindspore/include/api/dual_abi_helper.h +8 -24
  119. mindspore/include/api/metrics/accuracy.h +2 -2
  120. mindspore/include/api/metrics/metrics.h +4 -3
  121. mindspore/include/api/model.h +9 -4
  122. mindspore/include/api/model_group.h +68 -0
  123. mindspore/include/api/model_parallel_runner.h +17 -17
  124. mindspore/include/api/net.h +12 -11
  125. mindspore/include/api/serialization.h +20 -4
  126. mindspore/include/api/status.h +7 -1
  127. mindspore/include/api/types.h +25 -21
  128. mindspore/include/api/visible.h +4 -0
  129. mindspore/include/c_api/model_c.h +5 -0
  130. mindspore/include/c_api/status_c.h +1 -1
  131. mindspore/include/dataset/config.h +1 -1
  132. mindspore/include/dataset/constants.h +14 -0
  133. mindspore/include/dataset/text.h +59 -0
  134. mindspore/include/dataset/vision.h +56 -117
  135. mindspore/include/dataset/vision_lite.h +102 -0
  136. mindspore/include/mindapi/base/type_id.h +42 -3
  137. mindspore/lib/libdnnl.so.2 +0 -0
  138. mindspore/lib/libicudata.so.69 +0 -0
  139. mindspore/lib/libicui18n.so.69 +0 -0
  140. mindspore/lib/libicuuc.so.69 +0 -0
  141. mindspore/lib/libmindspore.so +0 -0
  142. mindspore/lib/libmindspore_backend.so +0 -0
  143. mindspore/lib/libmindspore_common.so +0 -0
  144. mindspore/lib/libmindspore_core.so +0 -0
  145. mindspore/lib/libmindspore_glog.so.0 +0 -0
  146. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  147. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  148. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  149. mindspore/lib/libmindspore_shared_lib.so +0 -0
  150. mindspore/lib/libmpi_adapter.so +0 -0
  151. mindspore/lib/libmpi_collective.so +0 -0
  152. mindspore/lib/libnnacl.so +0 -0
  153. mindspore/lib/libopencv_core.so.4.5 +0 -0
  154. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  155. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  156. mindspore/lib/libps_cache.so +0 -0
  157. mindspore/lib/plugin/ascend/libakg.so +0 -0
  158. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  159. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  160. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  161. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  163. mindspore/lib/{libakg.so → plugin/cpu/libakg.so} +0 -0
  164. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  165. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  166. mindspore/log.py +28 -28
  167. mindspore/mindrecord/common/exceptions.py +2 -4
  168. mindspore/mindrecord/filereader.py +19 -1
  169. mindspore/mindrecord/filewriter.py +250 -88
  170. mindspore/mindrecord/mindpage.py +13 -13
  171. mindspore/mindrecord/shardheader.py +15 -15
  172. mindspore/mindrecord/shardreader.py +9 -0
  173. mindspore/mindrecord/shardwriter.py +29 -29
  174. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  175. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  176. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  177. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  178. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  179. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  180. mindspore/nn/__init__.py +1 -5
  181. mindspore/nn/cell.py +297 -234
  182. mindspore/nn/dynamic_lr.py +1 -1
  183. mindspore/nn/grad/cell_grad.py +17 -42
  184. mindspore/nn/layer/__init__.py +7 -4
  185. mindspore/nn/layer/activation.py +131 -88
  186. mindspore/nn/layer/basic.py +313 -613
  187. mindspore/nn/layer/channel_shuffle.py +103 -0
  188. mindspore/nn/layer/combined.py +1 -1
  189. mindspore/nn/layer/container.py +52 -6
  190. mindspore/nn/layer/conv.py +112 -43
  191. mindspore/nn/layer/dense.py +10 -9
  192. mindspore/nn/layer/embedding.py +36 -34
  193. mindspore/nn/layer/image.py +123 -27
  194. mindspore/nn/layer/math.py +108 -107
  195. mindspore/nn/layer/normalization.py +212 -366
  196. mindspore/nn/layer/padding.py +370 -42
  197. mindspore/nn/layer/pooling.py +1443 -219
  198. mindspore/nn/layer/rnn_cells.py +11 -16
  199. mindspore/nn/layer/rnns.py +38 -39
  200. mindspore/nn/layer/thor_layer.py +24 -25
  201. mindspore/nn/layer/timedistributed.py +5 -5
  202. mindspore/nn/layer/transformer.py +701 -0
  203. mindspore/nn/learning_rate_schedule.py +8 -8
  204. mindspore/nn/loss/__init__.py +9 -6
  205. mindspore/nn/loss/loss.py +678 -142
  206. mindspore/nn/metrics.py +53 -0
  207. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  208. mindspore/nn/optim/ada_grad.py +8 -8
  209. mindspore/nn/optim/adadelta.py +2 -3
  210. mindspore/nn/optim/adafactor.py +18 -14
  211. mindspore/nn/optim/adam.py +429 -87
  212. mindspore/nn/optim/adamax.py +5 -6
  213. mindspore/nn/optim/adasum.py +10 -8
  214. mindspore/nn/optim/asgd.py +7 -7
  215. mindspore/nn/optim/ftrl.py +81 -11
  216. mindspore/nn/optim/lamb.py +7 -8
  217. mindspore/nn/optim/lars.py +4 -4
  218. mindspore/nn/optim/lazyadam.py +82 -7
  219. mindspore/nn/optim/momentum.py +8 -7
  220. mindspore/nn/optim/optimizer.py +19 -10
  221. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  222. mindspore/nn/optim/rmsprop.py +3 -3
  223. mindspore/nn/optim/rprop.py +20 -16
  224. mindspore/nn/optim/sgd.py +21 -15
  225. mindspore/nn/optim/thor.py +23 -21
  226. mindspore/nn/probability/__init__.py +0 -2
  227. mindspore/nn/probability/bijector/bijector.py +7 -6
  228. mindspore/nn/probability/bijector/invert.py +4 -2
  229. mindspore/nn/probability/bijector/softplus.py +2 -2
  230. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  231. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  232. mindspore/nn/probability/distribution/__init__.py +6 -0
  233. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  234. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  235. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  236. mindspore/nn/probability/distribution/beta.py +1 -1
  237. mindspore/nn/probability/distribution/categorical.py +9 -9
  238. mindspore/nn/probability/distribution/cauchy.py +8 -8
  239. mindspore/nn/probability/distribution/distribution.py +12 -6
  240. mindspore/nn/probability/distribution/exponential.py +5 -5
  241. mindspore/nn/probability/distribution/gamma.py +3 -3
  242. mindspore/nn/probability/distribution/geometric.py +6 -5
  243. mindspore/nn/probability/distribution/gumbel.py +5 -5
  244. mindspore/nn/probability/distribution/half_normal.py +133 -0
  245. mindspore/nn/probability/distribution/laplace.py +128 -0
  246. mindspore/nn/probability/distribution/log_normal.py +0 -1
  247. mindspore/nn/probability/distribution/logistic.py +4 -5
  248. mindspore/nn/probability/distribution/normal.py +11 -15
  249. mindspore/nn/probability/distribution/poisson.py +6 -2
  250. mindspore/nn/probability/distribution/student_t.py +150 -0
  251. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  252. mindspore/nn/probability/distribution/uniform.py +5 -5
  253. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  254. mindspore/nn/reinforcement/tensor_array.py +2 -2
  255. mindspore/nn/sparse/sparse.py +8 -1
  256. mindspore/nn/wrap/cell_wrapper.py +55 -27
  257. mindspore/nn/wrap/grad_reducer.py +20 -11
  258. mindspore/nn/wrap/loss_scale.py +47 -30
  259. mindspore/numpy/array_creations.py +33 -22
  260. mindspore/numpy/array_ops.py +46 -42
  261. mindspore/numpy/logic_ops.py +6 -27
  262. mindspore/numpy/math_ops.py +26 -19
  263. mindspore/numpy/utils.py +1 -8
  264. mindspore/numpy/utils_const.py +112 -62
  265. mindspore/ops/__init__.py +6 -3
  266. mindspore/ops/_constants.py +0 -6
  267. mindspore/ops/_grad/__init__.py +2 -1
  268. mindspore/ops/_grad/grad_array_ops.py +209 -152
  269. mindspore/ops/_grad/grad_base.py +55 -17
  270. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  271. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  272. mindspore/ops/_grad/grad_implementations.py +21 -61
  273. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  274. mindspore/ops/_grad/grad_math_ops.py +306 -161
  275. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  276. mindspore/ops/_grad/grad_other_ops.py +1 -1
  277. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  278. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  279. mindspore/ops/_grad/grad_sparse.py +15 -9
  280. mindspore/ops/_grad_experimental/__init__.py +1 -0
  281. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  282. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  283. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  284. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  285. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  286. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  287. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  288. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  289. mindspore/ops/_op_impl/__init__.py +3 -3
  290. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  291. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  292. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  293. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  294. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  295. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  296. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  297. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  298. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  299. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  300. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  302. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  303. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  304. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  305. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  307. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  314. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  315. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  317. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  318. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  320. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  321. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  322. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  324. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  327. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  328. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  329. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  330. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  331. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  332. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  335. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  336. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  337. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  338. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  339. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  340. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  341. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  342. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  343. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  344. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  345. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  346. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  347. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  349. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  350. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  351. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  352. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  353. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  354. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  355. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  356. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  357. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  358. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  359. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  360. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  361. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  362. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  363. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  365. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  366. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  367. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  368. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  369. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  370. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  371. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  372. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  373. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  374. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  375. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  376. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  377. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  378. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  379. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  380. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  381. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  382. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  383. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  384. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  385. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  386. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  387. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  388. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  389. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  390. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  391. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  392. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  393. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  394. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  395. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  396. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  397. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  398. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  399. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  400. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  401. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  402. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  403. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  404. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  405. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  407. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  408. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  409. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  410. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  411. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  412. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  413. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  414. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  415. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  416. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  417. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  418. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  419. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  420. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  421. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  422. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  423. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  424. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  425. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  426. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  427. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  428. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  429. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  430. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  431. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  432. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  433. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  434. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  435. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  436. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  437. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  441. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  442. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  443. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  444. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  445. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  446. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  447. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  448. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  449. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  450. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  451. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  452. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  453. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  454. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  455. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  456. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  457. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  458. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  459. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  460. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  461. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  462. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  463. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  464. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  465. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  466. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  467. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  468. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  469. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  470. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  471. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  472. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  473. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  474. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  475. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  476. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  477. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  478. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  479. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  480. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  481. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  482. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  483. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  484. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  485. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  486. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  487. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  488. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  489. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  490. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  491. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  492. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  493. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  494. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  495. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  496. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  497. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  498. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  499. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  500. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  501. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  502. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  503. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  504. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  505. mindspore/ops/_primitive_cache.py +3 -2
  506. mindspore/ops/_register_for_op.py +11 -0
  507. mindspore/ops/_utils/__init__.py +1 -1
  508. mindspore/ops/_utils/utils.py +20 -41
  509. mindspore/ops/_vmap/__init__.py +2 -2
  510. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  511. mindspore/ops/_vmap/vmap_base.py +24 -10
  512. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  513. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  514. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  515. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  516. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  517. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  518. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  519. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  520. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  521. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  522. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  524. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  525. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  526. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  527. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  528. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  529. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  530. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  532. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  534. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  536. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  537. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  538. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  541. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  542. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  543. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  544. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  545. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  546. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  547. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  548. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  549. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  550. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  551. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  552. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  554. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  555. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  556. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  560. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  561. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  563. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  564. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  565. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  566. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  567. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  568. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  569. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  571. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  572. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  573. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  574. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  575. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  577. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  578. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  579. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  580. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  581. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  582. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  583. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  585. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  586. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  587. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  588. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  589. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  590. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  591. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  592. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  593. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  594. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  595. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  596. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  597. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  598. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  599. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  600. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  601. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  604. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  605. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  606. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  607. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  609. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  610. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  611. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  612. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  613. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  614. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  615. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  616. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  617. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  618. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  619. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  620. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  621. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  622. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  623. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  624. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  625. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  626. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  627. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  628. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  629. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  630. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  632. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  633. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  634. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  635. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  637. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  638. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  639. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  640. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  643. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  644. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  645. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  646. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  647. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  648. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  649. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  650. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  651. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  652. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  653. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  654. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  655. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  656. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  657. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  658. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  660. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  661. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  662. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  663. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  664. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  665. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  666. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  667. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  668. mindspore/ops/bprop_mindir/__init__.py +1 -4
  669. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  670. mindspore/ops/composite/__init__.py +12 -13
  671. mindspore/ops/composite/base.py +261 -254
  672. mindspore/ops/composite/env_ops.py +41 -0
  673. mindspore/ops/composite/math_ops.py +197 -156
  674. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  675. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  676. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  677. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  678. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  679. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  680. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  681. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  682. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  683. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  684. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  685. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  686. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  687. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  688. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  689. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  690. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  691. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  692. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  693. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  694. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  695. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  696. mindspore/ops/function/__init__.py +323 -8
  697. mindspore/ops/function/array_func.py +3511 -780
  698. mindspore/ops/function/clip_func.py +329 -0
  699. mindspore/ops/function/debug_func.py +6 -6
  700. mindspore/ops/function/grad/__init__.py +5 -1
  701. mindspore/ops/function/grad/grad_func.py +736 -65
  702. mindspore/ops/function/image_func.py +270 -0
  703. mindspore/ops/function/linalg_func.py +268 -8
  704. mindspore/ops/function/math_func.py +8032 -3164
  705. mindspore/ops/function/nn_func.py +5619 -1855
  706. mindspore/ops/function/other_func.py +115 -0
  707. mindspore/ops/function/parameter_func.py +11 -10
  708. mindspore/ops/function/random_func.py +939 -77
  709. mindspore/ops/function/sparse_func.py +249 -84
  710. mindspore/ops/function/sparse_unary_func.py +2303 -0
  711. mindspore/ops/function/spectral_func.py +146 -0
  712. mindspore/ops/function/vmap_func.py +114 -0
  713. mindspore/ops/functional.py +182 -254
  714. mindspore/ops/op_info_register.py +79 -34
  715. mindspore/ops/operations/__init__.py +210 -118
  716. mindspore/ops/operations/_csr_ops.py +7 -7
  717. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  718. mindspore/ops/operations/_grad_ops.py +447 -322
  719. mindspore/ops/operations/_inner_ops.py +547 -176
  720. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  721. mindspore/ops/operations/_ms_kernel.py +29 -27
  722. mindspore/ops/operations/_ocr_ops.py +11 -11
  723. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  724. mindspore/ops/operations/_quant_ops.py +186 -101
  725. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  726. mindspore/ops/operations/_scalar_ops.py +466 -0
  727. mindspore/ops/operations/_sequence_ops.py +1047 -0
  728. mindspore/ops/operations/_tensor_array.py +10 -11
  729. mindspore/ops/operations/_thor_ops.py +4 -4
  730. mindspore/ops/operations/array_ops.py +1428 -1226
  731. mindspore/ops/operations/comm_ops.py +180 -117
  732. mindspore/ops/operations/control_ops.py +4 -2
  733. mindspore/ops/operations/custom_ops.py +185 -98
  734. mindspore/ops/operations/debug_ops.py +92 -54
  735. mindspore/ops/operations/image_ops.py +406 -211
  736. mindspore/ops/operations/inner_ops.py +42 -53
  737. mindspore/ops/operations/linalg_ops.py +32 -29
  738. mindspore/ops/operations/math_ops.py +2076 -897
  739. mindspore/ops/operations/nn_ops.py +1282 -1252
  740. mindspore/ops/operations/other_ops.py +124 -278
  741. mindspore/ops/operations/random_ops.py +345 -178
  742. mindspore/ops/operations/rl_ops.py +8 -9
  743. mindspore/ops/operations/sparse_ops.py +502 -157
  744. mindspore/ops/operations/spectral_ops.py +107 -0
  745. mindspore/ops/primitive.py +192 -15
  746. mindspore/ops/vm_impl_registry.py +23 -2
  747. mindspore/parallel/__init__.py +6 -1
  748. mindspore/parallel/_auto_parallel_context.py +199 -92
  749. mindspore/parallel/_cell_wrapper.py +4 -2
  750. mindspore/parallel/_cost_model_context.py +3 -0
  751. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  752. mindspore/parallel/_offload_context.py +185 -0
  753. mindspore/parallel/_parallel_serialization.py +167 -28
  754. mindspore/parallel/_ps_context.py +9 -5
  755. mindspore/parallel/_recovery_context.py +1 -1
  756. mindspore/parallel/_tensor.py +9 -1
  757. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  758. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  759. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  760. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  761. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  762. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  763. mindspore/parallel/_utils.py +47 -7
  764. mindspore/parallel/algo_parameter_config.py +5 -1
  765. mindspore/parallel/checkpoint_transform.py +329 -0
  766. mindspore/parallel/shard.py +229 -0
  767. mindspore/profiler/__init__.py +2 -1
  768. mindspore/profiler/common/util.py +4 -3
  769. mindspore/profiler/common/validator/validate_path.py +2 -2
  770. mindspore/profiler/envprofiling.py +249 -0
  771. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  772. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  773. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  774. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  775. mindspore/profiler/parser/framework_parser.py +42 -16
  776. mindspore/profiler/parser/hccl_parser.py +158 -158
  777. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  778. mindspore/profiler/parser/integrator.py +18 -1579
  779. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  780. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  781. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  782. mindspore/profiler/parser/optime_parser.py +17 -18
  783. mindspore/profiler/parser/profiler_info.py +108 -0
  784. mindspore/profiler/parser/step_trace_parser.py +1 -1
  785. mindspore/profiler/profiling.py +396 -194
  786. mindspore/rewrite/__init__.py +6 -2
  787. mindspore/rewrite/api/node.py +51 -110
  788. mindspore/rewrite/api/node_type.py +10 -6
  789. mindspore/rewrite/api/pattern_engine.py +51 -7
  790. mindspore/rewrite/api/scoped_value.py +64 -53
  791. mindspore/rewrite/api/symbol_tree.py +108 -61
  792. mindspore/rewrite/api/tree_node_helper.py +2 -3
  793. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  794. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  795. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  796. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  797. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  798. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  799. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  800. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  801. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  802. mindspore/rewrite/common/__init__.py +2 -0
  803. mindspore/rewrite/common/event.py +1 -1
  804. mindspore/rewrite/common/observable.py +1 -1
  805. mindspore/rewrite/common/observer.py +1 -1
  806. mindspore/rewrite/common/rewrite_elog.py +35 -0
  807. mindspore/rewrite/namer.py +2 -2
  808. mindspore/rewrite/namespace.py +14 -4
  809. mindspore/rewrite/node.py +161 -13
  810. mindspore/rewrite/parser.py +0 -1
  811. mindspore/rewrite/parser_register.py +0 -1
  812. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  813. mindspore/rewrite/parsers/assign_parser.py +267 -67
  814. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  815. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  816. mindspore/rewrite/parsers/constant_parser.py +101 -0
  817. mindspore/rewrite/parsers/container_parser.py +88 -0
  818. mindspore/rewrite/parsers/for_parser.py +28 -15
  819. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  820. mindspore/rewrite/parsers/if_parser.py +11 -28
  821. mindspore/rewrite/parsers/module_parser.py +9 -6
  822. mindspore/rewrite/parsers/return_parser.py +3 -2
  823. mindspore/rewrite/sparsify/__init__.py +0 -0
  824. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  825. mindspore/rewrite/sparsify/sparsify.py +109 -0
  826. mindspore/rewrite/sparsify/utils.py +173 -0
  827. mindspore/rewrite/symbol_tree.py +322 -109
  828. mindspore/rewrite/symbol_tree_builder.py +45 -8
  829. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  830. mindspore/rewrite/topological_manager.py +1 -2
  831. mindspore/run_check/_check_version.py +209 -112
  832. mindspore/run_check/run_check.py +2 -1
  833. mindspore/scipy/linalg.py +13 -117
  834. mindspore/scipy/ops.py +5 -71
  835. mindspore/scipy/ops_grad.py +1 -25
  836. mindspore/scipy/ops_wrapper.py +1 -1
  837. mindspore/scipy/optimize/_bfgs.py +1 -1
  838. mindspore/scipy/optimize/_lagrange.py +200 -0
  839. mindspore/scipy/optimize/line_search.py +3 -2
  840. mindspore/scipy/optimize/minimize.py +43 -6
  841. mindspore/scipy/sparse/__init__.py +2 -2
  842. mindspore/scipy/sparse/linalg.py +5 -465
  843. mindspore/scipy/utils.py +2 -1
  844. mindspore/scipy/utils_const.py +7 -1
  845. mindspore/train/__init__.py +6 -4
  846. mindspore/train/_utils.py +28 -5
  847. mindspore/train/amp.py +321 -50
  848. mindspore/train/callback/__init__.py +3 -1
  849. mindspore/train/callback/_backup_and_restore.py +120 -0
  850. mindspore/train/callback/_callback.py +8 -8
  851. mindspore/train/callback/_checkpoint.py +12 -9
  852. mindspore/train/callback/_early_stop.py +13 -7
  853. mindspore/train/callback/_history.py +8 -8
  854. mindspore/train/callback/_lambda_callback.py +6 -6
  855. mindspore/train/callback/_landscape.py +36 -38
  856. mindspore/train/callback/_loss_monitor.py +12 -6
  857. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  858. mindspore/train/callback/_on_request_exit.py +212 -0
  859. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  860. mindspore/train/callback/_summary_collector.py +27 -19
  861. mindspore/train/callback/_time_monitor.py +13 -7
  862. mindspore/train/checkpoint_pb2.py +68 -8
  863. mindspore/train/data_sink.py +122 -33
  864. mindspore/train/dataset_helper.py +28 -87
  865. mindspore/train/loss_scale_manager.py +4 -7
  866. mindspore/{nn → train}/metrics/__init__.py +20 -20
  867. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  868. mindspore/{nn → train}/metrics/auc.py +4 -4
  869. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  870. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  871. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  872. mindspore/{nn → train}/metrics/dice.py +6 -5
  873. mindspore/{nn → train}/metrics/error.py +7 -5
  874. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  875. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  876. mindspore/{nn → train}/metrics/loss.py +4 -3
  877. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  878. mindspore/{nn → train}/metrics/metric.py +6 -5
  879. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  880. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  881. mindspore/{nn → train}/metrics/precision.py +5 -4
  882. mindspore/{nn → train}/metrics/recall.py +5 -4
  883. mindspore/{nn → train}/metrics/roc.py +7 -6
  884. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  885. mindspore/{nn → train}/metrics/topk.py +7 -5
  886. mindspore/train/mind_ir_pb2.py +339 -32
  887. mindspore/train/model.py +113 -84
  888. mindspore/train/serialization.py +547 -167
  889. mindspore/train/summary/_summary_adapter.py +1 -1
  890. mindspore/train/summary/summary_record.py +43 -12
  891. mindspore/train/train_thor/convert_utils.py +7 -1
  892. mindspore/train/train_thor/dataset_helper.py +3 -3
  893. mindspore/train/train_thor/model_thor.py +0 -4
  894. mindspore/version.py +1 -1
  895. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  896. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +899 -675
  897. mindspore/compression/common/constant.py +0 -124
  898. mindspore/compression/export/__init__.py +0 -19
  899. mindspore/compression/export/quant_export.py +0 -514
  900. mindspore/compression/quant/qat.py +0 -636
  901. mindspore/compression/quant/quant_utils.py +0 -462
  902. mindspore/compression/quant/quantizer.py +0 -68
  903. mindspore/nn/layer/quant.py +0 -1868
  904. mindspore/nn/layer/rnn_utils.py +0 -90
  905. mindspore/nn/probability/dpn/__init__.py +0 -22
  906. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  907. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  908. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  909. mindspore/nn/probability/infer/__init__.py +0 -22
  910. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  911. mindspore/nn/probability/infer/variational/svi.py +0 -84
  912. mindspore/nn/probability/toolbox/__init__.py +0 -22
  913. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  914. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  915. mindspore/nn/probability/transforms/__init__.py +0 -22
  916. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  917. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  918. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  919. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  920. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  921. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  922. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  923. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  924. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  925. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  926. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  927. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  928. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  929. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  930. mindspore/ops/composite/array_ops.py +0 -210
  931. mindspore/ops/composite/clip_ops.py +0 -238
  932. mindspore/ops/composite/random_ops.py +0 -426
  933. mindspore/ops/composite/vmap_ops.py +0 -38
  934. mindspore/ops/operations/sponge_ops.py +0 -3531
  935. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  936. mindspore/parallel/nn/__init__.py +0 -42
  937. mindspore/parallel/nn/loss.py +0 -22
  938. mindspore/parallel/nn/moe.py +0 -21
  939. mindspore/parallel/nn/op_parallel_config.py +0 -22
  940. mindspore/parallel/nn/transformer.py +0 -31
  941. mindspore/run_check/_check_deps_version.py +0 -84
  942. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  943. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  944. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -17,14 +17,17 @@
17
17
 
18
18
  import numpy as np
19
19
  import mindspore.numpy as mnp
20
+ from mindspore import context
20
21
  from mindspore.common import dtype as mstype
21
- from mindspore import nn
22
22
  from mindspore.nn import LGamma
23
23
  from mindspore.ops import functional as F
24
24
  from mindspore.ops.functional import broadcast_gradient_args
25
25
  from mindspore.ops import operations as P
26
+ from mindspore.ops.operations import _inner_ops as inner
26
27
  from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
28
+ from mindspore import nn, Tensor
27
29
  from mindspore.ops.operations.math_ops import Real, Imag, Complex, Angle
30
+ from mindspore.ops.operations.math_ops import Polar
28
31
  from mindspore.ops.operations.math_ops import ComplexAbs
29
32
  from mindspore.ops.operations.math_ops import Sinc
30
33
  from mindspore.ops.operations import _grad_ops as G
@@ -39,31 +42,50 @@ from mindspore.ops.operations.math_ops import BesselK0e
39
42
  from mindspore.ops.operations.math_ops import BesselK1e
40
43
  from mindspore.ops.operations.math_ops import BesselY0
41
44
  from mindspore.ops.operations.math_ops import BesselY1
45
+ from mindspore.ops.operations.math_ops import Lgamma
46
+ from mindspore.ops.operations.math_ops import Digamma
47
+ from mindspore.ops.operations.math_ops import Polygamma
42
48
  from mindspore.ops.operations.math_ops import NextAfter
43
49
  from mindspore.ops.operations.math_ops import Hypot
44
50
  from mindspore.ops.operations.math_ops import ReduceStd
45
51
  from mindspore.ops.operations.math_ops import LuUnpack
46
52
  from mindspore.ops.operations.math_ops import MatrixExp
53
+ from mindspore.ops.operations.math_ops import CumulativeLogsumexp
47
54
  from mindspore.ops.operations.math_ops import MatrixSolve
55
+ from mindspore.ops.operations.math_ops import MatrixSolveLs
56
+ from mindspore.ops.operations.math_ops import MatrixPower
48
57
  from mindspore.ops.operations.math_ops import Median
58
+ from mindspore.ops.operations.math_ops import MatrixTriangularSolve
59
+ from mindspore.ops.operations.math_ops import NanToNum
60
+ from mindspore.ops.operations.math_ops import FFTWithSize
49
61
  from mindspore.ops.operations.math_ops import Betainc
62
+ from mindspore.ops.operations.math_ops import Cholesky
63
+ from mindspore.ops.operations.math_ops import Fmin
50
64
  from mindspore.ops.operations.math_ops import CholeskySolve
65
+ from mindspore.ops.operations.math_ops import InplaceIndexAdd
51
66
  from mindspore.ops.operations.math_ops import AddV2
52
67
  from mindspore.ops.operations.math_ops import TridiagonalMatMul
68
+ from mindspore.ops.operations.math_ops import TridiagonalSolve
53
69
  from mindspore.ops.operations.math_ops import Logit
70
+ from mindspore.ops.operations.math_ops import Diagonal
71
+ from mindspore.ops.operations.math_ops import EuclideanNorm
72
+ from mindspore.ops.operations.array_ops import Transpose, MatrixSetDiagV3
73
+ from mindspore.ops.operations.math_ops import Fmax
74
+ from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
54
75
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
55
- from mindspore.ops.primitive import constexpr
56
- from mindspore.ops._utils.utils import is_shape_unknown
76
+ from mindspore.ops.primitive import _primexpr
57
77
  from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
78
+ from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
58
79
  from mindspore.ops._grad.grad_math_ops import binop_grad_common
59
-
80
+ from mindspore.ops.operations.array_ops import MatrixBandPart
81
+ from mindspore.ops.operations.array_ops import ConjugateTranspose
60
82
 
61
83
  transpose = P.Transpose()
62
84
  dyn_shape_op = P.TensorShape()
63
85
  _conj = P.Conj()
64
86
 
65
87
 
66
- @constexpr
88
+ @_primexpr
67
89
  def _generate_perm(x_dim):
68
90
  perm = tuple(range(x_dim - 2))
69
91
  return perm
@@ -110,6 +132,27 @@ def get_bprop_logit(self):
110
132
  return bprop
111
133
 
112
134
 
135
+ @bprop_getters.register(P.Roll)
136
+ def get_bprop_roll(self):
137
+ """Generate bprop for Roll"""
138
+ if context.get_context("device_target") == "GPU":
139
+ shift = []
140
+ axis = self.axis
141
+ for tmp in enumerate(self.shift):
142
+ shift.append(-tmp[1])
143
+ roll_grad = P.Roll(shift, axis)
144
+ else:
145
+ shift = self.shift
146
+ axis = self.axis
147
+ roll_grad = P.Roll(-shift, axis)
148
+
149
+ def bprop(x_input, out, dout):
150
+ dx = roll_grad(dout)
151
+ return (dx,)
152
+
153
+ return bprop
154
+
155
+
113
156
  @bprop_getters.register(P.Cdist)
114
157
  def get_bprop_cdist(self):
115
158
  """Generate bprop for Cdist"""
@@ -117,7 +160,7 @@ def get_bprop_cdist(self):
117
160
 
118
161
  def bprop(input_x, input_y, out, dout):
119
162
  dout_shape = F.shape(dout)
120
- if is_shape_unknown(dout_shape):
163
+ if F.is_sequence_value_unknown(dout_shape):
121
164
  dout_dim = dyn_rank(dout)
122
165
  dout_perm_part2 = create_tensor_by_element(
123
166
  (dout_dim - 1, dout_dim - 2))
@@ -146,7 +189,7 @@ def get_bprop_index_lerp(self):
146
189
  """Generate bprop for Lerp"""
147
190
  mul_op = P.Mul()
148
191
  sub_op = P.Sub()
149
- is_instance_op = P.IsInstance()
192
+ is_instance_op = inner.IsInstance()
150
193
 
151
194
  def bprop(start, end, weight, out, dout):
152
195
  dout = F.cast(dout, mstype.float32)
@@ -315,7 +358,7 @@ def get_bprop_index_addcmul(self):
315
358
  return bprop
316
359
 
317
360
 
318
- @constexpr
361
+ @_primexpr
319
362
  def renew_dim(shape, dim):
320
363
  """ Re-new dims"""
321
364
  new_dim = dim if dim >= 0 else len(shape) + dim
@@ -324,6 +367,21 @@ def renew_dim(shape, dim):
324
367
  return tuple(tmp)
325
368
 
326
369
 
370
+ @bprop_getters.register(EuclideanNorm)
371
+ def get_bprop_euclidean_norm(self):
372
+ """Generate bprop for EuclideanNorm"""
373
+ expand_dims = P.ExpandDims()
374
+ keep_dims = self.keep_dims
375
+
376
+ def bprop(x, axes, out, dout):
377
+ scale_v = dout / out
378
+ if not keep_dims and x.shape != ():
379
+ scale_v = expand_dims(scale_v, axes)
380
+ return (x * scale_v, zeros_like(axes))
381
+
382
+ return bprop
383
+
384
+
327
385
  @bprop_getters.register(Renorm)
328
386
  def get_bprop_renorm(self):
329
387
  """Generate bprop for Renorm """
@@ -332,7 +390,6 @@ def get_bprop_renorm(self):
332
390
  dim = self.dim
333
391
  max_norm = self.maxnorm
334
392
  greater_op = P.Greater()
335
- masked_fill_op = P.MaskedFill()
336
393
  pow_op = P.Pow()
337
394
  abs_op = P.Abs()
338
395
  sign_op = P.Sign()
@@ -349,13 +406,13 @@ def get_bprop_renorm(self):
349
406
  norm_bp = sig * grad_out
350
407
  elif p == 2:
351
408
  m = input_x * (grad_out / norm)
352
- norm_bp = masked_fill_op(m, norm == 0., 0.)
409
+ norm_bp = F.masked_fill(m, norm == 0., 0.)
353
410
  else:
354
411
  abs_ = abs_op(input_x)
355
412
  input_scaled = input_x * pow_op(abs_, (p - 2))
356
413
  pow_ = pow_op(norm, (p - 1))
357
414
  scale_v = grad_out / pow_
358
- scale_v = masked_fill_op(scale_v, norm == 0., 0.)
415
+ scale_v = F.masked_fill(scale_v, norm == 0., 0.)
359
416
  norm_bp = input_scaled * scale_v
360
417
 
361
418
  v = norm + ext
@@ -397,7 +454,88 @@ def get_bprop_lp_norm(self):
397
454
  else:
398
455
  input_scaled = pow_op(abs_op(input_x), (p - 2)) * input_x
399
456
  scale_v = dout / pow_op(out, (p - 1))
400
- return (input_scaled * scale_v,)
457
+ return (mnp.where(input_scaled == 0, 0, input_scaled * scale_v),)
458
+
459
+ return bprop
460
+
461
+
462
@bprop_getters.register(CumulativeLogsumexp)
def get_brop_cumulative_logsumexp(self):
    """Generate bprop for CumulativeLogsumexp.

    Returns:
        A function ``bprop(x, axis, out, dout)`` that returns the gradient for ``x``
        and a zero gradient for ``axis``.
    """
    exp_op = P.Exp()
    greater_op = P.Greater()
    log_op = P.Log()
    # Same `exclusive` flag as the forward op, but accumulated in the opposite direction.
    cumulative_op = CumulativeLogsumexp(self.exclusive, not self.reverse)
    less_op = P.Less()
    neg_op = P.Neg()
    cast = P.Cast()

    def bprop(x, axis, out, dout):
        # Lowest finite value of the working precision; it is the fill value
        # wherever log(dout) / log(-dout) is not taken.
        if x.dtype == mstype.float16:
            dtype_min = cast(np.finfo(np.float16).min, x.dtype)
        else:
            dtype_min = cast(np.finfo(np.float32).min, x.dtype)
        # Positive and negative parts of dout are handled separately in log space.
        log_grad_positive = mnp.where(greater_op(dout, 0), log_op(dout), dtype_min)
        log_grad_negative = mnp.where(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
        output_pos = exp_op(cumulative_op(log_grad_positive - out, axis) + x)
        output_neg = exp_op(cumulative_op(log_grad_negative - out, axis) + x)
        # The second gradient slot belongs to `axis`, so it must be shaped like `axis`.
        return (output_pos - output_neg, zeros_like(axis))

    return bprop
486
+
487
+
488
@bprop_getters.register(MatrixTriangularSolve)
def get_bprop_matrix_triangular_solve(self):
    """Grad definition for 'MatrixTriangularSolve' operation.

    Returns:
        A function ``bprop(matrix, rhs, out, dout)`` returning ``(grad_matrix, grad_rhs)``.
    """
    adjoint_a = self.adjoint
    lower_a = self.lower
    # The backward solve uses the same triangle with the adjoint flag flipped.
    matrix_triangular_solve_op = P.MatrixTriangularSolve(lower=lower_a, adjoint=not adjoint_a)
    mat_mul_2d_op = P.MatMul()
    mat_mul_op = P.BatchMatMul()
    real_op = P.Real()
    imag_op = P.Imag()
    neg_op = P.Neg()
    complex_op = P.Complex()
    matrix_band_part_op = MatrixBandPart()

    def bprop(matrix, rhs, out, dout):
        grad_rhs = matrix_triangular_solve_op(matrix, dout)

        if matrix.dtype in (mstype.complex64, mstype.complex128):
            grad_rhs_temp = _adjoint(grad_rhs)
            out_temp = _adjoint(out)
        else:
            grad_rhs_temp = cholesky_transpose(grad_rhs)
            out_temp = cholesky_transpose(out)

        # Pick the operands of the product; only their order depends on `adjoint`.
        if adjoint_a:
            left, right = out, grad_rhs_temp
        else:
            left, right = grad_rhs, out_temp

        if len(matrix.shape) == 2:
            product = mat_mul_2d_op(left, right)
        else:
            product = mat_mul_op(left, right)
        grad_matrix = neg_op(product)

        # Band limits passed to MatrixBandPart: (-1, 0) when `lower` is set, (0, -1) otherwise.
        if lower_a:
            band_lower, band_upper = -1, 0
        else:
            band_lower, band_upper = 0, -1

        if grad_matrix.dtype in (mstype.complex64, mstype.complex128):
            # Real and imaginary parts are band-filtered separately and recombined.
            grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), band_lower, band_upper)
            grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), band_lower, band_upper)
            grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
        else:
            grad_matrix = matrix_band_part_op(grad_matrix, band_lower, band_upper)
        return (grad_matrix, grad_rhs)

    return bprop
403
541
 
@@ -411,27 +549,43 @@ def get_bprop_matrix_exp(self):
411
549
  concat_col = P.Concat(-2)
412
550
  cast = P.Cast()
413
551
  slice_op = P.Slice()
552
+ range_op = P.Range()
553
+ expand_dims = P.ExpandDims()
554
+ dyn_shape = P.TensorShape()
414
555
 
415
556
  def bprop(x, out, dout):
416
- shape_x = P.Shape()(x)
417
- n = shape_x[-1]
418
- zero_matrix = zeros(shape_x, mstype.float32)
419
- zero_matrix = cast(zero_matrix, dout.dtype)
420
- x_len = len(shape_x)
421
- input_perm = [ele for ele in range(x_len)]
422
- input_perm[-1] = input_perm[-2]
423
- input_perm[-2] = x_len-1
424
- input_perm = tuple(input_perm)
425
- x_transpose = P.Transpose()(x, input_perm)
557
+ if F.is_sequence_value_unknown(x.shape):
558
+ shape_x = dyn_shape(x)
559
+ x_len = dyn_rank(x)
560
+ input_perm = range_op(cast(0, mstype.int64), x_len, cast(1, mstype.int64))
561
+ input_perm[-1] = input_perm[-2]
562
+ input_perm[-2] = x_len - 1
563
+ x_transpose = transpose(x, input_perm)
564
+ zero_matrix = dyn_fill(mstype.float32, shape_x, 0)
565
+ else:
566
+ shape_x = x.shape
567
+ x_len = len(shape_x)
568
+ input_perm = [ele for ele in range(x_len)]
569
+ input_perm[-1] = input_perm[-2]
570
+ input_perm[-2] = x_len - 1
571
+ input_perm = tuple(input_perm)
572
+ x_transpose = P.Transpose()(x, input_perm)
573
+ zero_matrix = zeros(shape_x, mstype.float32)
426
574
 
575
+ zero_matrix = cast(zero_matrix, dout.dtype)
427
576
  meta_grad_up = concat_row((x_transpose, dout))
428
577
  meta_grad_down = concat_row((zero_matrix, x_transpose))
429
578
  meta_grad = concat_col((meta_grad_up, meta_grad_down))
430
579
  meta_grad = matrix_exp(meta_grad)
431
580
 
432
- begins = [0] * x_len
581
+ if F.is_sequence_value_unknown(x.shape):
582
+ begins = dyn_fill(mstype.int32, expand_dims(x_len, 0), 0)
583
+ sizes = cast(shape_x, mstype.int32)
584
+ else:
585
+ begins = [0] * x_len
586
+ sizes = [i for i in shape_x]
587
+ n = shape_x[-1]
433
588
  begins[-1] = n
434
- sizes = [i for i in shape_x]
435
589
  sizes[-2] = n
436
590
  sizes[-1] = n
437
591
  return (slice_op(meta_grad, begins, sizes),)
@@ -439,17 +593,35 @@ def get_bprop_matrix_exp(self):
439
593
  return bprop
440
594
 
441
595
 
442
- @bprop_getters.register(P.MatrixInverse)
443
- def get_bprop_matrix_inverse(self):
444
- """Generate bprop for MatrixInverse"""
445
- matmul_x1 = nn.MatMul(transpose_x1=True)
446
- matmul_x2 = nn.MatMul(transpose_x2=True)
596
+ @bprop_getters.register(MatrixPower)
597
+ def get_bprop_matrix_power(self):
598
+ """Generate bprop for MatrixPower"""
599
+ n = self.n
600
+ batch_matmul_a = P.BatchMatMul(transpose_a=True)
601
+ batch_matmul_b = P.BatchMatMul(transpose_b=True)
447
602
  neg = P.Neg()
448
603
 
449
604
  def bprop(x, out, dout):
450
- dx = matmul_x2(dout, out)
451
- dx = matmul_x1(out, dx)
452
- dx = neg(dx)
605
+ dout = F.cast(dout, mstype.float32)
606
+ x = F.cast(x, mstype.float32)
607
+ power = n
608
+ dx = zeros_like(x)
609
+ if power < 0:
610
+ matrix_power = MatrixPower(n=-1)
611
+ x_inv = matrix_power(x)
612
+ for i in range(0, -power):
613
+ matrix_power = MatrixPower(n=(-power - 1 - i))
614
+ dx = dx + batch_matmul_b(dout, matrix_power(x_inv))
615
+ dout = batch_matmul_a(x_inv, dout)
616
+ dx = batch_matmul_b(dx, x_inv)
617
+ dx = batch_matmul_a(x_inv, dx)
618
+ dx = neg(dx)
619
+ else:
620
+ for i in range(0, power):
621
+ matrix_power = MatrixPower(n=(power - 1 - i))
622
+ dx = dx + batch_matmul_b(dout, matrix_power(x))
623
+ dout = batch_matmul_a(x, dout)
624
+ dx = F.cast(dx, F.dtype(out))
453
625
  return (dx,)
454
626
 
455
627
  return bprop
@@ -475,6 +647,13 @@ def get_bprop_matrix_solve(self):
475
647
  grad_b_type = F.dtype(grad_b)
476
648
  if grad_b_type == mstype.float64:
477
649
  grad_b = cast(grad_b, mstype.float32)
650
+
651
+ a_shape = F.shape(input_a)
652
+ if F.is_sequence_value_unknown(a_shape):
653
+ matrix_rank = dyn_rank(input_a)
654
+ else:
655
+ matrix_rank = rank(input_a)
656
+
478
657
  matrix_rank = rank(input_a)
479
658
  if adjoint:
480
659
  if matrix_rank > 2:
@@ -495,6 +674,156 @@ def get_bprop_matrix_solve(self):
495
674
  return bprop
496
675
 
497
676
 
677
@_primexpr
def _generate_perm_matrix_solve_ls(x_dim):
    """Build the permutation that swaps the last two of `x_dim` axes and keeps the others."""
    leading_axes = tuple(range(x_dim - 2))
    swapped_tail = (x_dim - 1, x_dim - 2)
    return leading_axes + swapped_tail
682
+
683
+
684
@bprop_getters.register(MatrixSolveLs)
def get_bprop_matrix_solve_ls(self):
    """Grad definition for 'MatrixSolveLs' operation.

    Returns:
        A function ``bprop(matrix, rhs, l2, out, dout)`` returning
        ``(grad_matrix, grad_rhs, grad_l2)``, where ``grad_l2`` is a zero scalar tensor.
        Calling it raises ValueError when the primitive was created with ``fast=False``.
    """
    fast = self.fast
    cast = P.Cast()
    neg = P.Neg()
    rank = P.Rank()
    cholesky = Cholesky()
    eye = P.Eye()
    add = P.Add()
    mul = P.Mul()
    matmul = P.MatMul()
    batch_matmul = P.BatchMatMul()
    cholesky_solve = CholeskySolve()
    _transpose = Transpose()
    conjugate_transpose = ConjugateTranspose()
    shape = P.Shape()
    _complex = P.Complex()
    scalar2tensor = P.ScalarToTensor()

    def regularized_gramian_cholesky(matrix, l2, first_kind):
        # Returns cholesky(G + l2 * I), where G = matrix^H @ matrix when `first_kind`
        # is true and G = matrix @ matrix^H otherwise. The l2 term is skipped only
        # when l2 is a non-Tensor value equal to 0.
        matrix_dim = rank(matrix)
        perm = _generate_perm_matrix_solve_ls(matrix_dim)
        if matrix.dtype in (mstype.complex64, mstype.complex128):
            matrix_temp = conjugate_transpose(matrix, perm)
        else:
            matrix_temp = _transpose(matrix, perm)
        if first_kind:
            if matrix_dim > 2:
                gramian = batch_matmul(matrix_temp, matrix)
            else:
                gramian = matmul(matrix_temp, matrix)
        else:
            if matrix_dim > 2:
                gramian = batch_matmul(matrix, matrix_temp)
            else:
                gramian = matmul(matrix, matrix_temp)
        if isinstance(l2, Tensor) or l2 != 0:
            matrix_shape = shape(matrix)
            # The Gramian is square with the size of the last axis (first kind)
            # or of the second-to-last axis (second kind).
            if first_kind:
                small_dim = matrix_shape[-1]
            else:
                small_dim = matrix_shape[-2]
            identity = eye(small_dim, small_dim, matrix.dtype)
            gramian = add(gramian, mul(l2, identity))

        # Cholesky does not support complex dtypes for now.
        return cholesky(gramian)

    def bprop(matrix, rhs, l2, out, dout):
        # Supported dtype: float32.
        # Supported dimensions: 2D, 3D.
        def over_determined(matrix, rhs, out, l2, dout):
            # Used when rows >= columns (see the dispatch at the end of bprop).
            # Promote l2 to the dtype of `matrix`; for complex inputs it becomes
            # a complex value with a zero imaginary part.
            if matrix.dtype == mstype.complex64:
                l2_regularizer = _complex(cast(l2, mstype.float32), Tensor(0, mstype.float32))
            elif matrix.dtype == mstype.complex128:
                l2_regularizer = _complex(cast(l2, mstype.float64), Tensor(0, mstype.float64))
            else:
                l2_regularizer = cast(l2, matrix.dtype)
            chol = cast(regularized_gramian_cholesky(matrix, l2_regularizer, first_kind=True), matrix.dtype)
            # CholeskySolve does not support complex dtypes and only supports 2D or 3D matrices for now.
            z = cholesky_solve(dout, chol)

            matrix_dim = rank(matrix)
            perm = _generate_perm_matrix_solve_ls(matrix_dim)
            if matrix.dtype in (mstype.complex64, mstype.complex128):
                z_temp = conjugate_transpose(z, perm)
            else:
                z_temp = _transpose(z, perm)
            if matrix_dim > 2:
                xzt = batch_matmul(out, z_temp)
            else:
                xzt = matmul(out, z_temp)
            # Symmetrize: xzt plus its (plain, non-conjugated) transpose.
            zx_sym = add(xzt, _transpose(xzt, perm))

            if matrix_dim > 2:
                grad_a = add(neg(batch_matmul(matrix, zx_sym)), batch_matmul(rhs, z_temp))
                grad_b = batch_matmul(matrix, z)
            else:
                grad_a = add(neg(matmul(matrix, zx_sym)), matmul(rhs, z_temp))
                grad_b = matmul(matrix, z)

            # l2 receives a zero scalar gradient of its own dtype.
            return (grad_a, grad_b, scalar2tensor(0, l2.dtype))

        def under_determined(matrix, rhs, l2, dout):
            # Used when rows < columns (see the dispatch at the end of bprop).
            # Same l2 promotion as in over_determined.
            if matrix.dtype == mstype.complex64:
                l2_regularizer = _complex(cast(l2, mstype.float32), Tensor(0, mstype.float32))
            elif matrix.dtype == mstype.complex128:
                l2_regularizer = _complex(cast(l2, mstype.float64), Tensor(0, mstype.float64))
            else:
                l2_regularizer = cast(l2, matrix.dtype)
            chol = cast(regularized_gramian_cholesky(matrix, l2_regularizer, first_kind=False), matrix.dtype)

            matrix_dim = rank(matrix)
            perm = _generate_perm_matrix_solve_ls(matrix_dim)
            if matrix_dim > 2:
                gramian = batch_matmul(matrix, dout)
            else:
                gramian = matmul(matrix, dout)
            # CholeskySolve does not support complex dtypes and only supports 2D or 3D matrices for now.
            grad_b = cholesky_solve(gramian, chol)
            tmp = cholesky_solve(rhs, chol)

            if matrix.dtype in (mstype.complex64, mstype.complex128):
                tmp_temp = conjugate_transpose(tmp, perm)
                matrix_temp = conjugate_transpose(matrix, perm)
            else:
                tmp_temp = _transpose(tmp, perm)
                matrix_temp = _transpose(matrix, perm)
            # grad_a is the sum of two terms, a1 and a2; the batched and 2D paths
            # differ only in which matmul primitive is used.
            if matrix_dim > 2:
                a1 = batch_matmul(tmp_temp, matrix)
                a1 = neg(batch_matmul(grad_b, a1))
                a2 = dout - batch_matmul(matrix_temp, grad_b)
                if matrix.dtype in (mstype.complex64, mstype.complex128):
                    a2_temp = conjugate_transpose(a2, perm)
                else:
                    a2_temp = _transpose(a2, perm)
                a2 = batch_matmul(tmp, a2_temp)
            else:
                a1 = matmul(tmp_temp, matrix)
                a1 = neg(matmul(grad_b, a1))
                a2 = dout - matmul(matrix_temp, grad_b)
                if matrix.dtype in (mstype.complex64, mstype.complex128):
                    a2_temp = conjugate_transpose(a2, perm)
                else:
                    a2_temp = _transpose(a2, perm)
                a2 = matmul(tmp, a2_temp)

            grad_a = add(a1, a2)
            # l2 receives a zero scalar gradient of its own dtype.
            return (grad_a, grad_b, scalar2tensor(0, l2.dtype))

        if fast is False:
            raise ValueError("For MatrixSolveLs, gradient not defined for fast=False")
        # Only the trailing (rows, columns) pair decides which formula applies.
        matrix_shape = shape(matrix)[-2:]

        if matrix_shape[-2] >= matrix_shape[-1]:
            return over_determined(matrix, rhs, out, l2, dout)

        return under_determined(matrix, rhs, l2, dout)

    return bprop
825
+
826
+
498
827
  @bprop_getters.register(P.MatrixDeterminant)
499
828
  def get_bprop_matrix_determinant(self):
500
829
  """Generate bprop for MatrixDeterminant"""
@@ -504,7 +833,7 @@ def get_bprop_matrix_determinant(self):
504
833
  concat = P.Concat(0)
505
834
 
506
835
  def bprop(x, out, dout):
507
- if is_shape_unknown(shape_op(x)):
836
+ if F.is_sequence_value_unknown(shape_op(x)):
508
837
  x_adj_inv = inverse_op(x)
509
838
  out_shape = dyn_shape_op(out)
510
839
  ones = create_tensor_by_element((1, 1))
@@ -528,7 +857,13 @@ def get_bprop_log_matrix_determinant(self):
528
857
 
529
858
  def bprop(x, out, dout):
530
859
  x_adj_inv = inverse_op(x)
531
- multipliers = reshape(dout[1], shape_op(out[1]) + (1, 1))
860
+ if F.is_sequence_value_unknown(shape_op(out[1])):
861
+ const_value = F.cast(1, mstype.int64)
862
+ const_value = P.ExpandDims()(const_value, 0)
863
+ new_shape = P.Concat()((dyn_shape_op(out[1]), const_value, const_value))
864
+ multipliers = reshape(dout[1], new_shape)
865
+ else:
866
+ multipliers = reshape(dout[1], shape_op(out[1]) + (1, 1))
532
867
  dx = multipliers * x_adj_inv
533
868
  return (dx,)
534
869
 
@@ -542,18 +877,17 @@ def get_bprop_betainc(self):
542
877
  exp = P.Exp()
543
878
  log1p = P.Log1p()
544
879
  xlogy = P.Xlogy()
545
- reduce_sum = P.ReduceSum()
880
+ dyn_shape = P.TensorShape()
546
881
 
547
882
  def bprop(input_a, input_b, input_x, out, dout):
548
- sa = F.shape(input_a)
549
- sx = F.shape(input_x)
550
- _, rx = F.broadcast_gradient_args(sa, sx)
551
-
883
+ if F.is_sequence_value_unknown(F.shape(input_x)):
884
+ sx = dyn_shape(input_x)
885
+ else:
886
+ sx = F.shape(input_x)
552
887
  log_beta = (lgamma(input_a) + lgamma(input_b) - lgamma(input_a + input_b))
553
888
  partial_x = exp((input_b - 1) * log1p(-input_x) + xlogy(input_a - 1, input_x) - log_beta)
554
- if rx != ():
555
- return (zeros_like(input_a), zeros_like(input_b), F.reshape(reduce_sum(partial_x * dout, rx), sx))
556
889
  return (zeros_like(input_a), zeros_like(input_b), F.reshape(partial_x * dout, sx))
890
+
557
891
  return bprop
558
892
 
559
893
 
@@ -646,6 +980,18 @@ def get_bprop_complex_abs(self):
646
980
  return bprop
647
981
 
648
982
 
983
@bprop_getters.register(NanToNum)
def get_bprop_nan_to_num(self):
    """Grad definition for `NanToNum` operation.

    The incoming gradient is multiplied by the IsFinite mask of the input.
    """
    isfinite = P.IsFinite()

    def bprop(x, out, dout):
        finite_mask = isfinite(x)
        return (dout * finite_mask,)

    return bprop
993
+
994
+
649
995
  @bprop_getters.register(Angle)
650
996
  def get_bprop_angle(self):
651
997
  """Grad definition for `Angle` operation."""
@@ -667,6 +1013,29 @@ def get_bprop_angle(self):
667
1013
  return bprop
668
1014
 
669
1015
 
1016
@bprop_getters.register(Polar)
def get_bprop_polar(self):
    """Grad definition for `Polar` operation.

    Returns:
        A function ``bprop(input1, angle, out, dout)`` returning ``(grad_abs, grad_angle)``.
    """
    complex_op = Complex()
    conj = P.Conj()
    real = P.Real()
    sig = P.Sign()
    ones = P.Ones()
    zeros = P.Zeros()

    def bprop(input1, angle, out, dout):
        grad_conj = conj(dout)
        # Imaginary unit, shaped like dout, with real/imag parts in the dtype of input1.
        real_part = zeros(dout.shape, input1.dtype)
        imag_part = ones(dout.shape, input1.dtype)
        imag_unit = complex_op(real_part, imag_part)
        grad_abs = real(grad_conj * sig(out))
        grad_angle = real(grad_conj * (out * imag_unit))
        return (grad_abs, grad_angle)

    return bprop
1037
+
1038
+
670
1039
  @bprop_getters.register(P.Erfinv)
671
1040
  def get_bprop_erfinv(self):
672
1041
  """Grad definition for `Erfinv` operation."""
@@ -927,7 +1296,7 @@ def get_bprop_trace(self):
927
1296
 
928
1297
  def bprop(x, out, dout):
929
1298
  shape = shape_op(x)
930
- if is_shape_unknown(shape):
1299
+ if F.is_sequence_value_unknown(shape):
931
1300
  shape = dyn_shape_op(x)
932
1301
  dx = input_grad(dout, shape)
933
1302
  else:
@@ -937,12 +1306,121 @@ def get_bprop_trace(self):
937
1306
  return bprop
938
1307
 
939
1308
 
1309
@bprop_getters.register(Fmin)
def get_bprop_fmin(self):
    """Grad definition for 'Fmin' operation.

    Returns:
        A function ``bprop(x1, x2, out, dout)`` returning the gradients for ``x1`` and
        ``x2``, each reshaped to its input's shape and cast back to its input's dtype.
    """
    shape_ = P.Shape()
    masked_fill_op = P.MaskedFill()
    logical_or_op = P.LogicalOr()
    logical_not_op = P.LogicalNot()
    logical_and_op = P.LogicalAnd()
    mul_op = P.Mul()
    is_nan_op = P.IsNan()
    reshape_ = P.Reshape()

    def bprop(x1, x2, out, dout):
        x1_dtype = F.dtype(x1)
        x2_dtype = F.dtype(x2)
        # All arithmetic is carried out in float32.
        x1 = F.cast(x1, mstype.float32)
        x2 = F.cast(x2, mstype.float32)
        dout = F.cast(dout, mstype.float32)

        # take_x1 / take_x2 mark the positions whose gradient flows to x1 / x2.
        take_x1 = logical_or_op((x1 <= x2), is_nan_op(x2))
        take_x2 = logical_or_op((x2 < x1), logical_and_op(is_nan_op(x1), logical_not_op(is_nan_op(x2))))

        # Turn each mask into a 1/0 weight tensor.
        weight_x1 = masked_fill_op(masked_fill_op(x1, take_x1, 1.), logical_not_op(take_x1), 0.)
        weight_x2 = masked_fill_op(masked_fill_op(x2, take_x2, 1.), logical_not_op(take_x2), 0.)
        grad_x1 = mul_op(weight_x1, dout)
        grad_x2 = mul_op(weight_x2, dout)

        shape_of_x1 = shape_(x1)
        shape_of_x2 = shape_(x2)
        x1_dim = len(shape_of_x1)
        x2_dim = len(shape_of_x2)
        if x1_dim != 0 and x2_dim != 0:
            # Both inputs have rank >= 1: reduce over the broadcast axes.
            rx, ry = DynamicBroadcastGradientArgs()(shape_of_x1, shape_of_x2)
            sum_r1 = sum_grad_reduce_axis(grad_x1, rx)
            sum_r2 = sum_grad_reduce_axis(grad_x2, ry)
        else:
            # A rank-0 input collects the sum of its gradient; the other side passes through.
            sum_r1 = grad_x1.sum() if x1_dim == 0 else grad_x1
            sum_r2 = grad_x2.sum() if x2_dim == 0 else grad_x2

        brrx1 = F.cast(reshape_(sum_r1, shape_of_x1), x1_dtype)
        brrx2 = F.cast(reshape_(sum_r2, shape_of_x2), x2_dtype)
        return brrx1, brrx2

    return bprop
1359
+
1360
+
1361
@bprop_getters.register(Fmax)
def get_bprop_fmax(self):
    """Grad definition for 'Fmax' operation.

    Returns:
        A function ``bprop(x1, x2, out, dout)`` returning the gradients for ``x1`` and
        ``x2``, each reshaped to its input's shape and cast back to its input's dtype.
    """
    shape_ = P.Shape()
    masked_fill_op = P.MaskedFill()
    logical_or_op = P.LogicalOr()
    logical_not_op = P.LogicalNot()
    logical_and_op = P.LogicalAnd()
    mul_op = P.Mul()
    is_nan_op = P.IsNan()
    reshape_ = P.Reshape()

    def bprop(x1, x2, out, dout):
        x1_dtype = F.dtype(x1)
        x2_dtype = F.dtype(x2)
        # Inputs that are not float32 are cast; dout is cast as soon as either input was.
        if x1_dtype != mstype.float32:
            x1 = F.cast(x1, mstype.float32)
        if x2_dtype != mstype.float32:
            x2 = F.cast(x2, mstype.float32)
        if x1_dtype != mstype.float32 or x2_dtype != mstype.float32:
            dout = F.cast(dout, mstype.float32)

        x1_is_nan = is_nan_op(x1)
        x2_is_nan = is_nan_op(x2)
        # take_x1 / take_x2 mark the positions whose gradient flows to x1 / x2.
        take_x1 = logical_or_op(logical_and_op((x1 >= x2), logical_not_op(x1_is_nan)), x2_is_nan)
        take_x2 = logical_or_op(logical_and_op(x2 > x1, logical_not_op(x2_is_nan)),
                                logical_and_op(x1_is_nan, logical_not_op(x2_is_nan)))

        # Turn each mask into a 1/0 weight tensor.
        weight_x1 = masked_fill_op(masked_fill_op(x1, take_x1, 1.), logical_not_op(take_x1), 0.)
        weight_x2 = masked_fill_op(masked_fill_op(x2, take_x2, 1.), logical_not_op(take_x2), 0.)
        grad_x1 = mul_op(weight_x1, dout)
        grad_x2 = mul_op(weight_x2, dout)

        shape_of_x1 = shape_(x1)
        shape_of_x2 = shape_(x2)
        x1_dim = len(shape_of_x1)
        x2_dim = len(shape_of_x2)
        if x1_dim != 0 and x2_dim != 0:
            # Both inputs have rank >= 1: reduce over the broadcast axes.
            rx, ry = DynamicBroadcastGradientArgs()(shape_of_x1, shape_of_x2)
            sum_r1 = sum_grad_reduce_axis(grad_x1, rx)
            sum_r2 = sum_grad_reduce_axis(grad_x2, ry)
        else:
            # A rank-0 input collects the sum of its gradient; the other side passes through.
            sum_r1 = grad_x1.sum() if x1_dim == 0 else grad_x1
            sum_r2 = grad_x2.sum() if x2_dim == 0 else grad_x2

        brrx1 = F.cast(reshape_(sum_r1, shape_of_x1), x1_dtype)
        brrx2 = F.cast(reshape_(sum_r2, shape_of_x2), x2_dtype)
        return brrx1, brrx2

    return bprop
1416
+
1417
+
940
1418
  @bprop_getters.register(G.MinimumGrad)
941
1419
  def get_bprop_minimum_grad(self):
942
1420
  """Grad definition for 'MinimumGrad' operation"""
943
1421
  input_grad = G.MinimumGradGrad()
944
1422
 
945
- def bprop(grad, x1, x2, out, dout):
1423
+ def bprop(x1, x2, grad, out, dout):
946
1424
  sopd_x1, sopd_x2, sopd_grads = input_grad(x1, x2, dout[0], dout[1])
947
1425
  sopd_x1 = zeros_like(x1)
948
1426
  sopd_x2 = zeros_like(x2)
@@ -961,6 +1439,30 @@ def get_bprop_bernoulli(self):
961
1439
  return bprop
962
1440
 
963
1441
 
1442
@bprop_getters.register(TridiagonalSolve)
def get_bprop_tridiagonalsolve(self):
    """Grad definition for 'TridiagonalSolve' operation.

    Returns:
        A function ``bprop(diagonals, rhs, out, dout)`` returning ``(grad_diags, grad_rhs)``.
    """
    tridiagonalsolve = TridiagonalSolve()
    # Primitives are built once here rather than on every bprop call.
    zeros_op = P.Zeros()
    concat_last = P.Concat(-1)
    concat_rows = P.Concat(-2)
    stack_rows = P.Stack(-2)
    reduce_sum = P.ReduceSum()

    def bprop(diagonals, rhs, out, dout):
        diags = diagonals
        # Rebuild the diagonals with super- and sub-diagonal swapped and shifted by one,
        # then solve that system against dout.
        diag1 = diags[..., 1, :]
        zeros1 = zeros_op(diags.shape[:-2] + (1,), diags.dtype)
        superdiag1 = concat_last((diags[..., 2, 1:], zeros1))
        subdiag1 = concat_last((zeros1, diags[..., 0, :-1]))
        diags_transposed = stack_rows([superdiag1, diag1, subdiag1])
        grad_rhs = tridiagonalsolve(diags_transposed, dout)

        # Gradient of each diagonal: row-wise sums of grad_rhs times (shifted) out, negated.
        diag2 = reduce_sum(grad_rhs * out, -1)
        zeros2 = zeros_op(grad_rhs.shape[:-2] + (1, grad_rhs.shape[-1]), grad_rhs.dtype)
        superdiag2 = reduce_sum(grad_rhs * concat_rows((out[..., 1:, :], zeros2)), -1)
        subdiag2 = reduce_sum(grad_rhs * concat_rows((zeros2, out[..., :-1, :])), -1)
        stacked = stack_rows([superdiag2, diag2, subdiag2])
        grad_diags = 0 - stacked
        return grad_diags, grad_rhs

    return bprop
1464
+
1465
+
964
1466
  @bprop_getters.register(Igamma)
965
1467
  def get_bprop_igamma(self):
966
1468
  """Grad definition for `Igamma` operation."""
@@ -975,6 +1477,15 @@ def get_bprop_igamma(self):
975
1477
  def bprop(a, x, out, dout):
976
1478
  sa = shape_(a)
977
1479
  sx = shape_(x)
1480
+ if F.is_sequence_value_unknown(sa) or F.is_sequence_value_unknown(sx):
1481
+ sa = dyn_shape_op(a)
1482
+ sx = dyn_shape_op(x)
1483
+ ra, rx = DynamicBroadcastGradientArgs()(sa, sx)
1484
+ partial_a = igammagrada(a, x)
1485
+ partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
1486
+ r1 = reshape_(sum_grad_reduce_axis(partial_a * dout, ra), sa)
1487
+ r2 = reshape_(sum_grad_reduce_axis(partial_x * dout, rx), sx)
1488
+ return r1, r2
978
1489
  ra, rx = broadcast_gradient_args(sa, sx)
979
1490
  partial_a = igammagrada(a, x)
980
1491
  partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
@@ -1006,6 +1517,15 @@ def get_bprop_igammac(self):
1006
1517
  def bprop(a, x, out, dout):
1007
1518
  sa = shape_(a)
1008
1519
  sx = shape_(x)
1520
+ if F.is_sequence_value_unknown(sa) or F.is_sequence_value_unknown(sx):
1521
+ sa = dyn_shape_op(a)
1522
+ sx = dyn_shape_op(x)
1523
+ ra, rx = DynamicBroadcastGradientArgs()(sa, sx)
1524
+ partial_a = igammagrada(a, x)
1525
+ partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
1526
+ r1 = neg_(reshape_(sum_grad_reduce_axis(partial_a * dout, ra), sa))
1527
+ r2 = neg_(reshape_(sum_grad_reduce_axis(partial_x * dout, rx), sx))
1528
+ return r1, r2
1009
1529
  ra, rx = broadcast_gradient_args(sa, sx)
1010
1530
  partial_a = igammagrada(a, x)
1011
1531
  partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
@@ -1022,6 +1542,63 @@ def get_bprop_igammac(self):
1022
1542
  return bprop
1023
1543
 
1024
1544
 
1545
@bprop_getters.register(Lgamma)
def get_bprop_lgamma(self):
    """Grad definition for `Lgamma` operation.

    The gradient is ``dout * digamma(x)``. float16 and int32 inputs are evaluated
    in float32, and only the float16 result is cast back to float16.
    """
    digamma = Digamma()

    def bprop(x, out, dout):
        # Remember the incoming dtype: x itself may be re-cast below.
        x_dtype = x.dtype
        if x_dtype == mstype.float16 or x_dtype == mstype.int32:
            x = F.cast(x, mstype.float32)
        dx = dout * digamma(x)
        if x_dtype == mstype.float16:
            dx = F.cast(dx, mstype.float16)
        return (dx,)

    return bprop
1563
+
1564
+
1565
@bprop_getters.register(Digamma)
def get_bprop_digamma(self):
    """Grad definition for `Digamma` operation.

    The gradient is ``dout * polygamma(1, x)``. float16 inputs are evaluated in
    float32 and the result is cast back to float16.
    """
    polygamma = Polygamma()
    a = Tensor(1)

    def bprop(x, out, dout):
        # Remember the incoming dtype: x itself may be re-cast below.
        x_dtype = x.dtype
        if x_dtype == mstype.float16:
            x = F.cast(x, mstype.float32)
        dx = dout * polygamma(a, x)
        if x_dtype == mstype.float16:
            dx = F.cast(dx, mstype.float16)
        return (dx,)

    return bprop
1581
+
1582
+
1583
@bprop_getters.register(Polygamma)
def get_bprop_polygamma(self):
    """Grad definition for `Polygamma` operation.

    The gradient for ``x`` is ``dout * polygamma(a + 1, x)``; ``a`` gets a zero gradient.
    float16 inputs are evaluated in float64 and the result is cast back to float16.
    """
    polygamma = Polygamma()

    def bprop(a, x, out, dout):
        next_order = a + Tensor(1)
        # Remember the incoming dtype: x itself may be re-cast below.
        x_dtype = x.dtype
        if x_dtype == mstype.float16:
            x = F.cast(x, mstype.float64)
        dx = dout * polygamma(next_order, x)
        if x_dtype == mstype.float16:
            dx = F.cast(dx, mstype.float16)
        return zeros_like(next_order), dx

    return bprop
1600
+
1601
+
1025
1602
  @bprop_getters.register(TridiagonalMatMul)
1026
1603
  def get_bprop_tridiagonal_matmul(self):
1027
1604
  """Grad definition for 'TridiagonalMatMul' operation"""
@@ -1070,7 +1647,7 @@ def get_bprop_tridiagonal_matmul(self):
1070
1647
  maindiag_grad = reduce_sum(rhs_conj * grad, -1)
1071
1648
  subdiag_grad = reduce_sum(_rightshift(rhs_conj) * grad, -1)
1072
1649
  rhs_grad = _rightshift(superdiag_conj * grad) + maindiag_conj * grad + \
1073
- _leftshift(subdiag_conj * grad)
1650
+ _leftshift(subdiag_conj * grad)
1074
1651
  superdiag_grad = expand_dims(superdiag_grad, -2)
1075
1652
  maindiag_grad = expand_dims(maindiag_grad, -2)
1076
1653
  subdiag_grad = expand_dims(subdiag_grad, -2)
@@ -1101,13 +1678,18 @@ def get_bprop_cholesky_solve(self):
1101
1678
 
1102
1679
  def bprop(x1, x2, out, dout):
1103
1680
  flag = 0
1681
+ shape_x1 = shape_op(x1)
1682
+ if F.is_sequence_shape_unknown(shape_x1):
1683
+ len_x1 = dyn_rank(x1)
1684
+ else:
1685
+ len_x1 = len(shape_x1)
1104
1686
  if dout.dtype == mstype.float64:
1105
1687
  flag = 1
1106
1688
  x2 = F.cast(x2, mstype.float32)
1107
1689
  out = F.cast(out, mstype.float32)
1108
1690
  dout = F.cast(dout, mstype.float32)
1109
1691
  dx1 = cholesky_solve(dout, x2)
1110
- if len(shape_op(x2)) == 2:
1692
+ if len_x1 == 2:
1111
1693
  common_term = matmul_op(dx1, transpose(out, (1, 0)))
1112
1694
  common_term = common_term + transpose(common_term, (1, 0))
1113
1695
  if upper is True:
@@ -1115,8 +1697,11 @@ def get_bprop_cholesky_solve(self):
1115
1697
  else:
1116
1698
  dx2 = neg_op(matmul_op(common_term, x2))
1117
1699
  else:
1118
- common_term = batchmatmul_op(dx1, transpose(out, (0, 2, 1)))
1119
- common_term = common_term + transpose(common_term, (0, 2, 1))
1700
+ x2_dim_size = len(shape_op(x2))
1701
+ x2_dim_order = list(range(x2_dim_size))
1702
+ target_order = x2_dim_order[:-2] + x2_dim_order[-2:][::-1]
1703
+ common_term = batchmatmul_op(dx1, transpose(out, tuple(target_order)))
1704
+ common_term = common_term + transpose(common_term, tuple(target_order))
1120
1705
  if upper is True:
1121
1706
  dx2 = neg_op(batchmatmul_op(x2, common_term))
1122
1707
  else:
@@ -1133,6 +1718,7 @@ def get_bprop_cholesky_solve(self):
1133
1718
  def get_bprop_nextafter(self):
1134
1719
  """Grad definition for 'NextAfter' operation"""
1135
1720
  shape = P.Shape()
1721
+ dyn_shape = P.TensorShape()
1136
1722
  ones = P.Ones()
1137
1723
  zeros = P.Zeros()
1138
1724
  dtype = P.DType()
@@ -1151,11 +1737,355 @@ def get_bprop_nextafter(self):
1151
1737
  dout = cast(dout, mstype.float32)
1152
1738
 
1153
1739
  s_x1 = shape(x1)
1740
+ partial_x1 = ()
1741
+ if F.is_sequence_value_unknown(s_x1):
1742
+ s_x1 = dyn_shape(x1)
1743
+ partial_x1 = dyn_ones(s_x1, dtype(x1))
1744
+ else:
1745
+ partial_x1 = ones(s_x1, dtype(x1))
1746
+
1154
1747
  s_x2 = shape(x2)
1155
- partial_x1 = ones(s_x1, dtype(x1))
1156
- partial_x2 = zeros(s_x2, dtype(x2))
1748
+ partial_x2 = ()
1749
+ if F.is_sequence_value_unknown(s_x2):
1750
+ s_x2 = dyn_shape(x2)
1751
+ partial_x2 = dyn_fill(dtype(x2), s_x2, 0)
1752
+ else:
1753
+ partial_x2 = zeros(s_x2, dtype(x2))
1754
+
1157
1755
  dx1 = reshape(partial_x1 * dout, s_x1)
1158
1756
  dx2 = reshape(partial_x2 * dout, s_x2)
1159
1757
  return cast(dx1, dtype(dout)), cast(dx2, dtype(dout))
1160
1758
 
1161
1759
  return bprop
1760
+
1761
+
1762
@bprop_getters.register(Diagonal)
def get_bprop_diagonal(self):
    """Grad definition for 'Diagonal' operation.

    Returns:
        A function ``bprop(x, out, dout)`` returning a one-element tuple with the
        gradient for ``x``.
    """
    offset = self.offset
    dim1 = self.dim1
    dim2 = self.dim2
    zeros_op = P.FillV2()
    size_op = P.Size()
    transpose_op = Transpose()
    matrix_set_diag_op = MatrixSetDiagV3(align="LEFT_RIGHT")

    def bprop(x, out, dout):
        x_shape = x.shape
        x_dtype = x.dtype
        x_dim = len(x_shape)
        # Normalise negative axes.
        dim1_ = dim1 + x_dim if dim1 < 0 else dim1
        dim2_ = dim2 + x_dim if dim2 < 0 else dim2
        if size_op(out):
            # Zero tensor laid out as (batch dims..., dim1 size, dim2 size).
            dx_trans_shape = out.shape[:-1] + (x_shape[dim1_], x_shape[dim2_])
            dx = zeros_op(dx_trans_shape, Tensor(0, x_dtype))
            # Write dout onto the requested diagonal of the trailing plane.
            dx = matrix_set_diag_op(dx, dout, F.cast(offset, mstype.int32))
            # Build the permutation that moves the trailing plane back to (dim1, dim2).
            next_batch_axis = 0
            perm = ()
            for axis in range(x_dim):
                if axis == dim1_:
                    perm = perm + (x_dim - 2,)
                elif axis == dim2_:
                    perm = perm + (x_dim - 1,)
                else:
                    perm = perm + (next_batch_axis,)
                    next_batch_axis = next_batch_axis + 1
            dx = transpose_op(dx, perm)
        else:
            # Empty output: the gradient is all zeros.
            dx = zeros_like(x)
        return (dx,)

    return bprop
1809
+
1810
+
1811
@bprop_getters.register(Cholesky)
def get_bprop_cholesky(self):
    """Grad definition for `Cholesky` operation.

    When `upper` is set, both the output and its gradient are transposed before
    being handed to CholeskyGrad.
    """
    upper = self.upper
    choleskygrad = G.CholeskyGrad()

    def bprop(x, out, dout):
        if upper:
            out = cholesky_transpose(out)
            dout = cholesky_transpose(dout)
        return (choleskygrad(out, dout),)

    return bprop
1824
+
1825
+
1826
@bprop_getters.register(InplaceIndexAdd)
def get_bprop_inplace_index_add(self):
    """Generate bprop for InplaceIndexAdd.

    `var` receives dout unchanged, `indices` a zero gradient, and `updates` the
    slices of dout gathered at `indices` along the primitive's axis.
    """
    gather = P.Gather()
    _axis = self.axis

    def bprop(var, indices, updates, out, dout):
        grad_updates = gather(dout, indices, _axis)
        return dout, zeros_like(indices), grad_updates

    return bprop
1836
+
1837
+
1838
@bprop_getters.register(P.Zeta)
def get_bprop_zeta(self):
    """Generate bprop for Zeta.

    `x` gets a zero gradient; the gradient for `q` is ``-x * zeta(x + 1, q) * dout``.
    """
    zeta = P.Zeta()

    def bprop(x, q, out, dout):
        shifted_zeta = zeta(x + 1, q)
        dq = -x * shifted_zeta * dout
        return zeros_like(x), dq

    return bprop
1848
+
1849
+
1850
@_primexpr
def _fft_rank_offset(norm_shape, rank):
    """Return the product of the last `rank` entries of `norm_shape`."""
    trailing_dims = norm_shape[-rank:]
    product = 1
    for dim_size in trailing_dims:
        product = product * dim_size
    return product
1857
+
1858
+
1859
+ @_primexpr
1860
+ def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
1861
+ """generate reverse term for fft_with_size"""
1862
+ if inverse is False:
1863
+ if norm == "forward":
1864
+ norm_ = 1 / _fft_rank_offset(norm_shape, rank)
1865
+ if norm == "backward":
1866
+ norm_ = 1 * _fft_rank_offset(norm_shape, rank)
1867
+ if norm == "ortho":
1868
+ norm_ = 1
1869
+ if inverse is True:
1870
+ if norm == "forward":
1871
+ norm_ = 1 * _fft_rank_offset(norm_shape, rank)
1872
+ if norm == "backward":
1873
+ norm_ = 1 / _fft_rank_offset(norm_shape, rank)
1874
+ if norm == "ortho":
1875
+ norm_ = 1
1876
+ return norm_
1877
+
1878
+
1879
+ @_primexpr
1880
+ def _rfft_norm(norm_shape, norm, rank):
1881
+ """generate norm for rfft"""
1882
+ norm_ = 1.0
1883
+ if norm == "forward":
1884
+ norm_ = 1 / _fft_rank_offset(norm_shape, rank)
1885
+ if norm == "backward":
1886
+ norm_ = 1
1887
+ if norm == "ortho":
1888
+ norm_ = 1 / np.sqrt(_fft_rank_offset(norm_shape, rank))
1889
+ return norm_
1890
+
1891
+
1892
+ @_primexpr
1893
+ def _get_last_dim_slice_shape(tensor_shape, index):
1894
+ """generate shape for slice last tensor"""
1895
+ from_shape = [0 for x in tensor_shape]
1896
+ if index < 0:
1897
+ from_shape[-1] = tensor_shape[-1] + index
1898
+ else:
1899
+ from_shape[-1] = index
1900
+ to_shape = list(tensor_shape)
1901
+ to_shape[-1] = 1
1902
+ return tuple(from_shape), tuple(to_shape)
1903
+
1904
+
1905
+ @_primexpr
1906
+ def _rfft_reshape(shape_a, shape_b):
1907
+ """generate rfft shape for reshape"""
1908
+ new_shape = list(shape_b)
1909
+ for i in range(len(shape_a) - 2):
1910
+ new_shape.insert(i, 1)
1911
+ return tuple(new_shape)
1912
+
1913
+
1914
+ @_primexpr
1915
+ def _rfft_tile_reshape(shape_a):
1916
+ """generate rfft shape for tile"""
1917
+ reshape_a = list(shape_a)
1918
+ reshape_a[-2] = 1
1919
+ reshape_a[-1] = 1
1920
+ return tuple(reshape_a)
1921
+
1922
+
1923
+ @_primexpr
1924
+ def _rfft_last_term_shape(shape_a, shape_b):
1925
+ """generate rfft shape for last term"""
1926
+ new_shape = list(shape_b)
1927
+ for i in range(len(shape_a) - 1):
1928
+ new_shape.insert(i, 1)
1929
+ return tuple(new_shape)
1930
+
1931
+
1932
+ @_primexpr
1933
+ def _batch_matmul_shape_increase(shape_before):
1934
+ """increase tensor shape for batch_matmul"""
1935
+ return (1, *shape_before)
1936
+
1937
+
1938
+ @_primexpr
1939
+ def _batch_matmul_shape_decrease(matrix_shape):
1940
+ """decrease tensor shape after batch_matmul"""
1941
+ shape_tmp = list(matrix_shape)
1942
+ shape_tmp[-1] = 1
1943
+ return tuple(shape_tmp)
1944
+
1945
+
1946
@bprop_getters.register(FFTWithSize)
def get_bprop_fft_with_size(self):
    """Grad definition for `FFTWithSize` operation.

    Reads `signal_ndim`, `inverse`, `real`, `norm` and `onesided` from the primitive,
    builds the four transform variants used in the backward pass, and returns a
    `bprop(x, out, dout)` closure that produces a one-element tuple `(dx,)`.
    """
    signal_ndim = self.signal_ndim
    inverse = self.inverse
    real = self.real
    norm = self.norm
    onesided = self.onesided
    # Complex-to-complex transforms in both directions, same rank and norm as the forward op.
    fft_fn = FFTWithSize(signal_ndim=signal_ndim,
                         inverse=False,
                         real=False,
                         norm=norm)
    ifft_fn = FFTWithSize(signal_ndim=signal_ndim,
                          inverse=True,
                          real=False,
                          norm=norm)
    # Real-input transforms in both directions; these also carry the `onesided` flag.
    rfft_fn = FFTWithSize(signal_ndim=signal_ndim,
                          inverse=False,
                          real=True,
                          norm=norm,
                          onesided=onesided)
    irfft_fn = FFTWithSize(signal_ndim=signal_ndim,
                           inverse=True,
                           real=True,
                           norm=norm,
                           onesided=onesided)

    # Primitive instances are created once here and reused by every bprop call.
    complex_op = P.Complex()
    shape_op = P.Shape()
    to_tensor_op = P.ScalarToTensor()
    type_op = P.DType()
    concat_op = P.Concat()
    ones_op = P.Ones()
    zeros_op = P.Zeros()
    real_op = P.Real()
    imag_op = P.Imag()
    slice_op = P.Slice()
    tile_op = P.Tile()
    expand_dims = P.ExpandDims()
    transpose_op = P.Transpose()
    exp_op = P.Exp()
    reshape_op = P.Reshape()
    conj_op = P.Conj()
    batch_matmul_op = P.BatchMatMul()

    def bprop(x, out, dout):
        dx = 0
        input_type = type_op(x)
        output_type = type_op(out)
        # `input_shape` and `offset_shape` both hold the shape of `x`.
        input_shape = shape_op(x)
        offset_shape = shape_op(x)
        dout_shape = shape_op(dout)
        # Scale factor derived from the trailing `signal_ndim` dims of x, cast to the output dtype.
        offset_size = to_tensor_op(_fft_with_size_back_norm(offset_shape, norm, inverse, signal_ndim), output_type)
        if real is False:
            # Complex case: apply the opposite-direction transform to dout and rescale.
            if inverse is False:
                dx = ifft_fn(dout) * offset_size
            else:
                dx = fft_fn(dout) * offset_size
        else:
            # Inverse real transforms of rank 1, 2 and 3 with norm="backward"; `signal_sizes`
            # is taken from the trailing dims of x.
            irfft_ = FFTWithSize(signal_ndim=1, inverse=True, real=True, norm="backward", onesided=onesided,
                                 signal_sizes=offset_shape[-1:])
            irfft2d_ = FFTWithSize(signal_ndim=2, inverse=True, real=True, norm="backward", onesided=onesided,
                                   signal_sizes=offset_shape[-2:])
            irfft3d_ = FFTWithSize(signal_ndim=3, inverse=True, real=True, norm="backward", onesided=onesided,
                                   signal_sizes=offset_shape[-3:])
            if inverse is False:
                if onesided is True:
                    terms = 0
                    # 1 when the last dim of x is even, 0 when it is odd.
                    is_even = to_tensor_op(1 - (input_shape[-1] % 2), input_type)
                    # Slice of dout at index 0 of its last axis (last dim kept with size 1).
                    dout_first_from, dout_first_to = _get_last_dim_slice_shape(dout_shape, 0)
                    dout_first = slice_op(dout, dout_first_from, dout_first_to)
                    rfft_offset_size = to_tensor_op(_fft_rank_offset(input_shape, signal_ndim), input_type)
                    rfft_norm_offset = to_tensor_op(_rfft_norm(input_shape, norm, signal_ndim), input_type)
                    # Slice of dout at the final index of its last axis.
                    dout_last_from, dout_last_to = _get_last_dim_slice_shape(dout_shape, -1)
                    dout_last = slice_op(dout, dout_last_from, dout_last_to)
                    if signal_ndim == 1:
                        dx = irfft_(dout)
                        # Complex vector with real part alternating +1, -1, ... and zero imaginary part,
                        # one entry per element of the last dim of x.
                        vec_mask = complex_op(1 - 2 * (mnp.arange(0, input_shape[-1], 1, input_type) % 2),
                                              zeros_op(input_shape[-1], input_type))
                        # NOTE(review): presumably corrects for the first and last (Nyquist) bins that a
                        # one-sided spectrum stores only once - confirm against the derivation.
                        terms = real_op(dout_first) + is_even * real_op(dout_last * vec_mask)
                    elif signal_ndim == 2:
                        dx = irfft2d_(dout)
                        # matrix_mul[i][j] == i * j for i, j in [0, input_shape[-2]).
                        arange_inner = mnp.arange(0, input_shape[-2], 1, input_type)
                        matrix_a = tile_op(expand_dims(arange_inner, 0), (input_shape[-2], 1))
                        matrix_b = transpose_op(matrix_a, (1, 0))
                        matrix_mul = matrix_a * matrix_b
                        # exp(-2j * pi * i * j / input_shape[-2]).
                        # NOTE(review): looks like a DFT matrix along the second-to-last axis - confirm.
                        imag_offset = complex_op(to_tensor_op(0, input_type), to_tensor_op(-2, input_type))
                        pi_tensor = to_tensor_op(mnp.pi, output_type)
                        matrix_mul_complex = complex_op(matrix_mul, zeros_op(shape_op(matrix_mul), input_type))
                        matrix_base_mask = exp_op(imag_offset * pi_tensor * matrix_mul_complex /
                                                  to_tensor_op(input_shape[-2], output_type))
                        # Prepend unit dims to the mask, then tile it; real and imaginary parts are
                        # tiled separately and recombined with complex_op.
                        expanded_matrix_mask = reshape_op(matrix_base_mask, _rfft_reshape(shape_op(dout_first),
                                                                                          shape_op(matrix_base_mask)))
                        tile_matrix_mask = complex_op(tile_op(real_op(expanded_matrix_mask), _rfft_tile_reshape(
                            shape_op(dout_first))), tile_op(imag_op(expanded_matrix_mask),
                                                            _rfft_tile_reshape(shape_op(dout_first))))
                        tile_matrix_mask_shape = shape_op(tile_matrix_mask)
                        # Multiply the tiled mask with the conjugated first/last slices. Both operands get
                        # an extra leading dim for BatchMatMul, and the product is reshaped to the mask
                        # shape with its last dim set to 1.
                        dout_first_term = reshape_op(batch_matmul_op(reshape_op(tile_matrix_mask,
                                                                                _batch_matmul_shape_increase(
                                                                                    tile_matrix_mask_shape)),
                                                                     reshape_op(conj_op(
                                                                         dout_first), _batch_matmul_shape_increase(
                                                                             shape_op(dout_first)))),
                                                     _batch_matmul_shape_decrease(tile_matrix_mask_shape))
                        dout_last_term = reshape_op(batch_matmul_op(reshape_op(tile_matrix_mask,
                                                                               _batch_matmul_shape_increase(
                                                                                   tile_matrix_mask_shape)),
                                                                    reshape_op(conj_op(dout_last),
                                                                               _batch_matmul_shape_increase(
                                                                                   shape_op(dout_last)))),
                                                    _batch_matmul_shape_decrease(
                                                        tile_matrix_mask_shape))
                        # Same alternating +1/-1 complex vector as in the 1-D branch.
                        vec_mask = complex_op(1 - 2 * (mnp.arange(0, input_shape[-1], 1, input_type) % 2), zeros_op(
                            input_shape[-1], input_type))
                        # Tile the last-slice term by input_shape[-1] along the last axis before applying
                        # the alternating-sign mask.
                        dout_last_term = complex_op(tile_op(real_op(dout_last_term), _rfft_last_term_shape(dout_shape,
                                                                                                           [input_shape[
                                                                                                               -1]])),
                                                    tile_op(imag_op(dout_last_term), _rfft_last_term_shape(
                                                        dout_shape, [input_shape[-1]])))
                        dout_last_term = dout_last_term * vec_mask
                        terms = real_op(dout_first_term) + is_even * real_op(dout_last_term)
                    elif signal_ndim == 3:
                        # NOTE(review): `terms` keeps its initial value 0 in this branch, and dx is scaled by
                        # real_op(offset_size) here in addition to the shared rescaling below - confirm
                        # this is intended.
                        dx = irfft3d_(dout) * real_op(offset_size)
                    # NOTE(review): nesting was reconstructed from flattened text. This statement is placed
                    # after the if/elif chain because `terms` from the 1-D/2-D branches is consumed only
                    # here - confirm against upstream.
                    dx = to_tensor_op(0.5, input_type) * (dx * rfft_offset_size + terms) * rfft_norm_offset
                else:
                    dx = irfft_fn(dout) * real_op(offset_size)
            else:
                # Forward op was an inverse real transform: start from the forward real FFT of dout.
                dx = rfft_fn(dout)
                if onesided is True:
                    if signal_ndim != 3:
                        is_odd = dout_shape[-1] % 2
                        last_shape = offset_shape[-1]
                        # Weight vector: 1, then (last_shape - 2 + is_odd) entries of 2.0, then a final 1
                        # only when the last dim of dout is even.
                        mask = concat_op((ones_op(1, output_type), 2.0 * ones_op(
                            (last_shape - 2 + is_odd,), output_type), ones_op((1 - is_odd,), output_type)))
                        dx = dx * complex_op(mask, zeros_op(shape_op(mask), output_type))
                        # Here the scale factor is recomputed from the shape of dout instead of x.
                        irfft_offset_size = to_tensor_op(
                            _fft_with_size_back_norm(shape_op(dout), norm, inverse, signal_ndim),
                            output_type)
                        dx = dx * complex_op(irfft_offset_size, zeros_op(1, output_type))
                    else:
                        dx = dx * complex_op(offset_size, zeros_op(1, output_type))
                else:
                    dx = dx * complex_op(offset_size, zeros_op(1, output_type))
        return (dx,)

    return bprop