mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the registry's advisory for this release for more details.

Files changed (884):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,1347 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """array_ops"""
17
-
18
- import numpy as np
19
- import mindspore as ms
20
- from mindspore import Tensor
21
- from mindspore.ops import operations as P
22
- from mindspore.ops.operations import _grad_ops as G
23
- from mindspore.ops.operations.array_ops import Fills, NonZero
24
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
25
- from mindspore.ops.functional import broadcast_gradient_args
26
- from mindspore.ops import functional as F
27
- from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element
28
- from mindspore.ops.primitive import constexpr
29
- from mindspore.ops.primitive import _primexpr
30
- from mindspore.common import dtype as mstype
31
- from mindspore.common.sparse_tensor import RowTensorInner
32
- from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index
33
- from mindspore.ops._grad.grad_base import dyn_rank, convert_to_tensor, dyn_ones, dyn_fill
34
- from mindspore.ops._grad.grad_base import sum_grad_reduce_axis
35
- from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
36
- from ..operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass
37
-
38
- reduce_sum = P.ReduceSum()
39
- unsorted_segment_sum = P.UnsortedSegmentSum()
40
- transpose = P.Transpose()
41
- shape_op = P.Shape()
42
- dyn_shape_op = P.TensorShape()
43
- reshape = P.Reshape()
44
- size_op = P.Size()
45
- invert_permutation = P.InvertPermutation()
46
- logical_and = P.LogicalAnd()
47
- is_sub_class = IsSubClass()
48
-
49
-
50
- @bprop_getters.register(P.Fill)
51
- def get_bprop_fill(self):
52
- """Generate bprop for Fill"""
53
-
54
- def bprop(dtype, dims, x, out, dout):
55
- return zeros_like(dims), zeros_like(x)
56
-
57
- return bprop
58
-
59
-
60
- @bprop_getters.register(Fills)
61
- def get_bprop_fills(self):
62
- """Generate bprop for Fills."""
63
-
64
- def bprop(x, value, out, dout):
65
- return zeros_like(x), zeros_like(value)
66
-
67
- return bprop
68
-
69
-
70
- @bprop_getters.register(P.Ones)
71
- def get_bprop_ones(self):
72
- """Generate bprop for Ones"""
73
-
74
- def bprop(dims, dtype, out, dout):
75
- return zeros_like(dims)
76
-
77
- return bprop
78
-
79
-
80
- @bprop_getters.register(P.Zeros)
81
- def get_bprop_zeros(self):
82
- """Generate bprop for Zeros"""
83
-
84
- def bprop(dims, dtype, out, dout):
85
- return zeros_like(dims)
86
-
87
- return bprop
88
-
89
-
90
- @bprop_getters.register(P.DType)
91
- def get_bprop_dtype(self):
92
- """Generate bprop for DType"""
93
-
94
- def bprop(x, out, dout):
95
- return (zeros_like(x),)
96
-
97
- return bprop
98
-
99
-
100
- @bprop_getters.register(P.Shape)
101
- def get_bprop_shape(self):
102
- """Generate bprop for Shape"""
103
-
104
- def bprop(x, out, dout):
105
- return (zeros_like(x),)
106
-
107
- return bprop
108
-
109
-
110
- @bprop_getters.register(P.DynamicShape)
111
- def get_bprop_dynamicshape(self):
112
- """Generate bprop for DynamicShape"""
113
-
114
- def bprop(x, out, dout):
115
- return (zeros_like(x),)
116
-
117
- return bprop
118
-
119
-
120
- @bprop_getters.register(P.TensorShape)
121
- def get_bprop_tensorshape(self):
122
- """Generate bprop for TensorShape"""
123
-
124
- def bprop(x, out, dout):
125
- return (zeros_like(x),)
126
-
127
- return bprop
128
-
129
-
130
- @bprop_getters.register(P.Split)
131
- def get_bprop_split(self):
132
- """Generate bprop for Split"""
133
- axis = self.axis
134
-
135
- def bprop(x, out, dout):
136
- concat_op = P.Concat(axis)
137
- dx = concat_op(dout)
138
- return (dx,)
139
-
140
- return bprop
141
-
142
-
143
- @bprop_getters.register(P.Rank)
144
- def get_bprop_rank(self):
145
- """Generate bprop for Rank"""
146
-
147
- def bprop(x, out, dout):
148
- return (zeros_like(x),)
149
-
150
- return bprop
151
-
152
-
153
- @bprop_getters.register(P.Reshape)
154
- def get_bprop_reshape(self):
155
- """Generate bprop for Reshape"""
156
-
157
- def bprop(x, shp, out, dout):
158
- shapex = shape_op(x)
159
- if F.is_sequence_value_unknown(shapex):
160
- shapex = dyn_shape_op(x)
161
- return reshape(dout, shapex), zeros_like(shp)
162
-
163
- return bprop
164
-
165
-
166
- @bprop_getters.register(P.ExpandDims)
167
- def get_bprop_expand_dims(self):
168
- """Generate bprop for ExpandDims"""
169
-
170
- def bprop(x, axis, out, dout):
171
- shapex = shape_op(x)
172
- if F.is_sequence_value_unknown(shapex):
173
- shapex = dyn_shape_op(x)
174
- return reshape(dout, shapex), zeros_like(axis)
175
-
176
- return bprop
177
-
178
-
179
- @bprop_getters.register(P.Squeeze)
180
- def get_bprop_squeeze(self):
181
- """Generate bprop for Squeeze"""
182
-
183
- def bprop(x, out, dout):
184
- shapex = shape_op(x)
185
- if F.is_sequence_value_unknown(shapex):
186
- shapex = dyn_shape_op(x)
187
- return (reshape(dout, shapex),)
188
-
189
- return bprop
190
-
191
-
192
- @bprop_getters.register(P.Flatten)
193
- def get_bprop_flatten(self):
194
- """Generate bprop for Flatten"""
195
- flatten_grad = P.Reshape()
196
-
197
- def bprop(x, out, dout):
198
- shape_x = shape_op(x)
199
- if F.is_sequence_value_unknown(shape_x):
200
- shape_x = dyn_shape_op(x)
201
- dx = flatten_grad(dout, shape_x)
202
- return (dx,)
203
-
204
- return bprop
205
-
206
-
207
- @_primexpr
208
- def _tile_shape(multiples, shapex):
209
- """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
210
- len_muli = len(multiples)
211
- rank = len(shapex)
212
- len_cmp = len_muli - rank
213
- max_len = max(len_muli, rank)
214
- i = 0
215
- j = 0
216
- ret = []
217
- while (i < max_len) and (j < max_len):
218
- if len_cmp == 0:
219
- ret.append(multiples[i])
220
- ret.append(shapex[j])
221
- i += 1
222
- j += 1
223
- elif len_cmp > 0:
224
- ret.append(multiples[i])
225
- ret.append(1)
226
- i += 1
227
- len_cmp -= 1
228
- else:
229
- ret.append(1)
230
- ret.append(shapex[j])
231
- j += 1
232
- len_cmp += 1
233
- return tuple(ret)
234
-
235
-
236
- @bprop_getters.register(P.Tile)
237
- def get_bprop_tile(self):
238
- """Generate bprop for Tile"""
239
- cast = P.Cast()
240
- concat = P.Concat()
241
- stridedslice = P.StridedSlice()
242
-
243
- def get_reduce_axis(r_shape):
244
- """
245
- reshape grad to r_shape, and reduce along all even dimensions to get the result with input_shape
246
- For example:
247
- input_shape = [20, 30, 40]
248
- multiples = [2, 3, 4]
249
- r_shape = [2, 20, 3, 30, 4, 40]
250
- axis = [0, 2, 4]
251
- """
252
- rankr = dyn_shape_op(r_shape)[0]
253
- tmp = range_op(0, 20, 2, mstype.int64)
254
- return stridedslice(tmp, (0,), F.expand_dims(rankr // 2, 0), (1,))
255
-
256
- def bprop(x, multiples, out, dout):
257
- shapex = shape_op(x)
258
- if F.is_sequence_value_unknown(shapex):
259
- shapex = dyn_shape_op(x)
260
- if isinstance(multiples, tuple) and isinstance(shapex, tuple):
261
- r_shape = _tile_shape(multiples, shapex)
262
- # 0 represents the start index, and 2 represents the step
263
- axis = F.make_range(0, len(r_shape), 2)
264
- else:
265
- shapex = dyn_shape_op(x)
266
- shapey = create_tensor_by_element(multiples)
267
- rankx = dyn_rank(x)
268
- ranky = dyn_shape_op(shapey)[0]
269
- offset = F.expand_dims(ranky - rankx + 1, 0)
270
- shape_x = concat((dyn_ones(offset, mstype.int64), shapex))
271
- shape_x = shape_x[1:]
272
- shapey = concat((P.Ones()((1,), mstype.int64), shapey))
273
- shapey = shapey[1:]
274
- tile_shape = P.Stack(1)((shapey, shape_x))
275
- r_shape = P.Reshape()(tile_shape, (-1,))
276
- axis = get_reduce_axis(r_shape)
277
-
278
- dout_reshaped = P.Reshape()(dout, r_shape)
279
- dout_origin_dtype = dout_reshaped.dtype
280
- # Currently, for Ascend and GPU, the reduce_sum's input does not support int16, int32 and int64.
281
- if dout_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
282
- dout_reshaped = cast(dout_reshaped, mstype.float32)
283
- dx = reduce_sum(dout_reshaped, axis)
284
- dx = cast(dx, dout_origin_dtype)
285
- else:
286
- dx = reduce_sum(dout_reshaped, axis)
287
- dx = reshape(dx, shapex)
288
- return dx, zeros_like(multiples)
289
-
290
- return bprop
291
-
292
-
293
- @bprop_getters.register(P.EmbeddingLookup)
294
- def get_bprop_embedding_lookup(self):
295
- """Generate bprop for EmbeddingLookup"""
296
- sub_op = P.Sub()
297
- reshape_op = P.Reshape()
298
-
299
- def bprop_sparse(x, indices, offset, out, dout):
300
- x_shp = shape_op(x)
301
- if F.is_sequence_value_unknown(x_shp):
302
- raise RuntimeError("Now, EmbeddingLookup op's grad don't support Dynamic Sense!")
303
- new_indices = sub_op(indices, offset)
304
- indices_size = size_op(new_indices)
305
- if indices_size > 0:
306
- # Reshape the 'new_indices'
307
- new_indices_shape_changed = (indices_size,)
308
- new_indices = reshape_op(new_indices, new_indices_shape_changed)
309
- else:
310
- new_indices_shape_changed = ()
311
- x_shp_tail = x_shp[1:]
312
- actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
313
- # Reshape the 'actual_dout' on device
314
- actual_dout = reshape_op(dout, actual_dout_shape_changed)
315
- return RowTensorInner(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
316
-
317
- return bprop_sparse
318
-
319
-
320
- @_primexpr
321
- def make_begin(shp):
322
- """Creates a tuple with zero according to the shape."""
323
- begin = tuple([0 for _ in shp])
324
- return begin
325
-
326
-
327
- def make_dynamic_begin(shp):
328
- """Creates a tuple with zero according to the shape."""
329
- begin = zeros_like(shp)
330
- return begin
331
-
332
-
333
- @bprop_getters.register(P.Padding)
334
- def get_bprop_padding(self):
335
- """Grad definition for `Padding` operation."""
336
-
337
- def bprop(x, out, dout):
338
- shp = shape_op(x)
339
- begin = ()
340
- if F.is_sequence_value_unknown(shp):
341
- shp = dyn_shape_op(x)
342
- begin = make_dynamic_begin(shp)
343
- else:
344
- begin = make_begin(shp)
345
- dx = P.Slice()(dout, begin, shp)
346
- return (dx,)
347
-
348
- return bprop
349
-
350
-
351
- @_primexpr
352
- def _concat_grad_uniform(input_shapes, input_nums):
353
- """Helper function for bprop of Concat"""
354
- is_uniform = True
355
- for i in range(1, input_nums):
356
- if input_shapes[i - 1] != input_shapes[i]:
357
- is_uniform = False
358
- break
359
- return is_uniform
360
-
361
-
362
@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
    """Generate bprop for Concat.

    Each input receives the slice of ``dout`` at its ``ConcatOffset`` with its own shape.
    When all inputs share one static shape, a single ``Split`` along ``axis`` is used
    instead. The gradient container (list or tuple) follows the type of ``x``.
    """
    axis = self.axis

    def bprop(x, out, dout):
        num_inputs = len(x)
        offsets = G.ConcatOffset(num_inputs, axis)(x)
        shapes = ()
        if isinstance(offsets, tuple):
            for idx in range(num_inputs):
                shapes = shapes + (shape_op(x[idx]),)
            uniform = _concat_grad_uniform(shapes, num_inputs)
        else:
            # for dynamic shape: use runtime shape tensors and always slice.
            for idx in range(num_inputs):
                shapes = shapes + (dyn_shape_op(x[idx]),)
            uniform = False

        if isinstance(x, list):
            grads = []
            if uniform:
                for piece in P.Split(axis, num_inputs)(dout):
                    grads = grads + [piece,]
            else:
                for idx in range(num_inputs):
                    grads = grads + [P.Slice()(dout, offsets[idx], shapes[idx]),]
        else:
            grads = ()
            if uniform:
                grads = P.Split(axis, num_inputs)(dout)
            else:
                for idx in range(num_inputs):
                    grads = grads + (P.Slice()(dout, offsets[idx], shapes[idx]),)
        return (grads,)

    return bprop
402
-
403
-
404
@bprop_getters.register(P.Slice)
def get_bprop_slice(self):
    """Generate bprop for Slice.

    ``SliceGrad`` places ``dout`` back into an ``x``-shaped gradient. ``begin`` and
    ``size`` get zero gradients.
    """

    def bprop(x, begin, size, out, dout):
        slice_grad = G.SliceGrad()
        dx = slice_grad(dout, x, begin, size)
        return (dx, zeros_like(begin), zeros_like(size))

    return bprop
413
-
414
-
415
@_primexpr
def _generate_inverse_index(x_shape, axis, batch_dims=0):
    """Build the permutation that moves dimension ``batch_dims`` back to position ``axis``.

    ``axis`` may be negative and is normalised against ``len(x_shape)``.
    Example: rank 3, axis 2, batch_dims 0 -> (1, 2, 0).
    """
    rank = len(x_shape)
    if axis < 0:
        axis = axis + rank
    dims = tuple(range(rank))
    leading = dims[:batch_dims]
    shifted = dims[batch_dims + 1:axis + 1]
    trailing = dims[axis + 1:]
    return leading + shifted + (dims[batch_dims],) + trailing
423
-
424
-
425
@_primexpr
def _regenerate_output_shape(x_shp, ind_shp, axis):
    """Return ``x_shp`` with dimension ``axis`` replaced by the whole of ``ind_shp``.

    ``axis`` may be negative and is normalised against ``len(x_shp)``.
    """
    if axis < 0:
        axis = axis + len(x_shp)
    head, tail = x_shp[:axis], x_shp[axis + 1:]
    return head + ind_shp + tail
432
-
433
-
434
def _dyn_regenerate_output_shape(x_shp, ind_shp, axis):
    """Get reshape new_shape for the dynamic-shape Gather bprop.

    ``x_shp`` and ``ind_shp`` are shape tensors. The result concatenates
    ``x_shp[:axis]``, ``ind_shp`` and ``x_shp[axis + 1:]``; a negative ``axis`` is first
    shifted by the length of ``x_shp``.
    """
    x_rank = dyn_shape_op(x_shp)[0]
    if axis < 0:
        axis += x_rank
    pieces = (x_shp[:axis], ind_shp, x_shp[axis + 1:])
    return P.Concat(0)(pieces)
441
-
442
-
443
def _dyn_generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
    """Get the transpose order for the dynamic-shape Gather bprop.

    The dimensions that Gather inserted at ``axis`` (``rank(indices) - batch_dims`` of
    them) are moved to sit directly after the first ``batch_dims`` dimensions.

    Args:
        out_shape: shape tensor of the (reshaped) gradient ``dout``.
        indices_shape: shape tensor of ``indices``.
        axis: gather axis; a negative value is shifted by ``out_rank - ind_rank + 1``.
        batch_dims: number of leading batch dimensions.

    Returns:
        int32 tensor holding a permutation of ``range(out_rank)``.
    """
    out_rank = F.reshape(dyn_shape_op(out_shape), ())
    ind_rank = F.reshape(dyn_shape_op(indices_shape), ())
    if axis < 0:
        axis += out_rank - ind_rank + 1
    ind_end = axis + ind_rank - batch_dims
    index = P.Range()(F.cast(0, mstype.int32), F.cast(out_rank, mstype.int32), F.cast(1, mstype.int32))
    # Take the indices-derived dims from the full range(out_rank). Slicing a fixed
    # range(0, 20) silently truncated the permutation when ind_end exceeded 20.
    perm_part1 = index[axis:ind_end]
    perm = F.hstack((index[:batch_dims], perm_part1, index[batch_dims:axis], index[ind_end:]))
    return perm
455
-
456
-
457
def _dyn_generate_inverse_index(x_shp, axis, batch_dims=0):
    """Get the transpose order that moves dimension ``batch_dims`` back to ``axis`` (dynamic shape).

    Dynamic-shape counterpart of ``_generate_inverse_index``. ``x_shp`` is a shape tensor
    and the permutation is built from an int32 ``Range`` over its length.
    """
    rank = F.reshape(dyn_shape_op(x_shp), ())
    dims = P.Range()(F.cast(0, mstype.int32), F.cast(rank, mstype.int32), F.cast(1, mstype.int32))
    if axis < 0:
        axis += rank
    leading = dims[:batch_dims]
    shifted = dims[batch_dims + 1:1 + axis]
    moved = dims[batch_dims]
    trailing = dims[1 + axis:]
    return F.hstack((leading, shifted, moved, trailing))
465
-
466
-
467
def calculate_batch_gather(values, indices, x_shape, axis, batch_dims):
    """Calculate gather grad with batch_dims.

    The batch dimensions are flattened into one. Each batch's indices are offset by
    ``batch * x_shape[axis]`` so that one ``unsorted_segment_sum`` accumulates every batch.
    The result is reshaped to ``shape(values)[:batch_dims] + (x_shape[axis],) +`` the
    remaining gradient dims.
    """
    values_shape = dyn_shape_op(values)
    batch_size = F.cast(F.prod(x_shape[:batch_dims]), mstype.int32)
    axis_dim = F.cast(x_shape[axis], mstype.int32)

    # Move batch dimension to first non-batch dimension
    values = values.reshape((-1,) + values.shape[batch_dims:])
    indices = indices.reshape((-1,) + indices.shape[batch_dims:])
    num_segments = batch_size * axis_dim
    offset = P.Range()(F.cast(0, mstype.int32), num_segments, axis_dim)
    # Shape (batch_size, 1, ..., 1) so that the per-batch offset broadcasts over indices.
    offset_shape = F.hstack([batch_size] + [Tensor(1, dtype=mstype.int32) for _ in range(len(indices.shape) - 1)])
    indices = indices + reshape(offset, offset_shape)
    params_grad = unsorted_segment_sum(values, indices, num_segments)
    grad_shape = dyn_shape_op(params_grad)
    ret_shape = F.hstack([values_shape[:batch_dims], F.cast(axis_dim, mstype.int64), grad_shape[1:]])
    return reshape(params_grad, ret_shape)
487
-
488
-
489
@bprop_getters.register(P.Gather)
@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
    """Generate bprop for GatherV2.

    The gradient w.r.t. ``x`` transposes the gathered dimensions of ``dout`` to the front,
    accumulates them per index with ``unsorted_segment_sum`` (per batch when
    ``batch_dims > 0``), then transposes the result back to ``x``'s layout. ``indices``
    and ``axis`` receive zero gradients.
    """

    def _dyn_bprop_gather_v2(x, indices, axis, dout):
        """dyn shape bprop for GatherV2"""
        orig_indices = indices
        x_shp = dyn_shape_op(x)
        ind_shp = dyn_shape_op(indices)
        out_shp = dyn_shape_op(dout)

        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            # A scalar index is treated as shape (1,), as in the static path. The indices
            # shape must be refreshed after the expansion; otherwise the regenerated output
            # shape and the transpose order omit the dimension that `indices` now has.
            indices = P.ExpandDims()(indices, -1)
            ind_shp = dyn_shape_op(indices)
            out_shp = _dyn_regenerate_output_shape(x_shp, ind_shp, axis)
            dout = reshape(dout, out_shp)

        batch_dims = self.batch_dims
        if batch_dims < 0:
            # Normalise against the (possibly expanded) indices rank, matching the static path.
            batch_dims += F.reshape(dyn_shape_op(ind_shp), ())

        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _dyn_generate_shape_index(out_shp, ind_shp, axis, batch_dims)
        values_transpose = transpose(dout, perm_1)
        if batch_dims > 0:
            params_grad = calculate_batch_gather(values_transpose, indices, x_shp, axis, batch_dims)
        else:
            params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
        perm_2 = _dyn_generate_inverse_index(x_shp, axis, batch_dims)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    def bprop(x, indices, axis, out, dout):
        is_mutable, axis = convert_to_tensor(axis)
        if (F.is_sequence_value_unknown(shape_op(x)) or F.is_sequence_value_unknown(shape_op(indices)) or
                F.is_sequence_value_unknown(shape_op(dout))) and is_mutable:
            return _dyn_bprop_gather_v2(x, indices, axis, dout)
        orig_indices = indices
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            # Treat a scalar index as shape (1,) and give dout the matching dimension.
            indices = P.ExpandDims()(indices, -1)
            x_shp = shape_op(x)
            ind_shp = shape_op(indices)
            out_shp = _regenerate_output_shape(x_shp, ind_shp, axis)
            dout = reshape(dout, out_shp)

        x_shp = shape_op(x)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        batch_dims = self.batch_dims
        if batch_dims < 0:
            batch_dims += len(ind_shp)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis, batch_dims)
        values_transpose = transpose(dout, perm_1)
        x_dyn_shape = dyn_shape_op(x)
        if batch_dims > 0:
            params_grad = calculate_batch_gather(values_transpose, indices, x_dyn_shape, axis, batch_dims)
        else:
            params_grad = unsorted_segment_sum(values_transpose, indices, x_dyn_shape[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis, batch_dims)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    return bprop
557
-
558
-
559
@bprop_getters.register(P.GatherD)
def get_bprop_gather_d(self):
    """Generate bprop for GatherD.

    ``GatherDGradV2`` computes the input gradient. ``dim`` and ``index`` get zero gradients.
    """

    def bprop(x, dim, index, out, dout):
        gather_d_grad = G.GatherDGradV2()
        dx = gather_d_grad(x, dim, index, dout)
        return dx, zeros_like(dim), zeros_like(index)

    return bprop
568
-
569
-
570
@bprop_getters.register(G.GatherDGrad)
def get_bprop_gather_d_grad(self):
    """Generate bprop for GatherDGrad.

    The gradient w.r.t. the second input ``x`` reads ``dout`` (shaped like
    ``self.out_shape``) at the positions given by ``index`` along ``self.dim``. This is
    done as a flat ``Gather`` on the flattened ``dout`` with explicitly computed read
    offsets. ``index`` gets a zero gradient.
    """
    op = P.Gather()
    dim = self.dim
    x_shp = self.out_shape

    def bprop(index, x, out, dout):
        index_shp = shape_op(index)
        dim_before_axis = 1
        for i in range(dim):
            dim_before_axis *= x_shp[i]
        dim_at_axis_index = index_shp[dim]
        dim_at_axis_output = x_shp[dim]
        dim_after_axis = 1
        for i in range(dim + 1, len(x_shp)):
            dim_after_axis *= x_shp[i]
        # NOTE(review): the element count uses out_shape for the non-`dim` dimensions, so
        # `index` is assumed to match out_shape outside `dim` -- confirm.
        element = dim_before_axis * dim_at_axis_index * dim_after_axis
        id_ = range_op(0, element, 1, index.dtype)
        # Split each flat position of `index` into its outer (i) and inner (k) coordinates.
        i = id_ // (dim_at_axis_index * dim_after_axis)
        k = id_ % dim_after_axis
        # Index values address the output dimension, so negative values wrap by its size.
        # Wrapping by the index extent was wrong whenever index.shape[dim] != out_shape[dim].
        j = P.Cast()(index < 0, index.dtype)
        j_read = dim_at_axis_output * j + index
        j_read = P.Reshape()(j_read, (-1,))
        read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k
        dout = P.Reshape()(dout, (-1,))
        dx = op(dout, read_id, 0)
        dx = P.Reshape()(dx, shape_op(x))
        return zeros_like(index), dx

    return bprop
601
-
602
-
603
@bprop_getters.register(G.GatherDGradV2)
def get_bprop_gather_d_grad_v2(self):
    """Generate bprop for GatherDGradV2.

    The gradient w.r.t. the second input ``x`` reads ``dout`` at the positions given by
    ``index`` along ``self.dim``. This is done as a flat ``Gather`` on the flattened
    ``dout`` with explicitly computed read offsets. Shapes are read from ``x`` at trace
    time. ``index`` gets a zero gradient.
    """
    op = P.Gather()
    dim = self.dim

    def bprop(index, x, out, dout):
        index_shp = shape_op(index)
        dim_before_axis = 1
        x_shp = shape_op(x)
        for i in range(dim):
            dim_before_axis *= x_shp[i]
        dim_at_axis_index = index_shp[dim]
        dim_at_axis_output = x_shp[dim]
        dim_after_axis = 1
        for i in range(dim + 1, len(x_shp)):
            dim_after_axis *= x_shp[i]
        element = dim_before_axis * dim_at_axis_index * dim_after_axis
        id_ = range_op(0, element, 1, index.dtype)
        # Split each flat position of `index` into its outer (i) and inner (k) coordinates.
        i = id_ // (dim_at_axis_index * dim_after_axis)
        k = id_ % dim_after_axis
        # Negative index values wrap by the size of the dimension they address
        # (dim_at_axis_output), not by the index extent along `dim`.
        j = P.Cast()(index < 0, index.dtype)
        j_read = dim_at_axis_output * j + index
        j_read = P.Reshape()(j_read, (-1,))
        read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k
        dout = P.Reshape()(dout, (-1,))
        dx = op(dout, read_id, 0)
        dx = P.Reshape()(dx, shape_op(x))
        return zeros_like(index), dx

    return bprop
634
-
635
-
636
@bprop_getters.register(P.SparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for SparseGatherV2.

    For ``axis == 0`` the gradient is returned as a ``RowTensorInner`` built from the
    flattened indices and the matching rows of ``dout``. Otherwise a dense gradient is
    built by transposing, applying ``unsorted_segment_sum`` and transposing back.
    ``indices`` and ``axis`` get zero gradients shaped like the caller's inputs.
    """

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            indices_size = (size_op(indices),)
            if len(x_shp) <= 1:
                x_tail_shp = ()
            else:
                x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices_new = reshape(indices, indices_size)
            return RowTensorInner(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
        # Keep the caller's indices: their zero gradient must keep the original shape even
        # when a scalar index is expanded to shape (1,) below.
        orig_indices = indices
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    return bprop
668
-
669
-
670
@constexpr
def _get_transposition(axis, rank):
    """helper function for grad of Sort

    Return the permutation that swaps dimension ``axis`` (possibly negative) with the last
    dimension. The Sort bprop only calls this when ``axis`` is not already last.
    """
    if axis < 0:
        axis += rank
    order = list(range(axis))
    order.append(rank - 1)
    order.extend(range(axis + 1, rank - 1))
    order.append(axis)
    return tuple(order)
678
-
679
-
680
@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
    """Grad definition for `Sort` operation.

    The sort order is recomputed with TopK over the sorted axis, which is swapped to the
    last position if needed. The value gradient ``dout[0]`` is then scattered back to the
    original positions. For ascending sorts, the input and gradient are negated so that
    TopK produces the same order.
    """
    axis = self.axis
    descending = self.descending
    scatter_nd = P.ScatterNd()
    expand_dims = P.ExpandDims()
    reshape_prim = P.Reshape()
    get_dtype = P.DType()
    top_k = P.TopK()
    negate = P.Neg()
    transpose_prim = P.Transpose()

    def bprop(input_x, out, dout):
        k = input_x.shape[axis]
        rank = F.rank(input_x)
        value_grad = dout[0]
        if not descending:
            input_x = negate(input_x)
            value_grad = negate(value_grad)
        if axis == -1 or (axis + 1) == rank:
            perm = None
            topk_in = input_x
        else:
            perm = _get_transposition(axis, rank)
            topk_in = transpose_prim(input_x, perm)

        _, topk_indices = top_k(topk_in, k)
        topk_in_shape = topk_in.shape
        last_in = topk_in_shape[-1]
        last_ind = topk_indices.shape[-1]
        ind_2d = reshape_prim(topk_indices, (-1, last_ind))
        rows = ind_2d.shape[0]
        # Flat offset of each row start, turning row-local indices into flat positions.
        row_offsets = range_op(0, rows * last_in, last_in, get_dtype(topk_indices))
        # expand_dims to (k, 1), then broadcast
        flat_ind = reshape_prim(ind_2d + expand_dims(row_offsets, -1), (-1,))
        flat_shape = get_1d_shape(topk_in_shape)

        if perm is not None:
            value_grad = transpose_prim(value_grad, invert_permutation(perm))
            scattered = reshape_prim(
                scatter_nd(expand_dims(flat_ind, -1), reshape_prim(value_grad, (-1,)), flat_shape), topk_in_shape)
            dx = transpose_prim(scattered, invert_permutation(perm))
        else:
            dx = reshape_prim(scatter_nd(expand_dims(flat_ind, -1), reshape_prim(value_grad, (-1,)), flat_shape),
                              topk_in_shape)
        if not descending:
            dx = negate(dx)
        return (dx,)

    return bprop
735
-
736
-
737
@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
    """Generate bprop for Identity: the incoming gradient passes through unchanged."""

    def bprop(x, out, dout):
        dx = dout
        return (dx,)

    return bprop
745
-
746
-
747
@bprop_getters.register(P.Range)
def get_bprop_range(self):
    """Generate bprop for Range: start, limit and delta all receive zero gradients."""

    def bprop(start, limit, delta, out, dout):
        d_start = zeros_like(start)
        d_limit = zeros_like(limit)
        d_delta = zeros_like(delta)
        return (d_start, d_limit, d_delta)

    return bprop
755
-
756
-
757
@bprop_getters.register(P.Pack)
@bprop_getters.register(P.Stack)
def get_bprop_stack(self):
    """Generate bprop for Stack.

    ``dout`` is unstacked along ``axis`` into one gradient per input. The gradients are
    returned as a list when ``x`` is a list, otherwise as the tuple from Unstack.
    """
    axis = self.axis

    def bprop(x, out, dout):
        unstack = P.Unstack(num=len(x), axis=axis)
        grads = unstack(dout)
        if is_sub_class(F.typeof(x), ms.list_):
            grad_list = []
            for grad in grads:
                grad_list.append(grad)
            return (grad_list,)
        return (grads,)

    return bprop
774
-
775
-
776
@bprop_getters.register(P.ReverseV2)
def get_bprop_reverse_v2(self):
    """Generate bprop for ReverseV2: reverse ``dout`` along the same axes."""
    axis = self.axis

    def bprop(x, out, dout):
        return (P.ReverseV2(axis)(dout),)

    return bprop
787
-
788
-
789
@bprop_getters.register(P.Unstack)
def get_bprop_unstack(self):
    """Generate bprop for Unstack: stack the per-output gradients back along ``axis``."""
    axis = self.axis

    def bprop(x, out, dout):
        return (P.Stack(axis)(dout),)

    return bprop
800
-
801
-
802
@bprop_getters.register(P.StridedSlice)
def get_bprop_strided_slice(self):
    """Generate bprop for StridedSlice.

    ``StridedSliceGrad``, built with the forward op's masks, writes ``dout`` into an
    ``x``-shaped gradient. The begin/end/strides inputs get zero gradients.
    """
    slice_grad = G.StridedSliceGrad(self.begin_mask,
                                    self.end_mask,
                                    self.ellipsis_mask,
                                    self.new_axis_mask,
                                    self.shrink_axis_mask)

    def bprop(x, begin, end, strides, out, dout):
        in_shape = shape_op(x)
        if F.is_sequence_value_unknown(in_shape):
            # Fall back to the runtime shape tensor when the static shape is unknown.
            in_shape = dyn_shape_op(x)
        dx = slice_grad(dout, in_shape, begin, end, strides)
        return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)

    return bprop
819
-
820
-
821
@bprop_getters.register(G.StridedSliceGrad)
def get_bprop_strided_slice_grad(self):
    """Generate bprop for StridedSliceGrad.

    The gradient for ``dy`` is the forward ``StridedSlice``, with the same masks, applied
    to ``dout``. All shape and index inputs get zero gradients.
    """
    forward_slice = P.StridedSlice(begin_mask=self.begin_mask,
                                   end_mask=self.end_mask,
                                   ellipsis_mask=self.ellipsis_mask,
                                   new_axis_mask=self.new_axis_mask,
                                   shrink_axis_mask=self.shrink_axis_mask)

    def bprop(dy, shapex, begin, end, strides, out, dout):
        d_dy = forward_slice(dout, begin, end, strides)
        return d_dy, zeros_like(shapex), zeros_like(begin), zeros_like(end), zeros_like(strides)

    return bprop
835
-
836
-
837
@bprop_getters.register(P.Eye)
def get_bprop_eye(self):
    """Generate bprop for Eye: every input receives a zero gradient."""

    def bprop(n, m, t, out, dout):
        d_n, d_m, d_t = zeros_like(n), zeros_like(m), zeros_like(t)
        return d_n, d_m, d_t

    return bprop
845
-
846
-
847
@bprop_getters.register(P.Select)
def get_bprop_select(self):
    """Generate bprop for Select.

    ``dout`` is routed to ``x`` where ``cond`` is true and to ``y`` where it is false.
    ``cond`` gets a zero gradient.
    """
    select = P.Select()

    def bprop(cond, x, y, out, dout):
        dx = select(cond, dout, zeros_like(x))
        dy = select(cond, zeros_like(y), dout)
        return zeros_like(cond), dx, dy

    return bprop
856
-
857
-
858
@bprop_getters.register(P.OnesLike)
def get_bprop_oneslike(self):
    """Generate bprop for OnesLike: the output does not depend on x's values, so dx is zero."""

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
866
-
867
-
868
@bprop_getters.register(P.ZerosLike)
def get_bprop_zeroslike(self):
    """Generate bprop for ZerosLike: the output does not depend on x's values, so dx is zero."""

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
876
-
877
-
878
@bprop_getters.register(P.ResizeNearestNeighbor)
def get_bprop_resize_nearest_neighbor(self):
    """Generate bprop for ResizeNearestNeighbor.

    ``ResizeNearestNeighborGrad`` maps ``dout`` back to the input's spatial size, taken
    from the input shape starting at index 2 (height and width). The runtime shape is used
    when the static shape is unknown.
    """
    op = G.ResizeNearestNeighborGrad(self.align_corners)
    tensor_shape = P.TensorShape()

    def bprop(inputs, out, dout):
        static_shape = shape_op(inputs)
        if F.is_sequence_value_unknown(static_shape) or F.is_sequence_shape_unknown(static_shape):
            in_shape = tensor_shape(inputs)
        else:
            in_shape = static_shape
        # 2 and 3 represent the height and width
        spatial_size = in_shape[2:]
        return (op(dout, spatial_size),)

    return bprop
894
-
895
-
896
@bprop_getters.register(P.GatherNd)
def get_bprop_gather_nd(self):
    """Generate bprop for GatherNd.

    ``ScatterNd`` scatters ``dout`` into an ``x``-shaped tensor at ``indices``. ``indices``
    gets a zero gradient.
    """
    scatter_nd = P.ScatterNd()

    def bprop(x, indices, out, dout):
        target_shape = shape_op(x)
        if F.is_sequence_value_unknown(target_shape):
            target_shape = dyn_shape_op(x)
        dx = scatter_nd(indices, dout, target_shape)
        return dx, zeros_like(indices)

    return bprop
908
-
909
-
910
@bprop_getters.register(P.ScatterNd)
def get_bprop_scatter_nd(self):
    """Generate bprop for ScatterNd: ``x`` receives ``dout`` gathered back at ``indices``."""
    gather_nd = P.GatherNd()

    def bprop(indices, x, shape, out, dout):
        dx = gather_nd(dout, indices)
        return zeros_like(indices), dx, zeros_like(shape)

    return bprop
919
-
920
-
921
@bprop_getters.register(P.ScatterNdUpdate)
def get_bprop_scatter_nd_update(self):
    """Generate bprop for ScatterNdUpdate.

    ``x`` receives ``dout`` unchanged. ``update`` receives ``dout`` gathered at ``indices``.
    """
    gather_nd = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        d_update = gather_nd(dout, indices)
        return dout, zeros_like(indices), d_update

    return bprop
930
-
931
-
932
@bprop_getters.register(P.ScatterNonAliasingAdd)
def get_bprop_scatter_non_aliasing_add_update(self):
    """Generate bprop for ScatterNonAliasingAdd.

    ``x`` receives ``dout`` unchanged. ``update`` receives ``dout`` gathered at ``indices``.
    """
    gather_nd = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        d_update = gather_nd(dout, indices)
        return dout, zeros_like(indices), d_update

    return bprop
941
-
942
-
943
@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
    """Generate bprop for TensorScatterUpdate.

    Positions overwritten by ``update`` get no gradient in ``x`` (they are zeroed in
    ``dout``). ``update`` receives ``dout`` gathered at ``indices``.
    """
    gather_nd = P.GatherNd()
    tensor_scatter_update = P.TensorScatterUpdate()

    def bprop(x, indices, update, out, dout):
        dx = tensor_scatter_update(dout, indices, zeros_like(update))
        d_update = gather_nd(dout, indices)
        return dx, zeros_like(indices), d_update

    return bprop
955
-
956
-
957
@bprop_getters.register(P.TensorScatterAdd)
def get_bprop_tensor_scatter_add(self):
    """Generate bprop for TensorScatterAdd.

    ``x`` receives ``dout`` unchanged. ``update`` receives ``dout`` gathered at ``indices``.
    """
    gather_nd = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        return dout, zeros_like(indices), gather_nd(dout, indices)

    return bprop
967
-
968
-
969
@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
    """Generate bprop for ScatterMax.

    ``x`` receives ``dout``. ``update`` receives the rows of ``dout`` selected by
    ``indices`` along axis 0.
    """
    gather = P.Gather()

    def bprop(x, indices, update, out, dout):
        d_update = gather(dout, indices, 0)
        return dout, zeros_like(indices), d_update

    return bprop
978
-
979
-
980
@bprop_getters.register(P.ScatterMin)
def get_bprop_scatter_min(self):
    """Generate bprop for ScatterMin.

    ``x`` receives ``dout``. ``update`` receives the rows of ``dout`` selected by
    ``indices`` along axis 0.
    """
    gather = P.Gather()

    def bprop(x, indices, update, out, dout):
        d_update = gather(dout, indices, 0)
        return dout, zeros_like(indices), d_update

    return bprop
989
-
990
-
991
@bprop_getters.register(P.ScatterUpdate)
def get_bprop_scatter_update(self):
    """Generate bprop for ScatterUpdate.

    ``x`` receives ``dout``. ``update`` receives the rows of ``dout`` selected by
    ``indices`` along axis 0.
    """
    gather = P.Gather()

    def bprop(x, indices, update, out, dout):
        d_update = gather(dout, indices, 0)
        return dout, zeros_like(indices), d_update

    return bprop
1000
-
1001
-
1002
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
    """Generate bprop for Argmax: index outputs are not differentiable, so dx is zero."""

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
1010
-
1011
-
1012
@bprop_getters.register(P.Argmin)
def get_bprop_argmin(self):
    """Generate bprop for Argmin: index outputs are not differentiable, so dx is zero."""

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
1020
-
1021
-
1022
@bprop_getters.register(P.SpaceToDepth)
def get_bprop_space_to_depth(self):
    """Generate bprop for SpaceToDepth: apply the inverse DepthToSpace to ``dout``."""
    depth_to_space = P.DepthToSpace(self.block_size)

    def bprop(x, out, dout):
        dx = depth_to_space(dout)
        return (dx,)

    return bprop
1031
-
1032
-
1033
@bprop_getters.register(P.DepthToSpace)
def get_bprop_depth_to_space(self):
    """Generate bprop for DepthToSpace: apply the inverse SpaceToDepth to ``dout``."""
    space_to_depth = P.SpaceToDepth(self.block_size)

    def bprop(x, out, dout):
        dx = space_to_depth(dout)
        return (dx,)

    return bprop
1042
-
1043
-
1044
@bprop_getters.register(P.Diag)
def get_bprop_diag(self):
    """Generate bprop for Diag: the input gradient is the diagonal part of ``dout``."""
    diag_part = P.DiagPart()

    def bprop(x, out, dout):
        dx = diag_part(dout)
        return (dx,)

    return bprop
1053
-
1054
-
1055
@bprop_getters.register(P.DiagPart)
def get_bprop_diag_part(self):
    """Generate bprop for DiagPart: spread ``dout`` back onto a diagonal tensor with Diag."""
    diag = P.Diag()

    def bprop(x, out, dout):
        dx = diag(dout)
        return (dx,)

    return bprop
1064
-
1065
-
1066
def _gather_drop_negatives(params,
                           ids,
                           zero_clipped_indices=None,
                           is_positive=None):
    """Helper function for unsorted segment ops.

    Gathers ``params`` along axis 0 at ``ids``, with negative ids clipped to 0, then zeroes
    the entries whose id was negative.

    Args:
        params: tensor to gather from along axis 0.
        ids: segment ids; only read when one of the optional arguments is None.
        zero_clipped_indices: precomputed ``maximum(ids, 0)``; derived from ``ids`` if None.
        is_positive: precomputed mask; if None, it is built from ``ids >= 0`` and broadcast
            to the gathered result's shape.

    Returns:
        ``(masked_gathered, zero_clipped_indices, is_positive)``, so that callers can reuse
        the last two values in a later call.
    """
    maximum = P.Maximum()
    gather = P.Gather()
    greater_equal = P.GreaterEqual()
    rank = P.Rank()
    fill = P.Fill()
    select = P.Select()

    if zero_clipped_indices is None:
        zero_clipped_indices = maximum(ids, zeros_like(ids))
    gathered = gather(params, zero_clipped_indices, 0)
    zero_slice = zeros_like(gathered)
    if is_positive is None:
        is_positive = greater_equal(ids, 0)
        mask_shape = shape_op(is_positive)
        gathered_shape = shape_op(gathered)
        if F.is_sequence_value_unknown(gathered_shape) or F.is_sequence_value_unknown(mask_shape):
            # Runtime shapes: pad the mask's shape tensor with trailing 1s up to the gathered rank.
            gathered_dyn_shape = dyn_shape_op(gathered)
            gathered_rank = dyn_rank(gathered)
            ones_of_gathered = dyn_fill(mstype.int64, gathered_dyn_shape, 1)
            mask_dyn_shape = dyn_shape_op(is_positive)
            mask_rank = dyn_rank(is_positive)
            if gathered_rank - mask_rank > 0:
                pad_len = F.expand_dims(gathered_rank - mask_rank, 0)
                pad_ones = dyn_ones(pad_len, mask_dyn_shape.dtype)
                mask_dyn_shape = P.Concat(-1)((mask_dyn_shape, pad_ones))
            is_positive = reshape(is_positive, mask_dyn_shape)
            is_positive = logical_and(is_positive, F.cast(ones_of_gathered, mstype.bool_))
        else:
            padded_shape = mask_shape
            for _ in range(rank(gathered) - rank(is_positive)):
                padded_shape += (1,)
            is_positive = reshape(is_positive, padded_shape)
            # AND with an all-true tensor to broadcast the mask to the gathered shape.
            is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
    return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
1105
-
1106
-
1107
def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
    """Gradient for UnsortedSegmentMin or UnsortedSegmentMax.

    Elements of ``x`` that equal their segment's result share that segment's gradient
    equally (ties split it). All other elements, and those with negative segment ids, get
    zero. ``segment_ids`` and ``num_segments`` receive zero gradients.
    """
    equal = P.Equal()
    cast = P.Cast()
    divide = P.RealDiv()
    get_dtype = P.DType()
    select = P.Select()

    segment_out, clipped_ids, valid_mask = _gather_drop_negatives(out, segment_ids, None, None)
    chosen = logical_and(equal(x, segment_out), valid_mask)
    chosen_count = unsorted_segment_sum(cast(chosen, get_dtype(dout)), segment_ids, num_segments)
    per_element_grad = divide(dout, chosen_count)
    spread_grad, _, _ = _gather_drop_negatives(per_element_grad, None, clipped_ids, valid_mask)
    dx = select(chosen, spread_grad, zeros_like(spread_grad))
    return dx, zeros_like(segment_ids), zeros_like(num_segments)
1125
-
1126
-
1127
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
    """Generate bprop for UnsortedSegmentSum.

    Each element of ``x`` receives its segment's gradient. Elements with negative segment
    ids receive zero.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        dx, _, _ = _gather_drop_negatives(dout, segment_ids, None, None)
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
1136
-
1137
-
1138
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
    """Generate bprop for UnsortedSegmentMin via the shared min/max gradient helper."""

    def bprop(x, segment_ids, num_segments, out, dout):
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop
1146
-
1147
-
1148
@bprop_getters.register(P.UnsortedSegmentMax)
def get_bprop_unsorted_segment_max(self):
    """Generate bprop for UnsortedSegmentMax via the shared min/max gradient helper."""

    def bprop(x, segment_ids, num_segments, out, dout):
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop
1156
-
1157
-
1158
@bprop_getters.register(P.UnsortedSegmentProd)
def get_bprop_unsorted_segment_prod(self):
    """Generate bprop for UnsortedSegmentProd.

    For each element, the partial derivative is the segment product divided by the element.
    A zero element instead uses the product of the segment's non-zero elements, and
    segments with more than one zero get a zero gradient. uint32/uint64 inputs go through
    float32 for the division and the final product.

    Raises:
        TypeError: if ``x`` is complex64 or complex128 (not supported).
    """
    equal = P.Equal()
    cast = P.Cast()
    select = P.Select()
    gather = P.Gather()
    greater = P.Greater()
    ones_like = P.OnesLike()
    maximum = P.Maximum()
    unsorted_segment_prod = P.UnsortedSegmentProd()

    def bprop(x, segment_ids, num_segments, out, dout):
        if x.dtype == mstype.complex64 or x.dtype == mstype.complex128:
            raise TypeError("For 'UnsortedSegmentProd', complex number is not supported for gradient currently.")
        # Complex inputs were rejected above, so the complex-specific branches that used to
        # follow were unreachable; zero detection and the "ones" fill always go via float32.
        is_zero = equal(cast(x, mstype.float32), F.scalar_to_tensor(0).astype(np.float32))
        num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments)
        # Segments containing two or more zeros have an all-zero gradient.
        grad = select(greater(num_zero, 1), zeros_like(dout), dout)
        temp_var = ones_like(cast(x, mstype.float32))
        non_zero_data = select(is_zero, cast(temp_var, x.dtype), x)
        non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments)
        zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids))
        gathered_prod = gather(out, zero_clipped_indices, 0)
        gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
        if x.dtype == mstype.uint32 or x.dtype == mstype.uint64:
            prod_divided_by_x = cast(gathered_prod, mstype.float32) / cast(x, mstype.float32)
        else:
            prod_divided_by_x = gathered_prod / x
        partial_derivative = select(is_zero, gathered_non_zero_prod,
                                    cast(prod_divided_by_x, gathered_non_zero_prod.dtype))
        gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices, None)
        if x.dtype == mstype.uint32 or x.dtype == mstype.uint64:
            temp_dx = cast(gathered_grad, mstype.float32) * cast(partial_derivative, mstype.float32)
            dx = cast(temp_dx, x.dtype)
        else:
            dx = gathered_grad * partial_derivative
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
1203
-
1204
-
1205
@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
    """Generate bprop for SpaceToBatch: apply BatchToSpace with the same block size and paddings."""
    inverse_op = P.BatchToSpace(self.block_size, self.paddings)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
1215
-
1216
-
1217
@bprop_getters.register(P.BatchToSpace)
def get_bprop_batch_to_space(self):
    """Generate bprop for BatchToSpace: apply SpaceToBatch with the same block size and crops."""
    inverse_op = P.SpaceToBatch(self.block_size, self.crops)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
1227
-
1228
-
1229
@bprop_getters.register(P.SpaceToBatchND)
def get_bprop_space_to_batch_nd(self):
    """Generate bprop for SpaceToBatchND: apply BatchToSpaceND with the same block shape and paddings."""
    inverse_op = P.BatchToSpaceND(self.block_shape, self.paddings)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
1239
-
1240
-
1241
@bprop_getters.register(P.BatchToSpaceND)
def get_bprop_batch_to_space_nd(self):
    """Generate bprop for BatchToSpaceND: apply SpaceToBatchND with the same block shape and crops."""
    inverse_op = P.SpaceToBatchND(self.block_shape, self.crops)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
1251
-
1252
-
1253
@bprop_getters.register(P.BroadcastTo)
def get_bprop_broadcast_to(self):
    """Generate bprop for BroadcastTo.

    ``dout`` is summed over the broadcast axes with dims kept, then reshaped to ``x``'s
    shape. int16/int32/int64 gradients are reduced in float32 and cast back. If no
    broadcasting happened (static shapes equal), ``dout`` passes through unchanged.
    """
    reduce_keep_dim = P.ReduceSum(keep_dims=True)

    def bprop(x, out, dout):
        x_shape = shape_op(x)
        dout_shape = shape_op(dout)
        broadcast_shape = shape_op(out)
        is_dynamic = F.is_sequence_value_unknown(x_shape) or F.is_sequence_value_unknown(dout_shape)
        if not is_dynamic and x_shape == dout_shape:
            return (dout,)
        is_dynamic = is_dynamic or F.is_sequence_value_unknown(broadcast_shape)
        grad_type = dout.dtype
        needs_float = grad_type in (ms.int16, ms.int32, ms.int64)
        if needs_float:
            dout = P.Cast()(dout, ms.float32)
        if is_dynamic:
            x_dyn_shape = dyn_shape_op(x)
            out_dyn_shape = dyn_shape_op(out)
            _, axes = DynamicBroadcastGradientArgs()(out_dyn_shape, x_dyn_shape)
            summed = sum_grad_reduce_axis(dout, axes, keep_dims=True)
            if needs_float:
                summed = P.Cast()(summed, grad_type)
            dx = reshape(summed, x_dyn_shape)
        else:
            _, axes = broadcast_gradient_args(broadcast_shape, x_shape)
            summed = reduce_keep_dim(dout, axes)
            if needs_float:
                summed = P.Cast()(summed, grad_type)
            dx = reshape(summed, x_shape)
        return (dx,)

    return bprop
1290
-
1291
-
1292
@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
    """Generate bprop for ReverseSequence: reverse ``dout`` again with the same lengths and dims."""
    reverse_again = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

    def bprop(x, seq_lengths, out, dout):
        return reverse_again(dout, seq_lengths), zeros_like(seq_lengths)

    return bprop
1302
-
1303
-
1304
@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
    """Generate bprop for TransShape: TransShape ``dout`` back to the shape of ``x``."""
    trans_shape = P.TransShape()

    def bprop(x, shape, out, dout):
        dx = trans_shape(dout, shape_op(x))
        return (dx, zeros_like(shape))

    return bprop
1314
-
1315
-
1316
@bprop_getters.register(P.Unique)
def get_bprop_unique(self):
    """Generate bprop for Unique: UniqueGrad combines ``dout`` with the forward outputs."""
    unique_grad = G.UniqueGrad()

    def bprop(x, out, dout):
        return (unique_grad(dout, out),)

    return bprop
1326
-
1327
-
1328
@bprop_getters.register(P.MaskedSelect)
def get_bprop_masked_select(self):
    """Generate bprop for MaskedSelect.

    ``MaskedSelectGrad`` computes the input gradient. ``mask`` gets a zero gradient.
    """
    masked_select_grad = G.MaskedSelectGrad()

    def bprop(x, mask, out, dout):
        return (masked_select_grad(x, mask, dout), zeros_like(mask))

    return bprop
1338
-
1339
-
1340
@bprop_getters.register(NonZero)
def get_bprop_non_zero(self):
    """Generate bprop for NonZero: index outputs are not differentiable, so dx is zero."""

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop