mindspore 1.10.0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (944)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +9 -4
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/builtin_operations.py +32 -4
  13. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  21. mindspore/_extends/parse/__init__.py +5 -3
  22. mindspore/_extends/parse/namespace.py +17 -2
  23. mindspore/_extends/parse/parser.py +193 -34
  24. mindspore/_extends/parse/resources.py +7 -8
  25. mindspore/_extends/parse/standard_method.py +1780 -435
  26. mindspore/_extends/parse/trope.py +3 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  28. mindspore/amp.py +53 -58
  29. mindspore/bin/cache_admin +0 -0
  30. mindspore/bin/cache_server +0 -0
  31. mindspore/boost/adasum.py +3 -2
  32. mindspore/boost/boost.py +2 -2
  33. mindspore/boost/boost_cell_wrapper.py +46 -26
  34. mindspore/boost/dim_reduce.py +6 -5
  35. mindspore/boost/grad_accumulation.py +2 -1
  36. mindspore/boost/group_loss_scale_manager.py +1 -1
  37. mindspore/common/__init__.py +11 -10
  38. mindspore/common/_decorator.py +2 -0
  39. mindspore/common/_register_for_adapter.py +55 -0
  40. mindspore/common/_stub_tensor.py +201 -0
  41. mindspore/common/_utils.py +57 -0
  42. mindspore/common/api.py +582 -297
  43. mindspore/common/dtype.py +66 -18
  44. mindspore/common/dump.py +2 -2
  45. mindspore/common/initializer.py +38 -1
  46. mindspore/common/jit_config.py +25 -13
  47. mindspore/common/mutable.py +53 -24
  48. mindspore/common/parameter.py +60 -37
  49. mindspore/common/seed.py +8 -24
  50. mindspore/common/sparse_tensor.py +927 -0
  51. mindspore/common/tensor.py +1627 -3900
  52. mindspore/communication/__init__.py +10 -5
  53. mindspore/communication/_comm_helper.py +78 -214
  54. mindspore/communication/_hccl_management.py +2 -1
  55. mindspore/communication/management.py +136 -47
  56. mindspore/config/op_info.config +501 -1008
  57. mindspore/config/super_bar_config.json +512 -0
  58. mindspore/context.py +291 -56
  59. mindspore/dataset/__init__.py +12 -8
  60. mindspore/dataset/audio/__init__.py +9 -9
  61. mindspore/dataset/audio/transforms.py +1090 -228
  62. mindspore/dataset/audio/utils.py +87 -39
  63. mindspore/dataset/audio/validators.py +223 -1
  64. mindspore/dataset/callback/ds_callback.py +17 -15
  65. mindspore/dataset/core/config.py +246 -17
  66. mindspore/dataset/core/py_util_helpers.py +4 -3
  67. mindspore/dataset/core/validator_helpers.py +10 -10
  68. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  69. mindspore/dataset/debug/debug_hook.py +65 -0
  70. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  71. mindspore/dataset/engine/__init__.py +7 -3
  72. mindspore/dataset/engine/cache_client.py +9 -9
  73. mindspore/dataset/engine/datasets.py +648 -477
  74. mindspore/dataset/engine/datasets_audio.py +165 -167
  75. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  76. mindspore/dataset/engine/datasets_text.py +492 -342
  77. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  78. mindspore/dataset/engine/datasets_vision.py +1224 -699
  79. mindspore/dataset/engine/graphdata.py +134 -69
  80. mindspore/dataset/engine/iterators.py +50 -9
  81. mindspore/dataset/engine/offload.py +52 -31
  82. mindspore/dataset/engine/samplers.py +27 -24
  83. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  84. mindspore/dataset/engine/validators.py +213 -52
  85. mindspore/dataset/text/__init__.py +10 -8
  86. mindspore/dataset/text/transforms.py +152 -57
  87. mindspore/dataset/text/utils.py +98 -49
  88. mindspore/dataset/text/validators.py +25 -0
  89. mindspore/dataset/transforms/__init__.py +4 -2
  90. mindspore/dataset/transforms/c_transforms.py +11 -13
  91. mindspore/dataset/transforms/py_transforms.py +2 -2
  92. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  93. mindspore/dataset/transforms/transforms.py +13 -15
  94. mindspore/dataset/transforms/validators.py +7 -7
  95. mindspore/dataset/utils/__init__.py +2 -1
  96. mindspore/dataset/utils/browse_dataset.py +13 -13
  97. mindspore/dataset/utils/line_reader.py +121 -0
  98. mindspore/dataset/vision/__init__.py +8 -7
  99. mindspore/dataset/vision/c_transforms.py +125 -126
  100. mindspore/dataset/vision/py_transforms.py +37 -37
  101. mindspore/dataset/vision/py_transforms_util.py +23 -20
  102. mindspore/dataset/vision/transforms.py +316 -315
  103. mindspore/dataset/vision/utils.py +313 -17
  104. mindspore/dataset/vision/validators.py +6 -6
  105. mindspore/default_config.py +0 -1
  106. mindspore/{compression → experimental}/__init__.py +6 -5
  107. mindspore/experimental/map_parameter.py +275 -0
  108. mindspore/include/OWNERS +0 -1
  109. mindspore/include/api/callback/callback.h +9 -13
  110. mindspore/include/api/callback/ckpt_saver.h +2 -2
  111. mindspore/include/api/callback/loss_monitor.h +2 -2
  112. mindspore/include/api/callback/lr_scheduler.h +5 -5
  113. mindspore/include/api/callback/time_monitor.h +2 -2
  114. mindspore/include/api/callback/train_accuracy.h +4 -6
  115. mindspore/include/api/cfg.h +19 -6
  116. mindspore/include/api/context.h +70 -9
  117. mindspore/include/api/delegate.h +8 -1
  118. mindspore/include/api/dual_abi_helper.h +8 -24
  119. mindspore/include/api/metrics/accuracy.h +2 -2
  120. mindspore/include/api/metrics/metrics.h +4 -3
  121. mindspore/include/api/model.h +9 -4
  122. mindspore/include/api/model_group.h +68 -0
  123. mindspore/include/api/model_parallel_runner.h +17 -17
  124. mindspore/include/api/net.h +12 -11
  125. mindspore/include/api/serialization.h +20 -4
  126. mindspore/include/api/status.h +7 -1
  127. mindspore/include/api/types.h +25 -21
  128. mindspore/include/api/visible.h +4 -0
  129. mindspore/include/c_api/model_c.h +5 -0
  130. mindspore/include/c_api/status_c.h +1 -1
  131. mindspore/include/dataset/config.h +1 -1
  132. mindspore/include/dataset/constants.h +14 -0
  133. mindspore/include/dataset/text.h +59 -0
  134. mindspore/include/dataset/vision.h +56 -117
  135. mindspore/include/dataset/vision_lite.h +102 -0
  136. mindspore/include/mindapi/base/type_id.h +42 -3
  137. mindspore/lib/libdnnl.so.2 +0 -0
  138. mindspore/lib/libicudata.so.69 +0 -0
  139. mindspore/lib/libicui18n.so.69 +0 -0
  140. mindspore/lib/libicuuc.so.69 +0 -0
  141. mindspore/lib/libmindspore.so +0 -0
  142. mindspore/lib/libmindspore_backend.so +0 -0
  143. mindspore/lib/libmindspore_common.so +0 -0
  144. mindspore/lib/libmindspore_core.so +0 -0
  145. mindspore/lib/libmindspore_glog.so.0 +0 -0
  146. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  147. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  148. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  149. mindspore/lib/libmindspore_shared_lib.so +0 -0
  150. mindspore/lib/libmpi_adapter.so +0 -0
  151. mindspore/lib/libmpi_collective.so +0 -0
  152. mindspore/lib/libnnacl.so +0 -0
  153. mindspore/lib/libopencv_core.so.4.5 +0 -0
  154. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  155. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  156. mindspore/lib/libps_cache.so +0 -0
  157. mindspore/lib/plugin/ascend/libakg.so +0 -0
  158. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  159. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  160. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  161. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  163. mindspore/lib/{libakg.so → plugin/cpu/libakg.so} +0 -0
  164. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  165. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  166. mindspore/log.py +28 -28
  167. mindspore/mindrecord/common/exceptions.py +2 -4
  168. mindspore/mindrecord/filereader.py +19 -1
  169. mindspore/mindrecord/filewriter.py +250 -88
  170. mindspore/mindrecord/mindpage.py +13 -13
  171. mindspore/mindrecord/shardheader.py +15 -15
  172. mindspore/mindrecord/shardreader.py +9 -0
  173. mindspore/mindrecord/shardwriter.py +29 -29
  174. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  175. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  176. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  177. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  178. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  179. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  180. mindspore/nn/__init__.py +1 -5
  181. mindspore/nn/cell.py +297 -234
  182. mindspore/nn/dynamic_lr.py +1 -1
  183. mindspore/nn/grad/cell_grad.py +17 -42
  184. mindspore/nn/layer/__init__.py +7 -4
  185. mindspore/nn/layer/activation.py +131 -88
  186. mindspore/nn/layer/basic.py +313 -613
  187. mindspore/nn/layer/channel_shuffle.py +103 -0
  188. mindspore/nn/layer/combined.py +1 -1
  189. mindspore/nn/layer/container.py +52 -6
  190. mindspore/nn/layer/conv.py +112 -43
  191. mindspore/nn/layer/dense.py +10 -9
  192. mindspore/nn/layer/embedding.py +36 -34
  193. mindspore/nn/layer/image.py +123 -27
  194. mindspore/nn/layer/math.py +108 -107
  195. mindspore/nn/layer/normalization.py +212 -366
  196. mindspore/nn/layer/padding.py +370 -42
  197. mindspore/nn/layer/pooling.py +1443 -219
  198. mindspore/nn/layer/rnn_cells.py +11 -16
  199. mindspore/nn/layer/rnns.py +38 -39
  200. mindspore/nn/layer/thor_layer.py +24 -25
  201. mindspore/nn/layer/timedistributed.py +5 -5
  202. mindspore/nn/layer/transformer.py +701 -0
  203. mindspore/nn/learning_rate_schedule.py +8 -8
  204. mindspore/nn/loss/__init__.py +9 -6
  205. mindspore/nn/loss/loss.py +678 -142
  206. mindspore/nn/metrics.py +53 -0
  207. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  208. mindspore/nn/optim/ada_grad.py +8 -8
  209. mindspore/nn/optim/adadelta.py +2 -3
  210. mindspore/nn/optim/adafactor.py +18 -14
  211. mindspore/nn/optim/adam.py +429 -87
  212. mindspore/nn/optim/adamax.py +5 -6
  213. mindspore/nn/optim/adasum.py +10 -8
  214. mindspore/nn/optim/asgd.py +7 -7
  215. mindspore/nn/optim/ftrl.py +81 -11
  216. mindspore/nn/optim/lamb.py +7 -8
  217. mindspore/nn/optim/lars.py +4 -4
  218. mindspore/nn/optim/lazyadam.py +82 -7
  219. mindspore/nn/optim/momentum.py +8 -7
  220. mindspore/nn/optim/optimizer.py +19 -10
  221. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  222. mindspore/nn/optim/rmsprop.py +3 -3
  223. mindspore/nn/optim/rprop.py +20 -16
  224. mindspore/nn/optim/sgd.py +21 -15
  225. mindspore/nn/optim/thor.py +23 -21
  226. mindspore/nn/probability/__init__.py +0 -2
  227. mindspore/nn/probability/bijector/bijector.py +7 -6
  228. mindspore/nn/probability/bijector/invert.py +4 -2
  229. mindspore/nn/probability/bijector/softplus.py +2 -2
  230. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  231. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  232. mindspore/nn/probability/distribution/__init__.py +6 -0
  233. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  234. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  235. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  236. mindspore/nn/probability/distribution/beta.py +1 -1
  237. mindspore/nn/probability/distribution/categorical.py +9 -9
  238. mindspore/nn/probability/distribution/cauchy.py +8 -8
  239. mindspore/nn/probability/distribution/distribution.py +12 -6
  240. mindspore/nn/probability/distribution/exponential.py +5 -5
  241. mindspore/nn/probability/distribution/gamma.py +3 -3
  242. mindspore/nn/probability/distribution/geometric.py +6 -5
  243. mindspore/nn/probability/distribution/gumbel.py +5 -5
  244. mindspore/nn/probability/distribution/half_normal.py +133 -0
  245. mindspore/nn/probability/distribution/laplace.py +128 -0
  246. mindspore/nn/probability/distribution/log_normal.py +0 -1
  247. mindspore/nn/probability/distribution/logistic.py +4 -5
  248. mindspore/nn/probability/distribution/normal.py +11 -15
  249. mindspore/nn/probability/distribution/poisson.py +6 -2
  250. mindspore/nn/probability/distribution/student_t.py +150 -0
  251. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  252. mindspore/nn/probability/distribution/uniform.py +5 -5
  253. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  254. mindspore/nn/reinforcement/tensor_array.py +2 -2
  255. mindspore/nn/sparse/sparse.py +8 -1
  256. mindspore/nn/wrap/cell_wrapper.py +55 -27
  257. mindspore/nn/wrap/grad_reducer.py +20 -11
  258. mindspore/nn/wrap/loss_scale.py +47 -30
  259. mindspore/numpy/array_creations.py +33 -22
  260. mindspore/numpy/array_ops.py +46 -42
  261. mindspore/numpy/logic_ops.py +6 -27
  262. mindspore/numpy/math_ops.py +26 -19
  263. mindspore/numpy/utils.py +1 -8
  264. mindspore/numpy/utils_const.py +112 -62
  265. mindspore/ops/__init__.py +6 -3
  266. mindspore/ops/_constants.py +0 -6
  267. mindspore/ops/_grad/__init__.py +2 -1
  268. mindspore/ops/_grad/grad_array_ops.py +209 -152
  269. mindspore/ops/_grad/grad_base.py +55 -17
  270. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  271. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  272. mindspore/ops/_grad/grad_implementations.py +21 -61
  273. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  274. mindspore/ops/_grad/grad_math_ops.py +306 -161
  275. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  276. mindspore/ops/_grad/grad_other_ops.py +1 -1
  277. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  278. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  279. mindspore/ops/_grad/grad_sparse.py +15 -9
  280. mindspore/ops/_grad_experimental/__init__.py +1 -0
  281. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  282. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  283. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  284. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  285. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  286. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  287. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  288. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  289. mindspore/ops/_op_impl/__init__.py +3 -3
  290. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  291. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  292. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  293. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  294. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  295. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  296. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  297. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  298. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  299. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  300. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  302. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  303. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  304. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  305. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  307. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  314. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  315. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  317. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  318. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  320. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  321. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  322. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  324. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  327. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  328. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  329. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  330. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  331. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  332. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  335. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  336. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  337. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  338. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  339. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  340. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  341. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  342. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  343. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  344. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  345. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  346. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  347. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  349. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  350. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  351. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  352. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  353. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  354. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  355. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  356. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  357. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  358. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  359. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  360. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  361. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  362. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  363. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  365. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  366. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  367. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  368. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  369. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  370. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  371. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  372. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  373. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  374. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  375. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  376. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  377. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  378. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  379. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  380. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  381. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  382. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  383. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  384. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  385. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  386. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  387. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  388. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  389. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  390. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  391. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  392. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  393. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  394. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  395. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  396. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  397. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  398. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  399. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  400. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  401. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  402. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  403. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  404. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  405. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  407. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  408. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  409. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  410. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  411. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  412. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  413. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  414. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  415. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  416. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  417. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  418. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  419. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  420. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  421. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  422. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  423. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  424. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  425. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  426. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  427. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  428. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  429. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  430. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  431. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  432. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  433. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  434. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  435. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  436. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  437. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  441. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  442. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  443. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  444. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  445. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  446. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  447. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  448. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  449. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  450. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  451. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  452. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  453. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  454. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  455. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  456. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  457. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  458. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  459. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  460. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  461. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  462. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  463. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  464. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  465. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  466. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  467. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  468. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  469. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  470. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  471. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  472. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  473. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  474. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  475. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  476. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  477. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  478. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  479. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  480. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  481. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  482. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  483. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  484. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  485. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  486. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  487. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  488. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  489. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  490. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  491. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  492. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  493. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  494. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  495. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  496. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  497. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  498. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  499. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  500. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  501. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  502. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  503. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  504. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  505. mindspore/ops/_primitive_cache.py +3 -2
  506. mindspore/ops/_register_for_op.py +11 -0
  507. mindspore/ops/_utils/__init__.py +1 -1
  508. mindspore/ops/_utils/utils.py +20 -41
  509. mindspore/ops/_vmap/__init__.py +2 -2
  510. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  511. mindspore/ops/_vmap/vmap_base.py +24 -10
  512. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  513. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  514. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  515. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  516. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  517. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  518. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  519. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  520. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  521. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  522. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  524. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  525. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  526. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  527. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  528. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  529. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  530. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  532. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  534. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  536. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  537. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  538. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  541. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  542. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  543. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  544. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  545. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  546. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  547. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  548. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  549. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  550. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  551. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  552. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  554. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  555. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  556. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  560. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  561. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  563. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  564. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  565. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  566. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  567. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  568. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  569. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  571. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  572. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  573. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  574. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  575. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  577. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  578. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  579. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  580. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  581. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  582. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  583. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  585. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  586. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  587. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  588. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  589. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  590. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  591. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  592. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  593. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  594. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  595. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  596. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  597. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  598. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  599. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  600. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  601. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  604. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  605. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  606. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  607. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  609. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  610. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  611. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  612. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  613. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  614. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  615. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  616. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  617. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  618. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  619. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  620. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  621. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  622. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  623. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  624. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  625. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  626. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  627. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  628. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  629. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  630. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  632. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  633. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  634. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  635. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  637. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  638. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  639. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  640. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  643. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  644. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  645. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  646. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  647. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  648. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  649. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  650. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  651. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  652. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  653. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  654. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  655. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  656. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  657. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  658. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  660. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  661. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  662. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  663. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  664. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  665. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  666. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  667. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  668. mindspore/ops/bprop_mindir/__init__.py +1 -4
  669. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  670. mindspore/ops/composite/__init__.py +12 -13
  671. mindspore/ops/composite/base.py +261 -254
  672. mindspore/ops/composite/env_ops.py +41 -0
  673. mindspore/ops/composite/math_ops.py +197 -156
  674. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  675. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  676. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  677. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  678. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  679. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  680. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  681. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  682. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  683. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  684. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  685. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  686. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  687. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  688. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  689. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  690. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  691. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  692. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  693. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  694. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  695. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  696. mindspore/ops/function/__init__.py +323 -8
  697. mindspore/ops/function/array_func.py +3511 -780
  698. mindspore/ops/function/clip_func.py +329 -0
  699. mindspore/ops/function/debug_func.py +6 -6
  700. mindspore/ops/function/grad/__init__.py +5 -1
  701. mindspore/ops/function/grad/grad_func.py +736 -65
  702. mindspore/ops/function/image_func.py +270 -0
  703. mindspore/ops/function/linalg_func.py +268 -8
  704. mindspore/ops/function/math_func.py +8032 -3164
  705. mindspore/ops/function/nn_func.py +5619 -1855
  706. mindspore/ops/function/other_func.py +115 -0
  707. mindspore/ops/function/parameter_func.py +11 -10
  708. mindspore/ops/function/random_func.py +939 -77
  709. mindspore/ops/function/sparse_func.py +249 -84
  710. mindspore/ops/function/sparse_unary_func.py +2303 -0
  711. mindspore/ops/function/spectral_func.py +146 -0
  712. mindspore/ops/function/vmap_func.py +114 -0
  713. mindspore/ops/functional.py +182 -254
  714. mindspore/ops/op_info_register.py +79 -34
  715. mindspore/ops/operations/__init__.py +210 -118
  716. mindspore/ops/operations/_csr_ops.py +7 -7
  717. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  718. mindspore/ops/operations/_grad_ops.py +447 -322
  719. mindspore/ops/operations/_inner_ops.py +547 -176
  720. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  721. mindspore/ops/operations/_ms_kernel.py +29 -27
  722. mindspore/ops/operations/_ocr_ops.py +11 -11
  723. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  724. mindspore/ops/operations/_quant_ops.py +186 -101
  725. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  726. mindspore/ops/operations/_scalar_ops.py +466 -0
  727. mindspore/ops/operations/_sequence_ops.py +1047 -0
  728. mindspore/ops/operations/_tensor_array.py +10 -11
  729. mindspore/ops/operations/_thor_ops.py +4 -4
  730. mindspore/ops/operations/array_ops.py +1428 -1226
  731. mindspore/ops/operations/comm_ops.py +180 -117
  732. mindspore/ops/operations/control_ops.py +4 -2
  733. mindspore/ops/operations/custom_ops.py +185 -98
  734. mindspore/ops/operations/debug_ops.py +92 -54
  735. mindspore/ops/operations/image_ops.py +406 -211
  736. mindspore/ops/operations/inner_ops.py +42 -53
  737. mindspore/ops/operations/linalg_ops.py +32 -29
  738. mindspore/ops/operations/math_ops.py +2076 -897
  739. mindspore/ops/operations/nn_ops.py +1282 -1252
  740. mindspore/ops/operations/other_ops.py +124 -278
  741. mindspore/ops/operations/random_ops.py +345 -178
  742. mindspore/ops/operations/rl_ops.py +8 -9
  743. mindspore/ops/operations/sparse_ops.py +502 -157
  744. mindspore/ops/operations/spectral_ops.py +107 -0
  745. mindspore/ops/primitive.py +192 -15
  746. mindspore/ops/vm_impl_registry.py +23 -2
  747. mindspore/parallel/__init__.py +6 -1
  748. mindspore/parallel/_auto_parallel_context.py +199 -92
  749. mindspore/parallel/_cell_wrapper.py +4 -2
  750. mindspore/parallel/_cost_model_context.py +3 -0
  751. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  752. mindspore/parallel/_offload_context.py +185 -0
  753. mindspore/parallel/_parallel_serialization.py +167 -28
  754. mindspore/parallel/_ps_context.py +9 -5
  755. mindspore/parallel/_recovery_context.py +1 -1
  756. mindspore/parallel/_tensor.py +9 -1
  757. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  758. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  759. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  760. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  761. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  762. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  763. mindspore/parallel/_utils.py +47 -7
  764. mindspore/parallel/algo_parameter_config.py +5 -1
  765. mindspore/parallel/checkpoint_transform.py +329 -0
  766. mindspore/parallel/shard.py +229 -0
  767. mindspore/profiler/__init__.py +2 -1
  768. mindspore/profiler/common/util.py +4 -3
  769. mindspore/profiler/common/validator/validate_path.py +2 -2
  770. mindspore/profiler/envprofiling.py +249 -0
  771. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  772. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  773. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  774. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  775. mindspore/profiler/parser/framework_parser.py +42 -16
  776. mindspore/profiler/parser/hccl_parser.py +158 -158
  777. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  778. mindspore/profiler/parser/integrator.py +18 -1579
  779. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  780. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  781. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  782. mindspore/profiler/parser/optime_parser.py +17 -18
  783. mindspore/profiler/parser/profiler_info.py +108 -0
  784. mindspore/profiler/parser/step_trace_parser.py +1 -1
  785. mindspore/profiler/profiling.py +396 -194
  786. mindspore/rewrite/__init__.py +6 -2
  787. mindspore/rewrite/api/node.py +51 -110
  788. mindspore/rewrite/api/node_type.py +10 -6
  789. mindspore/rewrite/api/pattern_engine.py +51 -7
  790. mindspore/rewrite/api/scoped_value.py +64 -53
  791. mindspore/rewrite/api/symbol_tree.py +108 -61
  792. mindspore/rewrite/api/tree_node_helper.py +2 -3
  793. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  794. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  795. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  796. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  797. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  798. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  799. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  800. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  801. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  802. mindspore/rewrite/common/__init__.py +2 -0
  803. mindspore/rewrite/common/event.py +1 -1
  804. mindspore/rewrite/common/observable.py +1 -1
  805. mindspore/rewrite/common/observer.py +1 -1
  806. mindspore/rewrite/common/rewrite_elog.py +35 -0
  807. mindspore/rewrite/namer.py +2 -2
  808. mindspore/rewrite/namespace.py +14 -4
  809. mindspore/rewrite/node.py +161 -13
  810. mindspore/rewrite/parser.py +0 -1
  811. mindspore/rewrite/parser_register.py +0 -1
  812. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  813. mindspore/rewrite/parsers/assign_parser.py +267 -67
  814. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  815. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  816. mindspore/rewrite/parsers/constant_parser.py +101 -0
  817. mindspore/rewrite/parsers/container_parser.py +88 -0
  818. mindspore/rewrite/parsers/for_parser.py +28 -15
  819. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  820. mindspore/rewrite/parsers/if_parser.py +11 -28
  821. mindspore/rewrite/parsers/module_parser.py +9 -6
  822. mindspore/rewrite/parsers/return_parser.py +3 -2
  823. mindspore/rewrite/sparsify/__init__.py +0 -0
  824. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  825. mindspore/rewrite/sparsify/sparsify.py +109 -0
  826. mindspore/rewrite/sparsify/utils.py +173 -0
  827. mindspore/rewrite/symbol_tree.py +322 -109
  828. mindspore/rewrite/symbol_tree_builder.py +45 -8
  829. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  830. mindspore/rewrite/topological_manager.py +1 -2
  831. mindspore/run_check/_check_version.py +209 -112
  832. mindspore/run_check/run_check.py +2 -1
  833. mindspore/scipy/linalg.py +13 -117
  834. mindspore/scipy/ops.py +5 -71
  835. mindspore/scipy/ops_grad.py +1 -25
  836. mindspore/scipy/ops_wrapper.py +1 -1
  837. mindspore/scipy/optimize/_bfgs.py +1 -1
  838. mindspore/scipy/optimize/_lagrange.py +200 -0
  839. mindspore/scipy/optimize/line_search.py +3 -2
  840. mindspore/scipy/optimize/minimize.py +43 -6
  841. mindspore/scipy/sparse/__init__.py +2 -2
  842. mindspore/scipy/sparse/linalg.py +5 -465
  843. mindspore/scipy/utils.py +2 -1
  844. mindspore/scipy/utils_const.py +7 -1
  845. mindspore/train/__init__.py +6 -4
  846. mindspore/train/_utils.py +28 -5
  847. mindspore/train/amp.py +321 -50
  848. mindspore/train/callback/__init__.py +3 -1
  849. mindspore/train/callback/_backup_and_restore.py +120 -0
  850. mindspore/train/callback/_callback.py +8 -8
  851. mindspore/train/callback/_checkpoint.py +12 -9
  852. mindspore/train/callback/_early_stop.py +13 -7
  853. mindspore/train/callback/_history.py +8 -8
  854. mindspore/train/callback/_lambda_callback.py +6 -6
  855. mindspore/train/callback/_landscape.py +36 -38
  856. mindspore/train/callback/_loss_monitor.py +12 -6
  857. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  858. mindspore/train/callback/_on_request_exit.py +212 -0
  859. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  860. mindspore/train/callback/_summary_collector.py +27 -19
  861. mindspore/train/callback/_time_monitor.py +13 -7
  862. mindspore/train/checkpoint_pb2.py +68 -8
  863. mindspore/train/data_sink.py +122 -33
  864. mindspore/train/dataset_helper.py +28 -87
  865. mindspore/train/loss_scale_manager.py +4 -7
  866. mindspore/{nn → train}/metrics/__init__.py +20 -20
  867. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  868. mindspore/{nn → train}/metrics/auc.py +4 -4
  869. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  870. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  871. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  872. mindspore/{nn → train}/metrics/dice.py +6 -5
  873. mindspore/{nn → train}/metrics/error.py +7 -5
  874. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  875. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  876. mindspore/{nn → train}/metrics/loss.py +4 -3
  877. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  878. mindspore/{nn → train}/metrics/metric.py +6 -5
  879. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  880. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  881. mindspore/{nn → train}/metrics/precision.py +5 -4
  882. mindspore/{nn → train}/metrics/recall.py +5 -4
  883. mindspore/{nn → train}/metrics/roc.py +7 -6
  884. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  885. mindspore/{nn → train}/metrics/topk.py +7 -5
  886. mindspore/train/mind_ir_pb2.py +339 -32
  887. mindspore/train/model.py +113 -84
  888. mindspore/train/serialization.py +547 -167
  889. mindspore/train/summary/_summary_adapter.py +1 -1
  890. mindspore/train/summary/summary_record.py +43 -12
  891. mindspore/train/train_thor/convert_utils.py +7 -1
  892. mindspore/train/train_thor/dataset_helper.py +3 -3
  893. mindspore/train/train_thor/model_thor.py +0 -4
  894. mindspore/version.py +1 -1
  895. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  896. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +899 -675
  897. mindspore/compression/common/constant.py +0 -124
  898. mindspore/compression/export/__init__.py +0 -19
  899. mindspore/compression/export/quant_export.py +0 -514
  900. mindspore/compression/quant/qat.py +0 -636
  901. mindspore/compression/quant/quant_utils.py +0 -462
  902. mindspore/compression/quant/quantizer.py +0 -68
  903. mindspore/nn/layer/quant.py +0 -1868
  904. mindspore/nn/layer/rnn_utils.py +0 -90
  905. mindspore/nn/probability/dpn/__init__.py +0 -22
  906. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  907. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  908. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  909. mindspore/nn/probability/infer/__init__.py +0 -22
  910. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  911. mindspore/nn/probability/infer/variational/svi.py +0 -84
  912. mindspore/nn/probability/toolbox/__init__.py +0 -22
  913. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  914. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  915. mindspore/nn/probability/transforms/__init__.py +0 -22
  916. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  917. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  918. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  919. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  920. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  921. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  922. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  923. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  924. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  925. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  926. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  927. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  928. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  929. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  930. mindspore/ops/composite/array_ops.py +0 -210
  931. mindspore/ops/composite/clip_ops.py +0 -238
  932. mindspore/ops/composite/random_ops.py +0 -426
  933. mindspore/ops/composite/vmap_ops.py +0 -38
  934. mindspore/ops/operations/sponge_ops.py +0 -3531
  935. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  936. mindspore/parallel/nn/__init__.py +0 -42
  937. mindspore/parallel/nn/loss.py +0 -22
  938. mindspore/parallel/nn/moe.py +0 -21
  939. mindspore/parallel/nn/op_parallel_config.py +0 -22
  940. mindspore/parallel/nn/transformer.py +0 -31
  941. mindspore/run_check/_check_deps_version.py +0 -84
  942. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  943. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  944. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -17,19 +17,23 @@
17
17
 
18
18
  import numpy as np
19
19
  import mindspore as ms
20
- from mindspore.ops import composite as C
20
+ from mindspore import Tensor
21
21
  from mindspore.ops import operations as P
22
22
  from mindspore.ops.operations import _grad_ops as G
23
23
  from mindspore.ops.operations.array_ops import Fills, NonZero
24
24
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
25
25
  from mindspore.ops.functional import broadcast_gradient_args
26
26
  from mindspore.ops import functional as F
27
- from mindspore.ops._grad.grad_base import bprop_getters
27
+ from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element
28
28
  from mindspore.ops.primitive import constexpr
29
+ from mindspore.ops.primitive import _primexpr
29
30
  from mindspore.common import dtype as mstype
30
- from mindspore.common.tensor import RowTensor
31
- from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index, is_shape_unknown
32
- from .._grad.grad_base import dyn_rank, convert_to_tensor, dyn_invert_permutation, dyn_size, dyn_ones, dyn_fill
31
+ from mindspore.common.sparse_tensor import RowTensorInner
32
+ from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index
33
+ from mindspore.ops._grad.grad_base import dyn_rank, convert_to_tensor, dyn_ones, dyn_fill
34
+ from mindspore.ops._grad.grad_base import sum_grad_reduce_axis
35
+ from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
36
+ from ..operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass
33
37
 
34
38
  reduce_sum = P.ReduceSum()
35
39
  unsorted_segment_sum = P.UnsortedSegmentSum()
@@ -40,7 +44,7 @@ reshape = P.Reshape()
40
44
  size_op = P.Size()
41
45
  invert_permutation = P.InvertPermutation()
42
46
  logical_and = P.LogicalAnd()
43
- is_sub_class = P.IsSubClass()
47
+ is_sub_class = IsSubClass()
44
48
 
45
49
 
46
50
  @bprop_getters.register(P.Fill)
@@ -93,46 +97,6 @@ def get_bprop_dtype(self):
93
97
  return bprop
94
98
 
95
99
 
96
- dout_cast = C.MultitypeFuncGraph("dout_cast")
97
-
98
-
99
- @dout_cast.register("Tensor", "Tensor")
100
- def dout_cast_tensor(dout, x):
101
- """Casts dout to the dtype of x for Tensor."""
102
- cast = P.Cast()
103
- get_dtype = P.DType()
104
- dx = cast(dout, get_dtype(x))
105
- return dx
106
-
107
-
108
- @dout_cast.register("Number", "Number")
109
- def dout_cast_number(dout, x):
110
- """Casts dout to the dtype of x for Number."""
111
- cast = P.Cast()
112
- get_dtype = P.DType()
113
- dx = cast(dout, get_dtype(x))
114
- return dx
115
-
116
-
117
- @dout_cast.register("RowTensor", "Tensor")
118
- def dout_cast_row_tensor(dout, x):
119
- """Casts dout values to the dtype of x for RowTensor."""
120
- cast = P.Cast()
121
- get_dtype = P.DType()
122
- values = cast(dout.values, get_dtype(x))
123
- return RowTensor(dout.indices, values, dout.dense_shape)
124
-
125
-
126
- @bprop_getters.register(P.Cast)
127
- def get_bprop_cast(self):
128
- """Generate bprop for Cast"""
129
- def bprop(x, t, out, dout):
130
- dx = dout_cast(dout, x)
131
- return dx, zeros_like(t)
132
-
133
- return bprop
134
-
135
-
136
100
  @bprop_getters.register(P.Shape)
137
101
  def get_bprop_shape(self):
138
102
  """Generate bprop for Shape"""
@@ -152,6 +116,7 @@ def get_bprop_dynamicshape(self):
152
116
 
153
117
  return bprop
154
118
 
119
+
155
120
  @bprop_getters.register(P.TensorShape)
156
121
  def get_bprop_tensorshape(self):
157
122
  """Generate bprop for TensorShape"""
@@ -191,7 +156,7 @@ def get_bprop_reshape(self):
191
156
 
192
157
  def bprop(x, shp, out, dout):
193
158
  shapex = shape_op(x)
194
- if is_shape_unknown(shapex):
159
+ if F.is_sequence_value_unknown(shapex):
195
160
  shapex = dyn_shape_op(x)
196
161
  return reshape(dout, shapex), zeros_like(shp)
197
162
 
@@ -204,7 +169,7 @@ def get_bprop_expand_dims(self):
204
169
 
205
170
  def bprop(x, axis, out, dout):
206
171
  shapex = shape_op(x)
207
- if is_shape_unknown(shapex):
172
+ if F.is_sequence_value_unknown(shapex):
208
173
  shapex = dyn_shape_op(x)
209
174
  return reshape(dout, shapex), zeros_like(axis)
210
175
 
@@ -217,7 +182,7 @@ def get_bprop_squeeze(self):
217
182
 
218
183
  def bprop(x, out, dout):
219
184
  shapex = shape_op(x)
220
- if is_shape_unknown(shapex):
185
+ if F.is_sequence_value_unknown(shapex):
221
186
  shapex = dyn_shape_op(x)
222
187
  return (reshape(dout, shapex),)
223
188
 
@@ -230,13 +195,16 @@ def get_bprop_flatten(self):
230
195
  flatten_grad = P.Reshape()
231
196
 
232
197
  def bprop(x, out, dout):
233
- dx = flatten_grad(dout, shape_op(x))
198
+ shape_x = shape_op(x)
199
+ if F.is_sequence_value_unknown(shape_x):
200
+ shape_x = dyn_shape_op(x)
201
+ dx = flatten_grad(dout, shape_x)
234
202
  return (dx,)
235
203
 
236
204
  return bprop
237
205
 
238
206
 
239
- @constexpr
207
+ @_primexpr
240
208
  def _tile_shape(multiples, shapex):
241
209
  """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
242
210
  len_muli = len(multiples)
@@ -268,40 +236,45 @@ def _tile_shape(multiples, shapex):
268
236
  @bprop_getters.register(P.Tile)
269
237
  def get_bprop_tile(self):
270
238
  """Generate bprop for Tile"""
271
- tuple_to_array = P.TupleToArray()
272
239
  cast = P.Cast()
273
- stack_op = P.Stack(1)
274
- ones = P.Ones()
275
240
  concat = P.Concat()
241
+ stridedslice = P.StridedSlice()
242
+
243
+ def get_reduce_axis(r_shape):
244
+ """
245
+ reshape grad to r_shape, and reduce along all even dimensions to get the result with input_shape
246
+ For example:
247
+ input_shape = [20, 30, 40]
248
+ multiples = [2, 3, 4]
249
+ r_shape = [2, 20, 3, 30, 4, 40]
250
+ axis = [0, 2, 4]
251
+ """
252
+ rankr = dyn_shape_op(r_shape)[0]
253
+ tmp = range_op(0, 20, 2, mstype.int64)
254
+ return stridedslice(tmp, (0,), F.expand_dims(rankr // 2, 0), (1,))
276
255
 
277
256
  def bprop(x, multiples, out, dout):
278
257
  shapex = shape_op(x)
279
- if is_shape_unknown(shapex):
258
+ if F.is_sequence_value_unknown(shapex):
280
259
  shapex = dyn_shape_op(x)
281
- # if shapex or multiples not tuple, it should be dynamic shape.
282
260
  if isinstance(multiples, tuple) and isinstance(shapex, tuple):
283
261
  r_shape = _tile_shape(multiples, shapex)
262
+ # 0 represents the start index, and 2 represents the step
263
+ axis = F.make_range(0, len(r_shape), 2)
284
264
  else:
285
- if isinstance(multiples, tuple):
286
- multiples = tuple_to_array(multiples)
287
- multiples = cast(multiples, mstype.int64)
288
- len_multi = size_op(multiples)
289
- rank = len(shapex)
290
- if isinstance(shapex, tuple):
291
- shape_tensor = cast(tuple_to_array(shapex), mstype.int64)
292
- else:
293
- shape_tensor = shapex
294
- if len_multi > rank:
295
- one_tensor = ones((len_multi - rank,), mstype.int64)
296
- shape_tensor = concat((one_tensor, shape_tensor))
297
- elif len_multi < rank:
298
- one_tensor = ones((rank - len_multi,), mstype.int64)
299
- multiples = concat((one_tensor, multiples))
300
- tile_shape = stack_op((multiples, shape_tensor))
265
+ shapex = dyn_shape_op(x)
266
+ shapey = create_tensor_by_element(multiples)
267
+ rankx = dyn_rank(x)
268
+ ranky = dyn_shape_op(shapey)[0]
269
+ offset = F.expand_dims(ranky - rankx + 1, 0)
270
+ shape_x = concat((dyn_ones(offset, mstype.int64), shapex))
271
+ shape_x = shape_x[1:]
272
+ shapey = concat((P.Ones()((1,), mstype.int64), shapey))
273
+ shapey = shapey[1:]
274
+ tile_shape = P.Stack(1)((shapey, shape_x))
301
275
  r_shape = P.Reshape()(tile_shape, (-1,))
276
+ axis = get_reduce_axis(r_shape)
302
277
 
303
- # 0 represents the start index, and 2 represents the step
304
- axis = F.make_range(0, len(r_shape), 2)
305
278
  dout_reshaped = P.Reshape()(dout, r_shape)
306
279
  dout_origin_dtype = dout_reshaped.dtype
307
280
  # Currently, for Ascend and GPU, the reduce_sum's input does not support int16, int32 and int64.
@@ -325,6 +298,8 @@ def get_bprop_embedding_lookup(self):
325
298
 
326
299
  def bprop_sparse(x, indices, offset, out, dout):
327
300
  x_shp = shape_op(x)
301
+ if F.is_sequence_value_unknown(x_shp):
302
+ raise RuntimeError("Now, EmbeddingLookup op's grad don't support Dynamic Sense!")
328
303
  new_indices = sub_op(indices, offset)
329
304
  indices_size = size_op(new_indices)
330
305
  if indices_size > 0:
@@ -337,61 +312,43 @@ def get_bprop_embedding_lookup(self):
337
312
  actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
338
313
  # Reshape the 'actual_dout' on device
339
314
  actual_dout = reshape_op(dout, actual_dout_shape_changed)
340
- return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
315
+ return RowTensorInner(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
341
316
 
342
317
  return bprop_sparse
343
318
 
344
319
 
345
- @constexpr
320
+ @_primexpr
346
321
  def make_begin(shp):
347
322
  """Creates a tuple with zero according to the shape."""
348
323
  begin = tuple([0 for _ in shp])
349
324
  return begin
350
325
 
351
326
 
327
+ def make_dynamic_begin(shp):
328
+ """Creates a tuple with zero according to the shape."""
329
+ begin = zeros_like(shp)
330
+ return begin
331
+
332
+
352
333
  @bprop_getters.register(P.Padding)
353
334
  def get_bprop_padding(self):
354
335
  """Grad definition for `Padding` operation."""
355
336
 
356
337
  def bprop(x, out, dout):
357
338
  shp = shape_op(x)
358
- begin = make_begin(shp)
339
+ begin = ()
340
+ if F.is_sequence_value_unknown(shp):
341
+ shp = dyn_shape_op(x)
342
+ begin = make_dynamic_begin(shp)
343
+ else:
344
+ begin = make_begin(shp)
359
345
  dx = P.Slice()(dout, begin, shp)
360
346
  return (dx,)
361
347
 
362
348
  return bprop
363
349
 
364
350
 
365
- @constexpr
366
- def _transpose_perm_positive(perm):
367
- res = []
368
- for value in perm:
369
- value = value if (value >= 0) else (value + len(perm))
370
- res.append(value)
371
- return tuple(res)
372
-
373
-
374
- def _dyn_transpose_perm_positive(perm):
375
- return (perm + dyn_size(perm)) % (dyn_size(perm))
376
-
377
-
378
- @bprop_getters.register(P.Transpose)
379
- def get_bprop_transpose(self):
380
- """Generate bprop for Transpose"""
381
-
382
- def bprop(x, perm, out, dout):
383
- is_mutable, perm = convert_to_tensor(perm)
384
- if is_mutable:
385
- perm = _dyn_transpose_perm_positive(perm)
386
- return transpose(dout, dyn_invert_permutation(perm)), zeros_like(perm)
387
-
388
- perm = _transpose_perm_positive(perm)
389
- return transpose(dout, invert_permutation(perm)), zeros_like(perm)
390
-
391
- return bprop
392
-
393
-
394
- @constexpr
351
+ @_primexpr
395
352
  def _concat_grad_uniform(input_shapes, input_nums):
396
353
  """Helper function for bprop of Concat"""
397
354
  is_uniform = True
@@ -444,12 +401,6 @@ def get_bprop_concat(self):
444
401
  return bprop
445
402
 
446
403
 
447
- @constexpr
448
- def _slice_grad_pad(begins, sizes, shapes):
449
- pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
450
- return pads
451
-
452
-
453
404
  @bprop_getters.register(P.Slice)
454
405
  def get_bprop_slice(self):
455
406
  """Generate bprop for Slice"""
@@ -461,17 +412,17 @@ def get_bprop_slice(self):
461
412
  return bprop
462
413
 
463
414
 
464
- @constexpr
465
- def _generate_inverse_index(x_shape, axis):
415
+ @_primexpr
416
+ def _generate_inverse_index(x_shape, axis, batch_dims=0):
466
417
  x_rank = len(x_shape)
467
418
  index = tuple(range(x_rank))
468
419
  if axis < 0:
469
420
  axis += x_rank
470
- perm = index[1:1 + axis] + (0,) + index[1 + axis:]
421
+ perm = index[:batch_dims] + index[batch_dims + 1:1 + axis] + (index[batch_dims],) + index[1 + axis:]
471
422
  return perm
472
423
 
473
424
 
474
- @constexpr
425
+ @_primexpr
475
426
  def _regenerate_output_shape(x_shp, ind_shp, axis):
476
427
  rank = len(x_shp)
477
428
  if axis < 0:
@@ -480,12 +431,99 @@ def _regenerate_output_shape(x_shp, ind_shp, axis):
480
431
  return out_shape
481
432
 
482
433
 
434
+ def _dyn_regenerate_output_shape(x_shp, ind_shp, axis):
435
+ """Get reshape new_shape"""
436
+ rank = dyn_shape_op(x_shp)[0]
437
+ if axis < 0:
438
+ axis += rank
439
+ out_shape = P.Concat(0)((x_shp[:axis], ind_shp, x_shp[axis + 1:]))
440
+ return out_shape
441
+
442
+
443
+ def _dyn_generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
444
+ """Get tranpose order"""
445
+ out_rank = F.reshape(dyn_shape_op(out_shape), ())
446
+ ind_rank = F.reshape(dyn_shape_op(indices_shape), ())
447
+ if axis < 0:
448
+ axis += out_rank - ind_rank + 1
449
+ perm_part1 = P.Range()(F.cast(0, mstype.int32), F.cast(20, mstype.int32), F.cast(1, mstype.int32))
450
+ ind_end = axis + ind_rank - batch_dims
451
+ perm_part1 = perm_part1[axis: ind_end]
452
+ index = P.Range()(F.cast(0, mstype.int32), F.cast(out_rank, mstype.int32), F.cast(1, mstype.int32))
453
+ perm = F.hstack((index[:batch_dims], perm_part1, index[batch_dims:axis], index[ind_end:]))
454
+ return perm
455
+
456
+
457
+ def _dyn_generate_inverse_index(x_shp, axis, batch_dims=0):
458
+ """Get tranpose order"""
459
+ x_rank = F.reshape(dyn_shape_op(x_shp), ())
460
+ index = P.Range()(F.cast(0, mstype.int32), F.cast(x_rank, mstype.int32), F.cast(1, mstype.int32))
461
+ if axis < 0:
462
+ axis += x_rank
463
+ perm = F.hstack((index[:batch_dims], index[batch_dims + 1:1 + axis], index[batch_dims], index[1 + axis:]))
464
+ return perm
465
+
466
+
467
+ def calculate_batch_gather(values, indices, x_shape, axis, batch_dims):
468
+ """Calculate gather grad with batch_dims"""
469
+ values_shape = dyn_shape_op(values)
470
+ batch_size = F.prod(x_shape[:batch_dims])
471
+ batch_size = F.cast(batch_size, mstype.int32)
472
+ axis_dim = F.cast(x_shape[axis], mstype.int32)
473
+
474
+ # Move batch dimension to first non-batch dimension
475
+ values = values.reshape((-1,) + values.shape[batch_dims:])
476
+ indices = indices.reshape((-1,) + indices.shape[batch_dims:])
477
+ offset = P.Range()(F.cast(0, mstype.int32), batch_size * axis_dim, axis_dim)
478
+ offset_shape = F.hstack([batch_size] + [Tensor(1, dtype=mstype.int32) for _ in range(len(indices.shape) - 1)])
479
+ offset = reshape(offset, offset_shape)
480
+ indices = indices + offset
481
+ num_segments = batch_size * axis_dim
482
+ params_grad = unsorted_segment_sum(values, indices, num_segments)
483
+ grad_shape = dyn_shape_op(params_grad)
484
+ ret_shape = F.hstack([values_shape[:batch_dims], F.cast(axis_dim, mstype.int64), grad_shape[1:]])
485
+ params_grad = reshape(params_grad, ret_shape)
486
+ return params_grad
487
+
488
+
483
489
  @bprop_getters.register(P.Gather)
484
490
  @bprop_getters.register(P.GatherV2)
485
491
  def get_bprop_gather_v2(self):
486
492
  """Generate bprop for GatherV2"""
487
493
 
494
+ def _dyn_bprop_gather_v2(x, indices, axis, dout):
495
+ """dyn shape bprop for GatherV2"""
496
+ orig_indices = indices
497
+ x_shp = dyn_shape_op(x)
498
+ ind_shp = dyn_shape_op(indices)
499
+ out_shp = dyn_shape_op(dout)
500
+ batch_dims = self.batch_dims
501
+ if batch_dims < 0:
502
+ batch_dims += F.reshape(dyn_shape_op(ind_shp), ())
503
+
504
+ if F.rank(dout) == 0:
505
+ dout = P.ExpandDims()(dout, -1)
506
+ if F.rank(indices) == 0:
507
+ indices = P.ExpandDims()(indices, -1)
508
+ out_shp = _dyn_regenerate_output_shape(x_shp, ind_shp, axis)
509
+ dout = reshape(dout, out_shp)
510
+
511
+ # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
512
+ perm_1 = _dyn_generate_shape_index(out_shp, ind_shp, axis, batch_dims)
513
+ values_transpose = transpose(dout, perm_1)
514
+ if batch_dims > 0:
515
+ params_grad = calculate_batch_gather(values_transpose, indices, x_shp, axis, batch_dims)
516
+ else:
517
+ params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
518
+ perm_2 = _dyn_generate_inverse_index(x_shp, axis, batch_dims)
519
+ params_grad = transpose(params_grad, perm_2)
520
+ return params_grad, zeros_like(orig_indices), zeros_like(axis)
521
+
488
522
  def bprop(x, indices, axis, out, dout):
523
+ is_mutable, axis = convert_to_tensor(axis)
524
+ if (F.is_sequence_value_unknown(shape_op(x)) or F.is_sequence_value_unknown(shape_op(indices)) or \
525
+ F.is_sequence_value_unknown(shape_op(dout))) and is_mutable:
526
+ return _dyn_bprop_gather_v2(x, indices, axis, dout)
489
527
  orig_indices = indices
490
528
  if F.rank(dout) == 0:
491
529
  dout = P.ExpandDims()(dout, -1)
@@ -499,15 +537,19 @@ def get_bprop_gather_v2(self):
499
537
  x_shp = shape_op(x)
500
538
  out_shp = shape_op(dout)
501
539
  ind_shp = shape_op(indices)
540
+ batch_dims = self.batch_dims
541
+ if batch_dims < 0:
542
+ batch_dims += len(ind_shp)
502
543
  # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
503
- perm_1 = generate_shape_index(out_shp, ind_shp, axis)
544
+ perm_1 = generate_shape_index(out_shp, ind_shp, axis, batch_dims)
504
545
  values_transpose = transpose(dout, perm_1)
505
- if is_shape_unknown(shape_op(x)):
506
- params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
546
+ dyn_x_sape = dyn_shape_op(x)
547
+ if batch_dims > 0:
548
+ params_grad = calculate_batch_gather(values_transpose, indices, dyn_x_sape, axis, batch_dims)
507
549
  else:
508
- params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
550
+ params_grad = unsorted_segment_sum(values_transpose, indices, dyn_x_sape[axis])
509
551
  # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
510
- perm_2 = _generate_inverse_index(x_shp, axis)
552
+ perm_2 = _generate_inverse_index(x_shp, axis, batch_dims)
511
553
  params_grad = transpose(params_grad, perm_2)
512
554
  return params_grad, zeros_like(orig_indices), zeros_like(axis)
513
555
 
@@ -519,17 +561,19 @@ def get_bprop_gather_d(self):
519
561
  """Generate bprop for GatherD"""
520
562
 
521
563
  def bprop(x, dim, index, out, dout):
522
- dx = G.GatherDGradV2(dim)(x, index, dout)
564
+ dx = G.GatherDGradV2()(x, dim, index, dout)
523
565
  return dx, zeros_like(dim), zeros_like(index)
524
566
 
525
567
  return bprop
526
568
 
569
+
527
570
  @bprop_getters.register(G.GatherDGrad)
528
571
  def get_bprop_gather_d_grad(self):
529
572
  """Generate bprop for GatherDGrad"""
530
573
  op = P.Gather()
531
574
  dim = self.dim
532
575
  x_shp = self.out_shape
576
+
533
577
  def bprop(index, x, out, dout):
534
578
  index_shp = shape_op(index)
535
579
  dim_before_axis = 1
@@ -538,7 +582,7 @@ def get_bprop_gather_d_grad(self):
538
582
  dim_at_axis_index = index_shp[dim]
539
583
  dim_at_axis_output = x_shp[dim]
540
584
  dim_after_axis = 1
541
- for i in range(dim+1, len(x_shp)):
585
+ for i in range(dim + 1, len(x_shp)):
542
586
  dim_after_axis *= x_shp[i]
543
587
  element = dim_before_axis * dim_at_axis_index * dim_after_axis
544
588
  id_ = range_op(0, element, 1, index.dtype)
@@ -547,7 +591,7 @@ def get_bprop_gather_d_grad(self):
547
591
  j = P.Cast()(index < 0, index.dtype)
548
592
  j_read = dim_at_axis_index * j + index
549
593
  j_read = P.Reshape()(j_read, (-1,))
550
- read_id = i*dim_at_axis_output*dim_after_axis + j_read * dim_after_axis + k
594
+ read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k
551
595
  dout = P.Reshape()(dout, (-1,))
552
596
  dx = op(dout, read_id, 0)
553
597
  dx = P.Reshape()(dx, shape_op(x))
@@ -561,6 +605,7 @@ def get_bprop_gather_d_grad_v2(self):
561
605
  """Generate bprop for GatherDGradV2"""
562
606
  op = P.Gather()
563
607
  dim = self.dim
608
+
564
609
  def bprop(index, x, out, dout):
565
610
  index_shp = shape_op(index)
566
611
  dim_before_axis = 1
@@ -570,7 +615,7 @@ def get_bprop_gather_d_grad_v2(self):
570
615
  dim_at_axis_index = index_shp[dim]
571
616
  dim_at_axis_output = x_shp[dim]
572
617
  dim_after_axis = 1
573
- for i in range(dim+1, len(x_shp)):
618
+ for i in range(dim + 1, len(x_shp)):
574
619
  dim_after_axis *= x_shp[i]
575
620
  element = dim_before_axis * dim_at_axis_index * dim_after_axis
576
621
  id_ = range_op(0, element, 1, index.dtype)
@@ -579,7 +624,7 @@ def get_bprop_gather_d_grad_v2(self):
579
624
  j = P.Cast()(index < 0, index.dtype)
580
625
  j_read = dim_at_axis_index * j + index
581
626
  j_read = P.Reshape()(j_read, (-1,))
582
- read_id = i*dim_at_axis_output*dim_after_axis + j_read * dim_after_axis + k
627
+ read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k
583
628
  dout = P.Reshape()(dout, (-1,))
584
629
  dx = op(dout, read_id, 0)
585
630
  dx = P.Reshape()(dx, shape_op(x))
@@ -587,6 +632,7 @@ def get_bprop_gather_d_grad_v2(self):
587
632
 
588
633
  return bprop
589
634
 
635
+
590
636
  @bprop_getters.register(P.SparseGatherV2)
591
637
  def get_bprop_sparse_gather_v2(self):
592
638
  """Generate bprop for SparseGatherV2"""
@@ -602,7 +648,7 @@ def get_bprop_sparse_gather_v2(self):
602
648
  values_shape = indices_size + x_tail_shp
603
649
  values = reshape(dout, values_shape)
604
650
  indices_new = reshape(indices, indices_size)
605
- return RowTensor(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
651
+ return RowTensorInner(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
606
652
  if F.rank(dout) == 0:
607
653
  dout = P.ExpandDims()(dout, -1)
608
654
  if F.rank(indices) == 0:
@@ -715,7 +761,7 @@ def get_bprop_stack(self):
715
761
  axis = self.axis
716
762
 
717
763
  def bprop(x, out, dout):
718
- stack_grad = P.Unstack(axis)
764
+ stack_grad = P.Unstack(num=len(x), axis=axis)
719
765
  out = stack_grad(dout)
720
766
  if is_sub_class(F.typeof(x), ms.list_):
721
767
  ret = []
@@ -764,7 +810,7 @@ def get_bprop_strided_slice(self):
764
810
 
765
811
  def bprop(x, begin, end, strides, out, dout):
766
812
  x_shape = shape_op(x)
767
- if is_shape_unknown(x_shape):
813
+ if F.is_sequence_value_unknown(x_shape):
768
814
  x_shape = dyn_shape_op(x)
769
815
  dx = input_grad(dout, x_shape, begin, end, strides)
770
816
  return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
@@ -780,9 +826,10 @@ def get_bprop_strided_slice_grad(self):
780
826
  ellipsis_mask=self.ellipsis_mask,
781
827
  new_axis_mask=self.new_axis_mask,
782
828
  shrink_axis_mask=self.shrink_axis_mask)
829
+
783
830
  def bprop(dy, shapex, begin, end, strides, out, dout):
784
831
  return strided_slice(dout, begin, end, strides), zeros_like(shapex), zeros_like(begin), zeros_like(end), \
785
- zeros_like(strides)
832
+ zeros_like(strides)
786
833
 
787
834
  return bprop
788
835
 
@@ -835,7 +882,7 @@ def get_bprop_resize_nearest_neighbor(self):
835
882
  tensor_shape = P.TensorShape()
836
883
 
837
884
  def bprop(inputs, out, dout):
838
- if (-1 in shape_op(inputs)) or (-2 in shape_op(inputs)):
885
+ if F.is_sequence_value_unknown(shape_op(inputs)) or F.is_sequence_shape_unknown(shape_op(inputs)):
839
886
  shp = tensor_shape(inputs)
840
887
  else:
841
888
  shp = shape_op(inputs)
@@ -853,7 +900,7 @@ def get_bprop_gather_nd(self):
853
900
 
854
901
  def bprop(x, indices, out, dout):
855
902
  shp = shape_op(x)
856
- if is_shape_unknown(shp):
903
+ if F.is_sequence_value_unknown(shp):
857
904
  shp = dyn_shape_op(x)
858
905
  return op(indices, dout, shp), zeros_like(indices)
859
906
 
@@ -1029,18 +1076,14 @@ def _gather_drop_negatives(params,
1029
1076
  select = P.Select()
1030
1077
 
1031
1078
  if zero_clipped_indices is None:
1032
- if is_shape_unknown(shape_op(ids)):
1033
- zero_ids = dyn_fill(ids.dtype, dyn_shape_op(ids), 0)
1034
- else:
1035
- zero_ids = zeros_like(ids)
1036
- zero_clipped_indices = maximum(ids, zero_ids)
1079
+ zero_clipped_indices = maximum(ids, zeros_like(ids))
1037
1080
  gathered = gather(params, zero_clipped_indices, 0)
1038
1081
  zero_slice = zeros_like(gathered)
1039
1082
  if is_positive is None:
1040
1083
  is_positive = greater_equal(ids, 0)
1041
1084
  is_positive_shape = shape_op(is_positive)
1042
1085
  gathered_shape = shape_op(gathered)
1043
- if is_shape_unknown(gathered_shape) or is_shape_unknown(is_positive_shape):
1086
+ if F.is_sequence_value_unknown(gathered_shape) or F.is_sequence_value_unknown(is_positive_shape):
1044
1087
  gathered_shape = dyn_shape_op(gathered)
1045
1088
  rank_gathered = dyn_rank(gathered)
1046
1089
  fill_gathered = dyn_fill(mstype.int64, gathered_shape, 1)
@@ -1052,7 +1095,6 @@ def _gather_drop_negatives(params,
1052
1095
  is_positive_shape = P.Concat(-1)((is_positive_shape, padded_shape))
1053
1096
  is_positive = reshape(is_positive, is_positive_shape)
1054
1097
  is_positive = logical_and(is_positive, F.cast(fill_gathered, mstype.bool_))
1055
- zero_slice = dyn_fill(gathered.dtype, gathered_shape, 0)
1056
1098
  else:
1057
1099
  broadcastable_shape = is_positive_shape
1058
1100
  for _ in range(rank(gathered) - rank(is_positive)):
@@ -1087,13 +1129,7 @@ def get_bprop_unsorted_segment_sum(self):
1087
1129
  """Generate bprop for UnsortedSegmentSum"""
1088
1130
 
1089
1131
  def bprop(x, segment_ids, num_segments, out, dout):
1090
- segment_shape = shape_op(segment_ids)
1091
- if is_shape_unknown(segment_shape):
1092
- segment_shape = dyn_shape_op(segment_ids)
1093
- zeros_segment = dyn_fill(segment_ids.dtype, segment_shape, 0)
1094
- else:
1095
- zeros_segment = zeros_like(segment_ids)
1096
- return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_segment, \
1132
+ return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
1097
1133
  zeros_like(num_segments)
1098
1134
 
1099
1135
  return bprop
@@ -1132,6 +1168,8 @@ def get_bprop_unsorted_segment_prod(self):
1132
1168
  unsorted_segment_prod = P.UnsortedSegmentProd()
1133
1169
 
1134
1170
  def bprop(x, segment_ids, num_segments, out, dout):
1171
+ if x.dtype == mstype.complex64 or x.dtype == mstype.complex128:
1172
+ raise TypeError("For 'UnsortedSegmentProd', complex number is not supported for gradient currently.")
1135
1173
  if x.dtype == mstype.complex64 or x.dtype == mstype.complex128:
1136
1174
  is_zero = equal(x, F.scalar_to_tensor(0).astype(x.dtype))
1137
1175
  else:
@@ -1221,12 +1259,31 @@ def get_bprop_broadcast_to(self):
1221
1259
  x_shape = shape_op(x)
1222
1260
  dout_shape = shape_op(dout)
1223
1261
  broadcast_shape = shape_op(out)
1224
-
1225
- if x_shape == dout_shape:
1262
+ dynamic = F.is_sequence_value_unknown(x_shape) or F.is_sequence_value_unknown(dout_shape)
1263
+ if not dynamic and x_shape == dout_shape:
1226
1264
  return (dout,)
1227
- _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
1228
- reduced_grad = reduce_keep_dim(dout, reduction_axes)
1229
- dx = reshape(reduced_grad, x_shape)
1265
+ dynamic = dynamic or F.is_sequence_value_unknown(broadcast_shape)
1266
+ out_type = dout.dtype
1267
+ if not dynamic:
1268
+ _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
1269
+ if out_type in (ms.int16, ms.int32, ms.int64):
1270
+ dout = P.Cast()(dout, ms.float32)
1271
+ reduced_grad = reduce_keep_dim(dout, reduction_axes)
1272
+ reduced_grad = P.Cast()(reduced_grad, out_type)
1273
+ else:
1274
+ reduced_grad = reduce_keep_dim(dout, reduction_axes)
1275
+ dx = reshape(reduced_grad, x_shape)
1276
+ else:
1277
+ x_shape = dyn_shape_op(x)
1278
+ broadcast_shape = dyn_shape_op(out)
1279
+ _, reduction_axes = DynamicBroadcastGradientArgs()(broadcast_shape, x_shape)
1280
+ if out_type in (ms.int16, ms.int32, ms.int64):
1281
+ dout = P.Cast()(dout, ms.float32)
1282
+ reduced_grad = sum_grad_reduce_axis(dout, reduction_axes, keep_dims=True)
1283
+ reduced_grad = P.Cast()(reduced_grad, out_type)
1284
+ else:
1285
+ reduced_grad = sum_grad_reduce_axis(dout, reduction_axes, keep_dims=True)
1286
+ dx = reshape(reduced_grad, x_shape)
1230
1287
  return (dx,)
1231
1288
 
1232
1289
  return bprop