mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -17,83 +17,44 @@
17
17
 
18
18
  import numpy as np
19
19
  import mindspore.numpy as mnp
20
- from mindspore import context
21
20
  from mindspore.common import dtype as mstype
22
- from mindspore.nn import LGamma
23
21
  from mindspore.ops import functional as F
24
- from mindspore.ops.functional import broadcast_gradient_args
25
22
  from mindspore.ops import operations as P
26
- from mindspore.ops.operations import _inner_ops as inner
27
- from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
28
- from mindspore import nn, Tensor
23
+ from mindspore import Tensor
29
24
  from mindspore.ops.operations.math_ops import Real, Imag, Complex, Angle
30
25
  from mindspore.ops.operations.math_ops import Polar
31
- from mindspore.ops.operations.math_ops import ComplexAbs
32
- from mindspore.ops.operations.math_ops import Sinc
33
26
  from mindspore.ops.operations import _grad_ops as G
34
- from mindspore.ops.operations.math_ops import Igamma, Igammac
35
- from mindspore.ops.operations.math_ops import BesselI0
36
- from mindspore.ops.operations.math_ops import BesselI1
37
- from mindspore.ops.operations.math_ops import BesselJ0
38
- from mindspore.ops.operations.math_ops import BesselJ1
39
- from mindspore.ops.operations.math_ops import BesselK0
40
- from mindspore.ops.operations.math_ops import BesselK1
41
- from mindspore.ops.operations.math_ops import BesselK0e
42
- from mindspore.ops.operations.math_ops import BesselK1e
43
- from mindspore.ops.operations.math_ops import BesselY0
44
- from mindspore.ops.operations.math_ops import BesselY1
45
27
  from mindspore.ops.operations.math_ops import Lgamma
46
28
  from mindspore.ops.operations.math_ops import Digamma
47
29
  from mindspore.ops.operations.math_ops import Polygamma
48
- from mindspore.ops.operations.math_ops import NextAfter
49
- from mindspore.ops.operations.math_ops import Hypot
50
- from mindspore.ops.operations.math_ops import ReduceStd
51
- from mindspore.ops.operations.math_ops import LuUnpack
52
- from mindspore.ops.operations.math_ops import MatrixExp
53
30
  from mindspore.ops.operations.math_ops import CumulativeLogsumexp
54
31
  from mindspore.ops.operations.math_ops import MatrixSolve
55
32
  from mindspore.ops.operations.math_ops import MatrixSolveLs
56
- from mindspore.ops.operations.math_ops import MatrixPower
57
- from mindspore.ops.operations.math_ops import Median
58
33
  from mindspore.ops.operations.math_ops import MatrixTriangularSolve
59
34
  from mindspore.ops.operations.math_ops import NanToNum
60
35
  from mindspore.ops.operations.math_ops import FFTWithSize
61
- from mindspore.ops.operations.math_ops import Betainc
62
36
  from mindspore.ops.operations.math_ops import Cholesky
63
- from mindspore.ops.operations.math_ops import Fmin
64
37
  from mindspore.ops.operations.math_ops import CholeskySolve
65
38
  from mindspore.ops.operations.math_ops import InplaceIndexAdd
66
- from mindspore.ops.operations.math_ops import AddV2
67
- from mindspore.ops.operations.math_ops import TridiagonalMatMul
68
39
  from mindspore.ops.operations.math_ops import TridiagonalSolve
69
- from mindspore.ops.operations.math_ops import Logit
70
40
  from mindspore.ops.operations.math_ops import Diagonal
71
41
  from mindspore.ops.operations.math_ops import EuclideanNorm
72
42
  from mindspore.ops.operations.array_ops import Transpose, MatrixSetDiagV3
73
- from mindspore.ops.operations.math_ops import Fmax
74
43
  from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
75
44
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
76
45
  from mindspore.ops.primitive import _primexpr
77
- from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
78
- from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
79
- from mindspore.ops._grad.grad_math_ops import binop_grad_common
46
+ from mindspore.ops._grad_experimental.grad_base import bprop_getters
47
+ from mindspore.ops._grad_experimental.grad_base import sum_grad_reduce_axis
80
48
  from mindspore.ops.operations.array_ops import MatrixBandPart
81
49
  from mindspore.ops.operations.array_ops import ConjugateTranspose
50
+ from mindspore.ops.functional import broadcast_gradient_args
51
+
82
52
 
83
53
  transpose = P.Transpose()
84
- dyn_shape_op = P.TensorShape()
85
54
  _conj = P.Conj()
86
-
87
-
88
- @_primexpr
89
- def _generate_perm(x_dim):
90
- perm = tuple(range(x_dim - 2))
91
- return perm
92
-
93
-
94
- def _dyn_generate_perm(x_dim):
95
- perm = P.Range()(P.Cast()(0, x_dim.dtype), x_dim - 2, P.Cast()(1, x_dim.dtype))
96
- return perm
55
+ shape_op = P.Shape()
56
+ reduce_sum = P.ReduceSum()
57
+ reshape = P.Reshape()
97
58
 
98
59
 
99
60
  def _adjoint(a):
@@ -108,256 +69,6 @@ def cholesky_transpose(a):
108
69
  return transpose(a, tuple(n_range))
109
70
 
110
71
 
111
- @bprop_getters.register(P.ACos)
112
- def get_bprop_acos(self):
113
- """Grad definition for `ACos` operation."""
114
- input_grad = G.ACosGrad()
115
-
116
- def bprop(input_x, out, dout):
117
- dx = input_grad(input_x, dout)
118
- return (dx,)
119
-
120
- return bprop
121
-
122
-
123
- @bprop_getters.register(Logit)
124
- def get_bprop_logit(self):
125
- """Grad definition for `Logit` operation."""
126
- logitgrad = G.LogitGrad(self.eps)
127
-
128
- def bprop(x, out, dout):
129
- dx = logitgrad(dout, x)
130
- return (dx,)
131
-
132
- return bprop
133
-
134
-
135
- @bprop_getters.register(P.Roll)
136
- def get_bprop_roll(self):
137
- """Generate bprop for Roll"""
138
- if context.get_context("device_target") == "GPU":
139
- shift = []
140
- axis = self.axis
141
- for tmp in enumerate(self.shift):
142
- shift.append(-tmp[1])
143
- roll_grad = P.Roll(shift, axis)
144
- else:
145
- shift = self.shift
146
- axis = self.axis
147
- roll_grad = P.Roll(-shift, axis)
148
-
149
- def bprop(x_input, out, dout):
150
- dx = roll_grad(dout)
151
- return (dx,)
152
-
153
- return bprop
154
-
155
-
156
- @bprop_getters.register(P.Cdist)
157
- def get_bprop_cdist(self):
158
- """Generate bprop for Cdist"""
159
- input_grad = G.CdistGrad(p=self.p)
160
-
161
- def bprop(input_x, input_y, out, dout):
162
- dout_shape = F.shape(dout)
163
- if F.is_sequence_value_unknown(dout_shape):
164
- dout_dim = dyn_rank(dout)
165
- dout_perm_part2 = create_tensor_by_element(
166
- (dout_dim - 1, dout_dim - 2))
167
- if dout_dim <= 2:
168
- dout_perm = dout_perm_part2
169
- else:
170
- dout_perm_part1 = _dyn_generate_perm(dout_dim)
171
- dout_perm = P.Concat(0)((dout_perm_part1, dout_perm_part2))
172
- else:
173
- dout_dim = len(dout_shape)
174
- dout_perm_part1 = _generate_perm(dout_dim)
175
- dout_perm_part2 = (dout_dim - 1, dout_dim - 2)
176
- dout_perm = dout_perm_part1 + dout_perm_part2
177
- out_perm = dout_perm
178
- dout_transpose = transpose(dout, dout_perm)
179
- out_transpose = transpose(out, out_perm)
180
- dx = input_grad(dout, input_x, input_y, out)
181
- dy = input_grad(dout_transpose, input_y, input_x, out_transpose)
182
- return dx, dy
183
-
184
- return bprop
185
-
186
-
187
- @bprop_getters.register(P.Lerp)
188
- def get_bprop_index_lerp(self):
189
- """Generate bprop for Lerp"""
190
- mul_op = P.Mul()
191
- sub_op = P.Sub()
192
- is_instance_op = inner.IsInstance()
193
-
194
- def bprop(start, end, weight, out, dout):
195
- dout = F.cast(dout, mstype.float32)
196
- dstart = mul_op(dout, 1 - weight)
197
- dend = mul_op(dout, weight)
198
- dweight = mul_op(dout, sub_op(end, start))
199
- dstart, dend = binop_grad_common(start, end, dstart, dend)
200
- if is_instance_op(weight, mstype.number):
201
- dweight = 0
202
- else:
203
- _, dweight = binop_grad_common(start, weight, dstart, dweight)
204
- dweight = F.cast(dweight, F.dtype(weight))
205
- dstart = F.cast(dstart, F.dtype(start))
206
- dend = F.cast(dend, F.dtype(end))
207
- return dstart, dend, dweight
208
-
209
- return bprop
210
-
211
-
212
- @bprop_getters.register(LuUnpack)
213
- def get_bprop_lu_unpack(self):
214
- """Grad definition for `LuUnpack` operation."""
215
- input_grad = G.LuUnpackGrad(L_grad_flag=True, U_grad_flag=True)
216
-
217
- def bprop(lu_data, lu_pivots, out, dout):
218
- dl, du = input_grad(dout[1], dout[2], lu_data)
219
- lu_data_grad = dl + du
220
- return (lu_data_grad, zeros_like(lu_pivots))
221
-
222
- return bprop
223
-
224
-
225
- @bprop_getters.register(Sinc)
226
- def get_bprop_sinc(self):
227
- """Grad definition for `Sinc` operation."""
228
- sin = P.Sin()
229
- cos = P.Cos()
230
- cast = P.Cast()
231
- conj = P.Conj()
232
-
233
- def bprop(x, out, dout):
234
- kpi = cast(np.pi, x.dtype)
235
- product = kpi * x
236
- reciprocal = (product * cos(product) - sin(product)) / (product * x)
237
- if reciprocal.dtype in [mstype.complex64, mstype.complex128]:
238
- reciprocal = conj(reciprocal)
239
- dx = reciprocal * dout
240
- return (dx,)
241
-
242
- return bprop
243
-
244
-
245
- @bprop_getters.register(ReduceStd)
246
- def get_bprop_reduce_std(self):
247
- """Grad definition for `ReduceStd` operation."""
248
- axis = list(self.axis)
249
- keep_dims = self.keep_dims
250
- unbiased = self.unbiased
251
- expand_dims_op = P.ExpandDims()
252
- size_op = P.Size()
253
- mul_op = P.Mul()
254
- sub_op = P.Sub()
255
- div_op = P.Div()
256
- add_op = P.Add()
257
-
258
- def bprop(x, out, dout):
259
- std_d = dout[0]
260
- std = out[0]
261
- mean_d = dout[1]
262
- mean = out[1]
263
- if axis == [] and x.shape != ():
264
- for i, _ in enumerate(x.shape):
265
- axis.append(i)
266
- for i, _ in enumerate(axis):
267
- if axis[i] < 0:
268
- axis[i] = axis[i] + len(x.shape)
269
- for i in range(1, len(axis)):
270
- for j in range(0, len(axis) - i):
271
- if axis[j] > axis[j + 1]:
272
- axis[j], axis[j + 1] = axis[j + 1], axis[j]
273
- if not keep_dims and x.shape != ():
274
- for i in axis:
275
- std_d = expand_dims_op(std_d, i)
276
- std = expand_dims_op(std, i)
277
- mean_d = expand_dims_op(mean_d, i)
278
- mean = expand_dims_op(mean, i)
279
- dx = sub_op(x, mean)
280
- dx = mul_op(dx, std_d)
281
- dx = div_op(dx, std)
282
- num = size_op(x)
283
- for i, _ in enumerate(x.shape):
284
- if i not in axis:
285
- num = num / x.shape[i]
286
- if unbiased:
287
- dx = div_op(dx, num - 1)
288
- else:
289
- dx = div_op(dx, num)
290
- temp = div_op(mean_d, num)
291
- dx = add_op(dx, temp)
292
- return (dx,)
293
-
294
- return bprop
295
-
296
-
297
- @bprop_getters.register(P.Addcdiv)
298
- def get_bprop_index_addcdiv(self):
299
- """Generate bprop for Addcdiv"""
300
- mul_op = P.Mul()
301
- div_op = P.Div()
302
- pow_op = P.Pow()
303
- neg_op = P.Neg()
304
-
305
- def bprop(input_data, x1, x2, value, out, dout):
306
- dinput_data = dout
307
- if dout.dtype in [mstype.float16, mstype.int64, mstype.float64]:
308
- input_data = F.cast(input_data, mstype.float32)
309
- x1 = F.cast(x1, mstype.float32)
310
- x2 = F.cast(x2, mstype.float32)
311
- value = F.cast(value, mstype.float32)
312
- dinput_data = F.cast(dinput_data, mstype.float32)
313
- inner_out = mul_op(value, div_op(x1, x2)) + input_data
314
- dx2 = neg_op(mul_op(mul_op(mul_op(x1, value), pow_op(x2, -2)), dinput_data))
315
- dx1 = mul_op(dinput_data, div_op(value, x2))
316
- dvalue = mul_op(dinput_data, div_op(x1, x2))
317
- _, dinput_data = binop_grad_common(inner_out, input_data, dout, dinput_data)
318
- _, dx1 = binop_grad_common(inner_out, x1, dout, dx1)
319
- _, dx2 = binop_grad_common(inner_out, x2, dout, dx2)
320
- _, dvalue = binop_grad_common(inner_out, value, dout, dvalue)
321
- if dout.dtype in [mstype.float16, mstype.int64, mstype.float64]:
322
- dinput_data = F.cast(dinput_data, dout.dtype)
323
- dx1 = F.cast(dx1, dout.dtype)
324
- dx2 = F.cast(dx2, dout.dtype)
325
- dvalue = F.cast(dvalue, dout.dtype)
326
- return dinput_data, dx1, dx2, dvalue
327
-
328
- return bprop
329
-
330
-
331
- @bprop_getters.register(P.Addcmul)
332
- def get_bprop_index_addcmul(self):
333
- """Generate bprop for Addcmul"""
334
- mul_op = P.Mul()
335
-
336
- def bprop(input_data, x1, x2, value, out, dout):
337
- if dout.dtype in [mstype.float16, mstype.float64, mstype.uint8, mstype.int8, mstype.int32, mstype.int64]:
338
- input_data = F.cast(input_data, mstype.float32)
339
- x1 = F.cast(x1, mstype.float32)
340
- x2 = F.cast(x2, mstype.float32)
341
- value = F.cast(value, mstype.float32)
342
- dinput_data = dout
343
- dx1 = mul_op(dout, mul_op(value, x2))
344
- dx2 = mul_op(dout, mul_op(value, x1))
345
- inner_out = mul_op(x1, x2) * value + input_data
346
- dvalue = mul_op(dout, mul_op(x1, x2))
347
- _, dinput_data = binop_grad_common(inner_out, input_data, dout, dinput_data)
348
- _, dx1 = binop_grad_common(inner_out, x1, dout, dx1)
349
- _, dx2 = binop_grad_common(inner_out, x2, dout, dx2)
350
- _, dvalue = binop_grad_common(inner_out, value, dout, dvalue)
351
- if dout.dtype in [mstype.float16, mstype.uint8, mstype.int8, mstype.float64, mstype.int32, mstype.int64]:
352
- dinput_data = F.cast(dinput_data, dout.dtype)
353
- dx1 = F.cast(dx1, dout.dtype)
354
- dx2 = F.cast(dx2, dout.dtype)
355
- dvalue = F.cast(dvalue, dout.dtype)
356
- return dinput_data, dx1, dx2, dvalue
357
-
358
- return bprop
359
-
360
-
361
72
  @_primexpr
362
73
  def renew_dim(shape, dim):
363
74
  """ Re-new dims"""
@@ -382,83 +93,6 @@ def get_bprop_euclidean_norm(self):
382
93
  return bprop
383
94
 
384
95
 
385
- @bprop_getters.register(Renorm)
386
- def get_bprop_renorm(self):
387
- """Generate bprop for Renorm """
388
- p = int(self.p)
389
- ext = 1e-7
390
- dim = self.dim
391
- max_norm = self.maxnorm
392
- greater_op = P.Greater()
393
- pow_op = P.Pow()
394
- abs_op = P.Abs()
395
- sign_op = P.Sign()
396
- reciprocal_op = P.Reciprocal()
397
-
398
- def bprop(input_x, out, dout):
399
- shape = F.shape(input_x)
400
- dims = renew_dim(shape, dim)
401
- norm = P.LpNorm(dims, p, keep_dims=True)(input_x)
402
- grad_out = (input_x * dout)
403
- grad_out = grad_out.sum(dims, keepdims=True)
404
- if p == 1:
405
- sig = sign_op(input_x)
406
- norm_bp = sig * grad_out
407
- elif p == 2:
408
- m = input_x * (grad_out / norm)
409
- norm_bp = F.masked_fill(m, norm == 0., 0.)
410
- else:
411
- abs_ = abs_op(input_x)
412
- input_scaled = input_x * pow_op(abs_, (p - 2))
413
- pow_ = pow_op(norm, (p - 1))
414
- scale_v = grad_out / pow_
415
- scale_v = F.masked_fill(scale_v, norm == 0., 0.)
416
- norm_bp = input_scaled * scale_v
417
-
418
- v = norm + ext
419
- inv_norm = reciprocal_op(v)
420
- grad_norm = max_norm * inv_norm * (dout - inv_norm * norm_bp)
421
- q = greater_op(norm, max_norm)
422
- return (mnp.where(q, grad_norm, dout),)
423
-
424
- return bprop
425
-
426
-
427
- @bprop_getters.register(P.LpNorm)
428
- def get_bprop_lp_norm(self):
429
- """Grad definition for `LpNorm` operation."""
430
- p = self.p
431
- keep_dims = self.keep_dims
432
- axis = self.axis
433
- if isinstance(axis, int):
434
- axis = [axis]
435
- sign_op = P.Sign()
436
- abs_op = P.Abs()
437
- zeros_like_op = P.ZerosLike()
438
- expand_dims_op = P.ExpandDims()
439
- pow_op = P.Pow()
440
-
441
- def bprop(input_x, out, dout):
442
- if not keep_dims and input_x.shape != ():
443
- for i in axis:
444
- dout = expand_dims_op(dout, i)
445
- out = expand_dims_op(out, i)
446
-
447
- if p == 0:
448
- return (zeros_like_op(input_x),)
449
- if p == 1:
450
- return (dout * sign_op(input_x),)
451
- if p == 2:
452
- input_scaled = input_x
453
- scale_v = dout / out
454
- else:
455
- input_scaled = pow_op(abs_op(input_x), (p - 2)) * input_x
456
- scale_v = dout / pow_op(out, (p - 1))
457
- return (mnp.where(input_scaled == 0, 0, input_scaled * scale_v),)
458
-
459
- return bprop
460
-
461
-
462
96
  @bprop_getters.register(CumulativeLogsumexp)
463
97
  def get_brop_cumulative_logsumexp(self):
464
98
  """Generate bprop for CumulativeLogsumexp"""
@@ -540,93 +174,6 @@ def get_bprop_matrix_triangular_solve(self):
540
174
  return bprop
541
175
 
542
176
 
543
- @bprop_getters.register(MatrixExp)
544
- def get_bprop_matrix_exp(self):
545
- """Gegerate brop for MatrixExp"""
546
- matrix_exp = MatrixExp()
547
- zeros = P.Zeros()
548
- concat_row = P.Concat(-1)
549
- concat_col = P.Concat(-2)
550
- cast = P.Cast()
551
- slice_op = P.Slice()
552
- range_op = P.Range()
553
- expand_dims = P.ExpandDims()
554
- dyn_shape = P.TensorShape()
555
-
556
- def bprop(x, out, dout):
557
- if F.is_sequence_value_unknown(x.shape):
558
- shape_x = dyn_shape(x)
559
- x_len = dyn_rank(x)
560
- input_perm = range_op(cast(0, mstype.int64), x_len, cast(1, mstype.int64))
561
- input_perm[-1] = input_perm[-2]
562
- input_perm[-2] = x_len - 1
563
- x_transpose = transpose(x, input_perm)
564
- zero_matrix = dyn_fill(mstype.float32, shape_x, 0)
565
- else:
566
- shape_x = x.shape
567
- x_len = len(shape_x)
568
- input_perm = [ele for ele in range(x_len)]
569
- input_perm[-1] = input_perm[-2]
570
- input_perm[-2] = x_len - 1
571
- input_perm = tuple(input_perm)
572
- x_transpose = P.Transpose()(x, input_perm)
573
- zero_matrix = zeros(shape_x, mstype.float32)
574
-
575
- zero_matrix = cast(zero_matrix, dout.dtype)
576
- meta_grad_up = concat_row((x_transpose, dout))
577
- meta_grad_down = concat_row((zero_matrix, x_transpose))
578
- meta_grad = concat_col((meta_grad_up, meta_grad_down))
579
- meta_grad = matrix_exp(meta_grad)
580
-
581
- if F.is_sequence_value_unknown(x.shape):
582
- begins = dyn_fill(mstype.int32, expand_dims(x_len, 0), 0)
583
- sizes = cast(shape_x, mstype.int32)
584
- else:
585
- begins = [0] * x_len
586
- sizes = [i for i in shape_x]
587
- n = shape_x[-1]
588
- begins[-1] = n
589
- sizes[-2] = n
590
- sizes[-1] = n
591
- return (slice_op(meta_grad, begins, sizes),)
592
-
593
- return bprop
594
-
595
-
596
- @bprop_getters.register(MatrixPower)
597
- def get_bprop_matrix_power(self):
598
- """Generate bprop for MatrixPower"""
599
- n = self.n
600
- batch_matmul_a = P.BatchMatMul(transpose_a=True)
601
- batch_matmul_b = P.BatchMatMul(transpose_b=True)
602
- neg = P.Neg()
603
-
604
- def bprop(x, out, dout):
605
- dout = F.cast(dout, mstype.float32)
606
- x = F.cast(x, mstype.float32)
607
- power = n
608
- dx = zeros_like(x)
609
- if power < 0:
610
- matrix_power = MatrixPower(n=-1)
611
- x_inv = matrix_power(x)
612
- for i in range(0, -power):
613
- matrix_power = MatrixPower(n=(-power - 1 - i))
614
- dx = dx + batch_matmul_b(dout, matrix_power(x_inv))
615
- dout = batch_matmul_a(x_inv, dout)
616
- dx = batch_matmul_b(dx, x_inv)
617
- dx = batch_matmul_a(x_inv, dx)
618
- dx = neg(dx)
619
- else:
620
- for i in range(0, power):
621
- matrix_power = MatrixPower(n=(power - 1 - i))
622
- dx = dx + batch_matmul_b(dout, matrix_power(x))
623
- dout = batch_matmul_a(x, dout)
624
- dx = F.cast(dx, F.dtype(out))
625
- return (dx,)
626
-
627
- return bprop
628
-
629
-
630
177
  @bprop_getters.register(MatrixSolve)
631
178
  def get_bprop_matrix_solve(self):
632
179
  """Generate bprop for MatrixSolve"""
@@ -648,13 +195,8 @@ def get_bprop_matrix_solve(self):
648
195
  if grad_b_type == mstype.float64:
649
196
  grad_b = cast(grad_b, mstype.float32)
650
197
 
651
- a_shape = F.shape(input_a)
652
- if F.is_sequence_value_unknown(a_shape):
653
- matrix_rank = dyn_rank(input_a)
654
- else:
655
- matrix_rank = rank(input_a)
656
-
657
198
  matrix_rank = rank(input_a)
199
+
658
200
  if adjoint:
659
201
  if matrix_rank > 2:
660
202
  grad_a = batchmatmul(out, grad_b)
@@ -824,162 +366,6 @@ def get_bprop_matrix_solve_ls(self):
824
366
  return bprop
825
367
 
826
368
 
827
- @bprop_getters.register(P.MatrixDeterminant)
828
- def get_bprop_matrix_determinant(self):
829
- """Generate bprop for MatrixDeterminant"""
830
- inverse_op = P.MatrixInverse(adjoint=True)
831
- shape_op = P.Shape()
832
- reshape = P.Reshape()
833
- concat = P.Concat(0)
834
-
835
- def bprop(x, out, dout):
836
- if F.is_sequence_value_unknown(shape_op(x)):
837
- x_adj_inv = inverse_op(x)
838
- out_shape = dyn_shape_op(out)
839
- ones = create_tensor_by_element((1, 1))
840
- multipliers = reshape(dout * out, concat((out_shape, ones)))
841
- dx = multipliers * x_adj_inv
842
- return (dx,)
843
- x_adj_inv = inverse_op(x)
844
- multipliers = reshape(dout * out, shape_op(out) + (1, 1))
845
- dx = multipliers * x_adj_inv
846
- return (dx,)
847
-
848
- return bprop
849
-
850
-
851
- @bprop_getters.register(P.LogMatrixDeterminant)
852
- def get_bprop_log_matrix_determinant(self):
853
- """Generate bprop for LogMatrixDeterminant"""
854
- inverse_op = P.MatrixInverse(adjoint=True)
855
- shape_op = P.Shape()
856
- reshape = P.Reshape()
857
-
858
- def bprop(x, out, dout):
859
- x_adj_inv = inverse_op(x)
860
- if F.is_sequence_value_unknown(shape_op(out[1])):
861
- const_value = F.cast(1, mstype.int64)
862
- const_value = P.ExpandDims()(const_value, 0)
863
- new_shape = P.Concat()((dyn_shape_op(out[1]), const_value, const_value))
864
- multipliers = reshape(dout[1], new_shape)
865
- else:
866
- multipliers = reshape(dout[1], shape_op(out[1]) + (1, 1))
867
- dx = multipliers * x_adj_inv
868
- return (dx,)
869
-
870
- return bprop
871
-
872
-
873
- @bprop_getters.register(Betainc)
874
- def get_bprop_betainc(self):
875
- """Grad definition for 'Betainc' operation"""
876
- lgamma = LGamma()
877
- exp = P.Exp()
878
- log1p = P.Log1p()
879
- xlogy = P.Xlogy()
880
- dyn_shape = P.TensorShape()
881
-
882
- def bprop(input_a, input_b, input_x, out, dout):
883
- if F.is_sequence_value_unknown(F.shape(input_x)):
884
- sx = dyn_shape(input_x)
885
- else:
886
- sx = F.shape(input_x)
887
- log_beta = (lgamma(input_a) + lgamma(input_b) - lgamma(input_a + input_b))
888
- partial_x = exp((input_b - 1) * log1p(-input_x) + xlogy(input_a - 1, input_x) - log_beta)
889
- return (zeros_like(input_a), zeros_like(input_b), F.reshape(partial_x * dout, sx))
890
-
891
- return bprop
892
-
893
-
894
- @bprop_getters.register(P.CholeskyInverse)
895
- def get_bprop_cholesky_inverse(self):
896
- """Grad definition for `CholeskyInverse` operation."""
897
- matmul = P.MatMul()
898
- upper = self.upper
899
- neg = P.Neg()
900
-
901
- def bprop(input_x, out, dout):
902
- input_perm = (1, 0)
903
- if dout.dtype == mstype.float64:
904
- input_x = F.cast(input_x, mstype.float32)
905
- out = F.cast(out, mstype.float32)
906
- dout = F.cast(dout, mstype.float32)
907
- common_term = dout + transpose(dout, input_perm)
908
- common_term = F.cast(common_term, mstype.float32)
909
- common_term = matmul(out, matmul(common_term, out))
910
- if upper is True:
911
- dx = neg(matmul(input_x, common_term))
912
- dx = F.cast(dx, mstype.float64)
913
- else:
914
- dx = neg(matmul(common_term, input_x))
915
- dx = F.cast(dx, mstype.float64)
916
- return (dx,)
917
- common_term = dout + transpose(dout, input_perm)
918
- common_term = matmul(out, matmul(common_term, out))
919
- if upper is True:
920
- dx = neg(matmul(input_x, common_term))
921
- else:
922
- dx = neg(matmul(common_term, input_x))
923
- return (dx,)
924
-
925
- return bprop
926
-
927
-
928
- @bprop_getters.register(Real)
929
- def get_bprop_real(self):
930
- """Grad definition for `Real` operation."""
931
- complex_grad = Complex()
932
-
933
- def bprop(input_1, out, dout):
934
- zero = zeros_like(dout)
935
- dx = dout
936
- res = complex_grad(dx, zero)
937
- return (res,)
938
-
939
- return bprop
940
-
941
-
942
- @bprop_getters.register(Imag)
943
- def get_bprop_imag(self):
944
- """Grad definition for `Real` operation."""
945
- complex_grad = Complex()
946
-
947
- def bprop(input_1, out, dout):
948
- zero = zeros_like(dout)
949
- dx = dout
950
- res = complex_grad(zero, dx)
951
- return (res,)
952
-
953
- return bprop
954
-
955
-
956
- @bprop_getters.register(Complex)
957
- def get_bprop_complex(self):
958
- """Grad definition for `Real` operation."""
959
- real_grad = Real()
960
- imag_grad = Imag()
961
-
962
- def bprop(real, imag, out, dout):
963
- dx = real_grad(dout)
964
- dy = imag_grad(dout)
965
- return (dx, dy,)
966
-
967
- return bprop
968
-
969
-
970
- @bprop_getters.register(ComplexAbs)
971
- def get_bprop_complex_abs(self):
972
- """Grad definition for `Real` operation."""
973
- div_no_nan = P.DivNoNan()
974
- complex_grad = Complex()
975
- mul = P.Mul()
976
-
977
- def bprop(x, out, dout):
978
- return (div_no_nan(mul(complex_grad(dout, zeros_like(dout)), x), complex_grad(out, zeros_like(out))),)
979
-
980
- return bprop
981
-
982
-
983
369
  @bprop_getters.register(NanToNum)
984
370
  def get_bprop_nan_to_num(self):
985
371
  """Grad definition for `NanToNum` operation."""
@@ -1036,409 +422,6 @@ def get_bprop_polar(self):
1036
422
  return bprop
1037
423
 
1038
424
 
1039
- @bprop_getters.register(P.Erfinv)
1040
- def get_bprop_erfinv(self):
1041
- """Grad definition for `Erfinv` operation."""
1042
- exp = P.Exp()
1043
- square = P.Square()
1044
- sqrt = P.Sqrt()
1045
- cast = P.Cast()
1046
- dtype = P.DType()
1047
-
1048
- def bprop(input_x, out, dout):
1049
- root_pi_over_two = cast(sqrt(F.scalar_to_tensor(np.pi)) / 2, dtype(dout))
1050
- out_square = square(out)
1051
- dx = dout * root_pi_over_two * exp(out_square)
1052
- return (dx,)
1053
-
1054
- return bprop
1055
-
1056
-
1057
- @bprop_getters.register(BesselI0)
1058
- def get_bprop_bessel_i0(self):
1059
- """Generate bprop for BesselI0"""
1060
- bessel_i1 = BesselI1()
1061
-
1062
- def bprop(x, out, dout):
1063
- dx = dout * bessel_i1(x)
1064
- return (dx,)
1065
-
1066
- return bprop
1067
-
1068
-
1069
- @bprop_getters.register(BesselI1)
1070
- def get_bprop_bessel_i1(self):
1071
- """Generate bprop for BesselI1"""
1072
- equal = P.Equal()
1073
- div = P.Div()
1074
- cast = P.Cast()
1075
- dtype = P.DType()
1076
- bessel_i0 = BesselI0()
1077
-
1078
- def bprop(x, out, dout):
1079
- dout_dx = mnp.where(equal(x, 0.), cast(1., dtype(x)), bessel_i0(x) - div(out, x))
1080
- dx = dout * dout_dx
1081
- return (dx,)
1082
-
1083
- return bprop
1084
-
1085
-
1086
- @bprop_getters.register(BesselJ0)
1087
- def get_bprop_bessel_j0(self):
1088
- """Generate bprop for BesselJ0"""
1089
- bessel_j1 = BesselJ1()
1090
-
1091
- def bprop(x, out, dout):
1092
- dx = -dout * bessel_j1(x)
1093
- return (dx,)
1094
-
1095
- return bprop
1096
-
1097
-
1098
- @bprop_getters.register(BesselJ1)
1099
- def get_bprop_bessel_j1(self):
1100
- """Generate bprop for BesselJ1"""
1101
- equal = P.Equal()
1102
- div = P.Div()
1103
- cast = P.Cast()
1104
- dtype = P.DType()
1105
- bessel_j0 = BesselJ0()
1106
-
1107
- def bprop(x, out, dout):
1108
- dout_dx = mnp.where(equal(x, 0.), cast(0.5, dtype(x)), bessel_j0(x) - div(out, x))
1109
- dx = dout * dout_dx
1110
- return (dx,)
1111
-
1112
- return bprop
1113
-
1114
-
1115
- @bprop_getters.register(BesselK0)
1116
- def get_bprop_bessel_k0(self):
1117
- """Generate bprop for BesselK0"""
1118
- bessel_k1 = BesselK1()
1119
-
1120
- def bprop(x, out, dout):
1121
- dx = -dout * bessel_k1(x)
1122
- return (dx,)
1123
-
1124
- return bprop
1125
-
1126
-
1127
- @bprop_getters.register(BesselK1)
1128
- def get_bprop_bessel_k1(self):
1129
- """Generate bprop for BesselK1"""
1130
- div = P.Div()
1131
- bessel_k0 = BesselK0()
1132
-
1133
- def bprop(x, out, dout):
1134
- dout_dx = -(bessel_k0(x) + div(out, x))
1135
- dx = dout * dout_dx
1136
- return (dx,)
1137
-
1138
- return bprop
1139
-
1140
-
1141
- @bprop_getters.register(BesselK0e)
1142
- def get_bprop_bessel_k0e(self):
1143
- """Generate bprop for BesselK0e"""
1144
- bessel_k1e = BesselK1e()
1145
-
1146
- def bprop(x, out, dout):
1147
- dx = dout * (out - bessel_k1e(x))
1148
- return (dx,)
1149
-
1150
- return bprop
1151
-
1152
-
1153
- @bprop_getters.register(BesselK1e)
1154
- def get_bprop_bessel_k1e(self):
1155
- """Generate bprop for BesselK1e"""
1156
- reciprocal = P.Reciprocal()
1157
- bessel_k0e = BesselK0e()
1158
-
1159
- def bprop(x, out, dout):
1160
- dout_dx = out * (1. - reciprocal(x)) - bessel_k0e(x)
1161
- dx = dout * dout_dx
1162
- return (dx,)
1163
-
1164
- return bprop
1165
-
1166
-
1167
- @bprop_getters.register(BesselY0)
1168
- def get_bprop_bessel_y0(self):
1169
- """Generate bprop for BesselY0"""
1170
- bessel_y1 = BesselY1()
1171
-
1172
- def bprop(x, out, dout):
1173
- dx = -dout * bessel_y1(x)
1174
- return (dx,)
1175
-
1176
- return bprop
1177
-
1178
-
1179
- @bprop_getters.register(BesselY1)
1180
- def get_bprop_bessel_y1(self):
1181
- """Generate bprop for BesselY1"""
1182
- div = P.Div()
1183
- bessel_y0 = BesselY0()
1184
-
1185
- def bprop(x, out, dout):
1186
- dout_dx = bessel_y0(x) - div(out, x)
1187
- dx = dout * dout_dx
1188
- return (dx,)
1189
-
1190
- return bprop
1191
-
1192
-
1193
- @bprop_getters.register(Hypot)
1194
- def get_bprop_hypot(self):
1195
- """Generate bprop for Hypot"""
1196
- mul_ = P.Mul()
1197
- div_ = P.Div()
1198
-
1199
- def bprop(x1, x2, out, dout):
1200
- x1_f32 = F.cast(x1, mstype.float32)
1201
- x2_f32 = F.cast(x2, mstype.float32)
1202
- out_f32 = F.cast(out, mstype.float32)
1203
- dout_f32 = F.cast(dout, mstype.float32)
1204
- dx1 = mul_(div_(x1_f32, out_f32), dout_f32)
1205
- dx2 = mul_(div_(x2_f32, out_f32), dout_f32)
1206
- result_dx1, result_dx2 = binop_grad_common(x1_f32, x2_f32, dx1, dx2)
1207
- result_dx1 = F.cast(result_dx1, F.dtype(x1))
1208
- result_dx2 = F.cast(result_dx2, F.dtype(x2))
1209
- return (result_dx1, result_dx2)
1210
-
1211
- return bprop
1212
-
1213
-
1214
- @bprop_getters.register(P.Asin)
1215
- def get_bprop_asin(self):
1216
- """Grad definition for `Asin` operation."""
1217
- input_grad = G.AsinGrad()
1218
-
1219
- def bprop(input_x, out, dout):
1220
- dx = input_grad(input_x, dout)
1221
- return (dx,)
1222
-
1223
- return bprop
1224
-
1225
-
1226
- @bprop_getters.register(P.Trunc)
1227
- def get_bprop_trunc(self):
1228
- """Grad definition for `Trunc` operation."""
1229
-
1230
- def bprop(input_x, output_y, dout):
1231
- bc_x = zeros_like(input_x)
1232
- return (bc_x,)
1233
-
1234
- return bprop
1235
-
1236
-
1237
- @bprop_getters.register(P.Ger)
1238
- def get_bprop_ger(self):
1239
- """Grad definition for 'Ger' operation"""
1240
- transpose_op = P.Transpose()
1241
- matmul = P.MatMul()
1242
- expand_dims = P.ExpandDims()
1243
- squeeze = P.Squeeze(1)
1244
-
1245
- def bprop(input_x, input_y, out, dout):
1246
- dx = squeeze(matmul(dout, expand_dims(input_y, 1)))
1247
- dy = squeeze(matmul(transpose_op(dout, (1, 0)), expand_dims(input_x, 1)))
1248
- return dx, dy
1249
-
1250
- return bprop
1251
-
1252
-
1253
- @bprop_getters.register(P.Cross)
1254
- def get_bprop_cross(self):
1255
- """Grad definition for 'Cross' operation"""
1256
- cross = P.Cross(dim=self.dim)
1257
-
1258
- def bprop(input1, input2, out, dout):
1259
- return cross(input2, dout), cross(dout, input1)
1260
-
1261
- return bprop
1262
-
1263
-
1264
- @bprop_getters.register(Median)
1265
- def get_bprop_median(self):
1266
- """Grad definition for 'Median' operation"""
1267
- input_grad = G.MedianGrad(global_median=self.global_median, axis=self.axis, keep_dims=self.keep_dims)
1268
-
1269
- def bprop(x, out, dout):
1270
- dx = F.cast(input_grad(dout[0], x, out[0], out[1]), F.dtype(x))
1271
- return (dx,)
1272
-
1273
- return bprop
1274
-
1275
-
1276
- @bprop_getters.register(P.MulNoNan)
1277
- def get_bprop_mul_no_nan(self):
1278
- """Grad definition for `MulNoNan` operation."""
1279
- mul_func = P.Mul()
1280
-
1281
- def bprop(x, y, out, dout):
1282
- bc_x = mul_func(dout, y)
1283
- bc_y = mul_func(x, dout)
1284
- return binop_grad_common(x, y, bc_x, bc_y)
1285
-
1286
- return bprop
1287
-
1288
-
1289
- @bprop_getters.register(Trace)
1290
- def get_bprop_trace(self):
1291
- """Grad definition for `Trace` operation."""
1292
- input_grad = G.TraceGrad()
1293
- shape_op = P.Shape()
1294
- to_array = P.TupleToArray()
1295
- cast = P.Cast()
1296
-
1297
- def bprop(x, out, dout):
1298
- shape = shape_op(x)
1299
- if F.is_sequence_value_unknown(shape):
1300
- shape = dyn_shape_op(x)
1301
- dx = input_grad(dout, shape)
1302
- else:
1303
- dx = input_grad(dout, cast(to_array(shape), mstype.int64))
1304
- return (dx,)
1305
-
1306
- return bprop
1307
-
1308
-
1309
- @bprop_getters.register(Fmin)
1310
- def get_bprop_fmin(self):
1311
- """Grad definition for 'Fmin' operation"""
1312
- shape_ = P.Shape()
1313
- masked_fill_op = P.MaskedFill()
1314
- logical_or_op = P.LogicalOr()
1315
- logical_not_op = P.LogicalNot()
1316
- logical_and_op = P.LogicalAnd()
1317
- mul_op = P.Mul()
1318
- is_nan_op = P.IsNan()
1319
- reshape_ = P.Reshape()
1320
-
1321
- def bprop(x1, x2, out, dout):
1322
- x1_dtype = F.dtype(x1)
1323
- x2_dtype = F.dtype(x2)
1324
- x1 = F.cast(x1, mstype.float32)
1325
- x2 = F.cast(x2, mstype.float32)
1326
- dout = F.cast(dout, mstype.float32)
1327
- b1 = logical_or_op((x1 <= x2), is_nan_op(x2))
1328
- b2 = logical_or_op((x2 < x1), logical_and_op(is_nan_op(x1), logical_not_op(is_nan_op(x2))))
1329
- rx1 = masked_fill_op(x1, b1, 1.)
1330
- rx1 = masked_fill_op(rx1, logical_not_op(b1), 0.)
1331
- rx2 = masked_fill_op(x2, b2, 1.)
1332
- rx2 = masked_fill_op(rx2, logical_not_op(b2), 0.)
1333
- rrx1 = mul_op(rx1, dout)
1334
- rrx2 = mul_op(rx2, dout)
1335
- shape_of_x1 = shape_(x1)
1336
- shape_of_x2 = shape_(x2)
1337
- x1_dim = len(shape_of_x1)
1338
- x2_dim = len(shape_of_x2)
1339
- if x1_dim == 0 and x2_dim != 0:
1340
- sum_r1 = rrx1.sum()
1341
- sum_r2 = rrx2
1342
- elif x1_dim == 0 and x2_dim == 0:
1343
- sum_r1 = rrx1.sum()
1344
- sum_r2 = rrx2.sum()
1345
- elif x1_dim != 0 and x2_dim == 0:
1346
- sum_r2 = rrx2.sum()
1347
- sum_r1 = rrx1
1348
- else:
1349
- rx, ry = DynamicBroadcastGradientArgs()(shape_of_x1, shape_of_x2)
1350
- sum_r1 = sum_grad_reduce_axis(rrx1, rx)
1351
- sum_r2 = sum_grad_reduce_axis(rrx2, ry)
1352
- brrx1 = reshape_(sum_r1, shape_of_x1)
1353
- brrx2 = reshape_(sum_r2, shape_of_x2)
1354
- brrx1 = F.cast(brrx1, x1_dtype)
1355
- brrx2 = F.cast(brrx2, x2_dtype)
1356
- return brrx1, brrx2
1357
-
1358
- return bprop
1359
-
1360
-
1361
- @bprop_getters.register(Fmax)
1362
- def get_bprop_fmax(self):
1363
- """Grad definition for 'Fmax' operation"""
1364
- shape_ = P.Shape()
1365
- masked_fill_op = P.MaskedFill()
1366
- logical_or_op = P.LogicalOr()
1367
- logical_not_op = P.LogicalNot()
1368
- logical_and_op = P.LogicalAnd()
1369
- mul_op = P.Mul()
1370
- is_nan_op = P.IsNan()
1371
- reshape_ = P.Reshape()
1372
-
1373
- def bprop(x1, x2, out, dout):
1374
- x1_dtype = F.dtype(x1)
1375
- x2_dtype = F.dtype(x2)
1376
- if x1_dtype != mstype.float32:
1377
- x1 = F.cast(x1, mstype.float32)
1378
- dout = F.cast(dout, mstype.float32)
1379
- if x2_dtype != mstype.float32:
1380
- x2 = F.cast(x2, mstype.float32)
1381
- dout = F.cast(dout, mstype.float32)
1382
- b1 = logical_or_op(logical_and_op((x1 >= x2), logical_not_op(is_nan_op(x1))), is_nan_op(x2))
1383
- b2 = logical_or_op(logical_and_op(x2 > x1, logical_not_op(is_nan_op(x2))),
1384
- logical_and_op(is_nan_op(x1), logical_not_op(is_nan_op(x2))))
1385
- rx1 = masked_fill_op(x1, b1, 1.)
1386
- rx1 = masked_fill_op(rx1, logical_not_op(b1), 0.)
1387
- rx2 = masked_fill_op(x2, b2, 1.)
1388
- rx2 = masked_fill_op(rx2, logical_not_op(b2), 0.)
1389
- rrx1 = mul_op(rx1, dout)
1390
- rrx2 = mul_op(rx2, dout)
1391
- shape_of_x1 = shape_(x1)
1392
- shape_of_x2 = shape_(x2)
1393
- x1_dim = len(shape_of_x1)
1394
- x2_dim = len(shape_of_x2)
1395
- if x1_dim == 0 and x2_dim != 0:
1396
- sum_r1 = rrx1.sum()
1397
- sum_r2 = rrx2
1398
- elif x1_dim == 0 and x2_dim == 0:
1399
- sum_r1 = rrx1.sum()
1400
- sum_r2 = rrx2.sum()
1401
- elif x1_dim != 0 and x2_dim == 0:
1402
- sum_r2 = rrx2.sum()
1403
- sum_r1 = rrx1
1404
- else:
1405
- rx, ry = DynamicBroadcastGradientArgs()(shape_of_x1, shape_of_x2)
1406
- sum_r1 = sum_grad_reduce_axis(rrx1, rx)
1407
- sum_r2 = sum_grad_reduce_axis(rrx2, ry)
1408
- brrx1 = reshape_(sum_r1, shape_of_x1)
1409
- brrx2 = reshape_(sum_r2, shape_of_x2)
1410
- brrx1 = F.cast(brrx1, x1_dtype)
1411
- brrx2 = F.cast(brrx2, x2_dtype)
1412
- return brrx1, brrx2
1413
-
1414
-
1415
- return bprop
1416
-
1417
-
1418
- @bprop_getters.register(G.MinimumGrad)
1419
- def get_bprop_minimum_grad(self):
1420
- """Grad definition for 'MinimumGrad' operation"""
1421
- input_grad = G.MinimumGradGrad()
1422
-
1423
- def bprop(x1, x2, grad, out, dout):
1424
- sopd_x1, sopd_x2, sopd_grads = input_grad(x1, x2, dout[0], dout[1])
1425
- sopd_x1 = zeros_like(x1)
1426
- sopd_x2 = zeros_like(x2)
1427
- return sopd_x1, sopd_x2, sopd_grads
1428
-
1429
- return bprop
1430
-
1431
-
1432
- @bprop_getters.register(Bernoulli)
1433
- def get_bprop_bernoulli(self):
1434
- """"Grad definition for 'Bernoulli' operation."""
1435
-
1436
- def bprop(x, p, out, dout):
1437
- return zeros_like(x), zeros_like(p)
1438
-
1439
- return bprop
1440
-
1441
-
1442
425
  @bprop_getters.register(TridiagonalSolve)
1443
426
  def get_bprop_tridiagonalsolve(self):
1444
427
  """Grad definition for 'TridiagonalSolve' operation"""
@@ -1463,85 +446,6 @@ def get_bprop_tridiagonalsolve(self):
1463
446
  return bprop
1464
447
 
1465
448
 
1466
- @bprop_getters.register(Igamma)
1467
- def get_bprop_igamma(self):
1468
- """Grad definition for `Igamma` operation."""
1469
- shape_ = P.Shape()
1470
- igammagrada = G.IgammaGradA()
1471
- lgamma = nn.LGamma()
1472
- log_ = P.Log()
1473
- exp_ = P.Exp()
1474
- reshape_ = P.Reshape()
1475
- reduce_sum_ = P.ReduceSum()
1476
-
1477
- def bprop(a, x, out, dout):
1478
- sa = shape_(a)
1479
- sx = shape_(x)
1480
- if F.is_sequence_value_unknown(sa) or F.is_sequence_value_unknown(sx):
1481
- sa = dyn_shape_op(a)
1482
- sx = dyn_shape_op(x)
1483
- ra, rx = DynamicBroadcastGradientArgs()(sa, sx)
1484
- partial_a = igammagrada(a, x)
1485
- partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
1486
- r1 = reshape_(sum_grad_reduce_axis(partial_a * dout, ra), sa)
1487
- r2 = reshape_(sum_grad_reduce_axis(partial_x * dout, rx), sx)
1488
- return r1, r2
1489
- ra, rx = broadcast_gradient_args(sa, sx)
1490
- partial_a = igammagrada(a, x)
1491
- partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
1492
- if ra != ():
1493
- r1 = reshape_(reduce_sum_(partial_a * dout, ra), sa)
1494
- else:
1495
- r1 = reshape_(partial_a * dout, sa)
1496
- if rx != ():
1497
- r2 = reshape_(reduce_sum_(partial_x * dout, rx), sx)
1498
- else:
1499
- r2 = reshape_(partial_x * dout, sx)
1500
- return r1, r2
1501
-
1502
- return bprop
1503
-
1504
-
1505
- @bprop_getters.register(Igammac)
1506
- def get_bprop_igammac(self):
1507
- """Grad definition for `Igammac` operation."""
1508
- shape_ = P.Shape()
1509
- igammagrada = G.IgammaGradA()
1510
- lgamma = nn.LGamma()
1511
- log_ = P.Log()
1512
- exp_ = P.Exp()
1513
- reshape_ = P.Reshape()
1514
- reduce_sum_ = P.ReduceSum()
1515
- neg_ = P.Neg()
1516
-
1517
- def bprop(a, x, out, dout):
1518
- sa = shape_(a)
1519
- sx = shape_(x)
1520
- if F.is_sequence_value_unknown(sa) or F.is_sequence_value_unknown(sx):
1521
- sa = dyn_shape_op(a)
1522
- sx = dyn_shape_op(x)
1523
- ra, rx = DynamicBroadcastGradientArgs()(sa, sx)
1524
- partial_a = igammagrada(a, x)
1525
- partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
1526
- r1 = neg_(reshape_(sum_grad_reduce_axis(partial_a * dout, ra), sa))
1527
- r2 = neg_(reshape_(sum_grad_reduce_axis(partial_x * dout, rx), sx))
1528
- return r1, r2
1529
- ra, rx = broadcast_gradient_args(sa, sx)
1530
- partial_a = igammagrada(a, x)
1531
- partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
1532
- if ra != ():
1533
- r1 = neg_(reshape_(reduce_sum_(partial_a * dout, ra), sa))
1534
- else:
1535
- r1 = neg_(reshape_(partial_a * dout, sa))
1536
- if rx != ():
1537
- r2 = neg_(reshape_(reduce_sum_(partial_x * dout, rx), sx))
1538
- else:
1539
- r2 = neg_(reshape_(partial_x * dout, sx))
1540
- return r1, r2
1541
-
1542
- return bprop
1543
-
1544
-
1545
449
  @bprop_getters.register(Lgamma)
1546
450
  def get_bprop_lgamma(self):
1547
451
  """Grad definition for `Lgamma` operation."""
@@ -1599,90 +503,19 @@ def get_bprop_polygamma(self):
1599
503
  return bprop
1600
504
 
1601
505
 
1602
- @bprop_getters.register(TridiagonalMatMul)
1603
- def get_bprop_tridiagonal_matmul(self):
1604
- """Grad definition for 'TridiagonalMatMul' operation"""
1605
-
1606
- def _leftshift(x):
1607
- """Shifts next-to-last dimension to the left, adding zero on the right."""
1608
- rank = P.Rank()(x)
1609
- paddings = ((0,) * (2),) * (rank - 2) + ((0, 1), (0, 0))
1610
- pad_op = P.Pad(paddings)
1611
- return pad_op(x[..., 1:, :])
1612
-
1613
- def _rightshift(x):
1614
- """Shifts next-to-last dimension to the right, adding zero on the left."""
1615
- rank = P.Rank()(x)
1616
- paddings = ((0,) * (2),) * (rank - 2) + ((1, 0), (0, 0))
1617
- pad_op = P.Pad(paddings)
1618
- return pad_op(x[..., :-1, :])
1619
-
1620
- def matrix_transpose(x):
1621
- x_rank = P.Rank()(x)
1622
- if x_rank > 2:
1623
- m = x_rank - 2
1624
- n = x_rank - 1
1625
- x_range = range(m)
1626
- perm = (x_range) + (n, m)
1627
- else:
1628
- perm = (1, 0)
1629
- return P.Transpose()(x, perm)
1630
-
1631
- reduce_sum = P.ReduceSum()
1632
- expand_dims = P.ExpandDims()
1633
- conjugate = P.Conj()
1634
-
1635
- def bprop(superdiag, maindiag, subdiag, rhs, out, grad):
1636
- superdiag_type = F.dtype(superdiag)
1637
- superdiag_conj = matrix_transpose(superdiag)
1638
- maindiag_conj = matrix_transpose(maindiag)
1639
- subdiag_conj = matrix_transpose(subdiag)
1640
- rhs_conj = rhs
1641
- if superdiag_type in (mstype.complex64, mstype.complex128):
1642
- superdiag_conj = conjugate(superdiag_conj)
1643
- maindiag_conj = conjugate(maindiag_conj)
1644
- subdiag_conj = conjugate(subdiag_conj)
1645
- rhs_conj = conjugate(rhs)
1646
- superdiag_grad = reduce_sum(_leftshift(rhs_conj) * grad, -1)
1647
- maindiag_grad = reduce_sum(rhs_conj * grad, -1)
1648
- subdiag_grad = reduce_sum(_rightshift(rhs_conj) * grad, -1)
1649
- rhs_grad = _rightshift(superdiag_conj * grad) + maindiag_conj * grad + \
1650
- _leftshift(subdiag_conj * grad)
1651
- superdiag_grad = expand_dims(superdiag_grad, -2)
1652
- maindiag_grad = expand_dims(maindiag_grad, -2)
1653
- subdiag_grad = expand_dims(subdiag_grad, -2)
1654
- return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad
1655
-
1656
- return bprop
1657
-
1658
-
1659
- @bprop_getters.register(AddV2)
1660
- def get_bprop_add_v2(self):
1661
- """Grad definition for `AddV2` operation."""
1662
-
1663
- def bprop(x, y, out, dout):
1664
- return binop_grad_common(x, y, dout, dout)
1665
-
1666
- return bprop
1667
-
1668
-
1669
506
  @bprop_getters.register(CholeskySolve)
1670
507
  def get_bprop_cholesky_solve(self):
1671
508
  """Grad definition for 'CholeskySolve' operation"""
1672
509
  batchmatmul_op = P.BatchMatMul()
1673
510
  matmul_op = P.MatMul()
1674
511
  neg_op = P.Neg()
1675
- shape_op = P.Shape()
1676
512
  upper = self.upper
1677
513
  cholesky_solve = CholeskySolve(upper=self.upper)
514
+ rank = P.Rank()
1678
515
 
1679
516
  def bprop(x1, x2, out, dout):
1680
517
  flag = 0
1681
- shape_x1 = shape_op(x1)
1682
- if F.is_sequence_shape_unknown(shape_x1):
1683
- len_x1 = dyn_rank(x1)
1684
- else:
1685
- len_x1 = len(shape_x1)
518
+ len_x1 = rank(x1)
1686
519
  if dout.dtype == mstype.float64:
1687
520
  flag = 1
1688
521
  x2 = F.cast(x2, mstype.float32)
@@ -1714,51 +547,6 @@ def get_bprop_cholesky_solve(self):
1714
547
  return bprop
1715
548
 
1716
549
 
1717
- @bprop_getters.register(NextAfter)
1718
- def get_bprop_nextafter(self):
1719
- """Grad definition for 'NextAfter' operation"""
1720
- shape = P.Shape()
1721
- dyn_shape = P.TensorShape()
1722
- ones = P.Ones()
1723
- zeros = P.Zeros()
1724
- dtype = P.DType()
1725
- reshape = P.Reshape()
1726
- cast = P.Cast()
1727
-
1728
- def bprop(x1, x2, out, dout):
1729
- dout_type = dtype(dout)
1730
- x1_type = dtype(x1)
1731
- x2_type = dtype(x2)
1732
- if x1_type == mstype.float64:
1733
- x1 = cast(x1, mstype.float32)
1734
- if x2_type == mstype.float64:
1735
- x2 = cast(x2, mstype.float32)
1736
- if dout_type == mstype.float64:
1737
- dout = cast(dout, mstype.float32)
1738
-
1739
- s_x1 = shape(x1)
1740
- partial_x1 = ()
1741
- if F.is_sequence_value_unknown(s_x1):
1742
- s_x1 = dyn_shape(x1)
1743
- partial_x1 = dyn_ones(s_x1, dtype(x1))
1744
- else:
1745
- partial_x1 = ones(s_x1, dtype(x1))
1746
-
1747
- s_x2 = shape(x2)
1748
- partial_x2 = ()
1749
- if F.is_sequence_value_unknown(s_x2):
1750
- s_x2 = dyn_shape(x2)
1751
- partial_x2 = dyn_fill(dtype(x2), s_x2, 0)
1752
- else:
1753
- partial_x2 = zeros(s_x2, dtype(x2))
1754
-
1755
- dx1 = reshape(partial_x1 * dout, s_x1)
1756
- dx2 = reshape(partial_x2 * dout, s_x2)
1757
- return cast(dx1, dtype(dout)), cast(dx2, dtype(dout))
1758
-
1759
- return bprop
1760
-
1761
-
1762
550
  @bprop_getters.register(Diagonal)
1763
551
  def get_bprop_diagonal(self):
1764
552
  """Grad definition for 'Diagonal' operation"""
@@ -1971,7 +759,6 @@ def get_bprop_fft_with_size(self):
1971
759
  onesided=onesided)
1972
760
 
1973
761
  complex_op = P.Complex()
1974
- shape_op = P.Shape()
1975
762
  to_tensor_op = P.ScalarToTensor()
1976
763
  type_op = P.DType()
1977
764
  concat_op = P.Concat()
@@ -2089,3 +876,143 @@ def get_bprop_fft_with_size(self):
2089
876
  return (dx,)
2090
877
 
2091
878
  return bprop
879
+
880
+
881
+ def dyn_binop_grad_common(x, y, dx, dy):
882
+ """
883
+ Common grad definition for binary operations when the input is dynamic shape.
884
+
885
+ The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
886
+ """
887
+ shape_of_x = shape_op(x)
888
+ shape_of_y = shape_op(y)
889
+ rx, ry = DynamicBroadcastGradientArgs()(shape_of_x, shape_of_y)
890
+ dx_origin_dtype = dx.dtype
891
+ if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
892
+ dx = F.cast(dx, mstype.float32)
893
+ dx = sum_grad_reduce_axis(dx, rx)
894
+ dx = F.cast(dx, dx_origin_dtype)
895
+ else:
896
+ dx = sum_grad_reduce_axis(dx, rx)
897
+ dy_origin_dtype = dy.dtype
898
+ if dy_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
899
+ dy = F.cast(dy, mstype.float32)
900
+ dy = sum_grad_reduce_axis(dy, ry)
901
+ dy = F.cast(dy, dy_origin_dtype)
902
+ else:
903
+ dy = sum_grad_reduce_axis(dy, ry)
904
+ reduce_dx = reshape(dx, shape_of_x)
905
+ reduce_dy = reshape(dy, shape_of_y)
906
+ return reduce_dx, reduce_dy
907
+
908
+
909
+ def dyn_binop_grad_common_with_shift(x, y, dx, dy, shift):
910
+ """
911
+ Common grad definition for binary operations with shift when the input is dynamic shape.
912
+
913
+ The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
914
+ """
915
+ shape_of_x = shape_op(x)
916
+ shape_of_y = shape_op(y)
917
+ broadcast_shape_of_x = shape_of_x[:-shift]
918
+ broadcast_shape_of_y = shape_of_y[:-shift]
919
+ rx, ry = DynamicBroadcastGradientArgs()(broadcast_shape_of_x, broadcast_shape_of_y)
920
+ dx = sum_grad_reduce_axis(dx, rx)
921
+ dy = sum_grad_reduce_axis(dy, ry)
922
+ reduce_dx = reshape(dx, shape_of_x)
923
+ reduce_dy = reshape(dy, shape_of_y)
924
+ return reduce_dx, reduce_dy
925
+
926
+
927
+ def _reduce_sum_with_cast(dx, axis):
928
+ dx_origin_dtype = dx.dtype
929
+ # Currently, for Ascend and GPU, the reduce_sum's input does not support int16, int32 and int64.
930
+ if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
931
+ dx = F.cast(dx, mstype.float32)
932
+ dx = reduce_sum(dx, axis)
933
+ return F.cast(dx, dx_origin_dtype)
934
+ return reduce_sum(dx, axis)
935
+
936
+
937
+ def binop_grad_common(x, y, dx, dy):
938
+ """
939
+ Common grad definition for binary operations.
940
+
941
+ The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
942
+ """
943
+ shape_of_x = shape_op(x)
944
+ shape_of_y = shape_op(y)
945
+ # if input shape is the same as dout shape, do not need to reduce
946
+ reduce_dx = dx
947
+ reduce_dy = dy
948
+ if not (F.is_sequence_value_unknown(shape_of_x) or F.is_sequence_value_unknown(shape_of_y)):
949
+ rx = broadcast_gradient_args(shape_of_x, shape_of_y)
950
+ if rx[0]:
951
+ # if dx is scalar whose shape is (), do not need reduce
952
+ if shape_op(dx):
953
+ dx = _reduce_sum_with_cast(dx, rx[0])
954
+ reduce_dx = reshape(dx, shape_of_x)
955
+ if rx[1]:
956
+ # if dy is scalar whose shape is (), do not need reduce
957
+ if shape_op(dy):
958
+ dy = _reduce_sum_with_cast(dy, rx[1])
959
+ reduce_dy = reshape(dy, shape_of_y)
960
+ return reduce_dx, reduce_dy
961
+
962
+ if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
963
+ # x or y is scalar
964
+ if not isinstance(shape_of_x, tuple):
965
+ reduce_dx = _reduce_sum_with_cast(dx, ())
966
+ if not isinstance(shape_of_y, tuple):
967
+ reduce_dy = _reduce_sum_with_cast(dy, ())
968
+ return reduce_dx, reduce_dy
969
+
970
+ return dyn_binop_grad_common(x, y, dx, dy)
971
+
972
+
973
+ def binop_grad_common_with_shift(x, y, dx, dy, shift):
974
+ """
975
+ Common grad definition for binary operations with shift.
976
+
977
+ The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
978
+ """
979
+ shape_of_x = shape_op(x)
980
+ shape_of_y = shape_op(y)
981
+ broadcast_shape_of_x = shape_of_x[:-shift]
982
+ broadcast_shape_of_y = shape_of_y[:-shift]
983
+ # if input shape is the same as dout shape, do not need to reduce
984
+ reduce_dx = dx
985
+ reduce_dy = dy
986
+ if not (F.is_sequence_value_unknown(broadcast_shape_of_x) or F.is_sequence_value_unknown(broadcast_shape_of_y)):
987
+ rx = broadcast_gradient_args(broadcast_shape_of_x, broadcast_shape_of_y)
988
+ if rx[0]:
989
+ # if dx is scalar whose shape is (), do not need reduce
990
+ if shape_op(dx):
991
+ dx = _reduce_sum_with_cast(dx, rx[0])
992
+ reduce_dx = reshape(dx, shape_of_x)
993
+ if rx[1]:
994
+ # if dy is scalar whose shape is (), do not need reduce
995
+ if shape_op(dy):
996
+ dy = _reduce_sum_with_cast(dy, rx[1])
997
+ reduce_dy = reshape(dy, shape_of_y)
998
+ return reduce_dx, reduce_dy
999
+
1000
+ if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
1001
+ # x or y is scalar
1002
+ if not isinstance(shape_of_x, tuple):
1003
+ reduce_dx = _reduce_sum_with_cast(dx, ())
1004
+ if not isinstance(shape_of_y, tuple):
1005
+ reduce_dy = _reduce_sum_with_cast(dy, ())
1006
+ return reduce_dx, reduce_dy
1007
+
1008
+ return dyn_binop_grad_common_with_shift(x, y, dx, dy, shift)
1009
+
1010
+
1011
+ @bprop_getters.register(P.TensorAdd)
1012
+ def get_bprop_tensor_add(self):
1013
+ """Grad definition for `Add` operation."""
1014
+
1015
+ def bprop(x, y, out, dout):
1016
+ return binop_grad_common(x, y, dout, dout)
1017
+
1018
+ return bprop