mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,1529 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """Define the grad rules of neural network related operations."""
17
- from mindspore import context
18
- from mindspore.common import dtype as mstype
19
- from mindspore.common.tensor import Tensor
20
- from mindspore.ops.primitive import _primexpr
21
- from mindspore.ops.operations import nn_ops as nps
22
- from mindspore.ops._grad.grad_base import bprop_getters, dyn_size, create_tensor_by_element, dyn_rank
23
- from mindspore.ops import functional as F
24
- from mindspore.ops import operations as P
25
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
26
- from mindspore.ops.operations import _grad_ops as G
27
- from mindspore.ops.operations import _inner_ops as inner
28
- from mindspore.ops.operations import _rl_inner_ops as rl_ops
29
- from mindspore.ops._utils.utils import range_op, get_1d_shape
30
-
31
-
32
- @_primexpr
33
- def bias_add_gradgrad_helper(shape, bias_shape, data_format):
34
- """Helper function of BiasGradGrad to calculate expanded shape."""
35
- new_shape = list(shape)
36
- new_bias_shape = list(bias_shape)
37
-
38
- ones_1 = []
39
- ones_2 = []
40
- for _ in new_shape[2:]:
41
- ones_1.append(1)
42
-
43
- for _ in new_shape[:-1]:
44
- ones_2.append(1)
45
-
46
- if data_format == "NCHW":
47
- expanded_shape = [1] + new_bias_shape + ones_1
48
- tile_mults = [new_shape[0]] + [1] + new_shape[2:]
49
- else:
50
- expanded_shape = ones_2 + new_bias_shape
51
- tile_mults = new_shape[:-1] + [1]
52
- return tuple(expanded_shape), tuple(tile_mults)
53
-
54
-
55
- def bias_add_gradgrad_helper_dynamic(shape, bias_shape, data_format):
56
- """Helper function of BiasGradGrad to calculate expanded shape(dynamic version)."""
57
- if data_format == "NCHW":
58
- expanded_shape = P.Concat(0)((P.OnesLike()(shape[:1]), bias_shape, P.OnesLike()(shape[2:])))
59
- tile_mults = P.Concat(0)((shape[:1], Tensor([1], dtype=mstype.int64), shape[2:]))
60
- else:
61
- expanded_shape = P.Concat(0)((P.OnesLike()(shape[:-1]), bias_shape))
62
- tile_mults = P.Concat(0)((shape[:-1], Tensor([1], dtype=mstype.int64)))
63
- return expanded_shape, tile_mults
64
-
65
-
66
- @bprop_getters.register(G.BiasAddGrad)
67
- def get_bprop_bias_add_grad(self):
68
- """Grad definition for `BiasAddGrad` operation."""
69
-
70
- data_format = self.data_format
71
-
72
- def bprop(dy, out, dout):
73
- reshape = P.Reshape()
74
- tile = P.Tile()
75
- dyn_shape = P.TensorShape()
76
- dy_shape = dy.shape
77
- dout_shape = dout.shape
78
- if F.is_sequence_value_unknown(dy_shape) or F.is_sequence_value_unknown(dout_shape):
79
- dy_shape = dyn_shape(dy)
80
- dout_shape = dyn_shape(dout)
81
- expanded_shape, tile_mults = bias_add_gradgrad_helper_dynamic(dy_shape, dout_shape, data_format)
82
- expanded_grad = reshape(dout, expanded_shape)
83
- tiled_grad = tile(expanded_grad, tile_mults)
84
- else:
85
- expanded_shape, tile_mults = bias_add_gradgrad_helper(dy_shape, dout_shape, data_format)
86
- expanded_grad = reshape(dout, expanded_shape)
87
- tiled_grad = tile(expanded_grad, tile_mults)
88
- return (tiled_grad,)
89
-
90
- return bprop
91
-
92
-
93
- @bprop_getters.register(nps.Conv3D)
94
- def get_bprop_conv3d(self):
95
- """Grad definition for `Conv3D` operation."""
96
- input_grad = nps.Conv3DBackpropInput(
97
- self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
98
- pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
99
- )
100
- filter_grad = G.Conv3DBackpropFilter(
101
- self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
102
- pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
103
- )
104
- get_shape = P.Shape()
105
- get_dyn_shape = P.TensorShape()
106
- cast = P.Cast()
107
- get_dtype = P.DType()
108
-
109
- def bprop(x, w, out, dout):
110
- if F.is_sequence_value_unknown(get_shape(x)) or F.is_sequence_value_unknown(get_shape(w)):
111
- dx = input_grad(w, dout, get_dyn_shape(x))
112
- dw = cast(filter_grad(x, dout, get_dyn_shape(w)), get_dtype(x))
113
- return dx, dw
114
-
115
- dx = input_grad(w, dout, get_shape(x))
116
- dw = cast(filter_grad(x, dout, get_shape(w)), get_dtype(x))
117
- return dx, dw
118
-
119
- return bprop
120
-
121
-
122
- @bprop_getters.register(nps.Conv3DTranspose)
123
- def get_bprop_conv3d_transpose(self):
124
- """Grad definition for `Conv3DTranspose` operation."""
125
- stride = (self.stride[2], self.stride[3], self.stride[4])
126
- dilation = (self.dilation[2], self.dilation[3], self.dilation[4])
127
- pad_list = self.get_attr_dict()['pad_list']
128
- input_grad = nps.Conv3D(
129
- out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
130
- pad=pad_list, stride=stride, dilation=dilation, group=self.group, data_format=self.data_format
131
- )
132
- filter_grad = G.Conv3DBackpropFilter(
133
- out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
134
- pad=pad_list, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
135
- )
136
- get_dyn_shape = P.TensorShape()
137
-
138
- def bprop(x, w, out, dout):
139
- if F.is_sequence_value_unknown(F.shape(w)):
140
- dx = input_grad(dout, w)
141
- dw = filter_grad(dout, x, get_dyn_shape(w))
142
- return dx, dw
143
-
144
- dx = input_grad(dout, w)
145
- dw = filter_grad(dout, x, F.shape(w))
146
- return dx, dw
147
-
148
- return bprop
149
-
150
-
151
- @bprop_getters.register(inner.ExtractImagePatches)
152
- def get_bprop_extract_image_patches(self):
153
- """Grad definition for `ExtractImagePatches` operation."""
154
- get_shape = P.Shape()
155
- reshape = P.Reshape()
156
- extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
157
- strides=self.strides,
158
- rates=self.rates,
159
- padding=self.padding)
160
- concat = P.Concat(axis=-1)
161
- expand_dims = P.ExpandDims()
162
- scatter_nd = P.ScatterNd()
163
- dtype = P.DType()
164
- fill = P.Fill()
165
- slice_op = P.Slice()
166
- transpose = P.Transpose()
167
- cast = P.Cast()
168
- matmul = P.MatMul()
169
- range_ = P.Range()
170
- dyn_shape_op = P.TensorShape()
171
- ones_like = P.OnesLike()
172
-
173
- _, _, ksizes_row, ksizes_col = self.ksizes
174
-
175
- def _dyn_extract_image_patched(x, out, dout):
176
- x_shape = dyn_shape_op(x)
177
- out_shape = dyn_shape_op(out)
178
- x_batch, x_depth, x_row, x_col = x_shape[0], x_shape[1], x_shape[2], x_shape[3]
179
- x_indices_num = x_row * x_col + 1
180
- x_idx = range_(cast(1, mstype.float32), cast(x_indices_num, mstype.float32), cast(1, mstype.float32))
181
- x_idx = reshape(x_idx, create_tensor_by_element((1, 1, x_row, x_col)))
182
- x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
183
- x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
184
-
185
- out_row, out_col = out_shape[2], out_shape[3]
186
- out_indices_num = out_row * out_col * ksizes_row * ksizes_col
187
- out_idx_ori = range_(cast(0, mstype.int32), cast(out_indices_num, mstype.int32), cast(1, mstype.int32))
188
- out_idx = reshape(out_idx_ori, create_tensor_by_element((1, out_row, out_col, ksizes_row * ksizes_col)))
189
-
190
- idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
191
- idx_tensor = reshape(idx_tensor, (-1, 2))
192
- sp_shape = create_tensor_by_element((x_indices_num, out_indices_num))
193
- update = cast(ones_like(out_idx_ori), dtype(dout))
194
- sp_tensor = scatter_nd(idx_tensor, update, sp_shape)
195
- begin = create_tensor_by_element((1, 0))
196
- size = create_tensor_by_element((x_indices_num - 1, out_indices_num))
197
- sp_tensor = slice_op(sp_tensor, begin, size)
198
-
199
- grad = transpose(dout, (0, 2, 3, 1))
200
- grad = reshape(grad, create_tensor_by_element((x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth)))
201
- grad = transpose(grad, (1, 2, 3, 4, 0, 5))
202
- grad = reshape(grad, create_tensor_by_element((out_row * out_col * ksizes_row * ksizes_col, x_batch * x_depth)))
203
-
204
- jac = matmul(sp_tensor, grad)
205
- dx = reshape(jac, create_tensor_by_element((x_row, x_col, x_batch, x_depth)))
206
- dx = transpose(dx, (2, 3, 0, 1))
207
- return (dx,)
208
-
209
- def bprop(x, out, dout):
210
- x_shape = get_shape(x)
211
- out_shape = get_shape(out)
212
- if F.is_sequence_value_unknown(x_shape) or F.is_sequence_value_unknown(out_shape):
213
- return _dyn_extract_image_patched(x, out, dout)
214
- x_batch, x_depth, x_row, x_col = x_shape
215
- x_indices_num = x_row * x_col + 1
216
- x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
217
- x_idx = reshape(x_idx, (1, 1, x_row, x_col))
218
- x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
219
- x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
220
-
221
- _, _, out_row, out_col = out_shape
222
- out_indices_num = out_row * out_col * ksizes_row * ksizes_col
223
- out_idx = F.tuple_to_array(range(out_indices_num))
224
- out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
225
-
226
- idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
227
- idx_tensor = reshape(idx_tensor, (-1, 2))
228
- sp_shape = (x_indices_num, out_indices_num)
229
- sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
230
- sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
231
-
232
- grad = transpose(dout, (0, 2, 3, 1))
233
- grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
234
- grad = transpose(grad, (1, 2, 3, 4, 0, 5))
235
- grad = reshape(grad, (-1, x_batch * x_depth))
236
-
237
- jac = matmul(sp_tensor, grad)
238
- dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
239
- dx = transpose(dx, (2, 3, 0, 1))
240
- return (dx,)
241
-
242
- return bprop
243
-
244
-
245
- @bprop_getters.register(P.DepthwiseConv2dNative)
246
- def get_bprop_depthwise_conv2d_native(self):
247
- """Grad definition for `DepthwiseConv2dNative` operation."""
248
- input_grad = G.DepthwiseConv2dNativeBackpropInput(
249
- self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
250
- self.dilation, self.group
251
- )
252
- filter_grad = G.DepthwiseConv2dNativeBackpropFilter(
253
- self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
254
- self.dilation, self.group
255
- )
256
- get_shape = P.Shape()
257
-
258
- def bprop(x, w, out, dout):
259
- dx = input_grad(get_shape(x), w, dout)
260
-
261
- dw = filter_grad(x, get_shape(w), dout)
262
- return dx, dw
263
-
264
- return bprop
265
-
266
-
267
- @bprop_getters.register(P.MaxPoolWithArgmax)
268
- def get_bprop_max_pool_with_argmax(self):
269
- """Grad definition for `MaxPoolWithArgmax` operation."""
270
- maxpool_grad = G.MaxPoolGradWithArgmax(
271
- kernel_size=self.kernel_size,
272
- strides=self.strides,
273
- pad_mode=self.pad_mode)
274
-
275
- def bprop(x, out, dout):
276
- dx = maxpool_grad(x, dout[0], out[1])
277
- return (dx,)
278
-
279
- return bprop
280
-
281
-
282
- @bprop_getters.register(G.MaxPoolGrad)
283
- def get_bprop_max_pool_grad_grad(self):
284
- """Grad definition for `MaxPoolGrad` operation."""
285
- device_target = context.get_context("device_target")
286
- is_ascend = (device_target == "Ascend")
287
- if device_target == "Ascend":
288
- maxpool_grad_grad = G.MaxPoolGradGrad(
289
- kernel_size=self.kernel_size,
290
- strides=self.strides,
291
- pad_mode=self.pad_mode)
292
- elif device_target == "GPU":
293
- if self.data_format != "NCHW":
294
- raise RuntimeError("MaxPoolGradGrad does not support NHWC!")
295
- kernel_size = self.kernel_size
296
- if isinstance(kernel_size, tuple) and len(kernel_size) == 4:
297
- kernel_size = kernel_size[2:]
298
- strides = self.strides
299
- if isinstance(strides, tuple) and len(strides) == 4:
300
- strides = strides[2:]
301
- maxpool_with_argmax = P.MaxPoolWithArgmax(kernel_size=kernel_size, strides=strides, pad_mode=self.pad_mode)
302
- gather = P.GatherNd()
303
- reshape = P.Reshape()
304
- else:
305
- raise RuntimeError("MaxPoolGradGrad does not support on CPU!")
306
- shape_op = P.Shape()
307
- dyn_shape_op = P.TensorShape()
308
- op_range = P.Range()
309
- dyn_broadcast_op = inner.DynamicBroadcastTo()
310
-
311
-
312
- def bprop(x1, x2, grad, out, dout):
313
- dx1 = zeros_like(x1)
314
- dx2 = zeros_like(x2)
315
- if is_ascend:
316
- dgrad = maxpool_grad_grad(x1, x2, dout)
317
- else:
318
- shape_x2 = shape_op(x2)
319
- if F.is_sequence_value_unknown(shape_x2):
320
- shape_x2 = dyn_shape_op(x2)
321
- b, c, h, w = shape_x2
322
- _, ind = maxpool_with_argmax(x1)
323
- batch = op_range(F.cast(0, mstype.int32), F.cast(b, mstype.int32), F.cast(1, mstype.int32))
324
- batch = dyn_broadcast_op(reshape(batch, (-1, 1)),
325
- create_tensor_by_element((dyn_size(batch), c * h * w)))
326
- gather_ind = P.Stack(-1)((batch, reshape(ind, create_tensor_by_element((b, -1)))))
327
- dgrad = reshape(gather(reshape(dout, create_tensor_by_element((b, -1))), gather_ind),
328
- create_tensor_by_element((b, c, h, w)))
329
- else:
330
- b, c, h, w = shape_x2
331
- _, ind = maxpool_with_argmax(x1)
332
- batch = F.cast(F.tuple_to_array(range(b)), mstype.int32)
333
- batch = P.Tile()(reshape(batch, (-1, 1)), (1, c * h * w))
334
- gather_ind = P.Stack(-1)((batch, reshape(ind, (b, -1))))
335
- dgrad = reshape(gather(reshape(dout, (b, -1)), gather_ind), (b, c, h, w))
336
- return (dx1, dx2, dgrad)
337
-
338
- return bprop
339
-
340
-
341
- @bprop_getters.register(G.MaxPoolGradGrad)
342
- def get_bprop_max_pool_grad_grad_grad(self):
343
- """Grad definition for `MaxPoolGradGrad` operation."""
344
- maxpool_grad = G.MaxPoolGrad(
345
- kernel_size=self.kernel_size,
346
- strides=self.strides,
347
- pad_mode=self.pad_mode)
348
-
349
- def bprop(x1, x2, grad, out, dout):
350
- dx1 = zeros_like(x1)
351
- dx2 = zeros_like(x2)
352
- dgrad = maxpool_grad(x1, x2, dout)
353
- return (dx1, dx2, dgrad)
354
-
355
- return bprop
356
-
357
-
358
- @bprop_getters.register(P.MaxPool3D)
359
- def get_bprop_max_pool3d_grad(self):
360
- """Grad definition for `MaxPool3D` operation."""
361
- max_pool3d_grad = G.MaxPool3DGrad(
362
- kernel_size=self.kernel_size,
363
- strides=self.strides,
364
- pad_mode=self.pad_mode,
365
- pad_list=self.pad_list,
366
- data_format=self.data_format)
367
-
368
- def bprop(x, out, dout):
369
- dx = max_pool3d_grad(x, out, dout)
370
- return (dx,)
371
-
372
- return bprop
373
-
374
-
375
- @bprop_getters.register(G.MaxPool3DGrad)
376
- def get_bprop_max_pool3d_grad_grad(self):
377
- """Grad definition for `MaxPool3Grad` operation."""
378
- max_pool3d_grad_grad = G.MaxPool3DGradGrad(
379
- kernel_size=self.kernel_size,
380
- strides=self.strides,
381
- pad_mode=self.pad_mode,
382
- data_format=self.data_format)
383
-
384
- def bprop(x, y, grad, out, dout):
385
- dgrad = max_pool3d_grad_grad(x, y, dout)
386
- return zeros_like(x), zeros_like(y), dgrad
387
-
388
- return bprop
389
-
390
-
391
- @bprop_getters.register(G.MaxPool3DGradGrad)
392
- def get_bprop_max_pool3d_grad_grad_grad(self):
393
- """Grad definition for `MaxPool3GradGrad` operation."""
394
-
395
- max_pool3d_grad = G.MaxPool3DGrad(
396
- kernel_size=self.kernel_size,
397
- strides=self.strides,
398
- pad_mode=self.pad_mode,
399
- data_format=self.data_format)
400
-
401
- def bprop(x, y, grad, out, dout):
402
- dgrad = max_pool3d_grad(x, y, dout)
403
- return zeros_like(x), zeros_like(y), dgrad
404
-
405
- return bprop
406
-
407
-
408
- @bprop_getters.register(nps.AdaptiveMaxPool2D)
409
- def get_bprop_adaptive_max_pool2d_grad(self):
410
- """Grad definition for `AdaptiveMaxPool2D` operation."""
411
- adaptive_maxpool2d_grad = G.AdaptiveMaxPool2DGrad()
412
-
413
- def bprop(x, out, dout):
414
- dy = dout[0]
415
- index = out[1]
416
- dx = adaptive_maxpool2d_grad(dy, x, index)
417
- return (dx,)
418
-
419
- return bprop
420
-
421
-
422
- @bprop_getters.register(P.AvgPool)
423
- def get_bprop_avg_pool_grad(self):
424
- """Grad definition for `AvgPool` operation."""
425
- avgpool_grad = G.AvgPoolGrad(
426
- kernel_size=self.kernel_size,
427
- strides=self.strides,
428
- pad_mode=self.pad_mode,
429
- data_format=self.format)
430
-
431
- def bprop(x, out, dout):
432
- dx = avgpool_grad(x, out, dout)
433
- return (dx,)
434
-
435
- return bprop
436
-
437
-
438
- @bprop_getters.register(P.AdaptiveAvgPool2D)
439
- def get_bprop_adaptive_avg_pool2d_grad(self):
440
- """Grad definition for `AdaptiveAvgPool2D` operation."""
441
- adaptive_avgpool_grad = G.AdaptiveAvgPool2DGrad()
442
- shape = P.TensorShape()
443
-
444
- def bprop(x, out, dout):
445
- dx = adaptive_avgpool_grad(dout, shape(x))
446
- return (dx,)
447
-
448
- return bprop
449
-
450
-
451
- @bprop_getters.register(P.AvgPool3D)
452
- def get_bprop_avg_pool_3d_grad(self):
453
- """Grad definition for `AvgPool3D` operation."""
454
- pad_list = self.get_attr_dict()['pad_list']
455
- count_include_pad = self.get_attr_dict()['count_include_pad']
456
- avgpool3d_grad = G.AvgPool3DGrad(kernel_size=self.kernel_size,
457
- strides=self.strides,
458
- pads=pad_list,
459
- ceil_mode=self.ceil_mode,
460
- count_include_pad=count_include_pad,
461
- divisor_override=self.divisor_override,
462
- data_format=self.data_format,
463
- pad_mode=self.pad_mode)
464
-
465
- def bprop(x, out, dout):
466
- x_shape = F.shape(x)
467
- if F.is_sequence_value_unknown(x_shape):
468
- x_shape = P.TensorShape()(x)
469
- dx = avgpool3d_grad(x_shape, dout)
470
- return (dx,)
471
-
472
- return bprop
473
-
474
-
475
- @bprop_getters.register(P.DropoutGenMask)
476
- def get_bprop_dropout_gen_mask(self):
477
- """Grad definition for `DropoutGenMask` operation."""
478
-
479
- def bprop(shape, keep_prob, out, dout):
480
- return (zeros_like(shape), zeros_like(keep_prob))
481
-
482
- return bprop
483
-
484
-
485
- @bprop_getters.register(P.DropoutDoMask)
486
- def get_bprop_dropout_do_mask(self):
487
- """Grad definition for `DropoutDoMask` operation."""
488
- do_mask = P.DropoutDoMask()
489
-
490
- def bprop(x, y, keep_prob, out, dout):
491
- return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))
492
-
493
- return bprop
494
-
495
-
496
- @bprop_getters.register(P.Mish)
497
- def get_bprop_mish(self):
498
- """Grad definition for `Mish` operation."""
499
- tanh = P.Tanh()
500
- tanh_grad = G.TanhGrad()
501
- softplus = P.Softplus()
502
- softplus_grad = G.SoftplusGrad()
503
-
504
- def bprop(x, out, dout):
505
- dx1 = tanh(softplus(x))
506
- dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
507
- dx = (dx1 * dout + dx2)
508
- return (dx,)
509
-
510
- return bprop
511
-
512
-
513
- @bprop_getters.register(P.SeLU)
514
- def get_bprop_selu(self):
515
- """Grad definition for `SeLU` operation."""
516
- scale = 1.0507009873554804934193349852946
517
- elu_grad = G.EluGrad()
518
-
519
- def bprop(x, out, dout):
520
- dx = elu_grad(dout, out) * scale
521
- return (dx,)
522
-
523
- return bprop
524
-
525
-
526
- @bprop_getters.register(P.MulNoNan)
527
- def get_bprop_mul_no_nan(self):
528
- """Grad definition for `MulNoNan` operation."""
529
- mul_no_nan = P.MulNoNan()
530
- reduce_sum = P.ReduceSum()
531
- reshape = P.Reshape()
532
-
533
- def bprop(x, y, out, dout):
534
- x_shape = F.shape(x)
535
- y_shape = F.shape(y)
536
- dx = mul_no_nan(dout, y)
537
- dy = mul_no_nan(x, dout)
538
- broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
539
- if broadcast_x != ():
540
- dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
541
- if broadcast_y != ():
542
- dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
543
- return dx, dy
544
-
545
- return bprop
546
-
547
-
548
- @bprop_getters.register(G.ReluGrad)
549
- def get_bprop_relu_grad(self):
550
- """Grad definition for `ReLUGrad` operation."""
551
- input_grad = G.ReluGrad()
552
-
553
- def bprop(grad, y, out, dout):
554
- dgrad = input_grad(dout, y)
555
- return dgrad, zeros_like(y)
556
-
557
- return bprop
558
-
559
-
560
- @bprop_getters.register(P.ReLU6)
561
- def get_bprop_relu6(self):
562
- """Grad definition for `ReLU6` operation."""
563
- input_grad = G.ReLU6Grad()
564
-
565
- def bprop(x, out, dout):
566
- dx = input_grad(dout, x)
567
- return (dx,)
568
-
569
- return bprop
570
-
571
-
572
- @bprop_getters.register(P.ReLUV2)
573
- def get_bprop_relu_v2(self):
574
- """Grad definition for `ReLUV2` operation."""
575
- input_grad = G.ReluGradV2()
576
-
577
- def bprop(x, out, dout):
578
- mask = out[1]
579
- dx = input_grad(dout[0], mask)
580
- return (dx,)
581
-
582
- return bprop
583
-
584
-
585
- @bprop_getters.register(P.HSwish)
586
- def get_bprop_hswish(self):
587
- """Grad definition for `HSwish` operation."""
588
- input_grad = G.HSwishGrad()
589
-
590
- def bprop(x, out, dout):
591
- dx = input_grad(dout, x)
592
- return (dx,)
593
-
594
- return bprop
595
-
596
-
597
- @bprop_getters.register(P.HSigmoid)
598
- def get_bprop_hsigmoid(self):
599
- """Grad definition for `HSigmoid` operation."""
600
- input_grad = G.HSigmoidGrad()
601
-
602
- def bprop(x, out, dout):
603
- dx = input_grad(dout, x)
604
- return (dx,)
605
-
606
- return bprop
607
-
608
-
609
- @bprop_getters.register(P.Elu)
610
- def get_bprop_elu(self):
611
- """Grad definition for `Elu` operation."""
612
- input_grad = G.EluGrad()
613
-
614
- def bprop(x, out, dout):
615
- dx = input_grad(dout, out)
616
- return (dx,)
617
-
618
- return bprop
619
-
620
-
621
- @bprop_getters.register(P.Sigmoid)
622
- def get_bprop_sigmoid(self):
623
- """Grad definition for `Sigmoid` operation."""
624
- input_grad = G.SigmoidGrad()
625
-
626
- def bprop(x, out, dout):
627
- dx = input_grad(out, dout)
628
- return (dx,)
629
-
630
- return bprop
631
-
632
-
633
- @bprop_getters.register(G.SigmoidGrad)
634
- def get_bprop_sigmoid_grad(self):
635
- """Grad definition for `SigmoidGrad` operation."""
636
- sigmoid_grad = G.SigmoidGrad()
637
-
638
- def bprop(y, grad, out, dout):
639
- dy = dout * grad * (1. - 2 * y)
640
- dgrad = sigmoid_grad(y, dout)
641
- return dy, dgrad
642
-
643
- return bprop
644
-
645
-
646
- @_primexpr
647
- def _get_transpose_axis(x_shp, axis):
648
- rank = len(x_shp)
649
- if axis < 0:
650
- axis += rank
651
- reverse_axis = [i for i in range(rank)]
652
- reverse_axis[axis] = rank - 1
653
- reverse_axis[rank - 1] = axis
654
- return tuple(reverse_axis)
655
-
656
-
657
- def _get_dyn_transpose_axis(x, axis, is_ascend):
658
- """Get transpose axis"""
659
- if F.is_sequence_shape_unknown(P.Shape()(x)):
660
- rank = dyn_rank(x)
661
- start = Tensor(0, dtype=mstype.int64)
662
- delta = Tensor(1, dtype=mstype.int64)
663
- else:
664
- rank = P.Cast()(len(P.Shape()(x)), mstype.int64)
665
- start = P.Cast()(0, mstype.int64)
666
- delta = P.Cast()(1, mstype.int64)
667
-
668
- if axis < 0:
669
- axis += rank
670
- range_ops = P.Range()
671
-
672
- reverse_axis = range_ops(start, rank, delta)
673
- if is_ascend:
674
- reverse_axis = P.Cast()(reverse_axis, mstype.int8)
675
- axis = P.Cast()(axis, mstype.int32)
676
- reverse_axis[axis] = rank - 1
677
- rank = P.Cast()(rank, mstype.int32)
678
- else:
679
- reverse_axis[axis] = rank - 1
680
-
681
- reverse_axis[rank - 1] = axis
682
- return reverse_axis
683
-
684
-
685
- @bprop_getters.register(P.Softmax)
686
- def get_bprop_softmax(self):
687
- """Grad definition for `Softmax` operation."""
688
- sum_func = P.ReduceSum(keep_dims=True)
689
- sub = P.Sub()
690
- mul = P.Mul()
691
- get_shape = P.Shape()
692
- transpose = P.Transpose()
693
- axis = self.axis
694
- if not isinstance(axis, int):
695
- axis = axis[0]
696
-
697
- device_target = context.get_context("device_target")
698
- is_ascend = (device_target == "Ascend")
699
-
700
- def bprop(x, out, dout):
701
- # dx can be expressed as (dout - sum(dout * out)) * out
702
- # This formula is correct only when the `axis` is the last dimension.
703
- # In order to support the scenario where the `axis` is other values,
704
- # we transpose the data of the `axis` dimension to the last dimension for calculation,
705
- # and then transpose it back after the calculation.
706
- shp = get_shape(x)
707
- if F.is_sequence_value_unknown(shp):
708
- reverse_axis = _get_dyn_transpose_axis(x, axis, is_ascend)
709
- if is_ascend:
710
- reverse_axis = P.Cast()(reverse_axis, mstype.int32)
711
- else:
712
- reverse_axis = _get_transpose_axis(get_shape(x), axis)
713
- out = transpose(out, reverse_axis)
714
- dout = transpose(dout, reverse_axis)
715
- dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
716
- dx = transpose(dx, reverse_axis)
717
- return (dx,)
718
-
719
- return bprop
720
-
721
-
722
- @bprop_getters.register(P.LogSoftmax)
723
- def get_bprop_log_softmax(self):
724
- """Grad definition for `LogSoftmax` operation."""
725
- logsoftmax_grad = G.LogSoftmaxGrad(self.axis)
726
-
727
- def bprop(x, out, dout):
728
- dx = logsoftmax_grad(out, dout)
729
- return (dx,)
730
-
731
- return bprop
732
-
733
-
734
- @bprop_getters.register(P.Softplus)
735
- def get_bprop_softplus(self):
736
- """Grad definition for `Softplus` operation."""
737
- softplus_grad = G.SoftplusGrad()
738
-
739
- def bprop(x, out, dout):
740
- dx = softplus_grad(dout, x)
741
- return (dx,)
742
-
743
- return bprop
744
-
745
-
746
- @bprop_getters.register(P.Softsign)
747
- def get_bprop_softsign(self):
748
- """Grad definition for `Softsign` operation."""
749
- mul = P.Mul()
750
- absolute = P.Abs()
751
- div = P.Div()
752
- square = P.Square()
753
-
754
- def bprop(x, out, dout):
755
- dx = mul(dout, div(1, square(1 + absolute(x))))
756
- return (dx,)
757
-
758
- return bprop
759
-
760
-
761
- @bprop_getters.register(P.Tanh)
762
- def get_bprop_tanh(self):
763
- """Grad definition for `Tanh` operation."""
764
- tanh_grad = G.TanhGrad()
765
- conj = P.Conj()
766
-
767
- def bprop(x, out, dout):
768
- if x.dtype in (mstype.complex64, mstype.complex128):
769
- dout = conj(dout)
770
- dx = tanh_grad(out, dout)
771
- dx = conj(dx)
772
- else:
773
- dx = tanh_grad(out, dout)
774
- return (dx,)
775
-
776
- return bprop
777
-
778
-
779
- @bprop_getters.register(G.TanhGrad)
780
- def get_bprop_tanh_grad(self):
781
- """Grad definition for `TanhGrad` operation."""
782
- tanh_grad = G.TanhGrad()
783
-
784
- def bprop(y, grad, out, dout):
785
- dy = dout * -2.0 * grad * y
786
- dgrad = tanh_grad(y, dout)
787
- return dy, dgrad
788
-
789
- return bprop
790
-
791
-
792
- @bprop_getters.register(P.FastGeLU)
793
- def get_bprop_fast_gelu(self):
794
- """Grad definition for `FastGeLU` operation."""
795
- input_grad = G.FastGeLUGrad()
796
-
797
- def bprop(x, out, dout):
798
- dx = input_grad(dout, x)
799
- return (dx,)
800
-
801
- return bprop
802
-
803
-
804
- @bprop_getters.register(P.FastGelu)
805
- def get_bprop_fast_gelu_2(self):
806
- """Grad definition for `FastGeLU` operation."""
807
- input_grad = G.FastGeLUGrad()
808
-
809
- def bprop(x, out, dout):
810
- dx = input_grad(dout, x)
811
- return (dx,)
812
-
813
- return bprop
814
-
815
-
816
- @bprop_getters.register(P.InstanceNorm)
817
- def get_bprop_instance_norm(self):
818
- """Grad definition for `InstanceNorm` operation."""
819
- input_grad = G.InstanceNormGrad(self.epsilon, self.momentum)
820
-
821
- def bprop(x, gamma, beta, mean, variance, out, dout):
822
- saved_mean = out[1]
823
- saved_variance = out[2]
824
- out = input_grad(dout[0], x, gamma, saved_mean, saved_variance)
825
- dx = out[0]
826
- dgamma = out[1]
827
- dbeta = out[2]
828
- return dx, dgamma, dbeta, zeros_like(mean), zeros_like(variance)
829
-
830
- return bprop
831
-
832
-
833
- @bprop_getters.register(G.BatchNormGrad)
834
- def get_bprop_batch_norm_grad(self):
835
- """Grad definition for `BatchNorm` operation."""
836
- grad_op = G.BatchNormGradGrad(self.is_training, self.epsilon, self.data_format)
837
-
838
- def bprop(dy, x, scale, mean, variance, reserve, out, dout):
839
- dx, ddy, dscale = grad_op(x, dy, scale, mean, variance, dout[0], dout[1], dout[2])
840
- return ddy, dx, dscale, zeros_like(mean), zeros_like(variance), zeros_like(reserve)
841
-
842
- return bprop
843
-
844
-
845
- @bprop_getters.register(G.LayerNormGrad)
846
- def get_bprop_layer_norm_grad(self):
847
- """Grad definition for `LayerNormGrad` operation."""
848
- layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)
849
-
850
- def bprop(x, dy, variance, mean, gamma, out, dout):
851
- d_x, d_dy, d_gamma = layer_norm_grad_grad(
852
- x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
853
- return d_x, d_dy, zeros_like(variance), zeros_like(mean), d_gamma
854
-
855
- return bprop
856
-
857
-
858
- @bprop_getters.register(P.L2Normalize)
859
- def get_bprop_l2normalize(self):
860
- """Grad definition for `L2Normalize` operation."""
861
- input_grad = G.L2NormalizeGrad(self.axis, self.epsilon)
862
-
863
- def bprop(x, out, dout):
864
- dx = input_grad(x, out, dout)
865
- return (dx,)
866
-
867
- return bprop
868
-
869
-
870
- @bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
871
- def get_bprop_softmax_cross_entropy_with_logits(self):
872
- """Grad definition for `SoftmaxCrossEntropyWithLogits` operation."""
873
- expand = P.ExpandDims()
874
-
875
- def bprop(logits, labels, out, dout):
876
- grad = out[1]
877
- grad = grad * expand(dout[0], -1)
878
- return grad, zeros_like(labels)
879
-
880
- return bprop
881
-
882
-
883
- @bprop_getters.register(P.NLLLoss)
884
- def get_bprop_nll_loss(self):
885
- """Grad definition for `NLLLoss` operation."""
886
- nll_loss_grad = G.NLLLossGrad(reduction=self.reduction)
887
-
888
- def bprop(x, target, weight, out, dout):
889
- total_weight = out[1]
890
- dout_x = dout[0]
891
- dx = nll_loss_grad(x, dout_x, target, weight, total_weight)
892
- return dx, zeros_like(target), zeros_like(weight)
893
-
894
- return bprop
895
-
896
-
897
- @bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
898
- def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
899
- """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
900
- is_grad = self.is_grad
901
- grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)
902
-
903
- def bprop(logits, labels, out, dout):
904
- grad = out[0]
905
- if not is_grad:
906
- # if construct use loss
907
- grad = grad_op(logits, labels)
908
- grad = F.depend(grad, out)
909
- grad = grad * dout
910
- return grad, zeros_like(labels)
911
-
912
- return bprop
913
-
914
-
915
- @bprop_getters.register(P.ResizeBilinear)
916
- def get_bprop_resize_bilinear(self):
917
- """Grad definition for `ResizeBilinear` operation."""
918
- resize_grad = G.ResizeBilinearGrad(self.align_corners, self.half_pixel_centers)
919
-
920
- def bprop(x, out, dout):
921
- dx = resize_grad(dout, x)
922
- return (dx,)
923
-
924
- return bprop
925
-
926
-
927
- @bprop_getters.register(P.OneHot)
928
- def get_bprop_onehot(self):
929
- """Grad definition for `OneHot` operation."""
930
-
931
- def bprop(indices, depth, on_value, off_value, out, dout):
932
- return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
933
-
934
- return bprop
935
-
936
-
937
- @bprop_getters.register(P.TopK)
938
- def get_bprop_top_kv2(self):
939
- """Grad definition for `TopK` operation."""
940
- scatter = P.ScatterNd()
941
- expand_dims = P.ExpandDims()
942
- shape_op = P.Shape()
943
- dyn_shape = P.TensorShape()
944
- reshape_op = P.Reshape()
945
- dtype = P.DType()
946
- cast = P.Cast()
947
-
948
- def _bprop_static(input_x, k, out, dout):
949
- in_shape = shape_op(input_x)
950
- in_lastdim = in_shape[-1]
951
-
952
- indices = out[1]
953
- ind_shape = shape_op(indices)
954
- ind_lastdim = ind_shape[-1]
955
-
956
- ind_2d = reshape_op(indices, (-1, ind_lastdim))
957
- outerdim = shape_op(ind_2d)[0]
958
-
959
- # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
960
- indices_dtype = dtype(indices)
961
- range_flatten_index = range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
962
-
963
- # expand_dims to (k, 1), then broadcast
964
- ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
965
- in_shape_1d = get_1d_shape(in_shape)
966
-
967
- out_grad = reshape_op(
968
- scatter(
969
- expand_dims(ind, -1),
970
- reshape_op(dout[0], (-1,)),
971
- in_shape_1d),
972
- in_shape)
973
- return out_grad, zeros_like(k)
974
-
975
- def _bprop_dynshape(input_x, k, out, dout):
976
- in_shape = dyn_shape(input_x)
977
- in_lastdim = in_shape[-1]
978
-
979
- indices = out[1]
980
- ind_shape = dyn_shape(indices)
981
- ind_lastdim = ind_shape[-1]
982
-
983
- ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim)))
984
- outerdim = dyn_shape(ind_2d)[0]
985
-
986
- # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
987
- range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim)
988
-
989
- # expand_dims to (k, 1), then broadcast
990
- ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), create_tensor_by_element((-1,)))
991
- in_shape_1d = expand_dims(dyn_size(input_x, mstype.int64), -1)
992
-
993
- out_grad = reshape_op(
994
- scatter(
995
- expand_dims(ind, -1),
996
- reshape_op(dout[0], create_tensor_by_element((-1,))),
997
- in_shape_1d),
998
- in_shape)
999
- return out_grad, zeros_like(k)
1000
-
1001
- def bprop(input_x, k, out, dout):
1002
- if F.is_sequence_value_unknown(shape_op(input_x)):
1003
- return _bprop_dynshape(input_x, k, out, dout)
1004
- return _bprop_static(input_x, k, out, dout)
1005
-
1006
- return bprop
1007
-
1008
-
1009
- @bprop_getters.register(P.SmoothL1Loss)
1010
- def get_bprop_smooth_l1_loss(self):
1011
- """Grad definition for `SmoothL1Loss` operation."""
1012
- grad = G.SmoothL1LossGrad(self.beta, self.reduction)
1013
-
1014
- def bprop(prediction, target, out, dout):
1015
- dx = grad(prediction, target, dout)
1016
- dy = grad(target, prediction, dout)
1017
- return dx, dy
1018
-
1019
- return bprop
1020
-
1021
-
1022
- @bprop_getters.register(P.L2Loss)
1023
- def get_bprop_l2_loss(self):
1024
- """Grad definition for `L2Loss` operation."""
1025
-
1026
- def bprop(x, out, dout):
1027
- dx = x * dout
1028
- return (dx,)
1029
-
1030
- return bprop
1031
-
1032
-
1033
- @bprop_getters.register(P.RNNTLoss)
1034
- def get_bprop_rnnt_loss(self):
1035
- """Grad definition for `RNNTLoss` operation."""
1036
-
1037
- def bprop(acts, labels, act_lens, label_lens, out, dout):
1038
- grad = out[1]
1039
- return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
1040
-
1041
- return bprop
1042
-
1043
-
1044
- @bprop_getters.register(P.PReLU)
1045
- def get_bprop_prelu(self):
1046
- """Grad definition for `PReLU` operation."""
1047
- grad = G.PReLUGrad()
1048
-
1049
- def bprop(x, w, out, dout):
1050
- dx, dw = grad(dout, x, w)
1051
- return dx, dw
1052
-
1053
- return bprop
1054
-
1055
-
1056
- @bprop_getters.register(P.LSTM)
1057
- def get_bprop_lstm(self):
1058
- """Grad definition for `LSTM` operation."""
1059
- lstm_grad_data = G.LSTMGradData(
1060
- input_size=self.input_size,
1061
- hidden_size=self.hidden_size,
1062
- num_layers=self.num_layers,
1063
- has_bias=self.has_bias,
1064
- bidirectional=self.bidirectional,
1065
- dropout=self.dropout
1066
- )
1067
-
1068
- lstm_grad_weight = G.LSTMGradWeight(
1069
- input_size=self.input_size,
1070
- hidden_size=self.hidden_size,
1071
- num_layers=self.num_layers,
1072
- has_bias=self.has_bias,
1073
- bidirectional=self.bidirectional,
1074
- dropout=self.dropout
1075
- )
1076
- lstm_grad = G.LSTMGrad(
1077
- input_size=self.input_size,
1078
- hidden_size=self.hidden_size,
1079
- num_layers=self.num_layers,
1080
- has_bias=self.has_bias,
1081
- bidirectional=self.bidirectional,
1082
- dropout=self.dropout
1083
- )
1084
-
1085
- def bprop(x, hx, cx, w, out, dout):
1086
- y, _, _, reserve, state = out
1087
- dy, dhy, dcy, _, _ = dout
1088
- dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
1089
- dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
1090
- return dx, dhx, dcx, dw
1091
-
1092
- #
1093
- def bprop_cpu(x, hx, cx, w, out, dout):
1094
- y, hy, cy, reserve, _ = out
1095
- dy, dhy, dcy, _, _ = dout
1096
- dx, dhx, dcx, dw = lstm_grad(x, hx, cx, w, y, hy, cy, dy, dhy, dcy, reserve)
1097
- return dx, dhx, dcx, dw
1098
-
1099
- if context.get_context('device_target') == "CPU":
1100
- self.add_prim_attr("is_training", True)
1101
- return bprop_cpu
1102
-
1103
- return bprop
1104
-
1105
-
1106
- @bprop_getters.register(rl_ops.GRUV2)
1107
- def get_bppro_gru_v2(self):
1108
- """Grad definition for `GRUV2` operation."""
1109
- gru_grad_v2 = G.GRUV2Grad(
1110
- self.input_size,
1111
- self.hidden_size,
1112
- self.num_layers,
1113
- self.has_bias,
1114
- self.bidirectional,
1115
- self.dropout
1116
- )
1117
-
1118
- def bpro(x, hx, w, seq_length, out, dout):
1119
- y, hy, reverse, _ = out
1120
- dy, dhy, _, _ = dout
1121
- dx, dhx, dw = gru_grad_v2(x, hx, w, seq_length, y, hy, dy, dhy, reverse)
1122
- return dx, dhx, dw, (0)
1123
-
1124
- return bpro
1125
-
1126
-
1127
- @bprop_getters.register(rl_ops.CudnnGRU)
1128
- def get_bprop_gru(self):
1129
- """Grad definition for `GRU` operation."""
1130
- gru_grad_data = G.GruGradData(
1131
- input_size=self.input_size,
1132
- hidden_size=self.hidden_size,
1133
- num_layers=self.num_layers,
1134
- has_bias=self.has_bias,
1135
- bidirectional=self.bidirectional,
1136
- dropout=self.dropout
1137
- )
1138
-
1139
- gru_grad_weight = G.GruGradWeight(
1140
- input_size=self.input_size,
1141
- hidden_size=self.hidden_size,
1142
- num_layers=self.num_layers,
1143
- has_bias=self.has_bias,
1144
- bidirectional=self.bidirectional,
1145
- dropout=self.dropout
1146
- )
1147
-
1148
- def bprop(x, hx, w, out, dout):
1149
- y, _, reserve, state = out
1150
- dy, dhy, _, _ = dout
1151
- dx, dhx = gru_grad_data(y, dy, dhy, w, hx, reserve, state)
1152
- dw = gru_grad_weight(F.depend(x, dx), hx, y, reserve, state)
1153
- return dx, dhx, dw
1154
-
1155
- return bprop
1156
-
1157
-
1158
- @bprop_getters.register(P.DynamicRNN)
1159
- def get_bprop_dynamic_rnn(self):
1160
- """Grad definition for `DynamicRNN` operation."""
1161
- dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type,
1162
- direction=self.direction,
1163
- cell_depth=self.cell_depth,
1164
- use_peephole=self.use_peephole,
1165
- keep_prob=self.keep_prob,
1166
- cell_clip=self.cell_clip,
1167
- num_proj=self.num_proj,
1168
- time_major=self.time_major,
1169
- forget_bias=self.forget_bias)
1170
- expand_dims = P.ExpandDims()
1171
-
1172
- def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
1173
- dy, dh, dc, _, _, _, _, _, = dout
1174
- dh = dh[-1]
1175
- dc = dc[-1]
1176
- y, h, c, i, j, f, o, tanhct = out
1177
- dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
1178
- c, dy, dh, dc, i, j, f, o, tanhct)
1179
- dh_prev = expand_dims(dh_prev, 0)
1180
- dc_prev = expand_dims(dc_prev, 0)
1181
- return dx, dw, db, (0), dh_prev, dc_prev
1182
-
1183
- return bprop
1184
-
1185
-
1186
- @bprop_getters.register(P.DynamicGRUV2)
1187
- def get_bprop_dynamic_gru_v2(self):
1188
- """Grad definition for `DynamicGRUV2` operation."""
1189
- dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip,
1190
- self.num_proj, self.time_major, self.gate_order,
1191
- self.reset_after)
1192
-
1193
- def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout):
1194
- y, out_h, update, reset, new, hidden_new = out
1195
- dy, dout_h, _, _, _, _ = dout
1196
-
1197
- dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h,
1198
- out_h, dy, dout_h[-1], update,
1199
- reset, new, hidden_new, None, None)
1200
- return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev
1201
-
1202
- return bprop
1203
-
1204
-
1205
- @bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
1206
- def get_bprop_sigmoid_crossentropy_with_logits(self):
1207
- """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
1208
- op = G.SigmoidCrossEntropyWithLogitsGrad()
1209
-
1210
- def bprop(x, y, out, dout):
1211
- dx = op(x, y, dout)
1212
- return (dx, zeros_like(y))
1213
-
1214
- return bprop
1215
-
1216
-
1217
- @bprop_getters.register(P.Pad)
1218
- def get_bprop_pad(self):
1219
- """Grad definition for `Pad` operation."""
1220
- shape_op = P.Shape()
1221
- dyn_shape_op = P.TensorShape()
1222
- paddings = self.paddings
1223
-
1224
- def bprop(x, out, dout):
1225
- begin = ()
1226
- for item in paddings:
1227
- begin += (item[0],)
1228
- shp = shape_op(x)
1229
- if F.is_sequence_value_unknown(shp):
1230
- shp = dyn_shape_op(x)
1231
- dx = P.Slice()(dout, begin, shp)
1232
- return (dx,)
1233
-
1234
- return bprop
1235
-
1236
-
1237
- @bprop_getters.register(P.MirrorPad)
1238
- def get_bprop_mirror_pad(self):
1239
- """Grad definition for `MirrorPad` operation."""
1240
- mirror_pad_grad = G.MirrorPadGrad(self.mode)
1241
-
1242
- def bprop(x, paddings, out, dout):
1243
- dx = mirror_pad_grad(dout, paddings)
1244
- return (dx, zeros_like(paddings))
1245
-
1246
- return bprop
1247
-
1248
-
1249
- @bprop_getters.register(P.ROIAlign)
1250
- def get_bprop_roi_align(self):
1251
- """Grad definition for `ROIAlign` operation."""
1252
- shape_op = P.Shape()
1253
- dyn_shape = P.TensorShape()
1254
- pooled_height = self.pooled_height
1255
- pooled_width = self.pooled_width
1256
- spatial_scale = self.spatial_scale
1257
- sample_num = self.sample_num
1258
-
1259
- def bprop(inputs, rois, out, dout):
1260
- inputs_shape = shape_op(inputs)
1261
- if F.is_sequence_value_unknown(inputs_shape):
1262
- inputs_shape = dyn_shape(inputs)
1263
- dx = G.ROIAlignGrad(pooled_height, pooled_width, spatial_scale, sample_num)(dout, rois, inputs_shape)
1264
- return dx, zeros_like(rois)
1265
-
1266
- return bprop
1267
-
1268
-
1269
- @bprop_getters.register(P.Conv2DTranspose)
1270
- @bprop_getters.register(P.Conv2DBackpropInput)
1271
- def get_bprop_conv2d_backprop_input(self):
1272
- """Grad definition for `Conv2DBackpropInput` operation."""
1273
- pad_list = self.get_attr_dict()['pad_list']
1274
- out_channel = self.get_attr_dict()['out_channel']
1275
- filter_grad = G.Conv2DBackpropFilter(
1276
- out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
1277
- dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
1278
- )
1279
- input_grad = P.Conv2D(
1280
- out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
1281
- dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
1282
- )
1283
- get_shape = P.Shape()
1284
- get_dyn_shape = P.TensorShape()
1285
-
1286
- def bprop(x, w, f_sizes, out, dout):
1287
- w_shape = get_shape(w)
1288
- if F.is_sequence_value_unknown(w_shape):
1289
- w_shape = get_dyn_shape(w)
1290
- dx = input_grad(dout, w)
1291
- dw = filter_grad(x, dout, w_shape)
1292
- return dx, dw, zeros_like(f_sizes)
1293
-
1294
- return bprop
1295
-
1296
-
1297
- @bprop_getters.register(P.BinaryCrossEntropy)
1298
- def get_bprop_binary_cross_entropy(self):
1299
- """Grad definition for `BinaryCrossEntropy` operation."""
1300
- grad = G.BinaryCrossEntropyGrad(self.reduction)
1301
-
1302
- def bprop(x, y, weight, out, dout):
1303
- dx = grad(x, y, dout, weight)
1304
- return dx, zeros_like(y), zeros_like(weight)
1305
-
1306
- return bprop
1307
-
1308
-
1309
- @bprop_getters.register(P.BCEWithLogitsLoss)
1310
- def get_bprop_bce_with_logits_loss(self):
1311
- """Grad definition for `BCEWithLogitsLoss` operation."""
1312
- reduction = self.reduction
1313
- mul = P.Mul()
1314
- sigmoid = P.Sigmoid()
1315
- add = P.Add()
1316
- sub = P.Sub()
1317
- size = P.Size()
1318
- neg = P.Neg()
1319
- log = P.Log()
1320
- shape = P.Shape()
1321
-
1322
- def bprop(predict, target, weight, pos_weight, out, dout):
1323
- sigmoid_input = sigmoid(predict)
1324
- if pos_weight is not None:
1325
- t = mul(target, pos_weight)
1326
- dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
1327
- grad_target = mul(sub(log(sub(1, sigmoid_input)), mul(pos_weight, log(sigmoid_input))), dout)
1328
- else:
1329
- dx = mul((sigmoid_input - target), dout)
1330
- grad_target = mul(predict, neg(dout))
1331
- if weight is not None:
1332
- dx = mul(dx, weight)
1333
- grad_target = mul(grad_target, weight)
1334
- if reduction == 'mean':
1335
- dx_size = dyn_size(dx) if F.is_sequence_value_unknown(shape(dx)) else size(dx)
1336
- target_size = dyn_size(target) if F.is_sequence_value_unknown(shape(target)) else size(target)
1337
- dx = dx / dx_size
1338
- grad_target = grad_target / target_size
1339
- return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)
1340
-
1341
- return bprop
1342
-
1343
-
1344
- @bprop_getters.register(P.KLDivLoss)
1345
- def get_bprop_kl_div_loss(self):
1346
- """Grad definition for `KLDivLoss` operation."""
1347
- reduce_type = self.reduction
1348
-
1349
- size = P.Size()
1350
- shape = P.Shape()
1351
-
1352
- def bprop(x, y, out, dout):
1353
- if reduce_type == "mean":
1354
- grad = G.KLDivLossGrad("sum")
1355
- else:
1356
- grad = G.KLDivLossGrad(self.reduction)
1357
- dx = grad(dout, x, y)
1358
- if reduce_type == "mean":
1359
- x_size = dyn_size(x) if F.is_sequence_value_unknown(shape(x)) else size(x)
1360
- return dx / x_size, zeros_like(y)
1361
- return dx, zeros_like(y)
1362
-
1363
- return bprop
1364
-
1365
-
1366
- @bprop_getters.register(P.Dropout)
1367
- def get_bprop_dropout(self):
1368
- """Grad definition for `Dropout` operation."""
1369
- grad = G.DropoutGrad(self.keep_prob)
1370
-
1371
- def bprop(x, out, dout):
1372
- _, mask = out
1373
- dy, _ = dout
1374
- dx = grad(dy, mask)
1375
- return (dx,)
1376
-
1377
- return bprop
1378
-
1379
-
1380
- @bprop_getters.register(G.DropoutGrad)
1381
- def get_bprop_dropout_grad(self):
1382
- """Grad definition for `DropoutGrad` operation."""
1383
- grad = G.DropoutGrad(self.keep_prob)
1384
-
1385
- def bprop(x, mask, out, dout):
1386
- dy = dout
1387
- dx = grad(dy, mask)
1388
- return dx, zeros_like(mask)
1389
-
1390
- return bprop
1391
-
1392
-
1393
- @bprop_getters.register(P.Dropout2D)
1394
- @bprop_getters.register(P.Dropout3D)
1395
- def get_bprop_dropout3d(self):
1396
- """Grad definition for `Dropout2D` and `Dropout3D` operation."""
1397
- dtype = P.DType()
1398
- cast = P.Cast()
1399
- mul = P.Mul()
1400
- keep_prob = self.keep_prob
1401
-
1402
- def bprop(x, out, dout):
1403
- _, mask = out
1404
- dy, _ = dout
1405
- mask = cast(mask, mstype.float32)
1406
- if keep_prob != 0:
1407
- dy = dy * (1 / keep_prob)
1408
- dy = mul(mask, dy)
1409
- dy = cast(dy, dtype(x))
1410
- return (dy,)
1411
-
1412
- return bprop
1413
-
1414
-
1415
- @bprop_getters.register(P.CTCLoss)
1416
- def get_bprop_ctc_loss(self):
1417
- """Grad definition for `CTCLoss` operation"""
1418
- expand = P.ExpandDims()
1419
-
1420
- def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
1421
- grad_loss = out[1]
1422
- grad = grad_loss * expand(dout[0], -1)
1423
- return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)
1424
-
1425
- return bprop
1426
-
1427
-
1428
- @bprop_getters.register(P.BasicLSTMCell)
1429
- def get_bprop_basic_lstm_cell(self):
1430
- """Grad definition for `BasicLSTMCell` operation."""
1431
- basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
1432
- forget_bias=self.forget_bias,
1433
- activation=self.activation
1434
- )
1435
-
1436
- basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
1437
-
1438
- basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)
1439
-
1440
- def bprop(x, h, c, w, b, out, dout):
1441
- _, _, it, jt, ft, ot, tanhct = out
1442
- dct, dht, _, _, _, _, _ = dout
1443
- dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
1444
- dxt, dht = basic_lstm_cell_input_grad(dgate, w)
1445
- dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
1446
- return dxt, dht, dct_1, dw, db
1447
-
1448
- return bprop
1449
-
1450
-
1451
- @bprop_getters.register(nps.DeformableOffsets)
1452
- def get_bprop_deformable_offsets(self):
1453
- """Grad definition for `DeformableOffsets` operation."""
1454
- grad = G.DeformableOffsetsGrad(self.strides, self.pads, self.ksize, self.dilations, self.data_format,
1455
- self.deformable_groups, self.modulated)
1456
-
1457
- def bprop(x, offsets, out, dout):
1458
- out_grad = grad(dout, x, offsets)
1459
- return out_grad
1460
-
1461
- return bprop
1462
-
1463
-
1464
- @bprop_getters.register(P.LRN)
1465
- def get_bprop_lrn(self):
1466
- """Grad definition for `LRN` operation."""
1467
- grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)
1468
-
1469
- def bprop(x, out, dout):
1470
- dx = grad(dout, x, out)
1471
- return (dx,)
1472
-
1473
- return bprop
1474
-
1475
-
1476
- @bprop_getters.register(G.Conv2DBackpropFilter)
1477
- def get_bprop_conv2d_backprop_filter(self):
1478
- """Grad definition for `Conv2DBackpropFilter` operation."""
1479
- input_grad = P.Conv2DBackpropInput(
1480
- self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
1481
- dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
1482
- )
1483
- filter_grad = P.Conv2D(
1484
- self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
1485
- dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
1486
- )
1487
- get_shape = P.Shape()
1488
- get_dyn_shape = P.TensorShape()
1489
-
1490
- def bprop(dy, x, filter_size, out, dout):
1491
- x_shape = get_shape(x)
1492
- if F.is_sequence_value_unknown(x_shape):
1493
- x_shape = get_dyn_shape(x)
1494
- dw_dx = input_grad(dy, dout, x_shape)
1495
- dw_dy = filter_grad(x, dout)
1496
- return dw_dy, dw_dx, zeros_like(filter_size)
1497
-
1498
- return bprop
1499
-
1500
-
1501
- @bprop_getters.register(nps.UpsampleNearest3D)
1502
- def get_bprop_upsample_nearest_3d_grad(self):
1503
- """Grad definition for `UpsampleNearest3D` operation."""
1504
- get_shape = P.Shape()
1505
- output_size = self.output_size
1506
- scales = self.scales
1507
-
1508
- def bprop(input_x, out, dout):
1509
- input_grad = G.UpsampleNearest3DGrad(get_shape(input_x), output_size, scales)
1510
- dx = input_grad(dout)
1511
- return (dx,)
1512
-
1513
- return bprop
1514
-
1515
-
1516
- @bprop_getters.register(nps.UpsampleTrilinear3D)
1517
- def get_bprop_upsample_trilinear_3d_grad(self):
1518
- """Grad definition for `UpsampleTrilinear3D` operation."""
1519
- get_shape = P.Shape()
1520
- output_size = self.output_size
1521
- scales = self.scales
1522
- align_corners = self.align_corners
1523
-
1524
- def bprop(input_x, out, dout):
1525
- input_grad = G.UpsampleTrilinear3DGrad(get_shape(input_x), output_size, scales, align_corners)
1526
- dx = input_grad(dout)
1527
- return (dx,)
1528
-
1529
- return bprop