mindspore 1.10.0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (944)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +9 -4
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/builtin_operations.py +32 -4
  13. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  21. mindspore/_extends/parse/__init__.py +5 -3
  22. mindspore/_extends/parse/namespace.py +17 -2
  23. mindspore/_extends/parse/parser.py +193 -34
  24. mindspore/_extends/parse/resources.py +7 -8
  25. mindspore/_extends/parse/standard_method.py +1780 -435
  26. mindspore/_extends/parse/trope.py +3 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  28. mindspore/amp.py +53 -58
  29. mindspore/bin/cache_admin +0 -0
  30. mindspore/bin/cache_server +0 -0
  31. mindspore/boost/adasum.py +3 -2
  32. mindspore/boost/boost.py +2 -2
  33. mindspore/boost/boost_cell_wrapper.py +46 -26
  34. mindspore/boost/dim_reduce.py +6 -5
  35. mindspore/boost/grad_accumulation.py +2 -1
  36. mindspore/boost/group_loss_scale_manager.py +1 -1
  37. mindspore/common/__init__.py +11 -10
  38. mindspore/common/_decorator.py +2 -0
  39. mindspore/common/_register_for_adapter.py +55 -0
  40. mindspore/common/_stub_tensor.py +201 -0
  41. mindspore/common/_utils.py +57 -0
  42. mindspore/common/api.py +582 -297
  43. mindspore/common/dtype.py +66 -18
  44. mindspore/common/dump.py +2 -2
  45. mindspore/common/initializer.py +38 -1
  46. mindspore/common/jit_config.py +25 -13
  47. mindspore/common/mutable.py +53 -24
  48. mindspore/common/parameter.py +60 -37
  49. mindspore/common/seed.py +8 -24
  50. mindspore/common/sparse_tensor.py +927 -0
  51. mindspore/common/tensor.py +1627 -3900
  52. mindspore/communication/__init__.py +10 -5
  53. mindspore/communication/_comm_helper.py +78 -214
  54. mindspore/communication/_hccl_management.py +2 -1
  55. mindspore/communication/management.py +136 -47
  56. mindspore/config/op_info.config +501 -1008
  57. mindspore/config/super_bar_config.json +512 -0
  58. mindspore/context.py +291 -56
  59. mindspore/dataset/__init__.py +12 -8
  60. mindspore/dataset/audio/__init__.py +9 -9
  61. mindspore/dataset/audio/transforms.py +1090 -228
  62. mindspore/dataset/audio/utils.py +87 -39
  63. mindspore/dataset/audio/validators.py +223 -1
  64. mindspore/dataset/callback/ds_callback.py +17 -15
  65. mindspore/dataset/core/config.py +246 -17
  66. mindspore/dataset/core/py_util_helpers.py +4 -3
  67. mindspore/dataset/core/validator_helpers.py +10 -10
  68. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  69. mindspore/dataset/debug/debug_hook.py +65 -0
  70. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  71. mindspore/dataset/engine/__init__.py +7 -3
  72. mindspore/dataset/engine/cache_client.py +9 -9
  73. mindspore/dataset/engine/datasets.py +648 -477
  74. mindspore/dataset/engine/datasets_audio.py +165 -167
  75. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  76. mindspore/dataset/engine/datasets_text.py +492 -342
  77. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  78. mindspore/dataset/engine/datasets_vision.py +1224 -699
  79. mindspore/dataset/engine/graphdata.py +134 -69
  80. mindspore/dataset/engine/iterators.py +50 -9
  81. mindspore/dataset/engine/offload.py +52 -31
  82. mindspore/dataset/engine/samplers.py +27 -24
  83. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  84. mindspore/dataset/engine/validators.py +213 -52
  85. mindspore/dataset/text/__init__.py +10 -8
  86. mindspore/dataset/text/transforms.py +152 -57
  87. mindspore/dataset/text/utils.py +98 -49
  88. mindspore/dataset/text/validators.py +25 -0
  89. mindspore/dataset/transforms/__init__.py +4 -2
  90. mindspore/dataset/transforms/c_transforms.py +11 -13
  91. mindspore/dataset/transforms/py_transforms.py +2 -2
  92. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  93. mindspore/dataset/transforms/transforms.py +13 -15
  94. mindspore/dataset/transforms/validators.py +7 -7
  95. mindspore/dataset/utils/__init__.py +2 -1
  96. mindspore/dataset/utils/browse_dataset.py +13 -13
  97. mindspore/dataset/utils/line_reader.py +121 -0
  98. mindspore/dataset/vision/__init__.py +8 -7
  99. mindspore/dataset/vision/c_transforms.py +125 -126
  100. mindspore/dataset/vision/py_transforms.py +37 -37
  101. mindspore/dataset/vision/py_transforms_util.py +23 -20
  102. mindspore/dataset/vision/transforms.py +316 -315
  103. mindspore/dataset/vision/utils.py +313 -17
  104. mindspore/dataset/vision/validators.py +6 -6
  105. mindspore/default_config.py +0 -1
  106. mindspore/{compression → experimental}/__init__.py +6 -5
  107. mindspore/experimental/map_parameter.py +275 -0
  108. mindspore/include/OWNERS +0 -1
  109. mindspore/include/api/callback/callback.h +9 -13
  110. mindspore/include/api/callback/ckpt_saver.h +2 -2
  111. mindspore/include/api/callback/loss_monitor.h +2 -2
  112. mindspore/include/api/callback/lr_scheduler.h +5 -5
  113. mindspore/include/api/callback/time_monitor.h +2 -2
  114. mindspore/include/api/callback/train_accuracy.h +4 -6
  115. mindspore/include/api/cfg.h +19 -6
  116. mindspore/include/api/context.h +70 -9
  117. mindspore/include/api/delegate.h +8 -1
  118. mindspore/include/api/dual_abi_helper.h +8 -24
  119. mindspore/include/api/metrics/accuracy.h +2 -2
  120. mindspore/include/api/metrics/metrics.h +4 -3
  121. mindspore/include/api/model.h +9 -4
  122. mindspore/include/api/model_group.h +68 -0
  123. mindspore/include/api/model_parallel_runner.h +17 -17
  124. mindspore/include/api/net.h +12 -11
  125. mindspore/include/api/serialization.h +20 -4
  126. mindspore/include/api/status.h +7 -1
  127. mindspore/include/api/types.h +25 -21
  128. mindspore/include/api/visible.h +4 -0
  129. mindspore/include/c_api/model_c.h +5 -0
  130. mindspore/include/c_api/status_c.h +1 -1
  131. mindspore/include/dataset/config.h +1 -1
  132. mindspore/include/dataset/constants.h +14 -0
  133. mindspore/include/dataset/text.h +59 -0
  134. mindspore/include/dataset/vision.h +56 -117
  135. mindspore/include/dataset/vision_lite.h +102 -0
  136. mindspore/include/mindapi/base/type_id.h +42 -3
  137. mindspore/lib/libdnnl.so.2 +0 -0
  138. mindspore/lib/libicudata.so.69 +0 -0
  139. mindspore/lib/libicui18n.so.69 +0 -0
  140. mindspore/lib/libicuuc.so.69 +0 -0
  141. mindspore/lib/libmindspore.so +0 -0
  142. mindspore/lib/libmindspore_backend.so +0 -0
  143. mindspore/lib/libmindspore_common.so +0 -0
  144. mindspore/lib/libmindspore_core.so +0 -0
  145. mindspore/lib/libmindspore_glog.so.0 +0 -0
  146. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  147. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  148. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  149. mindspore/lib/libmindspore_shared_lib.so +0 -0
  150. mindspore/lib/libmpi_adapter.so +0 -0
  151. mindspore/lib/libmpi_collective.so +0 -0
  152. mindspore/lib/libnnacl.so +0 -0
  153. mindspore/lib/libopencv_core.so.4.5 +0 -0
  154. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  155. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  156. mindspore/lib/libps_cache.so +0 -0
  157. mindspore/lib/plugin/ascend/libakg.so +0 -0
  158. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  159. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  160. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  161. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  163. mindspore/lib/{libakg.so → plugin/cpu/libakg.so} +0 -0
  164. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  165. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  166. mindspore/log.py +28 -28
  167. mindspore/mindrecord/common/exceptions.py +2 -4
  168. mindspore/mindrecord/filereader.py +19 -1
  169. mindspore/mindrecord/filewriter.py +250 -88
  170. mindspore/mindrecord/mindpage.py +13 -13
  171. mindspore/mindrecord/shardheader.py +15 -15
  172. mindspore/mindrecord/shardreader.py +9 -0
  173. mindspore/mindrecord/shardwriter.py +29 -29
  174. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  175. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  176. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  177. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  178. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  179. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  180. mindspore/nn/__init__.py +1 -5
  181. mindspore/nn/cell.py +297 -234
  182. mindspore/nn/dynamic_lr.py +1 -1
  183. mindspore/nn/grad/cell_grad.py +17 -42
  184. mindspore/nn/layer/__init__.py +7 -4
  185. mindspore/nn/layer/activation.py +131 -88
  186. mindspore/nn/layer/basic.py +313 -613
  187. mindspore/nn/layer/channel_shuffle.py +103 -0
  188. mindspore/nn/layer/combined.py +1 -1
  189. mindspore/nn/layer/container.py +52 -6
  190. mindspore/nn/layer/conv.py +112 -43
  191. mindspore/nn/layer/dense.py +10 -9
  192. mindspore/nn/layer/embedding.py +36 -34
  193. mindspore/nn/layer/image.py +123 -27
  194. mindspore/nn/layer/math.py +108 -107
  195. mindspore/nn/layer/normalization.py +212 -366
  196. mindspore/nn/layer/padding.py +370 -42
  197. mindspore/nn/layer/pooling.py +1443 -219
  198. mindspore/nn/layer/rnn_cells.py +11 -16
  199. mindspore/nn/layer/rnns.py +38 -39
  200. mindspore/nn/layer/thor_layer.py +24 -25
  201. mindspore/nn/layer/timedistributed.py +5 -5
  202. mindspore/nn/layer/transformer.py +701 -0
  203. mindspore/nn/learning_rate_schedule.py +8 -8
  204. mindspore/nn/loss/__init__.py +9 -6
  205. mindspore/nn/loss/loss.py +678 -142
  206. mindspore/nn/metrics.py +53 -0
  207. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  208. mindspore/nn/optim/ada_grad.py +8 -8
  209. mindspore/nn/optim/adadelta.py +2 -3
  210. mindspore/nn/optim/adafactor.py +18 -14
  211. mindspore/nn/optim/adam.py +429 -87
  212. mindspore/nn/optim/adamax.py +5 -6
  213. mindspore/nn/optim/adasum.py +10 -8
  214. mindspore/nn/optim/asgd.py +7 -7
  215. mindspore/nn/optim/ftrl.py +81 -11
  216. mindspore/nn/optim/lamb.py +7 -8
  217. mindspore/nn/optim/lars.py +4 -4
  218. mindspore/nn/optim/lazyadam.py +82 -7
  219. mindspore/nn/optim/momentum.py +8 -7
  220. mindspore/nn/optim/optimizer.py +19 -10
  221. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  222. mindspore/nn/optim/rmsprop.py +3 -3
  223. mindspore/nn/optim/rprop.py +20 -16
  224. mindspore/nn/optim/sgd.py +21 -15
  225. mindspore/nn/optim/thor.py +23 -21
  226. mindspore/nn/probability/__init__.py +0 -2
  227. mindspore/nn/probability/bijector/bijector.py +7 -6
  228. mindspore/nn/probability/bijector/invert.py +4 -2
  229. mindspore/nn/probability/bijector/softplus.py +2 -2
  230. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  231. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  232. mindspore/nn/probability/distribution/__init__.py +6 -0
  233. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  234. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  235. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  236. mindspore/nn/probability/distribution/beta.py +1 -1
  237. mindspore/nn/probability/distribution/categorical.py +9 -9
  238. mindspore/nn/probability/distribution/cauchy.py +8 -8
  239. mindspore/nn/probability/distribution/distribution.py +12 -6
  240. mindspore/nn/probability/distribution/exponential.py +5 -5
  241. mindspore/nn/probability/distribution/gamma.py +3 -3
  242. mindspore/nn/probability/distribution/geometric.py +6 -5
  243. mindspore/nn/probability/distribution/gumbel.py +5 -5
  244. mindspore/nn/probability/distribution/half_normal.py +133 -0
  245. mindspore/nn/probability/distribution/laplace.py +128 -0
  246. mindspore/nn/probability/distribution/log_normal.py +0 -1
  247. mindspore/nn/probability/distribution/logistic.py +4 -5
  248. mindspore/nn/probability/distribution/normal.py +11 -15
  249. mindspore/nn/probability/distribution/poisson.py +6 -2
  250. mindspore/nn/probability/distribution/student_t.py +150 -0
  251. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  252. mindspore/nn/probability/distribution/uniform.py +5 -5
  253. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  254. mindspore/nn/reinforcement/tensor_array.py +2 -2
  255. mindspore/nn/sparse/sparse.py +8 -1
  256. mindspore/nn/wrap/cell_wrapper.py +55 -27
  257. mindspore/nn/wrap/grad_reducer.py +20 -11
  258. mindspore/nn/wrap/loss_scale.py +47 -30
  259. mindspore/numpy/array_creations.py +33 -22
  260. mindspore/numpy/array_ops.py +46 -42
  261. mindspore/numpy/logic_ops.py +6 -27
  262. mindspore/numpy/math_ops.py +26 -19
  263. mindspore/numpy/utils.py +1 -8
  264. mindspore/numpy/utils_const.py +112 -62
  265. mindspore/ops/__init__.py +6 -3
  266. mindspore/ops/_constants.py +0 -6
  267. mindspore/ops/_grad/__init__.py +2 -1
  268. mindspore/ops/_grad/grad_array_ops.py +209 -152
  269. mindspore/ops/_grad/grad_base.py +55 -17
  270. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  271. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  272. mindspore/ops/_grad/grad_implementations.py +21 -61
  273. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  274. mindspore/ops/_grad/grad_math_ops.py +306 -161
  275. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  276. mindspore/ops/_grad/grad_other_ops.py +1 -1
  277. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  278. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  279. mindspore/ops/_grad/grad_sparse.py +15 -9
  280. mindspore/ops/_grad_experimental/__init__.py +1 -0
  281. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  282. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  283. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  284. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  285. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  286. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  287. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  288. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  289. mindspore/ops/_op_impl/__init__.py +3 -3
  290. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  291. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  292. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  293. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  294. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  295. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  296. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  297. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  298. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  299. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  300. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  302. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  303. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  304. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  305. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  307. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  314. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  315. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  317. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  318. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  320. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  321. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  322. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  324. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  327. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  328. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  329. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  330. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  331. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  332. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  335. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  336. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  337. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  338. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  339. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  340. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  341. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  342. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  343. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  344. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  345. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  346. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  347. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  349. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  350. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  351. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  352. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  353. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  354. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  355. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  356. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  357. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  358. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  359. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  360. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  361. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  362. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  363. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  365. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  366. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  367. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  368. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  369. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  370. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  371. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  372. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  373. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  374. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  375. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  376. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  377. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  378. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  379. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  380. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  381. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  382. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  383. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  384. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  385. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  386. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  387. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  388. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  389. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  390. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  391. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  392. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  393. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  394. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  395. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  396. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  397. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  398. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  399. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  400. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  401. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  402. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  403. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  404. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  405. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  407. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  408. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  409. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  410. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  411. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  412. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  413. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  414. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  415. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  416. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  417. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  418. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  419. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  420. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  421. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  422. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  423. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  424. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  425. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  426. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  427. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  428. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  429. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  430. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  431. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  432. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  433. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  434. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  435. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  436. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  437. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  441. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  442. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  443. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  444. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  445. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  446. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  447. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  448. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  449. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  450. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  451. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  452. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  453. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  454. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  455. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  456. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  457. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  458. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  459. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  460. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  461. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  462. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  463. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  464. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  465. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  466. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  467. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  468. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  469. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  470. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  471. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  472. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  473. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  474. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  475. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  476. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  477. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  478. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  479. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  480. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  481. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  482. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  483. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  484. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  485. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  486. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  487. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  488. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  489. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  490. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  491. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  492. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  493. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  494. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  495. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  496. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  497. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  498. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  499. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  500. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  501. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  502. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  503. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  504. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  505. mindspore/ops/_primitive_cache.py +3 -2
  506. mindspore/ops/_register_for_op.py +11 -0
  507. mindspore/ops/_utils/__init__.py +1 -1
  508. mindspore/ops/_utils/utils.py +20 -41
  509. mindspore/ops/_vmap/__init__.py +2 -2
  510. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  511. mindspore/ops/_vmap/vmap_base.py +24 -10
  512. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  513. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  514. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  515. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  516. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  517. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  518. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  519. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  520. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  521. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  522. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  524. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  525. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  526. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  527. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  528. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  529. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  530. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  532. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  534. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  536. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  537. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  538. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  541. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  542. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  543. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  544. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  545. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  546. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  547. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  548. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  549. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  550. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  551. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  552. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  554. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  555. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  556. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  560. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  561. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  563. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  564. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  565. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  566. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  567. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  568. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  569. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  571. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  572. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  573. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  574. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  575. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  577. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  578. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  579. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  580. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  581. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  582. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  583. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  585. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  586. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  587. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  588. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  589. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  590. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  591. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  592. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  593. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  594. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  595. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  596. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  597. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  598. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  599. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  600. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  601. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  604. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  605. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  606. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  607. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  609. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  610. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  611. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  612. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  613. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  614. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  615. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  616. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  617. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  618. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  619. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  620. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  621. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  622. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  623. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  624. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  625. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  626. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  627. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  628. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  629. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  630. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  632. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  633. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  634. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  635. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  637. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  638. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  639. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  640. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  643. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  644. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  645. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  646. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  647. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  648. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  649. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  650. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  651. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  652. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  653. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  654. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  655. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  656. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  657. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  658. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  660. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  661. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  662. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  663. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  664. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  665. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  666. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  667. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  668. mindspore/ops/bprop_mindir/__init__.py +1 -4
  669. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  670. mindspore/ops/composite/__init__.py +12 -13
  671. mindspore/ops/composite/base.py +261 -254
  672. mindspore/ops/composite/env_ops.py +41 -0
  673. mindspore/ops/composite/math_ops.py +197 -156
  674. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  675. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  676. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  677. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  678. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  679. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  680. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  681. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  682. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  683. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  684. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  685. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  686. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  687. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  688. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  689. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  690. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  691. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  692. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  693. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  694. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  695. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  696. mindspore/ops/function/__init__.py +323 -8
  697. mindspore/ops/function/array_func.py +3511 -780
  698. mindspore/ops/function/clip_func.py +329 -0
  699. mindspore/ops/function/debug_func.py +6 -6
  700. mindspore/ops/function/grad/__init__.py +5 -1
  701. mindspore/ops/function/grad/grad_func.py +736 -65
  702. mindspore/ops/function/image_func.py +270 -0
  703. mindspore/ops/function/linalg_func.py +268 -8
  704. mindspore/ops/function/math_func.py +8032 -3164
  705. mindspore/ops/function/nn_func.py +5619 -1855
  706. mindspore/ops/function/other_func.py +115 -0
  707. mindspore/ops/function/parameter_func.py +11 -10
  708. mindspore/ops/function/random_func.py +939 -77
  709. mindspore/ops/function/sparse_func.py +249 -84
  710. mindspore/ops/function/sparse_unary_func.py +2303 -0
  711. mindspore/ops/function/spectral_func.py +146 -0
  712. mindspore/ops/function/vmap_func.py +114 -0
  713. mindspore/ops/functional.py +182 -254
  714. mindspore/ops/op_info_register.py +79 -34
  715. mindspore/ops/operations/__init__.py +210 -118
  716. mindspore/ops/operations/_csr_ops.py +7 -7
  717. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  718. mindspore/ops/operations/_grad_ops.py +447 -322
  719. mindspore/ops/operations/_inner_ops.py +547 -176
  720. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  721. mindspore/ops/operations/_ms_kernel.py +29 -27
  722. mindspore/ops/operations/_ocr_ops.py +11 -11
  723. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  724. mindspore/ops/operations/_quant_ops.py +186 -101
  725. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  726. mindspore/ops/operations/_scalar_ops.py +466 -0
  727. mindspore/ops/operations/_sequence_ops.py +1047 -0
  728. mindspore/ops/operations/_tensor_array.py +10 -11
  729. mindspore/ops/operations/_thor_ops.py +4 -4
  730. mindspore/ops/operations/array_ops.py +1428 -1226
  731. mindspore/ops/operations/comm_ops.py +180 -117
  732. mindspore/ops/operations/control_ops.py +4 -2
  733. mindspore/ops/operations/custom_ops.py +185 -98
  734. mindspore/ops/operations/debug_ops.py +92 -54
  735. mindspore/ops/operations/image_ops.py +406 -211
  736. mindspore/ops/operations/inner_ops.py +42 -53
  737. mindspore/ops/operations/linalg_ops.py +32 -29
  738. mindspore/ops/operations/math_ops.py +2076 -897
  739. mindspore/ops/operations/nn_ops.py +1282 -1252
  740. mindspore/ops/operations/other_ops.py +124 -278
  741. mindspore/ops/operations/random_ops.py +345 -178
  742. mindspore/ops/operations/rl_ops.py +8 -9
  743. mindspore/ops/operations/sparse_ops.py +502 -157
  744. mindspore/ops/operations/spectral_ops.py +107 -0
  745. mindspore/ops/primitive.py +192 -15
  746. mindspore/ops/vm_impl_registry.py +23 -2
  747. mindspore/parallel/__init__.py +6 -1
  748. mindspore/parallel/_auto_parallel_context.py +199 -92
  749. mindspore/parallel/_cell_wrapper.py +4 -2
  750. mindspore/parallel/_cost_model_context.py +3 -0
  751. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  752. mindspore/parallel/_offload_context.py +185 -0
  753. mindspore/parallel/_parallel_serialization.py +167 -28
  754. mindspore/parallel/_ps_context.py +9 -5
  755. mindspore/parallel/_recovery_context.py +1 -1
  756. mindspore/parallel/_tensor.py +9 -1
  757. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  758. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  759. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  760. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  761. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  762. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  763. mindspore/parallel/_utils.py +47 -7
  764. mindspore/parallel/algo_parameter_config.py +5 -1
  765. mindspore/parallel/checkpoint_transform.py +329 -0
  766. mindspore/parallel/shard.py +229 -0
  767. mindspore/profiler/__init__.py +2 -1
  768. mindspore/profiler/common/util.py +4 -3
  769. mindspore/profiler/common/validator/validate_path.py +2 -2
  770. mindspore/profiler/envprofiling.py +249 -0
  771. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  772. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  773. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  774. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  775. mindspore/profiler/parser/framework_parser.py +42 -16
  776. mindspore/profiler/parser/hccl_parser.py +158 -158
  777. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  778. mindspore/profiler/parser/integrator.py +18 -1579
  779. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  780. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  781. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  782. mindspore/profiler/parser/optime_parser.py +17 -18
  783. mindspore/profiler/parser/profiler_info.py +108 -0
  784. mindspore/profiler/parser/step_trace_parser.py +1 -1
  785. mindspore/profiler/profiling.py +396 -194
  786. mindspore/rewrite/__init__.py +6 -2
  787. mindspore/rewrite/api/node.py +51 -110
  788. mindspore/rewrite/api/node_type.py +10 -6
  789. mindspore/rewrite/api/pattern_engine.py +51 -7
  790. mindspore/rewrite/api/scoped_value.py +64 -53
  791. mindspore/rewrite/api/symbol_tree.py +108 -61
  792. mindspore/rewrite/api/tree_node_helper.py +2 -3
  793. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  794. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  795. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  796. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  797. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  798. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  799. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  800. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  801. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  802. mindspore/rewrite/common/__init__.py +2 -0
  803. mindspore/rewrite/common/event.py +1 -1
  804. mindspore/rewrite/common/observable.py +1 -1
  805. mindspore/rewrite/common/observer.py +1 -1
  806. mindspore/rewrite/common/rewrite_elog.py +35 -0
  807. mindspore/rewrite/namer.py +2 -2
  808. mindspore/rewrite/namespace.py +14 -4
  809. mindspore/rewrite/node.py +161 -13
  810. mindspore/rewrite/parser.py +0 -1
  811. mindspore/rewrite/parser_register.py +0 -1
  812. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  813. mindspore/rewrite/parsers/assign_parser.py +267 -67
  814. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  815. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  816. mindspore/rewrite/parsers/constant_parser.py +101 -0
  817. mindspore/rewrite/parsers/container_parser.py +88 -0
  818. mindspore/rewrite/parsers/for_parser.py +28 -15
  819. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  820. mindspore/rewrite/parsers/if_parser.py +11 -28
  821. mindspore/rewrite/parsers/module_parser.py +9 -6
  822. mindspore/rewrite/parsers/return_parser.py +3 -2
  823. mindspore/rewrite/sparsify/__init__.py +0 -0
  824. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  825. mindspore/rewrite/sparsify/sparsify.py +109 -0
  826. mindspore/rewrite/sparsify/utils.py +173 -0
  827. mindspore/rewrite/symbol_tree.py +322 -109
  828. mindspore/rewrite/symbol_tree_builder.py +45 -8
  829. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  830. mindspore/rewrite/topological_manager.py +1 -2
  831. mindspore/run_check/_check_version.py +209 -112
  832. mindspore/run_check/run_check.py +2 -1
  833. mindspore/scipy/linalg.py +13 -117
  834. mindspore/scipy/ops.py +5 -71
  835. mindspore/scipy/ops_grad.py +1 -25
  836. mindspore/scipy/ops_wrapper.py +1 -1
  837. mindspore/scipy/optimize/_bfgs.py +1 -1
  838. mindspore/scipy/optimize/_lagrange.py +200 -0
  839. mindspore/scipy/optimize/line_search.py +3 -2
  840. mindspore/scipy/optimize/minimize.py +43 -6
  841. mindspore/scipy/sparse/__init__.py +2 -2
  842. mindspore/scipy/sparse/linalg.py +5 -465
  843. mindspore/scipy/utils.py +2 -1
  844. mindspore/scipy/utils_const.py +7 -1
  845. mindspore/train/__init__.py +6 -4
  846. mindspore/train/_utils.py +28 -5
  847. mindspore/train/amp.py +321 -50
  848. mindspore/train/callback/__init__.py +3 -1
  849. mindspore/train/callback/_backup_and_restore.py +120 -0
  850. mindspore/train/callback/_callback.py +8 -8
  851. mindspore/train/callback/_checkpoint.py +12 -9
  852. mindspore/train/callback/_early_stop.py +13 -7
  853. mindspore/train/callback/_history.py +8 -8
  854. mindspore/train/callback/_lambda_callback.py +6 -6
  855. mindspore/train/callback/_landscape.py +36 -38
  856. mindspore/train/callback/_loss_monitor.py +12 -6
  857. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  858. mindspore/train/callback/_on_request_exit.py +212 -0
  859. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  860. mindspore/train/callback/_summary_collector.py +27 -19
  861. mindspore/train/callback/_time_monitor.py +13 -7
  862. mindspore/train/checkpoint_pb2.py +68 -8
  863. mindspore/train/data_sink.py +122 -33
  864. mindspore/train/dataset_helper.py +28 -87
  865. mindspore/train/loss_scale_manager.py +4 -7
  866. mindspore/{nn → train}/metrics/__init__.py +20 -20
  867. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  868. mindspore/{nn → train}/metrics/auc.py +4 -4
  869. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  870. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  871. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  872. mindspore/{nn → train}/metrics/dice.py +6 -5
  873. mindspore/{nn → train}/metrics/error.py +7 -5
  874. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  875. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  876. mindspore/{nn → train}/metrics/loss.py +4 -3
  877. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  878. mindspore/{nn → train}/metrics/metric.py +6 -5
  879. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  880. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  881. mindspore/{nn → train}/metrics/precision.py +5 -4
  882. mindspore/{nn → train}/metrics/recall.py +5 -4
  883. mindspore/{nn → train}/metrics/roc.py +7 -6
  884. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  885. mindspore/{nn → train}/metrics/topk.py +7 -5
  886. mindspore/train/mind_ir_pb2.py +339 -32
  887. mindspore/train/model.py +113 -84
  888. mindspore/train/serialization.py +547 -167
  889. mindspore/train/summary/_summary_adapter.py +1 -1
  890. mindspore/train/summary/summary_record.py +43 -12
  891. mindspore/train/train_thor/convert_utils.py +7 -1
  892. mindspore/train/train_thor/dataset_helper.py +3 -3
  893. mindspore/train/train_thor/model_thor.py +0 -4
  894. mindspore/version.py +1 -1
  895. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  896. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +899 -675
  897. mindspore/compression/common/constant.py +0 -124
  898. mindspore/compression/export/__init__.py +0 -19
  899. mindspore/compression/export/quant_export.py +0 -514
  900. mindspore/compression/quant/qat.py +0 -636
  901. mindspore/compression/quant/quant_utils.py +0 -462
  902. mindspore/compression/quant/quantizer.py +0 -68
  903. mindspore/nn/layer/quant.py +0 -1868
  904. mindspore/nn/layer/rnn_utils.py +0 -90
  905. mindspore/nn/probability/dpn/__init__.py +0 -22
  906. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  907. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  908. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  909. mindspore/nn/probability/infer/__init__.py +0 -22
  910. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  911. mindspore/nn/probability/infer/variational/svi.py +0 -84
  912. mindspore/nn/probability/toolbox/__init__.py +0 -22
  913. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  914. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  915. mindspore/nn/probability/transforms/__init__.py +0 -22
  916. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  917. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  918. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  919. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  920. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  921. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  922. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  923. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  924. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  925. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  926. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  927. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  928. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  929. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  930. mindspore/ops/composite/array_ops.py +0 -210
  931. mindspore/ops/composite/clip_ops.py +0 -238
  932. mindspore/ops/composite/random_ops.py +0 -426
  933. mindspore/ops/composite/vmap_ops.py +0 -38
  934. mindspore/ops/operations/sponge_ops.py +0 -3531
  935. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  936. mindspore/parallel/nn/__init__.py +0 -42
  937. mindspore/parallel/nn/loss.py +0 -22
  938. mindspore/parallel/nn/moe.py +0 -21
  939. mindspore/parallel/nn/op_parallel_config.py +0 -22
  940. mindspore/parallel/nn/transformer.py +0 -31
  941. mindspore/run_check/_check_deps_version.py +0 -84
  942. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  943. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  944. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,15 @@
14
14
  # ============================================================================
15
15
 
16
16
  """array_ops"""
17
+ from __future__ import absolute_import
18
+
17
19
  from mindspore import Tensor
18
20
  from mindspore.ops.primitive import constexpr
19
21
  from mindspore.common import dtype as mstype
20
22
  from mindspore.numpy.array_ops import where
21
23
  from mindspore.ops._grad.grad_math_ops import binop_grad_common
22
- from mindspore.ops._grad.grad_base import bprop_getters
24
+ from mindspore.ops._grad.grad_base import bprop_getters, dyn_rank, dyn_fill, dyn_ones, create_tensor_by_element
25
+ from mindspore.ops._grad.grad_base import convert_to_tensor
23
26
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
24
27
  from mindspore.ops.operations.array_ops import Tril
25
28
  from mindspore.ops.operations.array_ops import MatrixDiagV3
@@ -30,6 +33,7 @@ from mindspore.ops.operations.array_ops import Mvlgamma
30
33
  from mindspore.ops.operations.array_ops import Triu
31
34
  from mindspore.ops.operations.array_ops import IdentityN
32
35
  from mindspore.ops.operations.array_ops import IndexFill
36
+ from mindspore.ops.operations.array_ops import IndexPut
33
37
  from mindspore.ops.operations.array_ops import CheckNumerics
34
38
  from mindspore.ops.operations.array_ops import ConjugateTranspose
35
39
  from mindspore.ops.operations.array_ops import SegmentMax
@@ -42,10 +46,72 @@ from mindspore.ops.operations.array_ops import SegmentMean
42
46
  from mindspore.ops.operations.array_ops import AffineGrid
43
47
  from mindspore.ops.operations.array_ops import Im2Col
44
48
  from mindspore.ops.operations.array_ops import Col2Im
49
+ from mindspore.ops.operations.array_ops import StridedSliceV2
50
+ from mindspore.ops.operations.array_ops import MaskedScatter
51
+ from mindspore.ops.operations.array_ops import MaskedSelect
52
+ from mindspore.ops.operations.array_ops import CountNonZero
53
+ from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
54
+ from mindspore.ops.operations.random_ops import LogNormalReverse
55
+ from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
56
+ from mindspore.ops.operations import _inner_ops as inner
45
57
  from mindspore.ops import functional as F
46
58
  from mindspore.ops import operations as P
47
- from mindspore.ops._utils.utils import is_shape_unknown
48
59
  from mindspore.ops.operations import _grad_ops as G
60
+ from mindspore import context
61
+
62
+
63
+ @constexpr
64
+ def _raise_value_error(*info):
65
+ info_str = ""
66
+ for obj in info:
67
+ info_str = info_str + f"{obj}"
68
+ raise ValueError(info_str)
69
+
70
+
71
+ @bprop_getters.register(P.FillV2)
72
+ def get_bprop_fill_v2(self):
73
+ """Generate bprop for FillV2"""
74
+ sum_op = P.ReduceSum()
75
+ cast_op = P.Cast()
76
+ shape_op = P.TensorShape()
77
+
78
+ def bprop(shape, value, out, dout):
79
+ dout_type = F.dtype(dout)
80
+ type_list = [
81
+ mstype.int8, mstype.int16, mstype.int32, mstype.int64,
82
+ mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
83
+ mstype.float16, mstype.float64
84
+ ]
85
+ if dout_type in type_list:
86
+ dout = cast_op(dout, mstype.float32)
87
+ dout_shape = shape_op(dout)
88
+ axis = tuple([i for i in range(len(dout_shape))])
89
+ dvalue = sum_op(dout, axis)
90
+ return zeros_like(shape), cast_op(dvalue, dout_type)
91
+
92
+ return bprop
93
+
94
+
95
+ @bprop_getters.register(StridedSliceV2)
96
+ def get_bprop_strided_slice_v2(self):
97
+ """Generate bprop for StridedSliceV2"""
98
+ shape_op = P.Shape()
99
+ dyn_shape_op = P.TensorShape()
100
+ input_grad = StridedSliceV2Grad(self.begin_mask,
101
+ self.end_mask,
102
+ self.ellipsis_mask,
103
+ self.new_axis_mask,
104
+ self.shrink_axis_mask)
105
+
106
+ def bprop(x, begin, end, strides, out, dout):
107
+ x_shape = shape_op(x)
108
+ if F.is_sequence_value_unknown(x_shape):
109
+ x_shape = dyn_shape_op(x)
110
+ dx = input_grad(x_shape, begin, end, strides, dout)
111
+ dx_all = (dx, zeros_like(begin), zeros_like(end), zeros_like(strides))
112
+ return dx_all
113
+
114
+ return bprop
49
115
 
50
116
 
51
117
  @constexpr
@@ -77,14 +143,19 @@ def get_bprop_masked_select(self):
77
143
  """Generate bprop for MaskedFill"""
78
144
  mul_op = P.Mul()
79
145
  sum_op = P.ReduceSum()
80
- is_instance_op = P.IsInstance()
146
+ is_instance_op = inner.IsInstance()
81
147
 
82
148
  def bprop(input_data, mask, value, out, dout):
83
149
  mask = F.cast(mask, mstype.float32)
84
150
  dinput = mul_op(dout, (1 - mask))
85
151
  dvalue = mul_op(dout, mask)
86
152
  dinput, dvalue = binop_grad_common(input_data, mask, dinput, dvalue)
87
- dvalue = sum_op(dvalue)
153
+ # for dynamic rank, reduce axis should be calc
154
+ if F.is_sequence_shape_unknown(P.Shape()(dvalue)):
155
+ axis = P.Range()(Tensor(0), dyn_rank(dvalue), Tensor(1))
156
+ dvalue = sum_op(dvalue, axis)
157
+ else:
158
+ dvalue = sum_op(dvalue)
88
159
  dinput = F.cast(dinput, F.dtype(input_data))
89
160
  if is_instance_op(value, mstype.number):
90
161
  dvalue = 0
@@ -95,6 +166,54 @@ def get_bprop_masked_select(self):
95
166
  return bprop
96
167
 
97
168
 
169
+ @bprop_getters.register(MaskedScatter)
170
+ def get_bprop_masked_scatter(self):
171
+ """Generate bprop for MaskedScatter"""
172
+ sort_ = P.Sort(descending=True)
173
+ masked_scatter = MaskedScatter()
174
+ masked_fill = P.MaskedFill()
175
+ masked_select = P.MaskedSelect()
176
+ size = P.Size()
177
+ zeros = P.Zeros()
178
+ concat = P.Concat(axis=0)
179
+ reshape = P.Reshape()
180
+ shape = P.Shape()
181
+
182
+ def bprop(x, mask, updates, out, dout):
183
+ dx = masked_fill(F.cast(dout, mstype.float32), mask, 0.0)
184
+ mask_selected = masked_select(F.cast(dout, mstype.float32), mask)
185
+ mask_broad = mask
186
+ if shape(mask) != shape(x):
187
+ broad_cast = P.BroadcastTo(shape(x))
188
+ mask_broad = broad_cast(mask)
189
+ mask_broad_vec = mask_broad.reshape(-1)
190
+ mask_sorted = F.cast(sort_(F.cast(mask_broad_vec, mstype.float32))[0], F.dtype(mask))
191
+ diff_num = size(updates) - size(mask_broad)
192
+ if diff_num > 0:
193
+ zeros_pad = zeros(diff_num, F.dtype(mask))
194
+ mask_sorted = concat((mask_sorted, zeros_pad))
195
+ zeros_tensor = zeros(size(updates), mstype.float32)
196
+ dupdates = masked_scatter(zeros_tensor, mask_sorted, mask_selected)
197
+ if shape(updates) != ():
198
+ dupdates = reshape(dupdates, shape(updates))
199
+ else:
200
+ zeros_tensor = zeros(shape(updates), mstype.float32)
201
+ dupdates = masked_scatter(zeros_tensor, mask, mask_selected)
202
+ return F.cast(dx, F.dtype(x)), zeros_like(mask), F.cast(dupdates, F.dtype(updates))
203
+
204
+ return bprop
205
+
206
+
207
+ @bprop_getters.register(CountNonZero)
208
+ def get_bprop_countnonzero(self):
209
+ """Grad definition for CountNonZero"""
210
+
211
+ def bprop(x, out, dout):
212
+ return (zeros_like(x),)
213
+
214
+ return bprop
215
+
216
+
98
217
  @bprop_getters.register(Mvlgamma)
99
218
  def get_bprop_mvlgamma(self):
100
219
  """Grad definition for Mvlgamma"""
@@ -142,16 +261,59 @@ def get_bprop_index_fill(self):
142
261
  def bprop(x, dim, indices, value, out, dout):
143
262
  zero_value = zeros_like(value)
144
263
  x_grad = index_fill(dout, dim, indices, zero_value)
145
- if shape(x) == ():
146
- value_grad = dout
264
+ if F.is_sequence_value_unknown(shape(x)):
265
+ if dyn_rank(x) == 0:
266
+ value_grad = dout
267
+ else:
268
+ value_grad = gather(dout, indices, dim).sum()
147
269
  else:
148
- value_grad = gather(dout, indices, dim).sum()
270
+ if shape(x) == ():
271
+ value_grad = dout
272
+ else:
273
+ value_grad = gather(dout, indices, dim).sum()
149
274
  result = (x_grad, zeros_like(dim), zeros_like(indices), value_grad)
150
275
  return result
151
276
 
152
277
  return bprop
153
278
 
154
279
 
280
+ @bprop_getters.register(IndexPut)
281
+ def get_bprop_index_put(self):
282
+ """Generate bprop for IndexPut"""
283
+ gather_nd = P.GatherNd()
284
+ stack = P.Stack()
285
+ tile = P.Tile()
286
+ masked_select = MaskedSelect()
287
+ masked_scatter = MaskedScatter()
288
+ accumulate_grad = self.accumulate
289
+ index_put = IndexPut(accumulate=accumulate_grad)
290
+ is_ascend = context.get_context("device_target") == 'Ascend'
291
+
292
+ # Negative value are not supported for GatherNd indices when Ascend, so convert it to positive value.
293
+ def convert_idx_positive(indices_i, x_shape_i):
294
+ mask = indices_i < 0
295
+ idx_pos = masked_select(indices_i + x_shape_i, mask)
296
+ idx = masked_scatter(indices_i, mask, idx_pos)
297
+ return idx
298
+
299
+ def bprop(x1, x2, indices, out, dout):
300
+ maxsize = max(x.shape[0] for x in indices)
301
+ indices_ms = [tile(x, (maxsize,)) if x.shape[0] == 1 else x for x in indices]
302
+ if is_ascend:
303
+ indices_ms = [convert_idx_positive(indices_ms[i], x1.shape[i]) for i in range(len(indices_ms))]
304
+ indices_grad = stack(indices_ms).T
305
+ values_grad = gather_nd(dout, indices_grad)
306
+ if x2.shape[0] == 1:
307
+ values_grad = values_grad.sum().reshape(1)
308
+ if values_grad.shape != x2.shape and len(indices) < len(x1.shape):
309
+ _, values_grad = binop_grad_common(x1, x2, dout, values_grad)
310
+ if accumulate_grad == 0:
311
+ dout = index_put(dout, zeros_like(x2), indices)
312
+ return dout, values_grad, [zeros_like(item) for item in indices]
313
+
314
+ return bprop
315
+
316
+
155
317
  @bprop_getters.register(P.TensorScatterSub)
156
318
  def get_bprop_tensor_scatter_sub(self):
157
319
  """Generate bprop for TensorScatterSub"""
@@ -206,7 +368,7 @@ def get_bprop_matrix_diag_part_v3(self):
206
368
 
207
369
  def bprop(x, k, padding_value, out, dout):
208
370
  shape_this = P.Shape()(x)[-2:]
209
- if not is_shape_unknown(shape_this):
371
+ if not F.is_sequence_value_unknown(shape_this):
210
372
  row = shape_this[0]
211
373
  col = shape_this[1]
212
374
  result = (matrix_diag_v3(dout, k, Tensor(row, dtype=mstype.int32), Tensor(col, dtype=mstype.int32),
@@ -224,36 +386,17 @@ def get_bprop_matrix_set_diag_v3(self):
224
386
  align = self.align
225
387
  matrix_diag_part_v3 = MatrixDiagPartV3(align=align)
226
388
  matrix_set_diag_v3 = MatrixSetDiagV3(align=align)
227
- resha = P.Reshape()
228
389
  zeros = P.Zeros()
229
- minimum = P.Minimum()
230
- concat = P.Concat()
231
390
 
232
391
  def bprop(x, diagonal, k, out, dout):
233
392
  diagonal_cal = matrix_diag_part_v3(dout, k, zeros((), dout.dtype))
234
393
 
235
394
  diagonal_shape = P.Shape()(diagonal)
236
- if is_shape_unknown(diagonal_shape):
237
- shape_dout = P.Shape()(dout)
238
- pre_shape = shape_dout[:-2]
239
- back_shape = shape_dout[-2:]
240
-
241
- site_dia = resha(k, (-1))
242
- index_min = -1 * site_dia[0]
243
- index_max = site_dia[-1]
244
- col = 0
245
- if index_max < 0:
246
- col = index_max
247
- row = 0
248
- if index_min < 0:
249
- row = index_min
250
- max_diag_len = minimum(back_shape[0] + col, back_shape[1] + row)
251
-
252
- back = [max_diag_len]
253
- if index_max != index_min:
254
- back = [index_max - index_min + 1, max_diag_len]
255
- diagonal_shape = concat([pre_shape, back])
256
- x_cal = matrix_set_diag_v3(dout, zeros(diagonal_shape, dout.dtype), k)
395
+ if F.is_sequence_value_unknown(diagonal_shape):
396
+ diagonal = F.cast(diagonal, dout.dtype)
397
+ x_cal = matrix_set_diag_v3(dout, zeros_like(diagonal), k)
398
+ else:
399
+ x_cal = matrix_set_diag_v3(dout, zeros(diagonal_shape, dout.dtype), k)
257
400
 
258
401
  return x_cal, diagonal_cal, zeros_like(k)
259
402
 
@@ -266,11 +409,16 @@ def tensor_scatter_possible_replacement(x, indices, updates, out, dout):
266
409
  scatter_nd = P.ScatterNd()
267
410
  equal = P.Equal()
268
411
  shape = P.Shape()
412
+ dyn_shape_op = P.TensorShape()
269
413
 
270
414
  x_indicators = F.cast(equal(x, out), mstype.int32)
271
415
  possibly_updated = gather_nd(out, indices)
272
416
  out_indicators = F.cast(equal(updates, possibly_updated), mstype.int32)
273
- scattered_out_indicators = scatter_nd(indices, out_indicators, shape(x))
417
+ input_shape = shape(x)
418
+ if F.is_sequence_value_unknown(input_shape):
419
+ input_shape = dyn_shape_op(x)
420
+
421
+ scattered_out_indicators = scatter_nd(indices, out_indicators, input_shape)
274
422
  indicators = x_indicators + scattered_out_indicators
275
423
  dx = dout * F.cast(x_indicators, F.dtype(dout)) / F.cast(indicators, F.dtype(dout))
276
424
  dupdates = gather_nd(dout / F.cast(indicators, F.dtype(dout)), indices) * F.cast(out_indicators, F.dtype(dout))
@@ -278,6 +426,24 @@ def tensor_scatter_possible_replacement(x, indices, updates, out, dout):
278
426
  return F.cast(dx, F.dtype(x)), zeros_like(indices), F.cast(dupdates, F.dtype(updates))
279
427
 
280
428
 
429
+ @bprop_getters.register(LogNormalReverse)
430
+ def get_bprop_log_normal_reverse(self):
431
+ """Grad definition for `LogNormalReverse` operation."""
432
+ def bprop(input_data, out, dout):
433
+ return (zeros_like(input_data),)
434
+
435
+ return bprop
436
+
437
+
438
+ @bprop_getters.register(ParameterizedTruncatedNormal)
439
+ def get_bprop_parameterized_truncated_normal(self):
440
+ """Grad definition for `ParameterizedTruncatedNormal` operation."""
441
+ def bprop(shape, mean, stdevs, min_val, max_val, out, dout):
442
+ return (zeros_like(shape), zeros_like(mean), zeros_like(stdevs), zeros_like(min_val), zeros_like(max_val))
443
+
444
+ return bprop
445
+
446
+
281
447
  @bprop_getters.register(P.TensorScatterMax)
282
448
  def get_bprop_tensor_scatter_max(self):
283
449
  """Generate bprop for TensorScatterMax"""
@@ -377,9 +543,16 @@ def get_bprop_resize_nearest_neighbor_v2(self):
377
543
 
378
544
  def bprop(x, size, output, dout):
379
545
  x_shape = P.Shape()(x)
546
+ if F.is_sequence_value_unknown(x_shape):
547
+ x_shape = P.TensorShape()(x)
380
548
  grad_in_size = x_shape[1:3]
381
549
  if data_format == 'NCHW':
382
550
  grad_in_size = x_shape[2:4]
551
+
552
+ if F.is_sequence_value_unknown(P.Shape()(x)):
553
+ dx = grad_op(dout, grad_in_size)
554
+ return dx, zeros_like(grad_in_size)
555
+
383
556
  dx = grad_op(dout, _create_tensor(grad_in_size, mstype.int32))
384
557
  return dx, zeros_like(grad_in_size)
385
558
 
@@ -393,7 +566,7 @@ def get_bprop_col2im(self):
393
566
  dilations = self.dilation
394
567
  strides = self.stride
395
568
  pads = self.padding
396
- im2col = Im2Col(ksizes=ksizes, dilations=dilations, strides=strides, padding_mode="CALCULATED", pads=pads)
569
+ im2col = Im2Col(ksizes=ksizes, dilations=dilations, strides=strides, pads=pads)
397
570
 
398
571
  def bprop(x, output_size, out, dout):
399
572
  dx = im2col(dout)
@@ -402,6 +575,36 @@ def get_bprop_col2im(self):
402
575
  return bprop
403
576
 
404
577
 
578
+ @bprop_getters.register(Im2Col)
579
+ def get_bprop_im2col(self):
580
+ """
581
+ Generate bprop for Im2Col
582
+
583
+ Im2Col, corresponding to torch's UnFold operator.
584
+ The Unfold operator has no `padding_mode` attribute,
585
+ and it's implementation corresponds to the mindspore
586
+ implementation with `padding_mode=CALCULATED` .
587
+ So, currently the bprop function of Im2Col only supports
588
+ the CALCULATED mode.
589
+ """
590
+ kernel_size = self.ksizes
591
+ dilation = self.dilations
592
+ stride = self.strides
593
+ padding = (self.pads[0], self.pads[-1])
594
+ shape_op = P.TensorShape()
595
+ col2im = Col2Im(kernel_size=kernel_size,
596
+ dilation=dilation,
597
+ stride=stride,
598
+ padding=padding)
599
+
600
+ def bprop(x, out, dout):
601
+ x_shape = shape_op(x)[2:]
602
+ dx = col2im(dout, x_shape)
603
+ return (dx,)
604
+
605
+ return bprop
606
+
607
+
405
608
  @bprop_getters.register(P.ExtractVolumePatches)
406
609
  def get_bprop_extract_volume_patches(self):
407
610
  """Generate bprop for ExtractVolumePatches"""
@@ -416,9 +619,54 @@ def get_bprop_extract_volume_patches(self):
416
619
  cast = P.Cast()
417
620
  matmul = P.MatMul()
418
621
  _, _, ksize_d, ksize_h, ksize_w = self.kernel_size
622
+ range_ = P.Range()
623
+ dyn_shape_op = P.TensorShape()
624
+ ones_like = P.OnesLike()
625
+
626
+ def _dyn_extract_volume_patches(x, out, dout):
627
+ x_shape = dyn_shape_op(x)
628
+ out_shape = dyn_shape_op(out)
629
+ x_n, x_c, x_d, x_h, x_w = x_shape[0], x_shape[1], x_shape[2], x_shape[3], x_shape[4]
630
+ x_indices_num = 1 + x_d * x_h * x_w
631
+ x_idx = range_(cast(1, mstype.float32), cast(x_indices_num, mstype.float32), cast(1, mstype.float32))
632
+ x_idx = cast(x_idx, mstype.float16)
633
+ x_idx = P.Reshape()(x_idx, create_tensor_by_element((1, 1, x_d, x_h, x_w)))
634
+ x_idx_patched = extract_volume_patches(x_idx)
635
+ x_idx_patched = P.Transpose()(x_idx_patched, (0, 2, 3, 4, 1))
636
+ x_idx_patched = cast(x_idx_patched, mstype.int32)
637
+
638
+ out_d, out_h, out_w = out_shape[2], out_shape[3], out_shape[4]
639
+ out_indices_num = out_d * out_h * out_w * ksize_d * ksize_h * ksize_w
640
+ out_idx_ori = range_(cast(0, mstype.int32), cast(out_indices_num, mstype.int32), cast(1, mstype.int32))
641
+ out_idx = P.Reshape()(out_idx_ori,
642
+ create_tensor_by_element((1, out_d, out_h, out_w, ksize_d * ksize_h * ksize_w)))
643
+
644
+ idx_tensor = concat((expend_dims(x_idx_patched, -1), expend_dims(out_idx, -1)))
645
+ idx_map = P.Reshape()(idx_tensor, (-1, 2))
646
+ sp_shape = create_tensor_by_element((x_indices_num, out_indices_num))
647
+ update = cast(ones_like(out_idx_ori), dtype(dout))
648
+ sp_mat_full = scatter_nd(idx_map, update, sp_shape)
649
+ begin = create_tensor_by_element((1, 0))
650
+ size = create_tensor_by_element((x_indices_num - 1, out_indices_num))
651
+ sp_tensor = slice_op(sp_mat_full, begin, size)
652
+
653
+ grad = P.Transpose()(dout, (0, 2, 3, 4, 1))
654
+ grad = P.Reshape()(grad, create_tensor_by_element((x_n, out_d, out_h, out_w, ksize_d,
655
+ ksize_h, ksize_w, x_c)))
656
+ grad_expended = P.Transpose()(grad, (1, 2, 3, 4, 5, 6, 0, 7))
657
+ grad_flat = P.Reshape()(grad_expended,
658
+ create_tensor_by_element((out_d * out_h * out_w * ksize_d * ksize_h * ksize_w,
659
+ x_n * x_c)))
660
+ jac = matmul(sp_tensor, grad_flat)
661
+ dx = P.Reshape()(jac, create_tensor_by_element((x_d, x_h, x_w, x_n, x_c)))
662
+ dx = P.Transpose()(dx, (3, 4, 0, 1, 2))
663
+ return (dx,)
419
664
 
420
665
  def bprop(x, out, dout):
421
666
  x_shape = P.Shape()(x)
667
+ out_shape = P.Shape()(out)
668
+ if F.is_sequence_value_unknown(x_shape) or F.is_sequence_value_unknown(out_shape):
669
+ return _dyn_extract_volume_patches(x, out, dout)
422
670
  x_n, x_c, x_d, x_h, x_w = x_shape
423
671
  x_indices_num = 1 + x_d * x_h * x_w
424
672
  x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float16)
@@ -427,7 +675,6 @@ def get_bprop_extract_volume_patches(self):
427
675
  x_idx_patched = P.Transpose()(x_idx_patched, (0, 2, 3, 4, 1))
428
676
  x_idx_patched = cast(x_idx_patched, mstype.int32)
429
677
 
430
- out_shape = P.Shape()(out)
431
678
  _, _, out_d, out_h, out_w = out_shape
432
679
  out_indices_num = out_d * out_h * out_w * ksize_d * ksize_h * ksize_w
433
680
  out_idx = F.tuple_to_array(range(0, out_indices_num))
@@ -489,16 +736,126 @@ def get_bprop_affinegrid(self):
489
736
  """Generate bprop for AffineGrid"""
490
737
 
491
738
  align_corners = self.align_corners
739
+ input_grad = G.AffineGridGrad(align_corners)
492
740
  ones = P.Ones()
493
741
  transpose = P.Transpose()
494
742
  concat = P.Concat(1)
743
+ concat0 = P.Concat(0)
495
744
  tile = P.Tile()
745
+ div = P.Div()
496
746
  reshape = P.Reshape()
497
747
  linspace = P.LinSpace()
498
748
  batmatmul = P.BatchMatMul()
499
749
  expend_dims = P.ExpandDims()
750
+ dyn_shape = P.TensorShape()
751
+ reducesum = P.ReduceSum(keep_dims=False)
752
+
753
+ def get_linspace(num):
754
+ start = Tensor(-1, mstype.float32)
755
+ stop = Tensor(1, mstype.float32)
756
+ lins_tensor = Tensor([0], dtype=mstype.float32)
757
+ if num != 1:
758
+ lins_tensor = linspace(start, stop, num)
759
+ return lins_tensor
760
+
761
+ def dyn_bprop_five(theta, output_size, out, dout, len_output_size):
762
+ perm1 = (1, 0)
763
+ perm2 = (0, 2, 1)
764
+ one_tensor = create_tensor_by_element((1,), mstype.int32)
765
+ n_value = reducesum(output_size[0])
766
+ d_value = reducesum(output_size[2])
767
+ h_value = reducesum(output_size[3])
768
+ w_value = reducesum(output_size[len_output_size - 1])
769
+ vecx = get_linspace(w_value.astype("int64"))
770
+ vecy = get_linspace(h_value.astype("int64"))
771
+ vecz = get_linspace(d_value.astype("int64"))
772
+ if align_corners is False:
773
+ vecx = div(vecx * (w_value - 1), w_value)
774
+ vecy = div(vecy * (h_value - 1), h_value)
775
+ vecz = div(vecz * (d_value - 1), d_value)
776
+ out = vecx
777
+ if h_value * d_value != 1:
778
+ multiples = concat0((expend_dims(h_value * d_value, -1), one_tensor))
779
+ out = tile(vecx, multiples)
780
+ hwd_value = h_value * w_value * d_value
781
+ hwd_shape = concat0((expend_dims(hwd_value, -1), one_tensor))
782
+ one = reshape(out, hwd_shape)
783
+ if w_value == 1:
784
+ out = expend_dims(vecy, 0)
785
+ elif w_value != 1:
786
+ multiples = concat0((expend_dims(w_value, -1), one_tensor))
787
+ out = tile(vecy, multiples)
788
+ out = transpose(out, perm1)
789
+ if d_value != 1:
790
+ multiples = concat0((expend_dims(d_value, -1), one_tensor))
791
+ out = tile(out, multiples)
792
+ two = reshape(out, hwd_shape)
793
+ out = expend_dims(vecz, 0)
794
+ if w_value * h_value != 1:
795
+ multiples = concat0((expend_dims(w_value * h_value, -1), one_tensor))
796
+ out = tile(vecz, multiples)
797
+ out = transpose(out, perm1)
798
+ four = dyn_ones(hwd_shape, mstype.float32)
799
+ output = concat((one, two, reshape(out, hwd_shape), four))
800
+ output = transpose(output, perm1)
801
+ if n_value != 1:
802
+ multiples = concat0((expend_dims(n_value, -1), one_tensor))
803
+ output = tile(output, multiples)
804
+ three_tensor = create_tensor_by_element((3,), mstype.int32)
805
+ four_tensor = create_tensor_by_element((4,), mstype.int32)
806
+ output_shape = concat0((expend_dims(n_value, -1), four_tensor, expend_dims(hwd_value, -1)))
807
+ dout_shape = concat0((expend_dims(n_value, -1), expend_dims(hwd_value, -1), three_tensor))
808
+ dtheta = batmatmul(reshape(output, output_shape), reshape(dout, dout_shape).astype("float32"))
809
+ return transpose(dtheta, perm2), four
810
+
811
+ def dyn_bprop_four(theta, output_size, out, dout):
812
+ perm1 = (1, 0)
813
+ perm2 = (0, 2, 1)
814
+ one_tensor = create_tensor_by_element((1,), mstype.int32)
815
+ n_value = reducesum(output_size[0])
816
+ h_value = reducesum(output_size[2])
817
+ w_value = reducesum(output_size[3])
818
+ vecx = get_linspace(w_value.astype("int64"))
819
+ vecy = get_linspace(h_value.astype("int64"))
820
+ if align_corners is False:
821
+ vecx = div(vecx * (w_value - 1), w_value)
822
+ vecy = div(vecy * (h_value - 1), h_value)
823
+ out = vecx
824
+ if h_value != 1:
825
+ multiples = concat0((expend_dims(h_value, -1), one_tensor))
826
+ out = tile(vecx, multiples)
827
+ hw_shape = concat0((expend_dims(h_value * w_value, -1), one_tensor))
828
+ one = reshape(out, hw_shape)
829
+ if w_value == 1:
830
+ out = expend_dims(vecy, 0)
831
+ elif w_value != 1:
832
+ multiples = concat0((expend_dims(w_value, -1), one_tensor))
833
+ out = tile(vecy, multiples)
834
+ out = transpose(out, perm1)
835
+ two = reshape(out, hw_shape)
836
+ tre = dyn_ones(hw_shape, mstype.float32)
837
+ output = concat((one, two, tre))
838
+ multiples = concat0((expend_dims(n_value, -1), one_tensor))
839
+ output = transpose(output, perm1)
840
+ output = tile(output, multiples)
841
+ two_tensor = create_tensor_by_element((2,), mstype.int32)
842
+ three_tensor = create_tensor_by_element((3,), mstype.int32)
843
+ output_shape = concat0((expend_dims(n_value, -1), three_tensor, expend_dims(h_value * w_value, -1)))
844
+ dout_shape = concat0((expend_dims(n_value, -1), expend_dims(h_value * w_value, -1), two_tensor))
845
+ dtheta = batmatmul(reshape(output, output_shape), reshape(dout, dout_shape).astype("float32"))
846
+ return transpose(dtheta, perm2), tre
847
+
848
+ def dyn_bprop(theta, output_size, out, dout):
849
+ len_output_size = reducesum(dyn_shape(output_size))
850
+ dtheta = dyn_ones(Tensor([1, 3, 2], mstype.int32), mstype.float32)
851
+ ret = dyn_ones(Tensor([1, 6], mstype.int32), mstype.float32)
852
+ if len_output_size == 5:
853
+ dtheta, ret = dyn_bprop_five(theta, output_size, out, dout, len_output_size)
854
+ elif len_output_size == 4:
855
+ dtheta, ret = dyn_bprop_four(theta, output_size, out, dout)
856
+ return dtheta, ret
500
857
 
501
- def bprop(theta, output_size, out, dout):
858
+ def static_bprop(theta, output_size, out, dout):
502
859
  x_shape = P.Shape()(dout)
503
860
  n_value = x_shape[0]
504
861
  h_value = x_shape[1]
@@ -530,10 +887,10 @@ def get_bprop_affinegrid(self):
530
887
  vecy = vecy * (h_value - 1) / h_value
531
888
  vecz = vecz * (d_value - 1) / d_value
532
889
  out = vecx
533
- if h_value*d_value != 1:
534
- multiples = (h_value*d_value, 1)
890
+ if h_value * d_value != 1:
891
+ multiples = (h_value * d_value, 1)
535
892
  out = tile(vecx, multiples)
536
- one = reshape(out, (h_value*w_value*d_value, 1))
893
+ one = reshape(out, (h_value * w_value * d_value, 1))
537
894
  if w_value == 1:
538
895
  out = expend_dims(vecy, 0)
539
896
  elif w_value != 1:
@@ -543,21 +900,21 @@ def get_bprop_affinegrid(self):
543
900
  if d_value != 1:
544
901
  multiples = (d_value, 1)
545
902
  out = tile(out, multiples)
546
- two = reshape(out, (h_value*w_value*d_value, 1))
903
+ two = reshape(out, (h_value * w_value * d_value, 1))
547
904
  out = expend_dims(vecz, 0)
548
- if w_value*h_value != 1:
549
- multiples = (w_value*h_value, 1)
905
+ if w_value * h_value != 1:
906
+ multiples = (w_value * h_value, 1)
550
907
  out = tile(vecz, multiples)
551
908
  out = transpose(out, perm1)
552
- tre = reshape(out, (h_value*w_value*d_value, 1))
553
- fou = ones((h_value*w_value*d_value, 1), mstype.float32)
909
+ tre = reshape(out, (h_value * w_value * d_value, 1))
910
+ fou = ones((h_value * w_value * d_value, 1), mstype.float32)
554
911
  output = concat((one, two, tre, fou))
555
912
  output = transpose(output, perm1)
556
913
  if n_value != 1:
557
914
  multiples = (n_value, 1)
558
915
  output = tile(output, multiples)
559
- output = output.view(n_value, 4, h_value*w_value*d_value)
560
- dout_ = dout.view(n_value, d_value*h_value*w_value, 3).astype("float32")
916
+ output = output.view(n_value, 4, h_value * w_value * d_value)
917
+ dout_ = dout.view(n_value, d_value * h_value * w_value, 3).astype("float32")
561
918
  dtheta = batmatmul(output, dout_)
562
919
  dtheta = transpose(dtheta, perm2)
563
920
  elif len_output_size == 4:
@@ -576,25 +933,38 @@ def get_bprop_affinegrid(self):
576
933
  if h_value != 1:
577
934
  multiples = (h_value, 1)
578
935
  out = tile(vecx, multiples)
579
- one = reshape(out, (h_value*w_value, 1))
936
+ one = reshape(out, (h_value * w_value, 1))
580
937
  if w_value == 1:
581
938
  out = expend_dims(vecy, 0)
582
939
  elif w_value != 1:
583
940
  multiples = (w_value, 1)
584
941
  out = tile(vecy, multiples)
585
942
  out = transpose(out, perm1)
586
- two = reshape(out, (h_value*w_value, 1))
587
- tre = ones((h_value*w_value, 1), mstype.float32)
943
+ two = reshape(out, (h_value * w_value, 1))
944
+ tre = ones((h_value * w_value, 1), mstype.float32)
588
945
  output = concat((one, two, tre))
589
946
  multiples = (n_value, 1)
590
947
  output = transpose(output, perm1)
591
948
  output = tile(output, multiples)
592
- output = output.view(n_value, 3, h_value*w_value)
593
- dout_ = dout.view(n_value, h_value*w_value, 2).astype("float32")
949
+ output = output.view(n_value, 3, h_value * w_value)
950
+ dout_ = dout.view(n_value, h_value * w_value, 2).astype("float32")
594
951
  dtheta = batmatmul(output, dout_)
595
952
  dtheta = transpose(dtheta, perm2)
596
953
  return dtheta, tre
597
954
 
955
+ def bprop_gpu(theta, output_size, out, dout):
956
+ is_tensor, _ = convert_to_tensor(output_size)
957
+ if is_tensor:
958
+ return dyn_bprop(theta, output_size, out, dout)
959
+ return static_bprop(theta, output_size, out, dout)
960
+
961
+ def bprop(theta, output_size, out, dout):
962
+ dx = input_grad(dout, output_size)
963
+ return dx, zeros_like(output_size)
964
+
965
+ if context.get_context('device_target') == "GPU":
966
+ return bprop_gpu
967
+
598
968
  return bprop
599
969
 
600
970
 
@@ -672,7 +1042,7 @@ def get_bprop_expand(self):
672
1042
 
673
1043
  def bprop(x, shape, out, dout):
674
1044
  reduce_dims = []
675
- dshape = zeroslike(dout)
1045
+ dshape = zeroslike(shape)
676
1046
  dx_shape = dout.shape
677
1047
  if dx_shape is None:
678
1048
  return dout.sum(), dshape
@@ -696,21 +1066,37 @@ def get_bprop_segment_mean(self):
696
1066
  """Generate bprop for SegmentMean"""
697
1067
  rank = P.Rank()
698
1068
  shape = P.Shape()
1069
+ dyn_shape = P.TensorShape()
699
1070
  fill = P.Fill()
700
1071
  divide = P.Div()
701
1072
  segment_sum = SegmentSum()
702
1073
  gather = P.Gather()
703
1074
  cast = P.Cast()
1075
+ concat = P.Concat()
1076
+ expand_dims = P.ExpandDims()
704
1077
 
705
1078
  def bprop(input_x, segment_ids, output, dout):
706
1079
  input_x_type = F.dtype(input_x)
707
1080
  input_x = cast(input_x, mstype.float32)
708
1081
  dout = cast(dout, mstype.float32)
709
1082
  dout_type = F.dtype(dout)
710
- input_rank = rank(input_x)
1083
+
711
1084
  ones_shape = shape(segment_ids)
712
- ones_shape = ones_shape + (1,) * (input_rank - 1)
713
- ones = fill(dout_type, ones_shape, 1)
1085
+ if F.is_sequence_value_unknown(ones_shape):
1086
+ ones_shape = dyn_shape(segment_ids)
1087
+
1088
+ ones = ()
1089
+ inputx_shape = shape(input_x)
1090
+ if F.is_sequence_value_unknown(inputx_shape):
1091
+ input_rank = dyn_rank(input_x)
1092
+ if input_rank > cast(1, mstype.float32):
1093
+ ones_shape = concat([ones_shape, dyn_ones(expand_dims(input_rank - 1, 0), mstype.int64)])
1094
+ ones = dyn_fill(dout_type, ones_shape, 1)
1095
+ else:
1096
+ input_rank = rank(input_x)
1097
+ ones_shape = ones_shape + (1,) * (input_rank - 1)
1098
+ ones = fill(dout_type, ones_shape, 1)
1099
+
714
1100
  scaled_grad = divide(dout, segment_sum(ones, segment_ids))
715
1101
  return cast(gather(scaled_grad, segment_ids, 0), input_x_type), zeros_like(segment_ids)
716
1102