mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package's registry page for more details.

Files changed (884):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,1529 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """Define the grad rules of neural network related operations."""
17
- from mindspore import context
18
- from mindspore.common import dtype as mstype
19
- from mindspore.common.tensor import Tensor
20
- from mindspore.ops.primitive import _primexpr
21
- from mindspore.ops.operations import nn_ops as nps
22
- from mindspore.ops._grad.grad_base import bprop_getters, dyn_size, create_tensor_by_element, dyn_rank
23
- from mindspore.ops import functional as F
24
- from mindspore.ops import operations as P
25
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
26
- from mindspore.ops.operations import _grad_ops as G
27
- from mindspore.ops.operations import _inner_ops as inner
28
- from mindspore.ops.operations import _rl_inner_ops as rl_ops
29
- from mindspore.ops._utils.utils import range_op, get_1d_shape
30
-
31
-
32
@_primexpr
def bias_add_gradgrad_helper(shape, bias_shape, data_format):
    """Helper function of BiasGradGrad to calculate expanded shape.

    Computes how to broadcast the bias gradient ``dout`` (shape ``bias_shape``)
    back to the full static shape of ``dy`` (``shape``).

    Args:
        shape: Static shape of ``dy``.
        bias_shape: Static shape of ``dout``.
        data_format: Layout string of the primitive. "NHWC" is channel-last;
            every other layout ("NCHW", and the 5-D "NCDHW") is channel-first.

    Returns:
        A pair ``(expanded_shape, tile_mults)`` of tuples: the shape ``dout`` is
        reshaped into, and the multiples it is tiled by to match ``shape``.
    """
    new_shape = list(shape)
    new_bias_shape = list(bias_shape)

    if data_format != "NHWC":
        # Channel-first: channel axis is 1, all trailing axes are spatial. This is
        # rank-generic, so a 5-D "NCDHW" layout is not mistaken for channel-last.
        expanded_shape = [1] + new_bias_shape + [1] * len(new_shape[2:])
        tile_mults = [new_shape[0]] + [1] + new_shape[2:]
    else:
        # Channel-last: channel is the trailing axis.
        expanded_shape = [1] * len(new_shape[:-1]) + new_bias_shape
        tile_mults = new_shape[:-1] + [1]
    return tuple(expanded_shape), tuple(tile_mults)
53
-
54
-
55
def bias_add_gradgrad_helper_dynamic(shape, bias_shape, data_format):
    """Helper function of BiasGradGrad to calculate expanded shape(dynamic version).

    Runtime counterpart of ``bias_add_gradgrad_helper``. In this file it is
    called with shape tensors produced by ``P.TensorShape`` for ``dy`` and
    ``dout``, and it returns 1-D tensors rather than tuples.

    Args:
        shape: 1-D shape tensor of ``dy``.
        bias_shape: 1-D shape tensor of ``dout``.
        data_format: Layout string. "NHWC" is channel-last; every other layout
            ("NCHW", "NCDHW") is channel-first.

    Returns:
        ``(expanded_shape, tile_mults)`` as 1-D tensors.
    """
    # NOTE(review): the literal ``[1]`` is int64, so ``shape`` is presumably an
    # int64 shape tensor (Concat needs matching dtypes) - confirm.
    if data_format != "NHWC":
        # Channel-first (rank-generic, covers NCDHW): [1, C, 1, ..., 1] tiled by [N, 1, spatial...].
        expanded_shape = P.Concat(0)((P.OnesLike()(shape[:1]), bias_shape, P.OnesLike()(shape[2:])))
        tile_mults = P.Concat(0)((shape[:1], Tensor([1], dtype=mstype.int64), shape[2:]))
    else:
        # Channel-last: [1, ..., 1, C] tiled by [leading dims..., 1].
        expanded_shape = P.Concat(0)((P.OnesLike()(shape[:-1]), bias_shape))
        tile_mults = P.Concat(0)((shape[:-1], Tensor([1], dtype=mstype.int64)))
    return expanded_shape, tile_mults
64
-
65
-
66
@bprop_getters.register(G.BiasAddGrad)
def get_bprop_bias_add_grad(self):
    """Grad definition for `BiasAddGrad` operation.

    The returned bprop broadcasts ``dout`` back to the shape of ``dy``. It
    reshapes ``dout`` so its channel axis lines up with ``dy``'s, then tiles it
    along the remaining axes. Static shapes use ``bias_add_gradgrad_helper``;
    unknown shapes use the tensor-based ``bias_add_gradgrad_helper_dynamic``.
    """

    data_format = self.data_format
    # Instantiate primitives once per getter, like the other getters in this
    # file, instead of re-creating them inside ``bprop`` on every trace.
    reshape = P.Reshape()
    tile = P.Tile()
    dyn_shape = P.TensorShape()

    def bprop(dy, out, dout):
        dy_shape = dy.shape
        dout_shape = dout.shape
        if F.is_sequence_value_unknown(dy_shape) or F.is_sequence_value_unknown(dout_shape):
            # Shapes only known at runtime: work with shape tensors.
            dy_shape = dyn_shape(dy)
            dout_shape = dyn_shape(dout)
            expanded_shape, tile_mults = bias_add_gradgrad_helper_dynamic(dy_shape, dout_shape, data_format)
            expanded_grad = reshape(dout, expanded_shape)
            tiled_grad = tile(expanded_grad, tile_mults)
        else:
            expanded_shape, tile_mults = bias_add_gradgrad_helper(dy_shape, dout_shape, data_format)
            expanded_grad = reshape(dout, expanded_shape)
            tiled_grad = tile(expanded_grad, tile_mults)
        return (tiled_grad,)

    return bprop
91
-
92
-
93
@bprop_getters.register(nps.Conv3D)
def get_bprop_conv3d(self):
    """Grad definition for `Conv3D` operation.

    Builds ``Conv3DBackpropInput`` (gradient w.r.t. ``x``) and
    ``Conv3DBackpropFilter`` (gradient w.r.t. ``w``) from this primitive's
    attributes. The filter gradient is cast back to ``x``'s dtype.
    """
    shared_kwargs = dict(pad_mode=self.pad_mode, pad=self.pad, stride=self.stride,
                         dilation=self.dilation, group=self.group, data_format=self.data_format)
    input_grad = nps.Conv3DBackpropInput(self.out_channel, self.kernel_size, self.mode, **shared_kwargs)
    filter_grad = G.Conv3DBackpropFilter(self.out_channel, self.kernel_size, self.mode, **shared_kwargs)
    get_shape = P.Shape()
    get_dyn_shape = P.TensorShape()
    cast = P.Cast()
    get_dtype = P.DType()

    def bprop(x, w, out, dout):
        if not F.is_sequence_value_unknown(get_shape(x)) and not F.is_sequence_value_unknown(get_shape(w)):
            # Both shapes are static: pass them as tuples.
            dx = input_grad(w, dout, get_shape(x))
            dw = cast(filter_grad(x, dout, get_shape(w)), get_dtype(x))
            return dx, dw

        # At least one shape is dynamic: pass runtime shape tensors.
        dx = input_grad(w, dout, get_dyn_shape(x))
        dw = cast(filter_grad(x, dout, get_dyn_shape(w)), get_dtype(x))
        return dx, dw

    return bprop
120
-
121
-
122
@bprop_getters.register(nps.Conv3DTranspose)
def get_bprop_conv3d_transpose(self):
    """Grad definition for `Conv3DTranspose` operation.

    The input gradient is a ``Conv3D`` of ``dout`` with the same filter. The
    filter gradient is a ``Conv3DBackpropFilter`` with ``dout`` as input and
    ``x`` as the backpropagated tensor. Both reuse the primitive's
    ``pad_list`` attribute as explicit padding.
    """
    pad_list = self.get_attr_dict()['pad_list']
    # Conv3D is given entries 2-4 of stride/dilation (presumably the D, H, W
    # axes); Conv3DBackpropFilter receives the full attribute values.
    spatial_stride = tuple(self.stride[axis] for axis in (2, 3, 4))
    spatial_dilation = tuple(self.dilation[axis] for axis in (2, 3, 4))

    input_grad = nps.Conv3D(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=pad_list, stride=spatial_stride, dilation=spatial_dilation, group=self.group,
        data_format=self.data_format
    )
    filter_grad = G.Conv3DBackpropFilter(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=pad_list, stride=self.stride, dilation=self.dilation, group=self.group,
        data_format=self.data_format
    )
    get_dyn_shape = P.TensorShape()

    def bprop(x, w, out, dout):
        dx = input_grad(dout, w)
        if F.is_sequence_value_unknown(F.shape(w)):
            # Filter shape only known at runtime: pass a shape tensor.
            dw = filter_grad(dout, x, get_dyn_shape(w))
        else:
            dw = filter_grad(dout, x, F.shape(w))
        return dx, dw

    return bprop
149
-
150
-
151
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Grad definition for `ExtractImagePatches` operation.

    The gradient is computed as a matrix product. A 0/1 matrix ``sp_tensor``
    is built that maps every element of the patch output (for one channel and
    one batch item) to the input pixel it was copied from. Multiplying that
    matrix by the flattened ``dout`` sums all patch gradients that came from
    each input pixel. Static shapes use ``bprop``; unknown shapes use
    ``_dyn_extract_image_patched``, which builds the same matrix from runtime
    shape tensors.
    """
    get_shape = P.Shape()
    reshape = P.Reshape()
    extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
                                                      strides=self.strides,
                                                      rates=self.rates,
                                                      padding=self.padding)
    concat = P.Concat(axis=-1)
    expand_dims = P.ExpandDims()
    scatter_nd = P.ScatterNd()
    dtype = P.DType()
    fill = P.Fill()
    slice_op = P.Slice()
    transpose = P.Transpose()
    cast = P.Cast()
    matmul = P.MatMul()
    range_ = P.Range()
    dyn_shape_op = P.TensorShape()
    ones_like = P.OnesLike()

    # ksizes is a 4-element sequence; only the last two (kernel rows/cols) are used.
    _, _, ksizes_row, ksizes_col = self.ksizes

    def _dyn_extract_image_patched(x, out, dout):
        """Dynamic-shape variant of ``bprop``; shape values are runtime tensors."""
        x_shape = dyn_shape_op(x)
        out_shape = dyn_shape_op(out)
        x_batch, x_depth, x_row, x_col = x_shape[0], x_shape[1], x_shape[2], x_shape[3]
        # Number input pixels 1..x_row*x_col; id 0 is reserved for patch
        # positions that read no real pixel (presumably padding) - see the slice below.
        x_indices_num = x_row * x_col + 1
        # NOTE(review): ids travel through extract_image_patches as float32, which
        # is only exact up to 2**24 pixels per image.
        x_idx = range_(cast(1, mstype.float32), cast(x_indices_num, mstype.float32), cast(1, mstype.float32))
        x_idx = reshape(x_idx, create_tensor_by_element((1, 1, x_row, x_col)))
        # Extract patches of the id image to learn which pixel feeds each patch slot.
        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))

        out_row, out_col = out_shape[2], out_shape[3]
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        # One sequential id per patch output slot.
        out_idx_ori = range_(cast(0, mstype.int32), cast(out_indices_num, mstype.int32), cast(1, mstype.int32))
        out_idx = reshape(out_idx_ori, create_tensor_by_element((1, out_row, out_col, ksizes_row * ksizes_col)))

        # Pair (input pixel id, output slot id) and scatter ones into a
        # (x_indices_num, out_indices_num) matrix.
        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = create_tensor_by_element((x_indices_num, out_indices_num))
        update = cast(ones_like(out_idx_ori), dtype(dout))
        sp_tensor = scatter_nd(idx_tensor, update, sp_shape)
        # Drop row 0 (the reserved id) so only real input pixels remain.
        begin = create_tensor_by_element((1, 0))
        size = create_tensor_by_element((x_indices_num - 1, out_indices_num))
        sp_tensor = slice_op(sp_tensor, begin, size)

        # Rearrange dout so rows are patch slots and columns are (batch, depth) pairs.
        grad = transpose(dout, (0, 2, 3, 1))
        grad = reshape(grad, create_tensor_by_element((x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth)))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, create_tensor_by_element((out_row * out_col * ksizes_row * ksizes_col, x_batch * x_depth)))

        # Sum patch gradients per input pixel, then restore (batch, depth, row, col).
        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, create_tensor_by_element((x_row, x_col, x_batch, x_depth)))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)

    def bprop(x, out, dout):
        x_shape = get_shape(x)
        out_shape = get_shape(out)
        if F.is_sequence_value_unknown(x_shape) or F.is_sequence_value_unknown(out_shape):
            return _dyn_extract_image_patched(x, out, dout)
        # Static path: same algorithm as above with Python-int shapes.
        x_batch, x_depth, x_row, x_col = x_shape
        # Pixel ids start at 1; id 0 marks patch slots with no real pixel.
        x_indices_num = x_row * x_col + 1
        x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
        x_idx = reshape(x_idx, (1, 1, x_row, x_col))
        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))

        _, _, out_row, out_col = out_shape
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        out_idx = F.tuple_to_array(range(out_indices_num))
        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))

        # 0/1 matrix: sp_tensor[pixel_id, slot_id] = 1 when the slot copied that pixel.
        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = (x_indices_num, out_indices_num)
        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
        # Remove the reserved row 0.
        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))

        grad = transpose(dout, (0, 2, 3, 1))
        grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, (-1, x_batch * x_depth))

        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)

    return bprop
243
-
244
-
245
@bprop_getters.register(P.DepthwiseConv2dNative)
def get_bprop_depthwise_conv2d_native(self):
    """Grad definition for `DepthwiseConv2dNative` operation.

    Both backprop primitives are configured from the same attribute list. The
    input gradient takes the static shape of ``x``; the filter gradient takes
    the static shape of ``w``.
    """
    conv_attrs = (self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list,
                  self.mode, self.stride, self.dilation, self.group)
    input_grad = G.DepthwiseConv2dNativeBackpropInput(*conv_attrs)
    filter_grad = G.DepthwiseConv2dNativeBackpropFilter(*conv_attrs)
    get_shape = P.Shape()

    def bprop(x, w, out, dout):
        x_shape = get_shape(x)
        w_shape = get_shape(w)
        return input_grad(x_shape, w, dout), filter_grad(x, w_shape, dout)

    return bprop
265
-
266
-
267
@bprop_getters.register(P.MaxPoolWithArgmax)
def get_bprop_max_pool_with_argmax(self):
    """Grad definition for `MaxPoolWithArgmax` operation.

    Only the gradient of the first output (``dout[0]``) is used. It is passed
    to ``MaxPoolGradWithArgmax`` together with ``x`` and the forward pass's
    second output (``out[1]``, the argmax).
    """
    pool_kwargs = dict(kernel_size=self.kernel_size, strides=self.strides, pad_mode=self.pad_mode)
    maxpool_grad = G.MaxPoolGradWithArgmax(**pool_kwargs)

    def bprop(x, out, dout):
        return (maxpool_grad(x, dout[0], out[1]),)

    return bprop
280
-
281
-
282
- @bprop_getters.register(G.MaxPoolGrad)
283
- def get_bprop_max_pool_grad_grad(self):
284
- """Grad definition for `MaxPoolGrad` operation."""
285
- device_target = context.get_context("device_target")
286
- is_ascend = (device_target == "Ascend")
287
- if device_target == "Ascend":
288
- maxpool_grad_grad = G.MaxPoolGradGrad(
289
- kernel_size=self.kernel_size,
290
- strides=self.strides,
291
- pad_mode=self.pad_mode)
292
- elif device_target == "GPU":
293
- if self.data_format != "NCHW":
294
- raise RuntimeError("MaxPoolGradGrad does not support NHWC!")
295
- kernel_size = self.kernel_size
296
- if isinstance(kernel_size, tuple) and len(kernel_size) == 4:
297
- kernel_size = kernel_size[2:]
298
- strides = self.strides
299
- if isinstance(strides, tuple) and len(strides) == 4:
300
- strides = strides[2:]
301
- maxpool_with_argmax = P.MaxPoolWithArgmax(kernel_size=kernel_size, strides=strides, pad_mode=self.pad_mode)
302
- gather = P.GatherNd()
303
- reshape = P.Reshape()
304
- else:
305
- raise RuntimeError("MaxPoolGradGrad does not support on CPU!")
306
- shape_op = P.Shape()
307
- dyn_shape_op = P.TensorShape()
308
- op_range = P.Range()
309
- dyn_broadcast_op = inner.DynamicBroadcastTo()
310
-
311
-
312
- def bprop(x1, x2, grad, out, dout):
313
- dx1 = zeros_like(x1)
314
- dx2 = zeros_like(x2)
315
- if is_ascend:
316
- dgrad = maxpool_grad_grad(x1, x2, dout)
317
- else:
318
- shape_x2 = shape_op(x2)
319
- if F.is_sequence_value_unknown(shape_x2):
320
- shape_x2 = dyn_shape_op(x2)
321
- b, c, h, w = shape_x2
322
- _, ind = maxpool_with_argmax(x1)
323
- batch = op_range(F.cast(0, mstype.int32), F.cast(b, mstype.int32), F.cast(1, mstype.int32))
324
- batch = dyn_broadcast_op(reshape(batch, (-1, 1)),
325
- create_tensor_by_element((dyn_size(batch), c * h * w)))
326
- gather_ind = P.Stack(-1)((batch, reshape(ind, create_tensor_by_element((b, -1)))))
327
- dgrad = reshape(gather(reshape(dout, create_tensor_by_element((b, -1))), gather_ind),
328
- create_tensor_by_element((b, c, h, w)))
329
- else:
330
- b, c, h, w = shape_x2
331
- _, ind = maxpool_with_argmax(x1)
332
- batch = F.cast(F.tuple_to_array(range(b)), mstype.int32)
333
- batch = P.Tile()(reshape(batch, (-1, 1)), (1, c * h * w))
334
- gather_ind = P.Stack(-1)((batch, reshape(ind, (b, -1))))
335
- dgrad = reshape(gather(reshape(dout, (b, -1)), gather_ind), (b, c, h, w))
336
- return (dx1, dx2, dgrad)
337
-
338
- return bprop
339
-
340
-
341
- @bprop_getters.register(G.MaxPoolGradGrad)
342
- def get_bprop_max_pool_grad_grad_grad(self):
343
- """Grad definition for `MaxPoolGradGrad` operation."""
344
- maxpool_grad = G.MaxPoolGrad(
345
- kernel_size=self.kernel_size,
346
- strides=self.strides,
347
- pad_mode=self.pad_mode)
348
-
349
- def bprop(x1, x2, grad, out, dout):
350
- dx1 = zeros_like(x1)
351
- dx2 = zeros_like(x2)
352
- dgrad = maxpool_grad(x1, x2, dout)
353
- return (dx1, dx2, dgrad)
354
-
355
- return bprop
356
-
357
-
358
- @bprop_getters.register(P.MaxPool3D)
359
- def get_bprop_max_pool3d_grad(self):
360
- """Grad definition for `MaxPool3D` operation."""
361
- max_pool3d_grad = G.MaxPool3DGrad(
362
- kernel_size=self.kernel_size,
363
- strides=self.strides,
364
- pad_mode=self.pad_mode,
365
- pad_list=self.pad_list,
366
- data_format=self.data_format)
367
-
368
- def bprop(x, out, dout):
369
- dx = max_pool3d_grad(x, out, dout)
370
- return (dx,)
371
-
372
- return bprop
373
-
374
-
375
- @bprop_getters.register(G.MaxPool3DGrad)
376
- def get_bprop_max_pool3d_grad_grad(self):
377
- """Grad definition for `MaxPool3Grad` operation."""
378
- max_pool3d_grad_grad = G.MaxPool3DGradGrad(
379
- kernel_size=self.kernel_size,
380
- strides=self.strides,
381
- pad_mode=self.pad_mode,
382
- data_format=self.data_format)
383
-
384
- def bprop(x, y, grad, out, dout):
385
- dgrad = max_pool3d_grad_grad(x, y, dout)
386
- return zeros_like(x), zeros_like(y), dgrad
387
-
388
- return bprop
389
-
390
-
391
- @bprop_getters.register(G.MaxPool3DGradGrad)
392
- def get_bprop_max_pool3d_grad_grad_grad(self):
393
- """Grad definition for `MaxPool3GradGrad` operation."""
394
-
395
- max_pool3d_grad = G.MaxPool3DGrad(
396
- kernel_size=self.kernel_size,
397
- strides=self.strides,
398
- pad_mode=self.pad_mode,
399
- data_format=self.data_format)
400
-
401
- def bprop(x, y, grad, out, dout):
402
- dgrad = max_pool3d_grad(x, y, dout)
403
- return zeros_like(x), zeros_like(y), dgrad
404
-
405
- return bprop
406
-
407
-
408
- @bprop_getters.register(nps.AdaptiveMaxPool2D)
409
- def get_bprop_adaptive_max_pool2d_grad(self):
410
- """Grad definition for `AdaptiveMaxPool2D` operation."""
411
- adaptive_maxpool2d_grad = G.AdaptiveMaxPool2DGrad()
412
-
413
- def bprop(x, out, dout):
414
- dy = dout[0]
415
- index = out[1]
416
- dx = adaptive_maxpool2d_grad(dy, x, index)
417
- return (dx,)
418
-
419
- return bprop
420
-
421
-
422
- @bprop_getters.register(P.AvgPool)
423
- def get_bprop_avg_pool_grad(self):
424
- """Grad definition for `AvgPool` operation."""
425
- avgpool_grad = G.AvgPoolGrad(
426
- kernel_size=self.kernel_size,
427
- strides=self.strides,
428
- pad_mode=self.pad_mode,
429
- data_format=self.format)
430
-
431
- def bprop(x, out, dout):
432
- dx = avgpool_grad(x, out, dout)
433
- return (dx,)
434
-
435
- return bprop
436
-
437
-
438
- @bprop_getters.register(P.AdaptiveAvgPool2D)
439
- def get_bprop_adaptive_avg_pool2d_grad(self):
440
- """Grad definition for `AdaptiveAvgPool2D` operation."""
441
- adaptive_avgpool_grad = G.AdaptiveAvgPool2DGrad()
442
- shape = P.TensorShape()
443
-
444
- def bprop(x, out, dout):
445
- dx = adaptive_avgpool_grad(dout, shape(x))
446
- return (dx,)
447
-
448
- return bprop
449
-
450
-
451
- @bprop_getters.register(P.AvgPool3D)
452
- def get_bprop_avg_pool_3d_grad(self):
453
- """Grad definition for `AvgPool3D` operation."""
454
- pad_list = self.get_attr_dict()['pad_list']
455
- count_include_pad = self.get_attr_dict()['count_include_pad']
456
- avgpool3d_grad = G.AvgPool3DGrad(kernel_size=self.kernel_size,
457
- strides=self.strides,
458
- pads=pad_list,
459
- ceil_mode=self.ceil_mode,
460
- count_include_pad=count_include_pad,
461
- divisor_override=self.divisor_override,
462
- data_format=self.data_format,
463
- pad_mode=self.pad_mode)
464
-
465
- def bprop(x, out, dout):
466
- x_shape = F.shape(x)
467
- if F.is_sequence_value_unknown(x_shape):
468
- x_shape = P.TensorShape()(x)
469
- dx = avgpool3d_grad(x_shape, dout)
470
- return (dx,)
471
-
472
- return bprop
473
-
474
-
475
- @bprop_getters.register(P.DropoutGenMask)
476
- def get_bprop_dropout_gen_mask(self):
477
- """Grad definition for `DropoutGenMask` operation."""
478
-
479
- def bprop(shape, keep_prob, out, dout):
480
- return (zeros_like(shape), zeros_like(keep_prob))
481
-
482
- return bprop
483
-
484
-
485
- @bprop_getters.register(P.DropoutDoMask)
486
- def get_bprop_dropout_do_mask(self):
487
- """Grad definition for `DropoutDoMask` operation."""
488
- do_mask = P.DropoutDoMask()
489
-
490
- def bprop(x, y, keep_prob, out, dout):
491
- return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))
492
-
493
- return bprop
494
-
495
-
496
- @bprop_getters.register(P.Mish)
497
- def get_bprop_mish(self):
498
- """Grad definition for `Mish` operation."""
499
- tanh = P.Tanh()
500
- tanh_grad = G.TanhGrad()
501
- softplus = P.Softplus()
502
- softplus_grad = G.SoftplusGrad()
503
-
504
- def bprop(x, out, dout):
505
- dx1 = tanh(softplus(x))
506
- dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
507
- dx = (dx1 * dout + dx2)
508
- return (dx,)
509
-
510
- return bprop
511
-
512
-
513
- @bprop_getters.register(P.SeLU)
514
- def get_bprop_selu(self):
515
- """Grad definition for `SeLU` operation."""
516
- scale = 1.0507009873554804934193349852946
517
- elu_grad = G.EluGrad()
518
-
519
- def bprop(x, out, dout):
520
- dx = elu_grad(dout, out) * scale
521
- return (dx,)
522
-
523
- return bprop
524
-
525
-
526
- @bprop_getters.register(P.MulNoNan)
527
- def get_bprop_mul_no_nan(self):
528
- """Grad definition for `MulNoNan` operation."""
529
- mul_no_nan = P.MulNoNan()
530
- reduce_sum = P.ReduceSum()
531
- reshape = P.Reshape()
532
-
533
- def bprop(x, y, out, dout):
534
- x_shape = F.shape(x)
535
- y_shape = F.shape(y)
536
- dx = mul_no_nan(dout, y)
537
- dy = mul_no_nan(x, dout)
538
- broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
539
- if broadcast_x != ():
540
- dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
541
- if broadcast_y != ():
542
- dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
543
- return dx, dy
544
-
545
- return bprop
546
-
547
-
548
- @bprop_getters.register(G.ReluGrad)
549
- def get_bprop_relu_grad(self):
550
- """Grad definition for `ReLUGrad` operation."""
551
- input_grad = G.ReluGrad()
552
-
553
- def bprop(grad, y, out, dout):
554
- dgrad = input_grad(dout, y)
555
- return dgrad, zeros_like(y)
556
-
557
- return bprop
558
-
559
-
560
- @bprop_getters.register(P.ReLU6)
561
- def get_bprop_relu6(self):
562
- """Grad definition for `ReLU6` operation."""
563
- input_grad = G.ReLU6Grad()
564
-
565
- def bprop(x, out, dout):
566
- dx = input_grad(dout, x)
567
- return (dx,)
568
-
569
- return bprop
570
-
571
-
572
- @bprop_getters.register(P.ReLUV2)
573
- def get_bprop_relu_v2(self):
574
- """Grad definition for `ReLUV2` operation."""
575
- input_grad = G.ReluGradV2()
576
-
577
- def bprop(x, out, dout):
578
- mask = out[1]
579
- dx = input_grad(dout[0], mask)
580
- return (dx,)
581
-
582
- return bprop
583
-
584
-
585
- @bprop_getters.register(P.HSwish)
586
- def get_bprop_hswish(self):
587
- """Grad definition for `HSwish` operation."""
588
- input_grad = G.HSwishGrad()
589
-
590
- def bprop(x, out, dout):
591
- dx = input_grad(dout, x)
592
- return (dx,)
593
-
594
- return bprop
595
-
596
-
597
- @bprop_getters.register(P.HSigmoid)
598
- def get_bprop_hsigmoid(self):
599
- """Grad definition for `HSigmoid` operation."""
600
- input_grad = G.HSigmoidGrad()
601
-
602
- def bprop(x, out, dout):
603
- dx = input_grad(dout, x)
604
- return (dx,)
605
-
606
- return bprop
607
-
608
-
609
- @bprop_getters.register(P.Elu)
610
- def get_bprop_elu(self):
611
- """Grad definition for `Elu` operation."""
612
- input_grad = G.EluGrad()
613
-
614
- def bprop(x, out, dout):
615
- dx = input_grad(dout, out)
616
- return (dx,)
617
-
618
- return bprop
619
-
620
-
621
- @bprop_getters.register(P.Sigmoid)
622
- def get_bprop_sigmoid(self):
623
- """Grad definition for `Sigmoid` operation."""
624
- input_grad = G.SigmoidGrad()
625
-
626
- def bprop(x, out, dout):
627
- dx = input_grad(out, dout)
628
- return (dx,)
629
-
630
- return bprop
631
-
632
-
633
- @bprop_getters.register(G.SigmoidGrad)
634
- def get_bprop_sigmoid_grad(self):
635
- """Grad definition for `SigmoidGrad` operation."""
636
- sigmoid_grad = G.SigmoidGrad()
637
-
638
- def bprop(y, grad, out, dout):
639
- dy = dout * grad * (1. - 2 * y)
640
- dgrad = sigmoid_grad(y, dout)
641
- return dy, dgrad
642
-
643
- return bprop
644
-
645
-
646
- @_primexpr
647
- def _get_transpose_axis(x_shp, axis):
648
- rank = len(x_shp)
649
- if axis < 0:
650
- axis += rank
651
- reverse_axis = [i for i in range(rank)]
652
- reverse_axis[axis] = rank - 1
653
- reverse_axis[rank - 1] = axis
654
- return tuple(reverse_axis)
655
-
656
-
657
- def _get_dyn_transpose_axis(x, axis, is_ascend):
658
- """Get transpose axis"""
659
- if F.is_sequence_shape_unknown(P.Shape()(x)):
660
- rank = dyn_rank(x)
661
- start = Tensor(0, dtype=mstype.int64)
662
- delta = Tensor(1, dtype=mstype.int64)
663
- else:
664
- rank = P.Cast()(len(P.Shape()(x)), mstype.int64)
665
- start = P.Cast()(0, mstype.int64)
666
- delta = P.Cast()(1, mstype.int64)
667
-
668
- if axis < 0:
669
- axis += rank
670
- range_ops = P.Range()
671
-
672
- reverse_axis = range_ops(start, rank, delta)
673
- if is_ascend:
674
- reverse_axis = P.Cast()(reverse_axis, mstype.int8)
675
- axis = P.Cast()(axis, mstype.int32)
676
- reverse_axis[axis] = rank - 1
677
- rank = P.Cast()(rank, mstype.int32)
678
- else:
679
- reverse_axis[axis] = rank - 1
680
-
681
- reverse_axis[rank - 1] = axis
682
- return reverse_axis
683
-
684
-
685
- @bprop_getters.register(P.Softmax)
686
- def get_bprop_softmax(self):
687
- """Grad definition for `Softmax` operation."""
688
- sum_func = P.ReduceSum(keep_dims=True)
689
- sub = P.Sub()
690
- mul = P.Mul()
691
- get_shape = P.Shape()
692
- transpose = P.Transpose()
693
- axis = self.axis
694
- if not isinstance(axis, int):
695
- axis = axis[0]
696
-
697
- device_target = context.get_context("device_target")
698
- is_ascend = (device_target == "Ascend")
699
-
700
- def bprop(x, out, dout):
701
- # dx can be expressed as (dout - sum(dout * out)) * out
702
- # This formula is correct only when the `axis` is the last dimension.
703
- # In order to support the scenario where the `axis` is other values,
704
- # we transpose the data of the `axis` dimension to the last dimension for calculation,
705
- # and then transpose it back after the calculation.
706
- shp = get_shape(x)
707
- if F.is_sequence_value_unknown(shp):
708
- reverse_axis = _get_dyn_transpose_axis(x, axis, is_ascend)
709
- if is_ascend:
710
- reverse_axis = P.Cast()(reverse_axis, mstype.int32)
711
- else:
712
- reverse_axis = _get_transpose_axis(get_shape(x), axis)
713
- out = transpose(out, reverse_axis)
714
- dout = transpose(dout, reverse_axis)
715
- dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
716
- dx = transpose(dx, reverse_axis)
717
- return (dx,)
718
-
719
- return bprop
720
-
721
-
722
- @bprop_getters.register(P.LogSoftmax)
723
- def get_bprop_log_softmax(self):
724
- """Grad definition for `LogSoftmax` operation."""
725
- logsoftmax_grad = G.LogSoftmaxGrad(self.axis)
726
-
727
- def bprop(x, out, dout):
728
- dx = logsoftmax_grad(out, dout)
729
- return (dx,)
730
-
731
- return bprop
732
-
733
-
734
- @bprop_getters.register(P.Softplus)
735
- def get_bprop_softplus(self):
736
- """Grad definition for `Softplus` operation."""
737
- softplus_grad = G.SoftplusGrad()
738
-
739
- def bprop(x, out, dout):
740
- dx = softplus_grad(dout, x)
741
- return (dx,)
742
-
743
- return bprop
744
-
745
-
746
- @bprop_getters.register(P.Softsign)
747
- def get_bprop_softsign(self):
748
- """Grad definition for `Softsign` operation."""
749
- mul = P.Mul()
750
- absolute = P.Abs()
751
- div = P.Div()
752
- square = P.Square()
753
-
754
- def bprop(x, out, dout):
755
- dx = mul(dout, div(1, square(1 + absolute(x))))
756
- return (dx,)
757
-
758
- return bprop
759
-
760
-
761
- @bprop_getters.register(P.Tanh)
762
- def get_bprop_tanh(self):
763
- """Grad definition for `Tanh` operation."""
764
- tanh_grad = G.TanhGrad()
765
- conj = P.Conj()
766
-
767
- def bprop(x, out, dout):
768
- if x.dtype in (mstype.complex64, mstype.complex128):
769
- dout = conj(dout)
770
- dx = tanh_grad(out, dout)
771
- dx = conj(dx)
772
- else:
773
- dx = tanh_grad(out, dout)
774
- return (dx,)
775
-
776
- return bprop
777
-
778
-
779
- @bprop_getters.register(G.TanhGrad)
780
- def get_bprop_tanh_grad(self):
781
- """Grad definition for `TanhGrad` operation."""
782
- tanh_grad = G.TanhGrad()
783
-
784
- def bprop(y, grad, out, dout):
785
- dy = dout * -2.0 * grad * y
786
- dgrad = tanh_grad(y, dout)
787
- return dy, dgrad
788
-
789
- return bprop
790
-
791
-
792
- @bprop_getters.register(P.FastGeLU)
793
- def get_bprop_fast_gelu(self):
794
- """Grad definition for `FastGeLU` operation."""
795
- input_grad = G.FastGeLUGrad()
796
-
797
- def bprop(x, out, dout):
798
- dx = input_grad(dout, x)
799
- return (dx,)
800
-
801
- return bprop
802
-
803
-
804
- @bprop_getters.register(P.FastGelu)
805
- def get_bprop_fast_gelu_2(self):
806
- """Grad definition for `FastGeLU` operation."""
807
- input_grad = G.FastGeLUGrad()
808
-
809
- def bprop(x, out, dout):
810
- dx = input_grad(dout, x)
811
- return (dx,)
812
-
813
- return bprop
814
-
815
-
816
- @bprop_getters.register(P.InstanceNorm)
817
- def get_bprop_instance_norm(self):
818
- """Grad definition for `InstanceNorm` operation."""
819
- input_grad = G.InstanceNormGrad(self.epsilon, self.momentum)
820
-
821
- def bprop(x, gamma, beta, mean, variance, out, dout):
822
- saved_mean = out[1]
823
- saved_variance = out[2]
824
- out = input_grad(dout[0], x, gamma, saved_mean, saved_variance)
825
- dx = out[0]
826
- dgamma = out[1]
827
- dbeta = out[2]
828
- return dx, dgamma, dbeta, zeros_like(mean), zeros_like(variance)
829
-
830
- return bprop
831
-
832
-
833
- @bprop_getters.register(G.BatchNormGrad)
834
- def get_bprop_batch_norm_grad(self):
835
- """Grad definition for `BatchNorm` operation."""
836
- grad_op = G.BatchNormGradGrad(self.is_training, self.epsilon, self.data_format)
837
-
838
- def bprop(dy, x, scale, mean, variance, reserve, out, dout):
839
- dx, ddy, dscale = grad_op(x, dy, scale, mean, variance, dout[0], dout[1], dout[2])
840
- return ddy, dx, dscale, zeros_like(mean), zeros_like(variance), zeros_like(reserve)
841
-
842
- return bprop
843
-
844
-
845
- @bprop_getters.register(G.LayerNormGrad)
846
- def get_bprop_layer_norm_grad(self):
847
- """Grad definition for `LayerNormGrad` operation."""
848
- layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)
849
-
850
- def bprop(x, dy, variance, mean, gamma, out, dout):
851
- d_x, d_dy, d_gamma = layer_norm_grad_grad(
852
- x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
853
- return d_x, d_dy, zeros_like(variance), zeros_like(mean), d_gamma
854
-
855
- return bprop
856
-
857
-
858
- @bprop_getters.register(P.L2Normalize)
859
- def get_bprop_l2normalize(self):
860
- """Grad definition for `L2Normalize` operation."""
861
- input_grad = G.L2NormalizeGrad(self.axis, self.epsilon)
862
-
863
- def bprop(x, out, dout):
864
- dx = input_grad(x, out, dout)
865
- return (dx,)
866
-
867
- return bprop
868
-
869
-
870
- @bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
871
- def get_bprop_softmax_cross_entropy_with_logits(self):
872
- """Grad definition for `SoftmaxCrossEntropyWithLogits` operation."""
873
- expand = P.ExpandDims()
874
-
875
- def bprop(logits, labels, out, dout):
876
- grad = out[1]
877
- grad = grad * expand(dout[0], -1)
878
- return grad, zeros_like(labels)
879
-
880
- return bprop
881
-
882
-
883
- @bprop_getters.register(P.NLLLoss)
884
- def get_bprop_nll_loss(self):
885
- """Grad definition for `NLLLoss` operation."""
886
- nll_loss_grad = G.NLLLossGrad(reduction=self.reduction)
887
-
888
- def bprop(x, target, weight, out, dout):
889
- total_weight = out[1]
890
- dout_x = dout[0]
891
- dx = nll_loss_grad(x, dout_x, target, weight, total_weight)
892
- return dx, zeros_like(target), zeros_like(weight)
893
-
894
- return bprop
895
-
896
-
897
- @bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
898
- def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
899
- """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
900
- is_grad = self.is_grad
901
- grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)
902
-
903
- def bprop(logits, labels, out, dout):
904
- grad = out[0]
905
- if not is_grad:
906
- # if construct use loss
907
- grad = grad_op(logits, labels)
908
- grad = F.depend(grad, out)
909
- grad = grad * dout
910
- return grad, zeros_like(labels)
911
-
912
- return bprop
913
-
914
-
915
- @bprop_getters.register(P.ResizeBilinear)
916
- def get_bprop_resize_bilinear(self):
917
- """Grad definition for `ResizeBilinear` operation."""
918
- resize_grad = G.ResizeBilinearGrad(self.align_corners, self.half_pixel_centers)
919
-
920
- def bprop(x, out, dout):
921
- dx = resize_grad(dout, x)
922
- return (dx,)
923
-
924
- return bprop
925
-
926
-
927
- @bprop_getters.register(P.OneHot)
928
- def get_bprop_onehot(self):
929
- """Grad definition for `OneHot` operation."""
930
-
931
- def bprop(indices, depth, on_value, off_value, out, dout):
932
- return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
933
-
934
- return bprop
935
-
936
-
937
- @bprop_getters.register(P.TopK)
938
- def get_bprop_top_kv2(self):
939
- """Grad definition for `TopK` operation."""
940
- scatter = P.ScatterNd()
941
- expand_dims = P.ExpandDims()
942
- shape_op = P.Shape()
943
- dyn_shape = P.TensorShape()
944
- reshape_op = P.Reshape()
945
- dtype = P.DType()
946
- cast = P.Cast()
947
-
948
- def _bprop_static(input_x, k, out, dout):
949
- in_shape = shape_op(input_x)
950
- in_lastdim = in_shape[-1]
951
-
952
- indices = out[1]
953
- ind_shape = shape_op(indices)
954
- ind_lastdim = ind_shape[-1]
955
-
956
- ind_2d = reshape_op(indices, (-1, ind_lastdim))
957
- outerdim = shape_op(ind_2d)[0]
958
-
959
- # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
960
- indices_dtype = dtype(indices)
961
- range_flatten_index = range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
962
-
963
- # expand_dims to (k, 1), then broadcast
964
- ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
965
- in_shape_1d = get_1d_shape(in_shape)
966
-
967
- out_grad = reshape_op(
968
- scatter(
969
- expand_dims(ind, -1),
970
- reshape_op(dout[0], (-1,)),
971
- in_shape_1d),
972
- in_shape)
973
- return out_grad, zeros_like(k)
974
-
975
- def _bprop_dynshape(input_x, k, out, dout):
976
- in_shape = dyn_shape(input_x)
977
- in_lastdim = in_shape[-1]
978
-
979
- indices = out[1]
980
- ind_shape = dyn_shape(indices)
981
- ind_lastdim = ind_shape[-1]
982
-
983
- ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim)))
984
- outerdim = dyn_shape(ind_2d)[0]
985
-
986
- # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
987
- range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim)
988
-
989
- # expand_dims to (k, 1), then broadcast
990
- ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), create_tensor_by_element((-1,)))
991
- in_shape_1d = expand_dims(dyn_size(input_x, mstype.int64), -1)
992
-
993
- out_grad = reshape_op(
994
- scatter(
995
- expand_dims(ind, -1),
996
- reshape_op(dout[0], create_tensor_by_element((-1,))),
997
- in_shape_1d),
998
- in_shape)
999
- return out_grad, zeros_like(k)
1000
-
1001
- def bprop(input_x, k, out, dout):
1002
- if F.is_sequence_value_unknown(shape_op(input_x)):
1003
- return _bprop_dynshape(input_x, k, out, dout)
1004
- return _bprop_static(input_x, k, out, dout)
1005
-
1006
- return bprop
1007
-
1008
-
1009
- @bprop_getters.register(P.SmoothL1Loss)
1010
- def get_bprop_smooth_l1_loss(self):
1011
- """Grad definition for `SmoothL1Loss` operation."""
1012
- grad = G.SmoothL1LossGrad(self.beta, self.reduction)
1013
-
1014
- def bprop(prediction, target, out, dout):
1015
- dx = grad(prediction, target, dout)
1016
- dy = grad(target, prediction, dout)
1017
- return dx, dy
1018
-
1019
- return bprop
1020
-
1021
-
1022
- @bprop_getters.register(P.L2Loss)
1023
- def get_bprop_l2_loss(self):
1024
- """Grad definition for `L2Loss` operation."""
1025
-
1026
- def bprop(x, out, dout):
1027
- dx = x * dout
1028
- return (dx,)
1029
-
1030
- return bprop
1031
-
1032
-
1033
- @bprop_getters.register(P.RNNTLoss)
1034
- def get_bprop_rnnt_loss(self):
1035
- """Grad definition for `RNNTLoss` operation."""
1036
-
1037
- def bprop(acts, labels, act_lens, label_lens, out, dout):
1038
- grad = out[1]
1039
- return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
1040
-
1041
- return bprop
1042
-
1043
-
1044
- @bprop_getters.register(P.PReLU)
1045
- def get_bprop_prelu(self):
1046
- """Grad definition for `PReLU` operation."""
1047
- grad = G.PReLUGrad()
1048
-
1049
- def bprop(x, w, out, dout):
1050
- dx, dw = grad(dout, x, w)
1051
- return dx, dw
1052
-
1053
- return bprop
1054
-
1055
-
1056
- @bprop_getters.register(P.LSTM)
1057
- def get_bprop_lstm(self):
1058
- """Grad definition for `LSTM` operation."""
1059
- lstm_grad_data = G.LSTMGradData(
1060
- input_size=self.input_size,
1061
- hidden_size=self.hidden_size,
1062
- num_layers=self.num_layers,
1063
- has_bias=self.has_bias,
1064
- bidirectional=self.bidirectional,
1065
- dropout=self.dropout
1066
- )
1067
-
1068
- lstm_grad_weight = G.LSTMGradWeight(
1069
- input_size=self.input_size,
1070
- hidden_size=self.hidden_size,
1071
- num_layers=self.num_layers,
1072
- has_bias=self.has_bias,
1073
- bidirectional=self.bidirectional,
1074
- dropout=self.dropout
1075
- )
1076
- lstm_grad = G.LSTMGrad(
1077
- input_size=self.input_size,
1078
- hidden_size=self.hidden_size,
1079
- num_layers=self.num_layers,
1080
- has_bias=self.has_bias,
1081
- bidirectional=self.bidirectional,
1082
- dropout=self.dropout
1083
- )
1084
-
1085
- def bprop(x, hx, cx, w, out, dout):
1086
- y, _, _, reserve, state = out
1087
- dy, dhy, dcy, _, _ = dout
1088
- dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
1089
- dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
1090
- return dx, dhx, dcx, dw
1091
-
1092
- #
1093
- def bprop_cpu(x, hx, cx, w, out, dout):
1094
- y, hy, cy, reserve, _ = out
1095
- dy, dhy, dcy, _, _ = dout
1096
- dx, dhx, dcx, dw = lstm_grad(x, hx, cx, w, y, hy, cy, dy, dhy, dcy, reserve)
1097
- return dx, dhx, dcx, dw
1098
-
1099
- if context.get_context('device_target') == "CPU":
1100
- self.add_prim_attr("is_training", True)
1101
- return bprop_cpu
1102
-
1103
- return bprop
1104
-
1105
-
1106
- @bprop_getters.register(rl_ops.GRUV2)
1107
- def get_bppro_gru_v2(self):
1108
- """Grad definition for `GRUV2` operation."""
1109
- gru_grad_v2 = G.GRUV2Grad(
1110
- self.input_size,
1111
- self.hidden_size,
1112
- self.num_layers,
1113
- self.has_bias,
1114
- self.bidirectional,
1115
- self.dropout
1116
- )
1117
-
1118
- def bpro(x, hx, w, seq_length, out, dout):
1119
- y, hy, reverse, _ = out
1120
- dy, dhy, _, _ = dout
1121
- dx, dhx, dw = gru_grad_v2(x, hx, w, seq_length, y, hy, dy, dhy, reverse)
1122
- return dx, dhx, dw, (0)
1123
-
1124
- return bpro
1125
-
1126
-
1127
- @bprop_getters.register(rl_ops.CudnnGRU)
1128
- def get_bprop_gru(self):
1129
- """Grad definition for `GRU` operation."""
1130
- gru_grad_data = G.GruGradData(
1131
- input_size=self.input_size,
1132
- hidden_size=self.hidden_size,
1133
- num_layers=self.num_layers,
1134
- has_bias=self.has_bias,
1135
- bidirectional=self.bidirectional,
1136
- dropout=self.dropout
1137
- )
1138
-
1139
- gru_grad_weight = G.GruGradWeight(
1140
- input_size=self.input_size,
1141
- hidden_size=self.hidden_size,
1142
- num_layers=self.num_layers,
1143
- has_bias=self.has_bias,
1144
- bidirectional=self.bidirectional,
1145
- dropout=self.dropout
1146
- )
1147
-
1148
- def bprop(x, hx, w, out, dout):
1149
- y, _, reserve, state = out
1150
- dy, dhy, _, _ = dout
1151
- dx, dhx = gru_grad_data(y, dy, dhy, w, hx, reserve, state)
1152
- dw = gru_grad_weight(F.depend(x, dx), hx, y, reserve, state)
1153
- return dx, dhx, dw
1154
-
1155
- return bprop
1156
-
1157
-
1158
- @bprop_getters.register(P.DynamicRNN)
1159
- def get_bprop_dynamic_rnn(self):
1160
- """Grad definition for `DynamicRNN` operation."""
1161
- dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type,
1162
- direction=self.direction,
1163
- cell_depth=self.cell_depth,
1164
- use_peephole=self.use_peephole,
1165
- keep_prob=self.keep_prob,
1166
- cell_clip=self.cell_clip,
1167
- num_proj=self.num_proj,
1168
- time_major=self.time_major,
1169
- forget_bias=self.forget_bias)
1170
- expand_dims = P.ExpandDims()
1171
-
1172
- def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
1173
- dy, dh, dc, _, _, _, _, _, = dout
1174
- dh = dh[-1]
1175
- dc = dc[-1]
1176
- y, h, c, i, j, f, o, tanhct = out
1177
- dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
1178
- c, dy, dh, dc, i, j, f, o, tanhct)
1179
- dh_prev = expand_dims(dh_prev, 0)
1180
- dc_prev = expand_dims(dc_prev, 0)
1181
- return dx, dw, db, (0), dh_prev, dc_prev
1182
-
1183
- return bprop
1184
-
1185
-
1186
- @bprop_getters.register(P.DynamicGRUV2)
1187
- def get_bprop_dynamic_gru_v2(self):
1188
- """Grad definition for `DynamicGRUV2` operation."""
1189
- dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip,
1190
- self.num_proj, self.time_major, self.gate_order,
1191
- self.reset_after)
1192
-
1193
- def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout):
1194
- y, out_h, update, reset, new, hidden_new = out
1195
- dy, dout_h, _, _, _, _ = dout
1196
-
1197
- dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h,
1198
- out_h, dy, dout_h[-1], update,
1199
- reset, new, hidden_new, None, None)
1200
- return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev
1201
-
1202
- return bprop
1203
-
1204
-
1205
- @bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
1206
- def get_bprop_sigmoid_crossentropy_with_logits(self):
1207
- """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
1208
- op = G.SigmoidCrossEntropyWithLogitsGrad()
1209
-
1210
- def bprop(x, y, out, dout):
1211
- dx = op(x, y, dout)
1212
- return (dx, zeros_like(y))
1213
-
1214
- return bprop
1215
-
1216
-
1217
- @bprop_getters.register(P.Pad)
1218
- def get_bprop_pad(self):
1219
- """Grad definition for `Pad` operation."""
1220
- shape_op = P.Shape()
1221
- dyn_shape_op = P.TensorShape()
1222
- paddings = self.paddings
1223
-
1224
- def bprop(x, out, dout):
1225
- begin = ()
1226
- for item in paddings:
1227
- begin += (item[0],)
1228
- shp = shape_op(x)
1229
- if F.is_sequence_value_unknown(shp):
1230
- shp = dyn_shape_op(x)
1231
- dx = P.Slice()(dout, begin, shp)
1232
- return (dx,)
1233
-
1234
- return bprop
1235
-
1236
-
1237
- @bprop_getters.register(P.MirrorPad)
1238
- def get_bprop_mirror_pad(self):
1239
- """Grad definition for `MirrorPad` operation."""
1240
- mirror_pad_grad = G.MirrorPadGrad(self.mode)
1241
-
1242
- def bprop(x, paddings, out, dout):
1243
- dx = mirror_pad_grad(dout, paddings)
1244
- return (dx, zeros_like(paddings))
1245
-
1246
- return bprop
1247
-
1248
-
1249
- @bprop_getters.register(P.ROIAlign)
1250
- def get_bprop_roi_align(self):
1251
- """Grad definition for `ROIAlign` operation."""
1252
- shape_op = P.Shape()
1253
- dyn_shape = P.TensorShape()
1254
- pooled_height = self.pooled_height
1255
- pooled_width = self.pooled_width
1256
- spatial_scale = self.spatial_scale
1257
- sample_num = self.sample_num
1258
-
1259
- def bprop(inputs, rois, out, dout):
1260
- inputs_shape = shape_op(inputs)
1261
- if F.is_sequence_value_unknown(inputs_shape):
1262
- inputs_shape = dyn_shape(inputs)
1263
- dx = G.ROIAlignGrad(pooled_height, pooled_width, spatial_scale, sample_num)(dout, rois, inputs_shape)
1264
- return dx, zeros_like(rois)
1265
-
1266
- return bprop
1267
-
1268
-
1269
@bprop_getters.register(P.Conv2DTranspose)
@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
    """Backward builder shared by `Conv2DTranspose` and `Conv2DBackpropInput`.

    Gradients produced by the returned bprop:
      * w.r.t. ``x``: a forward `Conv2D` of ``dout`` with the filter ``w``;
      * w.r.t. ``w``: `Conv2DBackpropFilter` applied to ``x`` and ``dout`` with
        the shape of ``w`` (runtime shape tensor if the static shape is unknown);
      * ``f_sizes`` receives zeros.
    """
    attrs = self.get_attr_dict()
    out_channel = attrs['out_channel']
    pad_list = attrs['pad_list']
    # Convolution settings common to both grad primitives.
    common = {'dilation': self.dilation, 'stride': self.stride,
              'group': self.group, 'data_format': self.format}
    filter_grad = G.Conv2DBackpropFilter(out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list,
                                         mode=self.mode, **common)
    input_grad = P.Conv2D(out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
                          **common)
    static_shape = P.Shape()
    runtime_shape = P.TensorShape()

    def bprop(x, w, f_sizes, out, dout):
        filter_shape = static_shape(w)
        if F.is_sequence_value_unknown(filter_shape):
            filter_shape = runtime_shape(w)
        grad_x = input_grad(dout, w)
        grad_w = filter_grad(x, dout, filter_shape)
        return grad_x, grad_w, zeros_like(f_sizes)

    return bprop
1295
-
1296
-
1297
@bprop_getters.register(P.BinaryCrossEntropy)
def get_bprop_binary_cross_entropy(self):
    """Backward builder for `BinaryCrossEntropy`.

    Only ``x`` receives a real gradient, computed by `BinaryCrossEntropyGrad`
    with the forward op's ``reduction``; ``y`` and ``weight`` receive zeros.
    """
    bce_grad = G.BinaryCrossEntropyGrad(self.reduction)

    def bprop(x, y, weight, out, dout):
        grad_x = bce_grad(x, y, dout, weight)
        return grad_x, zeros_like(y), zeros_like(weight)

    return bprop
1307
-
1308
-
1309
@bprop_getters.register(P.BCEWithLogitsLoss)
def get_bprop_bce_with_logits_loss(self):
    """Backward builder for `BCEWithLogitsLoss`.

    Produces gradients for ``predict`` and ``target``; ``weight`` and
    ``pos_weight`` receive zeros. When ``weight`` is given, both gradients are
    multiplied by it. For ``reduction == 'mean'`` each gradient is divided by the
    element count of the corresponding tensor. ``dyn_size`` is assumed to be a
    module-level helper defined outside this view.
    """
    reduction = self.reduction
    mul = P.Mul()
    sigmoid = P.Sigmoid()
    add = P.Add()
    sub = P.Sub()
    size = P.Size()
    neg = P.Neg()
    log = P.Log()
    shape = P.Shape()

    def bprop(predict, target, weight, pos_weight, out, dout):
        prob = sigmoid(predict)
        if pos_weight is not None:
            # With p = pos_weight, y = target, s = sigmoid(predict):
            #   d/dpredict = ((p*y + 1 - y) * s - p*y) * dout
            #   d/dtarget  = (log(1 - s) - p * log(s)) * dout
            weighted_target = mul(target, pos_weight)
            coeff = sub(add(weighted_target, 1), target)
            dx = mul(sub(mul(coeff, prob), weighted_target), dout)
            log_complement = log(sub(1, prob))
            grad_target = mul(sub(log_complement, mul(pos_weight, log(prob))), dout)
        else:
            dx = mul((prob - target), dout)
            grad_target = mul(predict, neg(dout))
        if weight is not None:
            dx = mul(dx, weight)
            grad_target = mul(grad_target, weight)
        if reduction == 'mean':
            dx_count = dyn_size(dx) if F.is_sequence_value_unknown(shape(dx)) else size(dx)
            target_count = dyn_size(target) if F.is_sequence_value_unknown(shape(target)) else size(target)
            dx = dx / dx_count
            grad_target = grad_target / target_count
        return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)

    return bprop
1342
-
1343
-
1344
@bprop_getters.register(P.KLDivLoss)
def get_bprop_kl_div_loss(self):
    """Grad definition for `KLDivLoss` operation.

    For ``reduction == "mean"`` the gradient is computed by a ``"sum"``-reduction
    `KLDivLossGrad` and then divided by the element count of ``x``. Any other
    reduction is handed to `KLDivLossGrad` unchanged. ``y`` receives zeros.
    ``dyn_size`` is assumed to be a module-level helper defined outside this view.
    """
    reduce_type = self.reduction
    # The reduction is fixed when the forward primitive is built, so create the grad
    # primitive once here instead of on every backward call. Use the captured
    # reduce_type (not self.reduction) so both branches agree.
    grad = G.KLDivLossGrad("sum" if reduce_type == "mean" else reduce_type)

    size = P.Size()
    shape = P.Shape()

    def bprop(x, y, out, dout):
        dx = grad(dout, x, y)
        if reduce_type == "mean":
            x_size = dyn_size(x) if F.is_sequence_value_unknown(shape(x)) else size(x)
            return dx / x_size, zeros_like(y)
        return dx, zeros_like(y)

    return bprop
1364
-
1365
-
1366
@bprop_getters.register(P.Dropout)
def get_bprop_dropout(self):
    """Backward builder for `Dropout`.

    The forward op returns a pair whose second element is the mask. Only the
    gradient of the first output is propagated, passed through `DropoutGrad`
    (configured with the op's ``keep_prob``) together with that mask.
    """
    dropout_grad = G.DropoutGrad(self.keep_prob)

    def bprop(x, out, dout):
        _, keep_mask = out
        grad_out, _ = dout
        grad_x = dropout_grad(grad_out, keep_mask)
        return (grad_x,)

    return bprop
1378
-
1379
-
1380
@bprop_getters.register(G.DropoutGrad)
def get_bprop_dropout_grad(self):
    """Backward builder for `DropoutGrad` (second-order dropout gradient).

    The incoming gradient is passed through another `DropoutGrad` with the same
    ``keep_prob`` and mask; the mask input receives zeros.
    """
    dropout_grad = G.DropoutGrad(self.keep_prob)

    def bprop(x, mask, out, dout):
        return dropout_grad(dout, mask), zeros_like(mask)

    return bprop
1391
-
1392
-
1393
@bprop_getters.register(P.Dropout2D)
@bprop_getters.register(P.Dropout3D)
def get_bprop_dropout3d(self):
    """Backward builder shared by `Dropout2D` and `Dropout3D`.

    The gradient of the first forward output is:
      1. scaled by ``1 / keep_prob`` (skipped when ``keep_prob`` is 0);
      2. multiplied by the forward mask (cast to float32);
      3. cast back to the dtype of ``x``.
    """
    get_dtype = P.DType()
    to_type = P.Cast()
    multiply = P.Mul()
    keep_prob = self.keep_prob

    def bprop(x, out, dout):
        _, float_mask = out
        grad, _ = dout
        float_mask = to_type(float_mask, mstype.float32)
        if keep_prob != 0:
            # Guarded so that keep_prob == 0 never causes a division by zero.
            grad = grad * (1 / keep_prob)
        grad = multiply(float_mask, grad)
        grad = to_type(grad, get_dtype(x))
        return (grad,)

    return bprop
1413
-
1414
-
1415
@bprop_getters.register(P.CTCLoss)
def get_bprop_ctc_loss(self):
    """Backward builder for `CTCLoss`.

    The second forward output is used as the gradient w.r.t. ``inputs``. It is
    scaled by the first component of ``dout``, expanded on its last axis. Label
    and length inputs receive zeros.
    """
    expand_dims = P.ExpandDims()

    def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
        loss_grad = out[1]
        grad_inputs = loss_grad * expand_dims(dout[0], -1)
        return (grad_inputs, zeros_like(labels_indices), zeros_like(labels_values),
                zeros_like(sequence_length))

    return bprop
1426
-
1427
-
1428
@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Backward builder for `BasicLSTMCell`.

    Chains three grad primitives:
      1. the cell-state grad yields the gate gradient and the gradient for ``c``;
      2. the input grad yields the gradients for ``x`` and ``h``;
      3. the weight grad yields the gradients for ``w`` and ``b``.
    """
    cstate_grad = G.BasicLSTMCellCStateGrad(
        forget_bias=self.forget_bias,
        activation=self.activation
    )
    weight_grad = G.BasicLSTMCellWeightGrad()
    input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        dgate, grad_c = cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
        grad_x, grad_h = input_grad(dgate, w)
        # F.depend ties x to grad_x so the weight grad is sequenced after the input grad.
        grad_w, grad_b = weight_grad(F.depend(x, grad_x), h, dgate)
        return grad_x, grad_h, grad_c, grad_w, grad_b

    return bprop
1449
-
1450
-
1451
@bprop_getters.register(nps.DeformableOffsets)
def get_bprop_deformable_offsets(self):
    """Backward builder for `DeformableOffsets`.

    Delegates to `DeformableOffsetsGrad`, configured with every attribute of the
    forward op, and returns its result unchanged.
    """
    offsets_grad = G.DeformableOffsetsGrad(
        self.strides, self.pads, self.ksize, self.dilations, self.data_format,
        self.deformable_groups, self.modulated
    )

    def bprop(x, offsets, out, dout):
        # NOTE(review): returned as-is; presumably the grad primitive already yields
        # one gradient per forward input (x, offsets) — confirm.
        return offsets_grad(dout, x, offsets)

    return bprop
1462
-
1463
-
1464
@bprop_getters.register(P.LRN)
def get_bprop_lrn(self):
    """Backward builder for `LRN`.

    Uses `LRNGrad` with the forward op's ``depth_radius``, ``bias``, ``alpha``
    and ``beta``; it consumes the incoming gradient, the input and the
    forward output.
    """
    lrn_grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)

    def bprop(x, out, dout):
        return (lrn_grad(dout, x, out),)

    return bprop
1474
-
1475
-
1476
@bprop_getters.register(G.Conv2DBackpropFilter)
def get_bprop_conv2d_backprop_filter(self):
    """Backward builder for `Conv2DBackpropFilter` (inputs ``dy``, ``x``, ``filter_size``).

    Gradients produced by the returned bprop:
      * w.r.t. ``dy``: `Conv2D` of ``x`` with ``dout`` as the filter;
      * w.r.t. ``x``: `Conv2DBackpropInput` of ``dy`` and ``dout`` using the shape
        of ``x`` (runtime shape tensor if the static shape is unknown);
      * ``filter_size`` receives zeros.
    """
    # Convolution settings common to both grad primitives.
    common = {'dilation': self.dilation, 'stride': self.stride,
              'group': self.group, 'data_format': self.format}
    input_grad = P.Conv2DBackpropInput(self.out_channel, self.kernel_size, self.pad_mode, self.pad,
                                       self.pad_list, mode=self.mode, **common)
    filter_grad = P.Conv2D(self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(),
                           pad=self.pad, **common)
    static_shape = P.Shape()
    runtime_shape = P.TensorShape()

    def bprop(dy, x, filter_size, out, dout):
        x_shape = static_shape(x)
        if F.is_sequence_value_unknown(x_shape):
            x_shape = runtime_shape(x)
        grad_x = input_grad(dy, dout, x_shape)
        grad_dy = filter_grad(x, dout)
        return grad_dy, grad_x, zeros_like(filter_size)

    return bprop
1499
-
1500
-
1501
@bprop_getters.register(nps.UpsampleNearest3D)
def get_bprop_upsample_nearest_3d_grad(self):
    """Backward builder for `UpsampleNearest3D`.

    `UpsampleNearest3DGrad` takes the shape of ``input_x`` (from `P.Shape`) as a
    constructor argument, so it is instantiated inside ``bprop``. It also takes
    the forward op's ``output_size`` and ``scales``.
    """
    shape_of = P.Shape()
    target_size = self.output_size
    scale_factors = self.scales

    def bprop(input_x, out, dout):
        nearest_grad = G.UpsampleNearest3DGrad(shape_of(input_x), target_size, scale_factors)
        return (nearest_grad(dout),)

    return bprop
1514
-
1515
-
1516
@bprop_getters.register(nps.UpsampleTrilinear3D)
def get_bprop_upsample_trilinear_3d_grad(self):
    """Backward builder for `UpsampleTrilinear3D`.

    `UpsampleTrilinear3DGrad` takes the shape of ``input_x`` (from `P.Shape`) as a
    constructor argument, so it is instantiated inside ``bprop``. It also takes
    the forward op's ``output_size``, ``scales`` and ``align_corners``.
    """
    shape_of = P.Shape()
    target_size = self.output_size
    scale_factors = self.scales
    corners_aligned = self.align_corners

    def bprop(input_x, out, dout):
        trilinear_grad = G.UpsampleTrilinear3DGrad(shape_of(input_x), target_size, scale_factors,
                                                   corners_aligned)
        return (trilinear_grad(dout),)

    return bprop