mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's advisory page for more details.

Files changed (870):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,1684 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """Define the grad rules of math related operations."""
17
-
18
- import numpy as np
19
- import mindspore as ms
20
- from mindspore import nn
21
- from mindspore.common import Tensor
22
- from mindspore.common import dtype as mstype
23
- from mindspore.ops import functional as F
24
- from mindspore.ops import operations as P
25
- from mindspore.ops.operations import _grad_ops as G
26
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
27
- from mindspore.ops.functional import broadcast_gradient_args, reduced_shape, tuple_div
28
- from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_invert_permutation
29
- from mindspore.ops._grad.grad_base import convert_to_tensor
30
- from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
31
- from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
32
- from mindspore.ops.primitive import _primexpr
33
- from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
34
- from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
35
- from mindspore.ops.operations import array_ops as A
36
-
37
# Module-level primitive instances shared by the grad helpers and bprop functions in this file.
shape_op = P.Shape()  # shape as a tuple; its entries may be unknown for dynamic shapes
dyn_shape_op = P.TensorShape()  # shape as a Tensor, used on the dynamic-shape paths
reduce_prod = P.ReduceProd()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
tile = P.Tile()
is_sub_class = IsSubClass()
to_array = P.TupleToArray()
real_div = P.RealDiv()
46
-
47
-
48
def dyn_binop_grad_common(x, y, dx, dy):
    """
    Reduce the gradients of a binary operation to the input shapes when the input is dynamic shape.

    The input shapes are queried at run time. The axes that broadcasting added are reduced out of
    `dx` and `dy` with `sum_grad_reduce_axis`, and the results are reshaped to the shapes of `x`
    and `y` respectively.

    Returns:
        tuple, (gradient w.r.t. `x`, gradient w.r.t. `y`).
    """
    x_shape = dyn_shape_op(x)
    y_shape = dyn_shape_op(y)
    x_axes, y_axes = DynamicBroadcastGradientArgs()(x_shape, y_shape)
    int_dtypes = (mstype.int16, mstype.int32, mstype.int64)

    # Integer gradients are reduced in float32 and cast back to their original dtype.
    dx_dtype = dx.dtype
    if dx_dtype in int_dtypes:
        dx_reduced = F.cast(sum_grad_reduce_axis(F.cast(dx, mstype.float32), x_axes), dx_dtype)
    else:
        dx_reduced = sum_grad_reduce_axis(dx, x_axes)

    dy_dtype = dy.dtype
    if dy_dtype in int_dtypes:
        dy_reduced = F.cast(sum_grad_reduce_axis(F.cast(dy, mstype.float32), y_axes), dy_dtype)
    else:
        dy_reduced = sum_grad_reduce_axis(dy, y_axes)

    return reshape(dx_reduced, x_shape), reshape(dy_reduced, y_shape)
74
-
75
-
76
def _dyn_sum_grad_reduce_with_cast(grad, axis):
    """
    Apply `sum_grad_reduce_axis` to `grad`, routing int16/int32/int64 gradients through float32.

    This is the same integer handling that `dyn_binop_grad_common` and `_reduce_sum_with_cast` apply:
    the integer gradient is cast to float32, reduced, and cast back to its original dtype.
    """
    origin_dtype = grad.dtype
    if origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
        grad = F.cast(grad, mstype.float32)
        grad = sum_grad_reduce_axis(grad, axis)
        return F.cast(grad, origin_dtype)
    return sum_grad_reduce_axis(grad, axis)


def dyn_binop_grad_common_with_shift(x, y, dx, dy, shift):
    """
    Common grad definition for binary operations with shift when the input is dynamic shape.

    Only the leading dimensions of the inputs (all but the last `shift`) take part in broadcasting.
    The broadcast axes are reduced out of `dx` and `dy`, which are then reshaped to the full
    run-time shapes of `x` and `y`. Integer gradients get the same float32 round-trip as the
    non-shift dynamic path (`dyn_binop_grad_common`) and the static shift path
    (`_reduce_sum_with_cast`).

    Returns:
        tuple, (gradient w.r.t. `x`, gradient w.r.t. `y`).
    """
    shape_of_x = dyn_shape_op(x)
    shape_of_y = dyn_shape_op(y)
    # Only the leading (non-shifted) dimensions are compared for broadcasting.
    broadcast_shape_of_x = shape_of_x[:-shift]
    broadcast_shape_of_y = shape_of_y[:-shift]
    rx, ry = DynamicBroadcastGradientArgs()(broadcast_shape_of_x, broadcast_shape_of_y)
    dx = _dyn_sum_grad_reduce_with_cast(dx, rx)
    dy = _dyn_sum_grad_reduce_with_cast(dy, ry)
    reduce_dx = reshape(dx, shape_of_x)
    reduce_dy = reshape(dy, shape_of_y)
    return reduce_dx, reduce_dy
92
-
93
-
94
def _reduce_sum_with_cast(dx, axis):
    """
    Sum `dx` over `axis`, routing int16/int32/int64 inputs through float32.

    Currently, for Ascend and GPU, the ReduceSum input does not support int16, int32 and int64, so
    integer gradients are summed in float32 and cast back to their original dtype.
    """
    origin_dtype = dx.dtype
    if origin_dtype not in (mstype.int16, mstype.int32, mstype.int64):
        return reduce_sum(dx, axis)
    summed = reduce_sum(F.cast(dx, mstype.float32), axis)
    return F.cast(summed, origin_dtype)
102
-
103
-
104
def binop_grad_common(x, y, dx, dy):
    """
    Common grad definition for binary operations.

    Sums the dimensions created by broadcasting out of `dx` and `dy`, so each gradient matches the
    shape of its input. Known static shapes are handled here. If either input is a scalar, its
    gradient is reduced over all axes. Otherwise, when the shape values are unknown, the work is
    delegated to `dyn_binop_grad_common`.

    Returns:
        tuple, (gradient w.r.t. `x`, gradient w.r.t. `y`).
    """
    x_shape = shape_op(x)
    y_shape = shape_op(y)
    # When an input shape already matches the dout shape, its gradient is passed through as-is.
    grad_x = dx
    grad_y = dy

    shapes_known = not F.is_sequence_value_unknown(x_shape) and not F.is_sequence_value_unknown(y_shape)
    if shapes_known:
        reduce_axes = broadcast_gradient_args(x_shape, y_shape)
        x_axes = reduce_axes[0]
        y_axes = reduce_axes[1]
        if x_axes:
            # A scalar dx (shape ()) has nothing to reduce; only reshape it.
            if shape_op(dx):
                dx = _reduce_sum_with_cast(dx, x_axes)
            grad_x = reshape(dx, x_shape)
        if y_axes:
            # A scalar dy (shape ()) has nothing to reduce; only reshape it.
            if shape_op(dy):
                dy = _reduce_sum_with_cast(dy, y_axes)
            grad_y = reshape(dy, y_shape)
        return grad_x, grad_y

    x_is_tuple = isinstance(x_shape, tuple)
    y_is_tuple = isinstance(y_shape, tuple)
    if x_is_tuple and y_is_tuple:
        return dyn_binop_grad_common(x, y, dx, dy)

    # x or y is a scalar: reduce its gradient over all axes.
    if not x_is_tuple:
        grad_x = _reduce_sum_with_cast(dx, ())
    if not y_is_tuple:
        grad_y = _reduce_sum_with_cast(dy, ())
    return grad_x, grad_y
138
-
139
-
140
def binop_grad_common_with_shift(x, y, dx, dy, shift):
    """
    Common grad definition for binary operations with shift.

    Only the leading dimensions of the inputs (all but the last `shift`) take part in broadcasting.
    The broadcast axes are summed out of `dx` and `dy`, which are then reshaped to the full shapes
    of `x` and `y`. If either input is a scalar, its gradient is reduced over all axes. Otherwise,
    when the leading shape values are unknown, the work is delegated to
    `dyn_binop_grad_common_with_shift`.

    Returns:
        tuple, (gradient w.r.t. `x`, gradient w.r.t. `y`).
    """
    x_shape = shape_op(x)
    y_shape = shape_op(y)
    x_lead_shape = x_shape[:-shift]
    y_lead_shape = y_shape[:-shift]
    # When an input shape already matches the dout shape, its gradient is passed through as-is.
    grad_x = dx
    grad_y = dy

    lead_known = not F.is_sequence_value_unknown(x_lead_shape) and not F.is_sequence_value_unknown(y_lead_shape)
    if lead_known:
        reduce_axes = broadcast_gradient_args(x_lead_shape, y_lead_shape)
        x_axes = reduce_axes[0]
        y_axes = reduce_axes[1]
        if x_axes:
            # A scalar dx (shape ()) has nothing to reduce; only reshape it.
            if shape_op(dx):
                dx = _reduce_sum_with_cast(dx, x_axes)
            grad_x = reshape(dx, x_shape)
        if y_axes:
            # A scalar dy (shape ()) has nothing to reduce; only reshape it.
            if shape_op(dy):
                dy = _reduce_sum_with_cast(dy, y_axes)
            grad_y = reshape(dy, y_shape)
        return grad_x, grad_y

    x_is_tuple = isinstance(x_shape, tuple)
    y_is_tuple = isinstance(y_shape, tuple)
    if x_is_tuple and y_is_tuple:
        return dyn_binop_grad_common_with_shift(x, y, dx, dy, shift)

    # x or y is a scalar: reduce its gradient over all axes.
    if not x_is_tuple:
        grad_x = _reduce_sum_with_cast(dx, ())
    if not y_is_tuple:
        grad_y = _reduce_sum_with_cast(dy, ())
    return grad_x, grad_y
176
-
177
-
178
def _dyn_reduced_shape(input_shape, axis, x):
    """
    Dynamic reduce shape.

    Computes, as a tensor, the shape of a reduction over `axis` with the reduced dimensions kept:
    `input_shape` with every entry listed in `axis` replaced by 1.

    Args:
        input_shape: Tensor, the shape of `x` (the callers pass `dyn_shape_op(x)`).
        axis: int, tuple, list or Tensor of axes; negative values are wrapped by the rank.
        x: the reduced input, used to obtain its rank.

    Returns:
        Tensor. In general a float32 shape tensor (the callers cast it to int32). For an empty
        `axis`, it is the result of `dyn_ones` over a 1-D tensor holding the rank, presumably a
        vector of ones, i.e. every dimension reduced.
    """
    input_shape = P.Cast()(input_shape, ms.int32)
    # Use the static rank when it is known, otherwise query it at run time.
    if x is not None and not F.is_sequence_shape_unknown(shape_op(x)):
        input_rank = len(shape_op(x))
    else:
        input_rank = dyn_rank(x)
    input_rank = P.Cast()(input_rank, ms.int32)

    # Empty axis: every dimension is reduced.
    if (isinstance(axis, tuple) and axis == ()) or (isinstance(axis, list) and axis == []):
        res_shape = P.ExpandDims()(input_rank, 0)
        return dyn_ones(res_shape, res_shape.dtype)

    if isinstance(axis, int):
        axis = (axis,)

    real_axis = axis
    if not isinstance(axis, Tensor):
        real_axis = Tensor(axis, ms.int32)

    # Wrap negative axes into [0, rank).
    real_axis = (real_axis + input_rank) % input_rank
    if real_axis.ndim == 0:
        real_axis = P.ExpandDims()(real_axis, 0)
    # Scatter a 1 into input_shape at each reduced axis; the indices are reshaped to (n, 1).
    expanded_axis = P.ExpandDims()(real_axis, 1)
    expanded_axis = P.Cast()(expanded_axis, ms.int32)
    update = P.Cast()(P.OnesLike()(real_axis), ms.float32)
    input_shape = P.Cast()(input_shape, ms.float32)
    return P.TensorScatterUpdate()(input_shape, expanded_axis, update)
206
-
207
-
208
def _sum_grad(x, axis, dout):
    """
    Grad definition for `Sum` operation.

    `dout` is reshaped to the reduced shape with the reduced axes kept as size 1, then expanded
    back to the shape of `x`. Static shapes use Tile; unknown shapes or a mutable `axis` use
    DynamicBroadcastTo.
    """
    x_shape = shape_op(x)
    axis_is_mutable, axis = convert_to_tensor(axis)
    if not (F.is_sequence_value_unknown(x_shape) or axis_is_mutable):
        kept_dims_shape = reduced_shape(x_shape, axis)
        grad = reshape(dout, kept_dims_shape)
        return tile(grad, tuple_div(x_shape, kept_dims_shape))

    x_shape = dyn_shape_op(x)
    kept_dims_shape = P.Cast()(_dyn_reduced_shape(x_shape, axis, x), ms.int32)
    grad = reshape(dout, kept_dims_shape)
    return DynamicBroadcastTo()(grad, x_shape)
223
-
224
-
225
def _min_or_max_grad(x, axis, out, dout):
    """
    Grad definition for `Min` and `Max` operations.

    The incoming gradient is routed to every position of `x` that equals the reduced result. When
    several positions tie, the gradient is split evenly among them.

    Args:
        x: forward input.
        axis: the reduction axis/axes of the forward op.
        out: forward output.
        dout: gradient w.r.t. `out`.

    Returns:
        Tensor, gradient w.r.t. `x`.
    """
    input_shape = shape_op(x)
    # Both branches define output_shape_kept_dims, so no placeholder initialisation is needed.
    if F.is_sequence_value_unknown(input_shape):
        input_shape = dyn_shape_op(x)
        output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
        output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
    else:
        output_shape_kept_dims = reduced_shape(input_shape, axis)

    y = reshape(out, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    # 1 where x equals the reduced value (broadcast along the reduced axes), 0 elsewhere.
    indicators = F.cast(F.equal(y, x), F.dtype(grad))
    # Tiny epsilon added to the count so the division below never divides by exactly zero.
    min_num = F.cast(F.scalar_to_tensor(1e-24), F.dtype(grad))
    num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
    return indicators / num_selected * grad
242
-
243
-
244
def _onehot_with_neg_axis(axis, indices, depth, on_value_dtype):
    """
    One-hot encode `indices` along a (possibly negative) tensor `axis`.

    `indices` is expanded at `axis` and compared against `range(depth)`, which is reshaped to be
    broadcastable along that axis; the boolean result is cast to `on_value_dtype`.

    Args:
        axis: axis at which the one-hot dimension is inserted; may be negative.
        indices: index tensor.
        depth: Tensor, number of classes.
        on_value_dtype: dtype of the returned tensor (1 for "on", 0 for "off").

    Returns:
        Tensor of dtype `on_value_dtype`.
    """
    depth_range = P.Range()(F.cast(0, depth.dtype), depth, F.cast(1, depth.dtype))
    indices_expand = P.ExpandDims()(indices, axis)
    indices_expand_rank = dyn_rank_1d(indices_expand)
    # It should use int64 dtype, but the TensorScatterUpdate op does not support the int64
    # dtype on Ascend device, so the float32 dtype is used here and cast to int64 afterwards.
    update_dtype = mstype.float32
    broad_shape = dyn_ones(indices_expand_rank, update_dtype)
    broad_shape[axis] = F.cast(depth, update_dtype)
    broad_shape = F.cast(broad_shape, mstype.int64)
    depth_broad = P.Reshape()(depth_range, broad_shape)
    one_hot_bool = P.Equal()(indices_expand, depth_broad)
    return F.cast(one_hot_bool, on_value_dtype)
260
-
261
-
262
def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """
    ArgMinWithValue and ArgMaxWithValue grad.

    The gradient is scattered back to the selected positions of `x` by multiplying `dout[1]` with
    a one-hot mask built from the index output `out[0]`.

    Args:
        x: forward input.
        axis: reduction axis; may be negative.
        keep_dims: whether the forward op kept the reduced axis.
        op: primitive re-run on `x` to recompute `out` in some of the cases below.
            NOTE(review): presumably the non-keep_dims variant of the forward op; confirm at the call site.
        out: forward outputs; `out[0]` holds the indices fed to OneHot.
        dout: output gradients; `dout[1]` is the incoming gradient used here
            (presumably the gradient of the value output).

    Returns:
        Tensor, gradient w.r.t. `x`.
    """
    expand = P.ExpandDims()
    squeeze = P.Squeeze()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    onehot_axis_is_neg = False
    # Normalize a negative axis when the rank is known. With an unknown rank the axis stays
    # negative, and the dynamic branch at the end builds the mask with _onehot_with_neg_axis.
    if x_axis < 0:
        if not F.is_sequence_shape_unknown(x_shape):
            x_axis = axis + x_dim
        else:
            onehot_axis_is_neg = True
    onehot_axis = x_axis
    if keep_dims:
        dout_expand = dout[1]
        # NOTE(review): out is recomputed with op, presumably so that out[0] lacks the kept
        # axis that OneHot re-inserts; confirm against the caller.
        out = op(x)
    else:
        # Re-insert the reduced axis so the gradient broadcasts against the one-hot mask.
        dout_expand = expand(dout[1], onehot_axis)
    out_shape = shape_op(out[0])
    if not F.is_sequence_shape_unknown(out_shape):
        # An axis at or past the rank of out[0] is passed to OneHot as -1.
        if onehot_axis >= len(out_shape):
            onehot_axis = -1
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_tensor(1.0), type_x)
    off_value = F.cast(F.scalar_to_tensor(0.0), type_x)
    if not F.is_sequence_value_unknown(x_shape):
        # Static shape: depth is the size of x along axis. A scalar x uses depth 1, and the
        # extra one-hot dimension is squeezed away afterwards.
        depth = 1
        if x_shape:
            depth = x_shape[axis]
        onehot = P.OneHot(onehot_axis)
        dx = dout_expand * onehot(out[0], depth, on_value, off_value)
        if not x_shape:
            dx = squeeze(dx)
        return dx
    # Dynamic shape: read depth from the run-time shape tensor.
    x_tensor_shape = P.TensorShape()(x)
    depth = x_tensor_shape[axis]
    if not onehot_axis_is_neg:
        onehot = P.OneHot(onehot_axis)
        dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    else:
        if out[0].value is not None:
            # It is a temporary method: In the pynative mode, out may be a constant tensor. Constant
            # folding occurs in ExpandDims op, but such scenarios are not supported currently.
            out = op(x)
        dx = dout_expand * _onehot_with_neg_axis(onehot_axis, out[0], depth, on_value.dtype)
    return dx
309
-
310
-
311
@bprop_getters.register(P.BatchMatMul)
def bprop_batchmatmul(self):
    """Build the backward function for `BatchMatMul`.

    The transpose flags of the two gradient matmuls are derived from the forward
    ``transpose_a``/``transpose_b`` attributes. This keeps ``dx`` and ``dw`` in the same
    layout as the forward inputs.
    """
    trans_a = self.transpose_a
    trans_b = self.transpose_b
    grad_x_matmul = P.BatchMatMul(transpose_a=(trans_a and trans_b),
                                  transpose_b=(trans_a or (not trans_b)))
    grad_w_matmul = P.BatchMatMul(transpose_a=((not trans_a) or trans_b),
                                  transpose_b=(trans_a and trans_b))

    def bprop(x, w, out, dout):
        if trans_a:
            grad_x = grad_x_matmul(w, dout)
        else:
            grad_x = grad_x_matmul(dout, w)
        if trans_b:
            grad_w = grad_w_matmul(dout, x)
        else:
            grad_w = grad_w_matmul(x, dout)
        # The trailing two (matrix) dimensions are excluded from broadcast reduction.
        return binop_grad_common_with_shift(x, w, grad_x, grad_w, 2)

    return bprop
333
-
334
-
335
@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
    """Build the backward function for `TensorAdd` (documented as `Add`).

    Addition passes ``dout`` unchanged to both operands. Both gradients then go through
    `binop_grad_common` (presumably to undo broadcasting).
    """

    def bprop(x, y, out, dout):
        upstream = dout
        return binop_grad_common(x, y, upstream, upstream)

    return bprop
343
-
344
-
345
@bprop_getters.register(P.MatrixInverse)
def get_bprop_matrix_inverse(self):
    """Build the backward function for `MatrixInverse`.

    With ``out = inv(x)`` the gradient is ``-(out^T @ dout @ out^T)``.
    """
    matmul_lhs_t = nn.MatMul(transpose_x1=True)
    matmul_rhs_t = nn.MatMul(transpose_x2=True)
    negate = P.Neg()

    def bprop(x, out, dout):
        right = matmul_rhs_t(dout, out)
        both = matmul_lhs_t(out, right)
        return (negate(both),)

    return bprop
359
-
360
-
361
@bprop_getters.register(P.Mul)
def get_bprop_mul(self):
    """Build the backward function for `Mul`: ``d(x*y) = (y*dout, x*dout)``.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """
    multiply = P.Mul()

    def bprop(x, y, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'Mul', gradient not support for complex type currently.")
        grad_x = multiply(y, dout)
        grad_y = multiply(x, dout)
        return binop_grad_common(x, y, grad_x, grad_y)

    return bprop
374
-
375
-
376
@bprop_getters.register(P.RealDiv)
def get_bprop_real_div(self):
    """Build the backward function for `RealDiv`.

    ``d(x/y)/dx = 1/y`` and ``d(x/y)/dy = -(x/y)/y``, so ``dy`` reuses the forward output.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """
    divide = P.RealDiv()
    negate = P.Neg()
    multiply = P.Mul()

    def bprop(x, y, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'RealDiv', gradient not support for complex type currently.")
        grad_x = divide(dout, y)
        grad_y = negate(multiply(grad_x, out))
        return binop_grad_common(x, y, grad_x, grad_y)

    return bprop
391
-
392
-
393
@bprop_getters.register(P.Div)
def get_bprop_div(self):
    """Build the backward function for `Div`.

    ``dx = dout / y`` and ``dy = -dx * out``, where ``out`` is the forward quotient.
    """
    divide = P.Div()
    negate = P.Neg()
    multiply = P.Mul()

    def bprop(x, y, out, dout):
        grad_x = divide(dout, y)
        grad_y = negate(multiply(grad_x, out))
        return binop_grad_common(x, y, grad_x, grad_y)

    return bprop
406
-
407
-
408
@bprop_getters.register(P.DivNoNan)
def get_bprop_div_no_nan(self):
    """Build the backward function for `DivNoNan`.

    Same formula as `Div`, but ``dout / y`` is itself computed with `DivNoNan`.
    """
    safe_divide = P.DivNoNan()
    negate = P.Neg()
    multiply = P.Mul()

    def bprop(x, y, out, dout):
        grad_x = safe_divide(dout, y)
        grad_y = negate(multiply(grad_x, out))
        return binop_grad_common(x, y, grad_x, grad_y)

    return bprop
421
-
422
-
423
@bprop_getters.register(P.Xdivy)
def get_bprop_xdivy(self):
    """Build the backward function for `Xdivy`.

    ``dx`` is ``1/y`` masked to positions where ``x != 0``, and ``dy`` is ``-x/y^2``.
    Both are computed through `Xdivy`.
    """
    xdivy = P.Xdivy()

    def bprop(x, y, out, dout):
        x_type = F.dtype(x)
        zero = F.cast(0.0, x_type)
        x_nonzero = F.cast(F.not_equal(x, zero), x_type)
        grad_x = xdivy(x_nonzero, y) * dout
        grad_y = xdivy(-x, F.square(y)) * dout
        return binop_grad_common(x, y, grad_x, grad_y)

    return bprop
436
-
437
-
438
@bprop_getters.register(P.Floor)
def get_bprop_floor(self):
    """Build the backward function for `Floor`, whose gradient is zero everywhere."""
    fill = P.Fill()
    get_shape = P.Shape()
    get_dtype = P.DType()

    def bprop(x, out, dout):
        x_shape = get_shape(x)
        if not F.is_sequence_value_unknown(x_shape):
            return (fill(get_dtype(x), x_shape, 0.),)
        # Shape only known at run time: fall back to zeros_like.
        return (zeros_like(x),)

    return bprop
453
-
454
-
455
@bprop_getters.register(P.Ceil)
def get_bprop_ceil(self):
    """Build the backward function for `Ceil`, whose gradient is zero everywhere."""
    fill = P.Fill()
    get_shape = P.Shape()
    get_dtype = P.DType()

    def bprop(x, out, dout):
        x_shape = get_shape(x)
        if not F.is_sequence_value_unknown(x_shape):
            return (fill(get_dtype(x), x_shape, 0.),)
        # Shape only known at run time: fall back to zeros_like.
        return (zeros_like(x),)

    return bprop
470
-
471
-
472
@bprop_getters.register(P.FloorDiv)
def get_bprop_floordiv(self):
    """Build the backward function for `FloorDiv`.

    Floor division is piecewise constant, so both operands get zero gradients.
    """

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
480
-
481
-
482
@bprop_getters.register(P.BitwiseAnd)
def get_bprop_bitwiseand(self):
    """Build the backward function for `BitwiseAnd`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
490
-
491
-
492
@bprop_getters.register(P.BitwiseOr)
def get_bprop_bitwiseor(self):
    """Build the backward function for `BitwiseOr`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
500
-
501
-
502
@bprop_getters.register(P.BitwiseXor)
def get_bprop_bitwisexor(self):
    """Build the backward function for `BitwiseXor`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
510
-
511
-
512
@bprop_getters.register(P.FloorMod)
def get_bprop_floormod(self):
    """Build the backward function for `FloorMod`.

    With ``x mod y = x - floor(x / y) * y``, ``dx = dout`` and ``dy = -dout * (x // y)``.
    """

    def bprop(x, y, out, dout):
        quotient = x // y
        return binop_grad_common(x, y, dout, -dout * quotient)

    return bprop
522
-
523
-
524
@bprop_getters.register(P.TruncateDiv)
def get_bprop_truncate_div(self):
    """Build the backward function for `TruncateDiv`.

    Truncated division is piecewise constant, so both operands get zero gradients.
    """

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
532
-
533
-
534
@bprop_getters.register(P.TruncateMod)
def get_bprop_truncate_mod(self):
    """Build the backward function for `TruncateMod`.

    ``dx = dout`` and ``dy = -dout * TruncateDiv(x, y)``.
    """
    truncate_div = P.TruncateDiv()

    def bprop(x, y, out, dout):
        quotient = truncate_div(x, y)
        return binop_grad_common(x, y, dout, -dout * quotient)

    return bprop
545
-
546
-
547
@bprop_getters.register(P.Mod)
def get_bprop_mod(self):
    """Build the backward function for `Mod`.

    ``dx = dout`` and ``dy = -dout * (x // y)``.
    """

    def bprop(x, y, out, dout):
        quotient = x // y
        return binop_grad_common(x, y, dout, -dout * quotient)

    return bprop
557
-
558
-
559
@bprop_getters.register(P.Square)
def get_bprop_square(self):
    """Build the backward function for `Square`: ``dx = 2 * (dout * x)``."""
    multiply = P.Mul()
    fill = P.Fill()
    get_dtype = P.DType()

    def bprop(x, out, dout):
        scaled = multiply(dout, x)
        x_shape = shape_op(x)
        if F.is_sequence_value_unknown(x_shape):
            twos = dyn_fill(get_dtype(scaled), dyn_shape_op(x), 2.0)
        else:
            twos = fill(get_dtype(scaled), x_shape, 2.0)
        return (multiply(twos, scaled),)

    return bprop
577
-
578
-
579
@bprop_getters.register(P.SquaredDifference)
def get_bprop_squared_difference(self):
    """Build the backward function for `SquaredDifference`.

    ``d((x-y)^2) = (2*dout*(x-y), -2*dout*(x-y))``.
    """
    negate = P.Neg()

    def bprop(x, y, out, dout):
        diff = x - y
        grad_x = 2 * dout * diff
        return binop_grad_common(x, y, grad_x, negate(grad_x))

    return bprop
591
-
592
-
593
@bprop_getters.register(P.Xlogy)
def get_bprop_xlogy(self):
    """Build the backward function for `Xlogy`.

    ``dx`` is ``log(y)`` masked to positions where ``x != 0`` (via `Xlogy`).
    ``dy`` is ``x / y`` (via `Xdivy`).
    """
    xlogy = P.Xlogy()
    xdivy = P.Xdivy()

    def bprop(x, y, out, dout):
        x_type = F.dtype(x)
        zero = F.cast(0.0, x_type)
        x_nonzero = F.cast(F.not_equal(x, zero), x_type)
        grad_x = xlogy(x_nonzero, y) * dout
        grad_y = xdivy(x, y) * dout
        return binop_grad_common(x, y, grad_x, grad_y)

    return bprop
607
-
608
-
609
@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
    """Build the backward function for `SquareSumAll`.

    Each input gets ``2 * (dout_i * input_i)``, with ``dout_i`` the gradient of its output.
    """
    multiply = P.Mul()
    fill = P.Fill()
    get_dtype = P.DType()

    def bprop(x, y, out, dout):
        scaled_x = multiply(dout[0], x)
        scaled_y = multiply(dout[1], y)

        if F.is_sequence_value_unknown(shape_op(x)):
            twos_x = dyn_fill(get_dtype(scaled_x), dyn_shape_op(x), 2.0)
        else:
            twos_x = fill(get_dtype(scaled_x), shape_op(x), 2.0)

        if F.is_sequence_value_unknown(shape_op(y)):
            twos_y = dyn_fill(get_dtype(scaled_y), dyn_shape_op(y), 2.0)
        else:
            twos_y = fill(get_dtype(scaled_y), shape_op(y), 2.0)

        return (multiply(twos_x, scaled_x), multiply(twos_y, scaled_y))

    return bprop
631
-
632
-
633
@bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self):
    """Build the backward function for `Sqrt`; delegates to `SqrtGrad(out, dout)`."""
    sqrt_backward = G.SqrtGrad()

    def bprop(x, out, dout):
        return (sqrt_backward(out, dout),)

    return bprop
643
-
644
-
645
@bprop_getters.register(G.SqrtGrad)
def get_bprop_sqrt_grad(self):
    """Build the second-order backward function for `SqrtGrad`.

    `SqrtGrad` computes ``grad / (2*y)``. Therefore:
      * ``d/dy  = -(dout/y) * out``
      * ``d/dgrad = 0.5 * dout/y``
    """

    def bprop(y, grad, out, dout):
        scaled = dout / y
        return -scaled * out, 0.5 * scaled

    return bprop
656
-
657
-
658
@bprop_getters.register(P.Rsqrt)
def get_bprop_rsqrt(self):
    """Build the backward function for `Rsqrt`; delegates to `RsqrtGrad(out, dout)`."""
    rsqrt_backward = G.RsqrtGrad()

    def bprop(x, out, dout):
        return (rsqrt_backward(out, dout),)

    return bprop
668
-
669
-
670
@bprop_getters.register(G.RsqrtGrad)
def get_bprop_rsqrt_grad(self):
    """Build the second-order backward function for `RsqrtGrad`.

    ``d/dy = -1.5 * grad * y^2 * dout``; ``d/dgrad`` reuses `RsqrtGrad(y, dout)`.
    """
    rsqrt_backward = G.RsqrtGrad()

    def bprop(y, grad, out, dout):
        grad_y = -1.5 * grad * y * y * dout
        grad_grad = rsqrt_backward(y, dout)
        return grad_y, grad_grad

    return bprop
681
-
682
-
683
@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
    """Build the backward function for `Reciprocal`; delegates to `ReciprocalGrad`."""
    reciprocal_backward = G.ReciprocalGrad()

    def bprop(x, out, dout):
        return (reciprocal_backward(out, dout),)

    return bprop
693
-
694
-
695
@bprop_getters.register(P.Log)
def get_bprop_log(self):
    """Build the backward function for `Log`: ``dx = dout / x``."""
    inverse = P.Reciprocal()

    def bprop(x, out, dout):
        return (inverse(x) * dout,)

    return bprop
706
-
707
-
708
@bprop_getters.register(P.Log1p)
def get_bprop_log1p(self):
    """Build the backward function for `Log1p`: ``dx = dout / (x + 1)``."""
    inverse = P.Reciprocal()

    def bprop(x, out, dout):
        shifted = x + 1
        return (inverse(shifted) * dout,)

    return bprop
720
-
721
-
722
@bprop_getters.register(P.Erf)
def get_bprop_erf(self):
    """Build the backward function for `Erf`: ``dx = dout * 2/sqrt(pi) * exp(-x^2)``."""
    exp_op = P.Exp()
    square_op = P.Square()
    sqrt_op = P.Sqrt()
    cast_op = P.Cast()
    dtype_op = P.DType()
    negate = P.Neg()

    def bprop(x, out, dout):
        two_over_sqrt_pi = cast_op(2 / sqrt_op(F.scalar_to_tensor(np.pi)), dtype_op(x))
        gaussian = exp_op(negate(square_op(x)))
        return (dout * two_over_sqrt_pi * gaussian,)

    return bprop
739
-
740
-
741
@bprop_getters.register(P.Erfc)
def get_bprop_erfc(self):
    """Build the backward function for `Erfc`: ``dx = dout * (-2/sqrt(pi) * exp(-x^2))``."""
    exp_op = P.Exp()
    square_op = P.Square()
    sqrt_op = P.Sqrt()
    cast_op = P.Cast()
    dtype_op = P.DType()
    negate = P.Neg()

    def bprop(x, out, dout):
        two_over_sqrt_pi = cast_op(2 / sqrt_op(F.scalar_to_tensor(np.pi)), dtype_op(x))
        gaussian = exp_op(negate(square_op(x)))
        return (dout * (negate(two_over_sqrt_pi) * gaussian),)

    return bprop
758
-
759
-
760
@bprop_getters.register(P.Pow)
def get_bprop_pow(self):
    """Build the backward function for `Pow`.

    * ``d/dx = power * x^(power-1)``
    * ``d/dpower = out * ln(x)``, where negative bases are replaced by 1 first so their
      log term is 0.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """
    power_op = P.Pow()
    log_op = P.Log()

    def bprop(x, power, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'Pow', gradient not support for complex type currently.")
        grad_x = power * power_op(x, power - 1.0) * dout
        if F.is_sequence_value_unknown(shape_op(x)):
            ones = dyn_fill(F.dtype(x), dyn_shape_op(x), 1)
        else:
            ones = F.fill(F.dtype(x), F.shape(x), 1)
        safe_base = F.select(x < 0, ones, x)
        grad_power = out * log_op(safe_base) * dout
        return binop_grad_common(safe_base, power, grad_x, grad_power)

    return bprop
779
-
780
-
781
@bprop_getters.register(P.Exp)
def get_bprop_exp(self):
    """Build the backward function for `Exp`: ``dx = exp(x) * dout`` (recomputed from x)."""
    exp_op = P.Exp()

    def bprop(x, out, dout):
        return (exp_op(x) * dout,)

    return bprop
792
-
793
-
794
@bprop_getters.register(P.Einsum)
def get_bprop_einsum(self):
    """Build the backward function for `Einsum`; delegates to `EinsumGrad` with the same equation."""
    einsum_backward = G.EinsumGrad(self.equation)

    def bprop(x, out, dout):
        return (einsum_backward(x, dout),)

    return bprop
804
-
805
-
806
@bprop_getters.register(P.Expm1)
def get_bprop_expm1(self):
    """Build the backward function for `Expm1`: ``d(exp(x) - 1) = exp(x) * dout``."""
    exp_op = P.Exp()

    def bprop(x, out, dout):
        return (exp_op(x) * dout,)

    return bprop
817
-
818
-
819
@bprop_getters.register(P.Minimum)
def get_bprop_minimum(self):
    """Build the backward function for `Minimum`; delegates to `MinimumGrad`."""
    minimum_backward = G.MinimumGrad()

    def bprop(x, y, out, dout):
        grad_x, grad_y = minimum_backward(x, y, dout)
        return grad_x, grad_y

    return bprop
829
-
830
-
831
@bprop_getters.register(P.Maximum)
def get_bprop_maximum(self):
    """Build the backward function for `Maximum`; delegates to `MaximumGrad`."""
    maximum_backward = G.MaximumGrad()

    def bprop(x, y, out, dout):
        grad_x, grad_y = maximum_backward(x, y, dout)
        return grad_x, grad_y

    return bprop
841
-
842
-
843
@bprop_getters.register(P.ReduceSum)
def get_bprop_reducesum(self):
    """Build the backward function for `ReduceSum`; ``axis`` gets a zero gradient."""

    def bprop(x, axis, out, dout):
        grad_x = _sum_grad(x, axis, dout)
        grad_axis = zeros_like(axis)
        return grad_x, grad_axis

    return bprop
852
-
853
-
854
@bprop_getters.register(P.CumSum)
def get_bprop_cumsum(self):
    """Build the backward function for `CumSum`.

    The gradient is a cumulative sum of ``dout`` in the opposite direction, with the same
    ``exclusive`` setting.
    """
    opposite_cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        grad_x = opposite_cumsum(dout, axis)
        return grad_x, zeros_like(axis)

    return bprop
863
-
864
-
865
@_primexpr
def _split_shape_index(input_shape, axis):
    """Compute the packed shape and transpose permutation for the ReduceProd gradient.

    Args:
        input_shape (tuple[int]): Static shape of the reduced tensor.
        axis (Union[int, tuple[int], list[int]]): Reduction axes; negative values are allowed.
            An empty tuple/list is treated as "reduce over every dimension".
            NOTE(review): this mirrors the reduce-all convention ReduceProd uses for
            ``axis=()`` — confirm against `reduced_shape`.

    Returns:
        tuple: ``((reduced_num, other_num), perm)``, where:
          * ``reduced_num`` is the element count over the reduced axes;
          * ``other_num`` is the element count over the remaining axes;
          * ``perm`` moves the reduced axes in front of the remaining ones.
    """
    rank = len(input_shape)
    if isinstance(axis, int):
        axis = (axis,)
    if not axis:
        # Empty axis means reduce-all. Without this the reduced group would be empty, and
        # the product-of-others gradient would collapse into a plain broadcast of dout.
        axis = tuple(range(rank))
    reduction_indices = tuple((i + rank) % rank for i in axis)
    other_indices = tuple(i for i in range(rank) if i not in reduction_indices)
    reduced_num = 1
    for i in reduction_indices:
        reduced_num = reduced_num * input_shape[i]
    other_num = 1
    for i in other_indices:
        other_num = other_num * input_shape[i]
    perm = reduction_indices + other_indices
    return (reduced_num, other_num), perm
887
-
888
-
889
@_primexpr
def _invert_permutation(perm):
    """Return the inverse of permutation ``perm`` as a tuple (``inv[perm[i]] == i``)."""
    inverse = [0] * len(perm)
    for position, target in enumerate(perm):
        inverse[target] = position
    return tuple(inverse)
896
-
897
-
898
def _split_dyn_shape_index(x, axis):
    """Dynamic-shape counterpart of `_split_shape_index` for the ReduceProd gradient.

    Returns ``((reduced_num, other_num), perm)``. The element counts and the permutation are
    computed as int64 tensors from the runtime shape of ``x``.
    """
    shape_tensor = dyn_shape_op(x)
    rank = dyn_rank(x)
    if not isinstance(axis, Tensor):
        axis = Tensor(axis, dtype=mstype.int64)
    flat_axis = reshape(axis, (-1,))
    # Normalize negative axes into [0, rank).
    flat_axis = (flat_axis + rank) % rank
    reduced = P.Cast()(flat_axis, mstype.int64)

    first = Tensor(0, dtype=mstype.int64)
    step = Tensor(1, dtype=mstype.int64)
    all_dims = P.Range()(first, rank, step)
    other, _ = A.ListDiff()(all_dims, reduced)
    perm = P.Concat()((reduced, other))
    reduced_dims = P.Cast()(P.Gather()(shape_tensor, reduced, 0), mstype.int64)
    other_dims = P.Cast()(P.Gather()(shape_tensor, other, 0), mstype.int64)
    reduced_num = reduce_prod(reduced_dims, ())
    other_num = reduce_prod(other_dims, ())
    return (reduced_num, other_num), perm
916
-
917
-
918
@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
    """Build the backward function for `ReduceProd`.

    The gradient with respect to one factor is the product of all other factors. The
    reduced axes are transposed to the front and packed into one dimension. An exclusive
    cumprod times an exclusive reverse cumprod along that dimension then yields the
    "product of the others" for each element, without dividing by ``x``.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """
    transpose = P.Transpose()
    excl_cumprod = P.CumProd(exclusive=True)
    excl_rev_cumprod = P.CumProd(exclusive=True, reverse=True)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceProd', gradient not support for complex type currently.")
        static_shape = shape_op(x)
        if static_shape == ():
            # Scalar input: handled by the sum gradient helper.
            return _sum_grad(x, axis, dout), zeros_like(axis)

        # Reshape dout so it broadcasts against the full input shape.
        if F.is_sequence_value_unknown(static_shape):
            input_shape = P.Cast()(dyn_shape_op(x), ms.int64)
            kept_dims_shape = P.Cast()(_dyn_reduced_shape(input_shape, axis, x), ms.int64)
        else:
            input_shape = static_shape
            kept_dims_shape = reduced_shape(input_shape, axis)
        dout = reshape(dout, kept_dims_shape)

        # Move the reduced axes to the front and flatten them into a single dimension.
        if F.is_sequence_value_unknown(shape_op(x)):
            pack_shape, perm = _split_dyn_shape_index(x, axis)
        else:
            pack_shape, perm = _split_shape_index(shape_op(x), axis)
        permuted = transpose(x, perm)
        permuted_shape = shape_op(permuted)
        if F.is_sequence_value_unknown(permuted_shape):
            permuted_shape = dyn_shape_op(permuted)
            pack_shape = create_tensor_by_element(pack_shape)
        packed = reshape(permuted, pack_shape)

        # Product of every other element along the packed reduced dimension.
        before = excl_cumprod(packed, 0)
        after = excl_rev_cumprod(packed, 0)
        others_prod = reshape(before * after, permuted_shape)

        # Undo the transpose and expand dout back to the full input shape.
        if F.is_sequence_value_unknown(shape_op(permuted)):
            dout = DynamicBroadcastTo()(dout, input_shape)
            grad_full = transpose(others_prod, dyn_invert_permutation(perm)) * dout
        else:
            tile_scaling = tuple_div(input_shape, kept_dims_shape)
            expanded = tile(dout, tile_scaling)
            grad_full = transpose(others_prod, _invert_permutation(perm)) * expanded

        return reshape(grad_full, input_shape), zeros_like(axis)

    return bprop
977
-
978
-
979
@bprop_getters.register(P.CumProd)
def get_bprop_cumprod(self):
    """Build the backward function for `CumProd`.

    NOTE: the result is divided by ``x``, so the gradient is invalid where ``x`` contains
    zeros.
    """
    forward_cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
    backward_cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        """Grad definition for `CumProd` operation."""
        weighted = forward_cumprod(x, axis) * dout
        accumulated = backward_cumsum(weighted, axis)
        return accumulated / x, zeros_like(axis)

    return bprop
993
-
994
-
995
@bprop_getters.register(P.ReduceAll)
def get_bprop_reduceall(self):
    """Build the backward function for `ReduceAll`; both inputs get zero gradients."""

    def bprop(x, axis, out, dout):
        grad_x = zeros_like(x)
        grad_axis = zeros_like(axis)
        return grad_x, grad_axis

    return bprop
1003
-
1004
-
1005
@bprop_getters.register(P.ReduceAny)
def get_bprop_reduceany(self):
    """Build the backward function for `ReduceAny`; both inputs get zero gradients."""

    def bprop(x, axis, out, dout):
        grad_x = zeros_like(x)
        grad_axis = zeros_like(axis)
        return grad_x, grad_axis

    return bprop
1013
-
1014
-
1015
@bprop_getters.register(P.ReduceMax)
def get_bprop_reducemax(self):
    """Build the backward function for `ReduceMax`; delegates to `_min_or_max_grad`.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """

    def bprop(x, axis, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceMax', gradient not support for complex type currently.")
        grad_x = _min_or_max_grad(x, axis, out, dout)
        return (grad_x, zeros_like(axis))

    return bprop
1026
-
1027
-
1028
@bprop_getters.register(P.ArgMaxWithValue)
def get_bprop_argmaxwithvalue(self):
    """Build the backward function for `ArgMaxWithValue`.

    Delegates to `_argmin_or_argmax_grad`, passing a fresh `ArgMaxWithValue` on the same
    axis that can be re-run to recompute the indices.
    """
    reduce_axis = self.axis
    keep = self.keep_dims
    recompute_op = P.ArgMaxWithValue(reduce_axis)

    def bprop(x, out, dout):
        grad_x = _argmin_or_argmax_grad(x, reduce_axis, keep, recompute_op, out, dout)
        return (grad_x,)

    return bprop
1040
-
1041
-
1042
@bprop_getters.register(P.ReduceMin)
def get_bprop_reducemin(self):
    """Build the backward function for `ReduceMin`; delegates to `_min_or_max_grad`.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """

    def bprop(x, axis, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceMin', gradient not support for complex type currently.")
        grad_x = _min_or_max_grad(x, axis, out, dout)
        return (grad_x, zeros_like(axis))

    return bprop
1053
-
1054
-
1055
@bprop_getters.register(P.ArgMinWithValue)
def get_bprop_argminwithvalue(self):
    """Build the backward function for `ArgMinWithValue`.

    Delegates to `_argmin_or_argmax_grad`, passing a fresh `ArgMinWithValue` on the same
    axis that can be re-run to recompute the indices.
    """
    reduce_axis = self.axis
    keep = self.keep_dims
    recompute_op = P.ArgMinWithValue(reduce_axis)

    def bprop(x, out, dout):
        grad_x = _argmin_or_argmax_grad(x, reduce_axis, keep, recompute_op, out, dout)
        return (grad_x,)

    return bprop
1067
-
1068
-
1069
@bprop_getters.register(P.ReduceMean)
def get_bprop_reduce_mean(self):
    """Build the backward function for `ReduceMean`.

    The sum gradient is divided by the ratio of input to output element counts, i.e. by
    the number of elements averaged into each output.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """
    real_div = P.RealDiv()
    cast_op = P.Cast()
    dtype_op = P.DType()

    def bprop(x, axis, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceMean', gradient not support for complex type currently.")
        grad = _sum_grad(x, axis, dout)
        in_shape = shape_op(x)
        out_shape = shape_op(out)
        if F.is_sequence_value_unknown(in_shape) or F.is_sequence_value_unknown(out_shape):
            in_numel = reduce_prod(cast_op(dyn_shape_op(x), mstype.float32), ())
            out_numel = reduce_prod(cast_op(dyn_shape_op(out), mstype.float32), ())
            ratio = cast_op(in_numel / out_numel, dtype_op(grad))
        else:
            count_ratio = F.shape_mul(in_shape) / F.shape_mul(out_shape)
            ratio = cast_op(F.scalar_to_tensor(count_ratio), dtype_op(grad))
        return real_div(grad, ratio), zeros_like(axis)

    return bprop
1094
-
1095
-
1096
@bprop_getters.register(P.IsFinite)
def get_bprop_isfinite(self):
    """Build the backward function for `IsFinite`; the input gets a zero gradient."""

    def bprop(x, out, dout):
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
1104
-
1105
-
1106
@bprop_getters.register(P.IsNan)
def get_bprop_isnan(self):
    """Build the backward function for `IsNan`; the input gets a zero gradient."""

    def bprop(x, out, dout):
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
1114
-
1115
-
1116
@bprop_getters.register(P.IsInf)
def get_bprop_isinf(self):
    """Build the backward function for `IsInf`; the input gets a zero gradient."""

    def bprop(x, out, dout):
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
1124
-
1125
-
1126
@bprop_getters.register(P.Equal)
def get_bprop_equal(self):
    """Build the backward function for `Equal`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1134
-
1135
-
1136
@bprop_getters.register(P.NotEqual)
def get_bprop_not_equal(self):
    """Build the backward function for `NotEqual`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1144
-
1145
-
1146
@bprop_getters.register(P.ApproximateEqual)
def get_bprop_approximate_equal(self):
    """Build the backward function for `ApproximateEqual`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1154
-
1155
-
1156
@bprop_getters.register(P.Greater)
def get_bprop_greater(self):
    """Build the backward function for `Greater`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1164
-
1165
-
1166
@bprop_getters.register(P.GreaterEqual)
def get_bprop_greater_equal(self):
    """Build the backward function for `GreaterEqual`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1174
-
1175
-
1176
@bprop_getters.register(P.Less)
def get_bprop_less(self):
    """Build the backward function for `Less`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1184
-
1185
-
1186
@bprop_getters.register(P.LessEqual)
def get_bprop_less_equal(self):
    """Build the backward function for `LessEqual`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1194
-
1195
-
1196
@bprop_getters.register(P.LogicalNot)
def get_bprop_logical_not(self):
    """Build the backward function for `LogicalNot`; the input gets a zero gradient."""

    def bprop(x, out, dout):
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
1204
-
1205
-
1206
@bprop_getters.register(P.LogicalAnd)
def get_bprop_logical_and(self):
    """Build the backward function for `LogicalAnd`; both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1214
-
1215
-
1216
@bprop_getters.register(P.NPUAllocFloatStatus)
def get_bprop_npu_alloc_float_status(self):
    """Build the backward function for `NPUAllocFloatStatus`, which takes no inputs."""

    def bprop(out, dout):
        no_grads = ()
        return no_grads

    return bprop
1224
-
1225
-
1226
@bprop_getters.register(P.NPUGetFloatStatus)
def get_bprop_npu_get_float_status(self):
    """Build the backward function for `NPUGetFloatStatus`; the input gets a zero gradient."""

    def bprop(x, out, dout):
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
1234
-
1235
-
1236
@bprop_getters.register(P.NPUClearFloatStatus)
def get_bprop_npu_clear_float_status(self):
    """Build the backward function for `NPUClearFloatStatus`; the input gets a zero gradient."""

    def bprop(x, out, dout):
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
1244
-
1245
-
1246
@bprop_getters.register(P.AssignAdd)
def get_bprop_assign_add(self):
    """Build the backward function for `AssignAdd`; no gradient flows to either input."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1254
-
1255
-
1256
@bprop_getters.register(P.AssignSub)
def get_bprop_assign_sub(self):
    """Build the backward function for `AssignSub`; no gradient flows to either input."""

    def bprop(x, y, out, dout):
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return grad_x, grad_y

    return bprop
1264
-
1265
-
1266
@bprop_getters.register(P.Sin)
def get_bprop_sin(self):
    """Build the backward function for `Sin`: ``dx = dout * cos(x)``."""
    cos_op = P.Cos()

    def bprop(x, out, dout):
        return (dout * cos_op(x),)

    return bprop
1276
-
1277
-
1278
@bprop_getters.register(P.Asin)
def get_bprop_asin(self):
    """Build the backward function for `Asin`; delegates to `AsinGrad(x, dout)`."""
    asin_backward = G.AsinGrad()

    def bprop(x, out, dout):
        return (asin_backward(x, dout),)

    return bprop
1288
-
1289
-
1290
@bprop_getters.register(G.AsinGrad)
def get_bprop_asin_grad(self):
    """Build the second-order backward function for `AsinGrad`.

    ``d/dx = dout * grad * x * (1 - x^2)^(-1.5)``; ``d/dgrad`` reuses `AsinGrad(x, dout)`.
    """
    asin_backward = G.AsinGrad()
    power_op = P.Pow()

    def bprop(x, grad, out, dout):
        one_minus_sq = 1 - x * x
        grad_x = dout * grad * x * power_op(one_minus_sq, - 1.5)
        grad_grad = asin_backward(x, dout)
        return (grad_x, grad_grad)

    return bprop
1302
-
1303
-
1304
@bprop_getters.register(P.Asinh)
def get_bprop_asinh(self):
    """Build the backward function for `Asinh`; delegates to `AsinhGrad(out, dout)`."""
    asinh_backward = G.AsinhGrad()

    def bprop(x, out, dout):
        return (asinh_backward(out, dout),)

    return bprop
1314
-
1315
-
1316
@bprop_getters.register(G.AsinhGrad)
def get_bprop_asinh_grad(self):
    """Build the second-order backward function for `AsinhGrad`.

    ``d/dy = -dout * out * tanh(y)``; ``d/dgrad`` reuses `AsinhGrad(y, dout)`.
    """
    asinh_backward = G.AsinhGrad()
    tanh_op = P.Tanh()

    def bprop(y, grad, out, dout):
        grad_y = dout * out * -1.0 * tanh_op(y)
        grad_grad = asinh_backward(y, dout)
        return grad_y, grad_grad

    return bprop
1328
-
1329
-
1330
@bprop_getters.register(P.Sinh)
def get_bprop_sinh(self):
    """Build the backward function for `Sinh`: ``dx = cosh(x) * dout``."""
    cosh_op = P.Cosh()

    def bprop(x, out, dout):
        return (cosh_op(x) * dout,)

    return bprop
1340
-
1341
-
1342
@bprop_getters.register(P.Cos)
def get_bprop_cos(self):
    """Build the backward function for `Cos`: ``dx = dout * -sin(x)``."""
    sin_op = P.Sin()
    negate = P.Neg()

    def bprop(x, out, dout):
        neg_sin = negate(sin_op(x))
        return (dout * neg_sin,)

    return bprop
1353
-
1354
-
1355
@bprop_getters.register(P.ACos)
def get_bprop_acos(self):
    """Build the backward function for `ACos`; delegates to `ACosGrad(x, dout)`."""
    acos_backward = G.ACosGrad()

    def bprop(x, out, dout):
        return (acos_backward(x, dout),)

    return bprop
1365
-
1366
-
1367
@bprop_getters.register(G.ACosGrad)
def get_bprop_acos_grad(self):
    """Build the second-order backward function for `ACosGrad`.

    ``d/dx = -dout * grad * x * (1 - x^2)^(-1.5)``; ``d/dgrad`` reuses `ACosGrad(x, dout)`.
    """
    acos_backward = G.ACosGrad()
    power_op = P.Pow()

    def bprop(x, grad, out, dout):
        one_minus_sq = 1 - x * x
        grad_x = -dout * grad * x * power_op(one_minus_sq, - 1.5)
        grad_grad = acos_backward(x, dout)
        return (grad_x, grad_grad)

    return bprop
1379
-
1380
-
1381
@bprop_getters.register(P.Acosh)
def get_bprop_acosh(self):
    """Build the backward function for `Acosh`; delegates to `AcoshGrad(out, dout)`."""
    acosh_backward = G.AcoshGrad()

    def bprop(x, out, dout):
        return (acosh_backward(out, dout),)

    return bprop
1391
-
1392
-
1393
@bprop_getters.register(G.AcoshGrad)
def get_bprop_acosh_grad(self):
    """Build the second-order backward function for `AcoshGrad`.

    ``d/dy = -dout * out / tanh(y)``; ``d/dgrad`` reuses `AcoshGrad(y, dout)`.
    """
    acosh_backward = G.AcoshGrad()
    tanh_op = P.Tanh()

    def bprop(y, grad, out, dout):
        grad_y = dout * out * -1.0 / tanh_op(y)
        grad_grad = acosh_backward(y, dout)
        return grad_y, grad_grad

    return bprop
1405
-
1406
-
1407
@bprop_getters.register(P.Cosh)
def get_bprop_cosh(self):
    """Build the backward function for `Cosh`: ``dx = sinh(x) * dout``.

    Raises:
        TypeError: for complex inputs, which are not supported yet.
    """
    sinh_op = P.Sinh()

    def bprop(x, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'Cosh', gradient not support for complex type currently.")
        return (sinh_op(x) * dout,)

    return bprop
1420
-
1421
-
1422
- @bprop_getters.register(P.Abs)
1423
- def get_bprop_abs(self):
1424
- """Grad definition for `Abs` operation."""
1425
- abs_grad = G.AbsGrad()
1426
-
1427
- def bprop(x, out, dout):
1428
- dx = abs_grad(x, dout)
1429
- return (dx,)
1430
-
1431
- return bprop
1432
-
1433
-
1434
- @bprop_getters.register(P.Conj)
1435
- def get_bprop_conj(self):
1436
- """Grad definition for `Conj` operation."""
1437
- conj = P.Conj()
1438
-
1439
- def bprop(x, out, dout):
1440
- dx = conj(dout)
1441
- return (dx,)
1442
-
1443
- return bprop
1444
-
1445
-
1446
- @bprop_getters.register(P.AccumulateNV2)
1447
- def get_bprop_scalar_accumulatenv2(self):
1448
- """Generate bprop for AccumulateNV2"""
1449
-
1450
- def bprop(x, out, dout):
1451
- dx = ()
1452
- for _ in range(len(x)):
1453
- dx = dx + (dout,)
1454
- return (dx,)
1455
-
1456
- return bprop
1457
-
1458
-
1459
- @bprop_getters.register(P.AddN)
1460
- def get_bprop_scalar_addn(self):
1461
- """Generate bprop for AddN"""
1462
-
1463
- def bprop(x, out, dout):
1464
- if is_sub_class(F.typeof(x), ms.list_):
1465
- dx = []
1466
- for _ in range(len(x)):
1467
- dx.append(dout)
1468
- return (dx,)
1469
-
1470
- dx = ()
1471
- for _ in range(len(x)):
1472
- dx = dx + (dout,)
1473
- return (dx,)
1474
-
1475
- return bprop
1476
-
1477
-
1478
- @bprop_getters.register(P.Sign)
1479
- def get_bprop_sign(self):
1480
- """Generate bprop for Sign"""
1481
-
1482
- def bprop(x, out, dout):
1483
- return (zeros_like(x),)
1484
-
1485
- return bprop
1486
-
1487
-
1488
- @bprop_getters.register(P.Round)
1489
- def get_bprop_round(self):
1490
- """Generate bprop for Round"""
1491
-
1492
- def bprop(x, out, dout):
1493
- return (zeros_like(x),)
1494
-
1495
- return bprop
1496
-
1497
-
1498
- @bprop_getters.register(P.Atan2)
1499
- def get_bprop_atan2(self):
1500
- """Generate bprop for Atan2"""
1501
-
1502
- square = P.Square()
1503
-
1504
- def bprop(x, y, out, dout):
1505
- tmp = dout / (square(x) + square(y))
1506
- bc_dx = tmp * y
1507
- bc_dy = tmp * (-x)
1508
- return binop_grad_common(x, y, bc_dx, bc_dy)
1509
-
1510
- return bprop
1511
-
1512
-
1513
- @bprop_getters.register(P.BesselI0e)
1514
- def get_bprop_bessel_i0e(self):
1515
- """Generate bprop for BesselI0e"""
1516
- sign = P.Sign()
1517
- bessel_i1e = P.BesselI1e()
1518
-
1519
- def bprop(x, out, dout):
1520
- dx = dout * (bessel_i1e(x) - sign(x) * out)
1521
- return (dx,)
1522
-
1523
- return bprop
1524
-
1525
-
1526
- @bprop_getters.register(P.Atan)
1527
- def get_bprop_atan(self):
1528
- """Grad definition for `Atan` operation."""
1529
- input_grad = G.AtanGrad()
1530
-
1531
- def bprop(x, out, dout):
1532
- dx = input_grad(x, dout)
1533
- return (dx,)
1534
-
1535
- return bprop
1536
-
1537
-
1538
- @bprop_getters.register(G.AtanGrad)
1539
- def get_bprop_atan_grad(self):
1540
- """Grad definition for `AtanGrad` operation."""
1541
- input_grad = G.AtanGrad()
1542
-
1543
- def bprop(x, grad, out, dout):
1544
- dgrad = input_grad(x, dout)
1545
- dx = out * dgrad * -2.0 * x
1546
- return dx, dgrad
1547
-
1548
- return bprop
1549
-
1550
-
1551
- @bprop_getters.register(P.Tan)
1552
- def get_bprop_tan(self):
1553
- """Grad definition for `Tan` operation."""
1554
- reciprocal = P.Reciprocal()
1555
- square = P.Square()
1556
- cos = P.Cos()
1557
-
1558
- def bprop(x, out, dout):
1559
- if x.dtype in (mstype.complex64, mstype.complex128):
1560
- raise TypeError("For 'Tan', gradient not support for complex type currently.")
1561
-
1562
- cosx = cos(x)
1563
- secx2 = square(reciprocal(cosx))
1564
- dx = secx2 * dout
1565
- return (dx,)
1566
-
1567
- return bprop
1568
-
1569
-
1570
- @bprop_getters.register(P.BesselI1e)
1571
- def get_bprop_bessel_i1e(self):
1572
- """Generate bprop for BesselI1e"""
1573
-
1574
- sign = P.Sign()
1575
- bessel_i0e = P.BesselI0e()
1576
- less = P.Less()
1577
- select = P.Select()
1578
- reciprocal = P.Reciprocal()
1579
- cast = P.Cast()
1580
- dtype = P.DType()
1581
- abs_ops = P.Abs()
1582
-
1583
- def bprop(x, out, dout):
1584
- zeros = zeros_like(x)
1585
- np_eps = const_utils.get_np_eps(dtype(x))
1586
- eps = cast(np_eps, dtype(x))
1587
- x_is_valid = less(eps, abs_ops(x))
1588
- x_safe = select(x_is_valid, x, eps + zeros)
1589
- tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
1590
- dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
1591
- return (dx,)
1592
-
1593
- return bprop
1594
-
1595
-
1596
- @bprop_getters.register(P.Atanh)
1597
- def get_bprop_atanh(self):
1598
- """Grad definition for `Atanh` operation."""
1599
- power = P.Pow()
1600
- div = P.Div()
1601
-
1602
- def bprop(x, out, dout):
1603
- if x.dtype in (mstype.complex64, mstype.complex128):
1604
- raise TypeError("For 'Atanh', gradient not support for complex type currently.")
1605
-
1606
- tmp = 1 - power(x, 2)
1607
- dx = div(1, tmp) * dout
1608
- return (dx,)
1609
-
1610
- return bprop
1611
-
1612
-
1613
- @bprop_getters.register(P.Inv)
1614
- def get_bprop_inv(self):
1615
- """Grad definition for 'Inv' operation"""
1616
- inv_grad = G.InvGrad()
1617
-
1618
- def bprop(x, out, dout):
1619
- dx = inv_grad(out, dout)
1620
- return (dx,)
1621
-
1622
- return bprop
1623
-
1624
-
1625
- @bprop_getters.register(P.LinSpace)
1626
- def get_bprop_lin_space(self):
1627
- """Grad definition for `LinSpace` operation."""
1628
-
1629
- def bprop(start, stop, num, out, dout):
1630
- return zeros_like(start), zeros_like(stop), zeros_like(num)
1631
-
1632
- return bprop
1633
-
1634
-
1635
- @bprop_getters.register(P.IndexAdd)
1636
- def get_bprop_index_add(self):
1637
- """Generate bprop for IndexAdd"""
1638
- gather = P.Gather()
1639
- _axis = self.axis
1640
-
1641
- def bprop(input_x, indices, input_y, out, dout):
1642
- return dout, zeros_like(indices), gather(dout, indices, _axis)
1643
-
1644
- return bprop
1645
-
1646
-
1647
- @bprop_getters.register(P.InplaceUpdate)
1648
- def get_bprop_inplace_update(self):
1649
- """Grad definition for `InplaceUpdate` operation."""
1650
-
1651
- def bprop(x, v, out, dout):
1652
- return zeros_like(x), zeros_like(v)
1653
-
1654
- return bprop
1655
-
1656
-
1657
- @bprop_getters.register(P.InplaceUpdateV2)
1658
- def get_bprop_inplace_update_v2(self):
1659
- """Grad definition for `InplaceUpdateV2` operation."""
1660
-
1661
- def bprop(x, indices, v, out, dout):
1662
- return zeros_like(x), zeros_like(indices), zeros_like(v)
1663
-
1664
- return bprop
1665
-
1666
-
1667
- @bprop_getters.register(P.InplaceSub)
1668
- def get_bprop_inplace_sub(self):
1669
- """Grad definition for `InplaceSub` operation."""
1670
-
1671
- def bprop(x, input_v, out, dout):
1672
- return zeros_like(x), zeros_like(input_v)
1673
-
1674
- return bprop
1675
-
1676
-
1677
- @bprop_getters.register(P.InplaceAdd)
1678
- def get_bprop_inplace_add(self):
1679
- """Grad definition for `InplaceAdd` operation."""
1680
-
1681
- def bprop(x, input_v, out, dout):
1682
- return zeros_like(x), zeros_like(input_v)
1683
-
1684
- return bprop