mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's advisory page for this release for more details.

Files changed (870):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -19,38 +19,28 @@ from __future__ import absolute_import
19
19
  from mindspore import Tensor
20
20
  from mindspore.ops.primitive import constexpr
21
21
  from mindspore.common import dtype as mstype
22
- from mindspore.numpy.array_ops import where
23
- from mindspore.ops._grad.grad_math_ops import binop_grad_common
24
- from mindspore.ops._grad.grad_base import bprop_getters, dyn_rank, dyn_fill, dyn_ones, create_tensor_by_element
25
- from mindspore.ops._grad.grad_base import convert_to_tensor
22
+ from mindspore.ops._grad_experimental.grad_math_ops import binop_grad_common
23
+ from mindspore.ops._grad_experimental.grad_base import bprop_getters, dyn_ones
24
+ from mindspore.ops._grad_experimental.grad_base import convert_to_tensor, create_tensor_by_element
26
25
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
27
- from mindspore.ops.operations.array_ops import Tril
28
26
  from mindspore.ops.operations.array_ops import MatrixDiagV3
29
27
  from mindspore.ops.operations.array_ops import MatrixDiagPartV3
30
28
  from mindspore.ops.operations.array_ops import ResizeNearestNeighborV2
31
29
  from mindspore.ops.operations.array_ops import MatrixSetDiagV3
30
+ from mindspore.ops.operations.array_ops import MatrixBandPart
32
31
  from mindspore.ops.operations.array_ops import Mvlgamma
33
- from mindspore.ops.operations.array_ops import Triu
34
- from mindspore.ops.operations.array_ops import IdentityN
35
32
  from mindspore.ops.operations.array_ops import IndexFill
36
33
  from mindspore.ops.operations.array_ops import IndexPut
37
- from mindspore.ops.operations.array_ops import CheckNumerics
38
- from mindspore.ops.operations.array_ops import ConjugateTranspose
39
- from mindspore.ops.operations.array_ops import SegmentMax
40
- from mindspore.ops.operations.array_ops import SegmentMin
41
34
  from mindspore.ops.operations.array_ops import SegmentSum
42
- from mindspore.ops.operations.array_ops import TensorScatterElements
43
35
  from mindspore.ops.operations.array_ops import ScatterAddWithAxis
44
36
  from mindspore.ops.operations.array_ops import Expand
45
37
  from mindspore.ops.operations.array_ops import SegmentMean
46
38
  from mindspore.ops.operations.array_ops import AffineGrid
47
39
  from mindspore.ops.operations.array_ops import Im2Col
48
40
  from mindspore.ops.operations.array_ops import Col2Im
49
- from mindspore.ops.operations.array_ops import StridedSliceV2
50
41
  from mindspore.ops.operations.array_ops import MaskedScatter
51
42
  from mindspore.ops.operations.array_ops import MaskedSelect
52
43
  from mindspore.ops.operations.array_ops import CountNonZero
53
- from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
54
44
  from mindspore.ops.operations.random_ops import LogNormalReverse
55
45
  from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
56
46
  from mindspore.ops.operations import _inner_ops as inner
@@ -58,6 +48,18 @@ from mindspore.ops import functional as F
58
48
  from mindspore.ops import operations as P
59
49
  from mindspore.ops.operations import _grad_ops as G
60
50
  from mindspore import context
51
+ from mindspore.ops.primitive import _primexpr
52
+ from mindspore.common.sparse_tensor import RowTensorInner
53
+ from mindspore.ops._utils.utils import generate_shape_index
54
+
55
+ reduce_sum = P.ReduceSum()
56
+ unsorted_segment_sum = P.UnsortedSegmentSum()
57
+ transpose = P.Transpose()
58
+ shape_op = P.Shape()
59
+ reshape = P.Reshape()
60
+ size_op = P.Size()
61
+ invert_permutation = P.InvertPermutation()
62
+ logical_and = P.LogicalAnd()
61
63
 
62
64
 
63
65
  @constexpr
@@ -68,91 +70,28 @@ def _raise_value_error(*info):
68
70
  raise ValueError(info_str)
69
71
 
70
72
 
71
- @bprop_getters.register(P.FillV2)
72
- def get_bprop_fill_v2(self):
73
- """Generate bprop for FillV2"""
74
- sum_op = P.ReduceSum()
75
- cast_op = P.Cast()
76
- shape_op = P.TensorShape()
77
-
78
- def bprop(shape, value, out, dout):
79
- dout_type = F.dtype(dout)
80
- type_list = [
81
- mstype.int8, mstype.int16, mstype.int32, mstype.int64,
82
- mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
83
- mstype.float16, mstype.float64
84
- ]
85
- if dout_type in type_list:
86
- dout = cast_op(dout, mstype.float32)
87
- dout_shape = shape_op(dout)
88
- axis = tuple([i for i in range(len(dout_shape))])
89
- dvalue = sum_op(dout, axis)
90
- return zeros_like(shape), cast_op(dvalue, dout_type)
91
-
92
- return bprop
93
-
94
-
95
- @bprop_getters.register(StridedSliceV2)
96
- def get_bprop_strided_slice_v2(self):
97
- """Generate bprop for StridedSliceV2"""
98
- shape_op = P.Shape()
99
- dyn_shape_op = P.TensorShape()
100
- input_grad = StridedSliceV2Grad(self.begin_mask,
101
- self.end_mask,
102
- self.ellipsis_mask,
103
- self.new_axis_mask,
104
- self.shrink_axis_mask)
105
-
106
- def bprop(x, begin, end, strides, out, dout):
107
- x_shape = shape_op(x)
108
- if F.is_sequence_value_unknown(x_shape):
109
- x_shape = dyn_shape_op(x)
110
- dx = input_grad(x_shape, begin, end, strides, dout)
111
- dx_all = (dx, zeros_like(begin), zeros_like(end), zeros_like(strides))
112
- return dx_all
113
-
114
- return bprop
115
-
116
-
117
73
  @constexpr
118
74
  def _create_tensor(data, dtype):
119
75
  return Tensor(data, dtype=dtype)
120
76
 
121
77
 
122
- def _segment_min_or_max_grad(segment_sum_op, input_x, segment_ids, output, dout):
123
- """Calculate the gradient of SegmentMax or SegmentMin"""
124
- gather = P.Gather()
125
- equal = P.Equal()
126
- cast = P.Cast()
127
- divide = P.Div()
128
- input_x_type = F.dtype(input_x)
129
- input_x = cast(input_x, mstype.float32)
130
- output = cast(output, mstype.float32)
131
- dout = cast(dout, mstype.float32)
132
- zeros = zeros_like(input_x)
133
- gathered_outputs = gather(output, segment_ids, 0)
134
- is_selected = equal(input_x, gathered_outputs)
135
- num_selected = segment_sum_op(cast(is_selected, F.dtype(dout)), segment_ids)
136
- weighted_grads = divide(dout, num_selected)
137
- gathered_grads = gather(weighted_grads, segment_ids, 0)
138
- return cast(where(is_selected, gathered_grads, zeros), input_x_type), zeros_like(segment_ids)
139
-
140
-
141
78
  @bprop_getters.register(P.MaskedFill)
142
79
  def get_bprop_masked_select(self):
143
80
  """Generate bprop for MaskedFill"""
144
81
  mul_op = P.Mul()
145
82
  sum_op = P.ReduceSum()
146
83
  is_instance_op = inner.IsInstance()
84
+ rank = P.Rank()
147
85
 
148
86
  def bprop(input_data, mask, value, out, dout):
149
87
  mask = F.cast(mask, mstype.float32)
88
+ dout = F.cast(dout, mstype.float32)
150
89
  dinput = mul_op(dout, (1 - mask))
151
90
  dvalue = mul_op(dout, mask)
152
91
  dinput, dvalue = binop_grad_common(input_data, mask, dinput, dvalue)
153
92
  # for dynamic rank, reduce axis should be calc
154
93
  if F.is_sequence_shape_unknown(P.Shape()(dvalue)):
155
- axis = P.Range()(Tensor(0), dyn_rank(dvalue), Tensor(1))
94
+ axis = range(0, rank(dvalue), 1)
156
95
  dvalue = sum_op(dvalue, axis)
157
96
  else:
158
97
  dvalue = sum_op(dvalue)
@@ -169,37 +108,21 @@ def get_bprop_masked_select(self):
169
108
  @bprop_getters.register(MaskedScatter)
170
109
  def get_bprop_masked_scatter(self):
171
110
  """Generate bprop for MaskedScatter"""
172
- sort_ = P.Sort(descending=True)
173
- masked_scatter = MaskedScatter()
174
111
  masked_fill = P.MaskedFill()
175
112
  masked_select = P.MaskedSelect()
176
- size = P.Size()
177
- zeros = P.Zeros()
178
- concat = P.Concat(axis=0)
179
- reshape = P.Reshape()
180
- shape = P.Shape()
181
-
113
+ shape = P.TensorShape()
114
+ range_ = P.Range()
115
+ scatter_update = P.TensorScatterElements()
182
116
  def bprop(x, mask, updates, out, dout):
183
- dx = masked_fill(F.cast(dout, mstype.float32), mask, 0.0)
184
- mask_selected = masked_select(F.cast(dout, mstype.float32), mask)
185
- mask_broad = mask
186
- if shape(mask) != shape(x):
187
- broad_cast = P.BroadcastTo(shape(x))
188
- mask_broad = broad_cast(mask)
189
- mask_broad_vec = mask_broad.reshape(-1)
190
- mask_sorted = F.cast(sort_(F.cast(mask_broad_vec, mstype.float32))[0], F.dtype(mask))
191
- diff_num = size(updates) - size(mask_broad)
192
- if diff_num > 0:
193
- zeros_pad = zeros(diff_num, F.dtype(mask))
194
- mask_sorted = concat((mask_sorted, zeros_pad))
195
- zeros_tensor = zeros(size(updates), mstype.float32)
196
- dupdates = masked_scatter(zeros_tensor, mask_sorted, mask_selected)
197
- if shape(updates) != ():
198
- dupdates = reshape(dupdates, shape(updates))
199
- else:
200
- zeros_tensor = zeros(shape(updates), mstype.float32)
201
- dupdates = masked_scatter(zeros_tensor, mask, mask_selected)
202
- return F.cast(dx, F.dtype(x)), zeros_like(mask), F.cast(dupdates, F.dtype(updates))
117
+ dout = F.cast(dout, mstype.float32)
118
+ dx = masked_fill(dout, mask, F.cast(0, mstype.float32))
119
+ dupdates = F.cast(zeros_like(updates).reshape(-1), mstype.float32)
120
+ dupdates_val = F.cast(masked_select(dout, mask), mstype.float32)
121
+ length = F.cast(shape(dupdates_val)[0], mstype.int32)
122
+ scatter_indices = range_(F.cast(0, mstype.int32), length, F.cast(1, mstype.int32))
123
+ dupdates = scatter_update(dupdates, scatter_indices, dupdates_val)
124
+ dupdates = reshape(dupdates, shape(updates))
125
+ return F.cast(dx, x.dtype), zeros_like(mask), F.cast(dupdates, updates.dtype)
203
126
 
204
127
  return bprop
205
128
 
@@ -226,43 +149,19 @@ def get_bprop_mvlgamma(self):
226
149
  return bprop
227
150
 
228
151
 
229
- @bprop_getters.register(P.TensorScatterDiv)
230
- def get_bprop_tensor_scatter_div(self):
231
- """Generate bprop for TensorScatterDiv"""
232
- gather_nd = P.GatherNd()
233
- tensor_scatter_div = P.TensorScatterDiv()
234
- neg = P.Neg()
235
- div = P.Div()
236
- mul = P.Mul()
237
-
238
- def bprop(x, indices, update, out, dout):
239
- # (input)' / update
240
- in_grad = tensor_scatter_div(dout, indices, update)
241
-
242
- # - (input * (update)') / (update * update)
243
- gather_update = gather_nd(dout, indices)
244
- gather_x = gather_nd(x, indices)
245
- mul_result = mul(update, update)
246
- neg_result = neg(mul_result)
247
- update_grad = gather_update * div(gather_x, neg_result)
248
-
249
- return in_grad, zeros_like(indices), update_grad
250
-
251
- return bprop
252
-
253
-
254
152
  @bprop_getters.register(IndexFill)
255
153
  def get_bprop_index_fill(self):
256
154
  """Generate bprop for IndexFill"""
257
155
  gather = P.Gather()
258
156
  index_fill = IndexFill()
259
157
  shape = P.Shape()
158
+ rank = P.Rank()
260
159
 
261
160
  def bprop(x, dim, indices, value, out, dout):
262
161
  zero_value = zeros_like(value)
263
162
  x_grad = index_fill(dout, dim, indices, zero_value)
264
163
  if F.is_sequence_value_unknown(shape(x)):
265
- if dyn_rank(x) == 0:
164
+ if rank(x) == 0:
266
165
  value_grad = dout
267
166
  else:
268
167
  value_grad = gather(dout, indices, dim).sum()
@@ -286,6 +185,8 @@ def get_bprop_index_put(self):
286
185
  masked_select = MaskedSelect()
287
186
  masked_scatter = MaskedScatter()
288
187
  accumulate_grad = self.accumulate
188
+ equal = P.Equal()
189
+ cast = P.Cast()
289
190
  index_put = IndexPut(accumulate=accumulate_grad)
290
191
  is_ascend = context.get_context("device_target") == 'Ascend'
291
192
 
@@ -301,9 +202,10 @@ def get_bprop_index_put(self):
301
202
  indices_ms = [tile(x, (maxsize,)) if x.shape[0] == 1 else x for x in indices]
302
203
  if is_ascend:
303
204
  indices_ms = [convert_idx_positive(indices_ms[i], x1.shape[i]) for i in range(len(indices_ms))]
304
- indices_grad = stack(indices_ms).T
205
+ indices_me = stack(indices_ms)
206
+ indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me)-1, -1, -1))
305
207
  values_grad = gather_nd(dout, indices_grad)
306
- if x2.shape[0] == 1:
208
+ if equal(cast(x2.shape[0], mstype.int32), Tensor(1)):
307
209
  values_grad = values_grad.sum().reshape(1)
308
210
  if values_grad.shape != x2.shape and len(indices) < len(x1.shape):
309
211
  _, values_grad = binop_grad_common(x1, x2, dout, values_grad)
@@ -314,50 +216,6 @@ def get_bprop_index_put(self):
314
216
  return bprop
315
217
 
316
218
 
317
- @bprop_getters.register(P.TensorScatterSub)
318
- def get_bprop_tensor_scatter_sub(self):
319
- """Generate bprop for TensorScatterSub"""
320
- gather_nd = P.GatherNd()
321
- neg = P.Neg()
322
-
323
- def bprop(x, indices, update, out, dout):
324
- update_grad = neg(gather_nd(dout, indices))
325
- return dout, zeros_like(indices), update_grad
326
-
327
- return bprop
328
-
329
-
330
- @bprop_getters.register(P.TensorScatterMul)
331
- def get_bprop_tensor_scatter_mul(self):
332
- """Generate bprop for TensorScatterMul"""
333
- gather_nd = P.GatherNd()
334
- mul_func = P.TensorScatterMul()
335
-
336
- def bprop(x, indices, update, out, dout):
337
- gather_update = gather_nd(dout, indices)
338
- gather_x = gather_nd(x, indices)
339
- dx = mul_func(dout, indices, update)
340
- d_update = gather_x * gather_update
341
- return dx, zeros_like(indices), d_update
342
-
343
- return bprop
344
-
345
-
346
- @bprop_getters.register(MatrixDiagV3)
347
- def get_bprop_matrix_diag_v3(self):
348
- """Generate bprop for MatrixDiagV3"""
349
- align = self.align
350
- matrix_diag_part_v3 = MatrixDiagPartV3(align=align)
351
- zeros = P.Zeros()
352
-
353
- def bprop(x, k, num_rows, num_cols, padding_value, out, dout):
354
- result = (matrix_diag_part_v3(dout, k, zeros((), dout.dtype)), zeros_like(k), zeros_like(num_rows),
355
- zeros_like(num_cols), zeros_like(padding_value))
356
- return result
357
-
358
- return bprop
359
-
360
-
361
219
  @bprop_getters.register(MatrixDiagPartV3)
362
220
  def get_bprop_matrix_diag_part_v3(self):
363
221
  """Generate bprop for MatrixDiagPartV3"""
@@ -380,6 +238,17 @@ def get_bprop_matrix_diag_part_v3(self):
380
238
  return bprop
381
239
 
382
240
 
241
+ @bprop_getters.register(MatrixBandPart)
242
+ def get_bprop_matrix_band_part(self):
243
+ """Grad definition for `MatrixBandPart` operation."""
244
+ matrix_band_part = MatrixBandPart()
245
+
246
+ def bprop(x, lower, upper, out, dout):
247
+ return matrix_band_part(dout, lower, upper), zeros_like(lower), zeros_like(upper)
248
+
249
+ return bprop
250
+
251
+
383
252
  @bprop_getters.register(MatrixSetDiagV3)
384
253
  def get_bprop_matrix_set_diag_v3(self):
385
254
  """Generate bprop for MatrixSetDiagV3"""
@@ -409,15 +278,11 @@ def tensor_scatter_possible_replacement(x, indices, updates, out, dout):
409
278
  scatter_nd = P.ScatterNd()
410
279
  equal = P.Equal()
411
280
  shape = P.Shape()
412
- dyn_shape_op = P.TensorShape()
413
281
 
414
282
  x_indicators = F.cast(equal(x, out), mstype.int32)
415
283
  possibly_updated = gather_nd(out, indices)
416
284
  out_indicators = F.cast(equal(updates, possibly_updated), mstype.int32)
417
285
  input_shape = shape(x)
418
- if F.is_sequence_value_unknown(input_shape):
419
- input_shape = dyn_shape_op(x)
420
-
421
286
  scattered_out_indicators = scatter_nd(indices, out_indicators, input_shape)
422
287
  indicators = x_indicators + scattered_out_indicators
423
288
  dx = dout * F.cast(x_indicators, F.dtype(dout)) / F.cast(indicators, F.dtype(dout))
@@ -474,80 +339,16 @@ def get_bprop_coalesce(self):
474
339
  return bprop
475
340
 
476
341
 
477
- @bprop_getters.register(ConjugateTranspose)
478
- def get_bprop_conjugate_transpose(self):
479
- """Generate bprop for ConjugateTranspose"""
480
- conjugate_transpose = ConjugateTranspose()
481
- invert_permutation = P.InvertPermutation()
482
-
483
- def bprop(x, perm, out, dout):
484
- return conjugate_transpose(dout, invert_permutation(perm)), zeros_like(perm)
485
-
486
- return bprop
487
-
488
-
489
- @bprop_getters.register(Triu)
490
- def get_bprop_triu(self):
491
- """Grad definition for 'Triu' operation"""
492
- diagonal = self.diagonal
493
- triu = Triu(diagonal)
494
-
495
- def bprop(x, out, dout):
496
- dx = triu(dout)
497
- return (dx,)
498
-
499
- return bprop
500
-
501
-
502
- @bprop_getters.register(CheckNumerics)
503
- def get_bprop_check_numerics(self):
504
- """Generate bprop for CheckNumerics"""
505
- check_numerics = CheckNumerics()
506
-
507
- def bprop(x_input, out, dout):
508
- return (check_numerics(dout),)
509
-
510
- return bprop
511
-
512
-
513
- @bprop_getters.register(P.SplitV)
514
- def get_bprop_split_v(self):
515
- """Generate bprop for SplitV"""
516
- split_dim = self.split_dim
517
- concat_op = P.Concat(split_dim)
518
-
519
- def bprop(x_input, output, dout):
520
- dx = concat_op(dout)
521
- return (dx,)
522
-
523
- return bprop
524
-
525
-
526
- @bprop_getters.register(IdentityN)
527
- def get_bprop_identity_n(self):
528
- """Generate bprop for IdentityN"""
529
-
530
- def bprop(x, out, dout):
531
- return (dout,)
532
-
533
- return bprop
534
-
535
-
536
342
  @bprop_getters.register(ResizeNearestNeighborV2)
537
343
  def get_bprop_resize_nearest_neighbor_v2(self):
538
344
  """Generate bprop for ResizeNearestNeighborV2"""
539
345
  align_corners = self.align_corners
540
346
  half_pixel_centers = self.half_pixel_centers
541
- data_format = self.data_format
542
- grad_op = G.ResizeNearestNeighborV2Grad(align_corners, half_pixel_centers, data_format)
347
+ grad_op = G.ResizeNearestNeighborV2Grad(align_corners, half_pixel_centers)
543
348
 
544
349
  def bprop(x, size, output, dout):
545
350
  x_shape = P.Shape()(x)
546
- if F.is_sequence_value_unknown(x_shape):
547
- x_shape = P.TensorShape()(x)
548
- grad_in_size = x_shape[1:3]
549
- if data_format == 'NCHW':
550
- grad_in_size = x_shape[2:4]
351
+ grad_in_size = x_shape[2:4]
551
352
 
552
353
  if F.is_sequence_value_unknown(P.Shape()(x)):
553
354
  dx = grad_op(dout, grad_in_size)
@@ -559,22 +360,6 @@ def get_bprop_resize_nearest_neighbor_v2(self):
559
360
  return bprop
560
361
 
561
362
 
562
- @bprop_getters.register(Col2Im)
563
- def get_bprop_col2im(self):
564
- """Generate bprop for Col2Im"""
565
- ksizes = self.kernel_size
566
- dilations = self.dilation
567
- strides = self.stride
568
- pads = self.padding
569
- im2col = Im2Col(ksizes=ksizes, dilations=dilations, strides=strides, pads=pads)
570
-
571
- def bprop(x, output_size, out, dout):
572
- dx = im2col(dout)
573
- return dx, zeros_like(output_size)
574
-
575
- return bprop
576
-
577
-
578
363
  @bprop_getters.register(Im2Col)
579
364
  def get_bprop_im2col(self):
580
365
  """
@@ -591,14 +376,13 @@ def get_bprop_im2col(self):
591
376
  dilation = self.dilations
592
377
  stride = self.strides
593
378
  padding = (self.pads[0], self.pads[-1])
594
- shape_op = P.TensorShape()
595
379
  col2im = Col2Im(kernel_size=kernel_size,
596
380
  dilation=dilation,
597
381
  stride=stride,
598
382
  padding=padding)
599
383
 
600
384
  def bprop(x, out, dout):
601
- x_shape = shape_op(x)[2:]
385
+ x_shape = P.TensorShape()(x)[2:]
602
386
  dx = col2im(dout, x_shape)
603
387
  return (dx,)
604
388
 
@@ -614,18 +398,16 @@ def get_bprop_extract_volume_patches(self):
614
398
  expend_dims = P.ExpandDims()
615
399
  scatter_nd = P.ScatterNd()
616
400
  slice_op = P.Slice()
617
- fill = P.Fill()
618
401
  dtype = P.DType()
619
402
  cast = P.Cast()
620
403
  matmul = P.MatMul()
621
404
  _, _, ksize_d, ksize_h, ksize_w = self.kernel_size
622
405
  range_ = P.Range()
623
- dyn_shape_op = P.TensorShape()
624
406
  ones_like = P.OnesLike()
625
407
 
626
408
  def _dyn_extract_volume_patches(x, out, dout):
627
- x_shape = dyn_shape_op(x)
628
- out_shape = dyn_shape_op(out)
409
+ x_shape = shape_op(x)
410
+ out_shape = shape_op(out)
629
411
  x_n, x_c, x_d, x_h, x_w = x_shape[0], x_shape[1], x_shape[2], x_shape[3], x_shape[4]
630
412
  x_indices_num = 1 + x_d * x_h * x_w
631
413
  x_idx = range_(cast(1, mstype.float32), cast(x_indices_num, mstype.float32), cast(1, mstype.float32))
@@ -683,7 +465,7 @@ def get_bprop_extract_volume_patches(self):
683
465
  idx_tensor = concat((expend_dims(x_idx_patched, -1), expend_dims(out_idx, -1)))
684
466
  idx_map = P.Reshape()(idx_tensor, (-1, 2))
685
467
  sp_shape = (x_indices_num, out_indices_num)
686
- sp_mat_full = scatter_nd(idx_map, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
468
+ sp_mat_full = scatter_nd(idx_map, F.fill(dtype(dout), (out_indices_num,), 1), sp_shape)
687
469
  sp_tensor = slice_op(sp_mat_full, (1, 0), (x_indices_num - 1, out_indices_num))
688
470
 
689
471
  grad = P.Transpose()(dout, (0, 2, 3, 4, 1))
@@ -700,19 +482,6 @@ def get_bprop_extract_volume_patches(self):
700
482
  return bprop
701
483
 
702
484
 
703
- @bprop_getters.register(Tril)
704
- def get_bprop_tril(self):
705
- """Grad definition for 'Tril' operation"""
706
- diagonal = self.diagonal
707
- tril = Tril(diagonal)
708
-
709
- def bprop(x, out, dout):
710
- dx = tril(dout)
711
- return (dx,)
712
-
713
- return bprop
714
-
715
-
716
485
  @bprop_getters.register(SegmentSum)
717
486
  def get_bprop_segment_sum(self):
718
487
  """Generate bprop for SegmentSum"""
@@ -738,16 +507,13 @@ def get_bprop_affinegrid(self):
738
507
  align_corners = self.align_corners
739
508
  input_grad = G.AffineGridGrad(align_corners)
740
509
  ones = P.Ones()
741
- transpose = P.Transpose()
742
510
  concat = P.Concat(1)
743
511
  concat0 = P.Concat(0)
744
512
  tile = P.Tile()
745
513
  div = P.Div()
746
- reshape = P.Reshape()
747
514
  linspace = P.LinSpace()
748
515
  batmatmul = P.BatchMatMul()
749
516
  expend_dims = P.ExpandDims()
750
- dyn_shape = P.TensorShape()
751
517
  reducesum = P.ReduceSum(keep_dims=False)
752
518
 
753
519
  def get_linspace(num):
@@ -846,7 +612,7 @@ def get_bprop_affinegrid(self):
846
612
  return transpose(dtheta, perm2), tre
847
613
 
848
614
  def dyn_bprop(theta, output_size, out, dout):
849
- len_output_size = reducesum(dyn_shape(output_size))
615
+ len_output_size = reducesum(shape_op(output_size))
850
616
  dtheta = dyn_ones(Tensor([1, 3, 2], mstype.int32), mstype.float32)
851
617
  ret = dyn_ones(Tensor([1, 6], mstype.int32), mstype.float32)
852
618
  if len_output_size == 5:
@@ -968,44 +734,6 @@ def get_bprop_affinegrid(self):
968
734
  return bprop
969
735
 
970
736
 
971
- @bprop_getters.register(SegmentMax)
972
- def get_bprop_segment_max(self):
973
- """Generate bprop for SegmentMax"""
974
- segment_sum = SegmentSum()
975
-
976
- def bprop(input_x, segment_ids, output, dout):
977
- return _segment_min_or_max_grad(segment_sum, input_x, segment_ids, output, dout)
978
-
979
- return bprop
980
-
981
-
982
- @bprop_getters.register(SegmentMin)
983
- def get_bprop_segment_min(self):
984
- """Generate bprop for SegmentMin"""
985
- segment_sum = SegmentSum()
986
-
987
- def bprop(input_x, segment_ids, output, dout):
988
- return _segment_min_or_max_grad(segment_sum, input_x, segment_ids, output, dout)
989
-
990
- return bprop
991
-
992
-
993
- @bprop_getters.register(TensorScatterElements)
994
- def get_bprop_tensor_scatter_elements(self):
995
- """Generate bprop for TensorScatterElements"""
996
- gather_d = P.GatherD()
997
- axis = self.axis
998
- reduction = self.reduction
999
- tensor_scatter_elements = TensorScatterElements(axis, reduction)
1000
-
1001
- def bprop(x, indices, update, out, dout):
1002
- x_grad = tensor_scatter_elements(dout, indices, zeros_like(update))
1003
- update_grad = gather_d(dout, axis, indices)
1004
- return x_grad, zeros_like(indices), update_grad
1005
-
1006
- return bprop
1007
-
1008
-
1009
737
  @bprop_getters.register(ScatterAddWithAxis)
1010
738
  def get_bprop_scatter_add_with_axis(self):
1011
739
  """Generate bprop for ScatterAddWithAxis"""
@@ -1066,38 +794,193 @@ def get_bprop_segment_mean(self):
1066
794
  """Generate bprop for SegmentMean"""
1067
795
  rank = P.Rank()
1068
796
  shape = P.Shape()
1069
- dyn_shape = P.TensorShape()
1070
- fill = P.Fill()
797
+ fill = P.FillV2()
1071
798
  divide = P.Div()
1072
799
  segment_sum = SegmentSum()
1073
800
  gather = P.Gather()
1074
801
  cast = P.Cast()
1075
- concat = P.Concat()
1076
- expand_dims = P.ExpandDims()
1077
802
 
1078
803
  def bprop(input_x, segment_ids, output, dout):
1079
804
  input_x_type = F.dtype(input_x)
1080
805
  input_x = cast(input_x, mstype.float32)
1081
806
  dout = cast(dout, mstype.float32)
1082
807
  dout_type = F.dtype(dout)
1083
-
1084
808
  ones_shape = shape(segment_ids)
1085
- if F.is_sequence_value_unknown(ones_shape):
1086
- ones_shape = dyn_shape(segment_ids)
1087
-
1088
- ones = ()
1089
- inputx_shape = shape(input_x)
1090
- if F.is_sequence_value_unknown(inputx_shape):
1091
- input_rank = dyn_rank(input_x)
1092
- if input_rank > cast(1, mstype.float32):
1093
- ones_shape = concat([ones_shape, dyn_ones(expand_dims(input_rank - 1, 0), mstype.int64)])
1094
- ones = dyn_fill(dout_type, ones_shape, 1)
1095
- else:
1096
- input_rank = rank(input_x)
1097
- ones_shape = ones_shape + (1,) * (input_rank - 1)
1098
- ones = fill(dout_type, ones_shape, 1)
1099
-
809
+ input_rank = rank(input_x)
810
+ ones_shape = ones_shape + (1,) * (input_rank - 1)
811
+ ones = fill(ones_shape, Tensor(1, dout_type))
1100
812
  scaled_grad = divide(dout, segment_sum(ones, segment_ids))
1101
813
  return cast(gather(scaled_grad, segment_ids, 0), input_x_type), zeros_like(segment_ids)
1102
814
 
1103
815
  return bprop
816
+
817
+
818
+ @bprop_getters.register(P.Ones)
819
+ def get_bprop_ones(self):
820
+ """Generate bprop for Ones"""
821
+
822
+ def bprop(dims, dtype, out, dout):
823
+ return zeros_like(dims)
824
+
825
+ return bprop
826
+
827
+
828
+ @bprop_getters.register(P.Zeros)
829
+ def get_bprop_zeros(self):
830
+ """Generate bprop for Zeros"""
831
+
832
+ def bprop(dims, dtype, out, dout):
833
+ return zeros_like(dims)
834
+
835
+ return bprop
836
+
837
+
838
+ @bprop_getters.register(P.EmbeddingLookup)
839
+ def get_bprop_embedding_lookup(self):
840
+ """Generate bprop for EmbeddingLookup"""
841
+ sub_op = P.Sub()
842
+ reshape_op = P.Reshape()
843
+
844
+ def bprop_sparse(x, indices, offset, out, dout):
845
+ x_shp = shape_op(x)
846
+ if F.is_sequence_value_unknown(x_shp):
847
+ raise RuntimeError("Now, EmbeddingLookup op's grad don't support Dynamic Sense!")
848
+ new_indices = sub_op(indices, offset)
849
+ indices_size = size_op(new_indices)
850
+ if indices_size > 0:
851
+ # Reshape the 'new_indices'
852
+ new_indices_shape_changed = (indices_size,)
853
+ new_indices = reshape_op(new_indices, new_indices_shape_changed)
854
+ else:
855
+ new_indices_shape_changed = ()
856
+ x_shp_tail = x_shp[1:]
857
+ actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
858
+ # Reshape the 'actual_dout' on device
859
+ actual_dout = reshape_op(dout, actual_dout_shape_changed)
860
+ return RowTensorInner(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
861
+
862
+ return bprop_sparse
863
+
864
+
865
+ @_primexpr
866
+ def _generate_inverse_index(x_shape, axis, batch_dims=0):
867
+ x_rank = len(x_shape)
868
+ index = tuple(range(x_rank))
869
+ if axis < 0:
870
+ axis += x_rank
871
+ perm = index[:batch_dims] + index[batch_dims + 1:1 + axis] + (index[batch_dims],) + index[1 + axis:]
872
+ return perm
873
+
874
+
875
+ @bprop_getters.register(P.SparseGatherV2)
876
+ def get_bprop_sparse_gather_v2(self):
877
+ """Generate bprop for SparseGatherV2"""
878
+
879
+ def bprop(x, indices, axis, out, dout):
880
+ x_shp = shape_op(x)
881
+ if axis == 0:
882
+ indices_size = (size_op(indices),)
883
+ if len(x_shp) <= 1:
884
+ x_tail_shp = ()
885
+ else:
886
+ x_tail_shp = x_shp[1:]
887
+ values_shape = indices_size + x_tail_shp
888
+ values = reshape(dout, values_shape)
889
+ indices_new = reshape(indices, indices_size)
890
+ return RowTensorInner(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
891
+ if F.rank(dout) == 0:
892
+ dout = P.ExpandDims()(dout, -1)
893
+ if F.rank(indices) == 0:
894
+ indices = P.ExpandDims()(indices, -1)
895
+ out_shp = shape_op(dout)
896
+ ind_shp = shape_op(indices)
897
+ # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
898
+ perm_1 = generate_shape_index(out_shp, ind_shp, axis)
899
+ values_transpose = transpose(dout, perm_1)
900
+ params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
901
+ # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
902
+ perm_2 = _generate_inverse_index(x_shp, axis)
903
+ params_grad = transpose(params_grad, perm_2)
904
+ return params_grad, zeros_like(indices), zeros_like(axis)
905
+
906
+ return bprop
907
+
908
+
909
+ @bprop_getters.register(P.Unstack)
910
+ def get_bprop_unstack(self):
911
+ """Generate bprop for Unstack"""
912
+ axis = self.axis
913
+
914
+ def bprop(x, out, dout):
915
+ unstack_grad = P.Stack(axis)
916
+ out = unstack_grad(dout)
917
+ return (out,)
918
+
919
+ return bprop
920
+
921
+
922
+ @bprop_getters.register(P.Eye)
923
+ def get_bprop_eye(self):
924
+ """Generate bprop for Eye"""
925
+
926
+ def bprop(n, m, t, out, dout):
927
+ return zeros_like(n), zeros_like(m), zeros_like(t)
928
+
929
+ return bprop
930
+
931
+
932
+ @bprop_getters.register(P.ScatterNdUpdate)
933
+ def get_bprop_scatter_nd_update(self):
934
+ """Generate bprop for ScatterNdUpdate"""
935
+ op = P.GatherNd()
936
+
937
+ def bprop(x, indices, update, out, dout):
938
+ return dout, zeros_like(indices), op(dout, indices)
939
+
940
+ return bprop
941
+
942
+
943
+ @bprop_getters.register(P.ScatterNonAliasingAdd)
944
+ def get_bprop_scatter_non_aliasing_add_update(self):
945
+ """Generate bprop for ScatterNonAliasingAdd"""
946
+ op = P.GatherNd()
947
+
948
+ def bprop(x, indices, update, out, dout):
949
+ return dout, zeros_like(indices), op(dout, indices)
950
+
951
+ return bprop
952
+
953
+
954
+ @bprop_getters.register(P.ScatterUpdate)
955
+ def get_bprop_scatter_update(self):
956
+ """Generate bprop for ScatterUpdate"""
957
+ gather = P.Gather()
958
+
959
+ def bprop(x, indices, update, out, dout):
960
+ return dout, zeros_like(indices), gather(dout, indices, 0)
961
+
962
+ return bprop
963
+
964
+
965
+ @bprop_getters.register(P.TransShape)
966
+ def get_bprop_trans_shape(self):
967
+ """Generate bprop for TransShape"""
968
+ op = P.TransShape()
969
+
970
+ def bprop(x, shape, out, dout):
971
+ dx = op(dout, shape_op(x))
972
+ return (dx, zeros_like(shape))
973
+
974
+ return bprop
975
+
976
+
977
+ @bprop_getters.register(P.Unique)
978
+ def get_bprop_unique(self):
979
+ """Generate bprop for Unique"""
980
+ op = G.UniqueGrad()
981
+
982
+ def bprop(x, out, dout):
983
+ dx = op(dout, out)
984
+ return (dx,)
985
+
986
+ return bprop