mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic (further details were linked from the original registry page).

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,8 @@ from mindspore.ops import operations as P
24
24
  from mindspore.ops.composite import base
25
25
  from mindspore.ops._primitive_cache import _get_cache_prim
26
26
  from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
27
- TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
27
+ TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
28
+ SelectView, CopyWithSlice
28
29
  from mindspore.common import dtype as mstype
29
30
  from mindspore.common._register_for_tensor import tensor_operator_registry
30
31
  from mindspore.common.initializer import Zero
@@ -32,6 +33,8 @@ from mindspore.common import Tensor, CSRTensor, COOTensor
32
33
  from mindspore.common import mutable
33
34
  from mindspore import ops
34
35
  from mindspore.ops.primitive import _primexpr
36
+ from mindspore import _checkparam as validator
37
+ from mindspore.common._stub_tensor import _convert_stub
35
38
 
36
39
  slice_get_item = SliceGetItem()
37
40
  hyper_map = base.HyperMap()
@@ -42,6 +45,8 @@ is_parameter = IsParameter()
42
45
  getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
43
46
  setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
44
47
 
48
+ selevt_view = SelectView()
49
+ copy_with_slice = CopyWithSlice()
45
50
 
46
51
  def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
47
52
  new_axis_mask=0, shrink_axis_mask=0):
@@ -65,16 +70,23 @@ class ValueTransferType(IntEnum):
65
70
  kGatherND = 9
66
71
  kScatterNdUpdate = 10
67
72
  kReshape = 11
68
- kScatterND = 12
69
- kNumberToTensor = 13
70
- kHandleSequenceValue = 14
71
- kByPass = 15
72
- kReSetItemByIndex = 16
73
- kCopySlice = 17
74
- kSetItemByBool = 18
75
- kEmptyTensor = 19
76
- kSetItemByEllipsis = 20
77
- kRaiseIndexError = 21
73
+ kSelectView = 12
74
+ kUnsqueeze = 13
75
+ kCopyView = 14
76
+ kScatterND = 15
77
+ kNumberToTensor = 16
78
+ kHandleSequenceValue = 17
79
+ kByPass = 18
80
+ kReSetItemByIndex = 19
81
+ kCopySlice = 20
82
+ kSetItemByBool = 21
83
+ kEmptyTensor = 22
84
+ kSetItemByEllipsis = 23
85
+ kFormatIndexTensor = 24
86
+ kGetitemByBoolTensor = 25
87
+ kSetitemByBoolTensor = 26
88
+ kJustReturn = 27
89
+ kRaiseIndexError = 28
78
90
 
79
91
 
80
92
  def data_update(transfer_types, args, data, new_index, value=None):
@@ -82,11 +94,14 @@ def data_update(transfer_types, args, data, new_index, value=None):
82
94
  We finally generate a new tensor when handling tensor getitem/setitem
83
95
  by transfer data and value with index.
84
96
  """
97
+ origin_data = data
85
98
  for transfer_type, arg in zip(transfer_types, args):
86
99
  if transfer_type == ValueTransferType.kUnknown:
87
100
  raise IndexError(f"Inlvaid transfer type {transfer_type}.")
88
101
  if transfer_type <= ValueTransferType.kScatterND:
89
- data = data_update_by_ops(transfer_type, arg, data, new_index, value)
102
+ data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
103
+ if transfer_type == ValueTransferType.kJustReturn:
104
+ return _convert_stub(arg)
90
105
  if transfer_type == ValueTransferType.kSetItemByBool:
91
106
  return tensor_setitem_by_bool(data, new_index, value)
92
107
  if transfer_type == ValueTransferType.kCopySlice:
@@ -98,13 +113,19 @@ def data_update(transfer_types, args, data, new_index, value=None):
98
113
  return data
99
114
  if transfer_type == ValueTransferType.kEmptyTensor:
100
115
  return handle_empty_tensor(arg, data)
116
+ if transfer_type == ValueTransferType.kFormatIndexTensor:
117
+ new_index = format_index_tensor(new_index, arg)
118
+ if transfer_type == ValueTransferType.kGetitemByBoolTensor:
119
+ return F.gather_nd(data, new_index.nonzero())
120
+ if transfer_type == ValueTransferType.kSetitemByBoolTensor:
121
+ return handle_setitem_by_bool_tensor(data, new_index, value)
101
122
  if transfer_type == ValueTransferType.kRaiseIndexError:
102
123
  raise IndexError(
103
124
  f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
104
125
  return data
105
126
 
106
127
 
107
- def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
128
+ def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=None):
108
129
  """
109
130
  Generate a new tensor when handling tensor getitem/setitem
110
131
  by ops.
@@ -125,14 +146,22 @@ def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
125
146
  F.scatter_nd_update(data, new_index, value)
126
147
  elif transfer_type == ValueTransferType.kSelect:
127
148
  data = F.select(Tensor(new_index), value, data)
149
+ elif transfer_type == ValueTransferType.kSelectView:
150
+ data = selevt_view(data, arg[0], arg[1])
151
+ elif transfer_type == ValueTransferType.kCopyView:
152
+ value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
153
+ data = copy_with_slice(data, value)
154
+ return origin_data
128
155
  elif transfer_type == ValueTransferType.kReshape:
129
156
  data = F.reshape(data, arg)
130
157
  elif transfer_type == ValueTransferType.kGather:
131
158
  data = F.gather(data, new_index, 0)
132
159
  elif transfer_type == ValueTransferType.kExpandDims:
133
160
  data = F.expand_dims(data, 0)
161
+ elif transfer_type == ValueTransferType.kUnsqueeze:
162
+ data = F.unsqueeze(data, arg)
134
163
  elif transfer_type == ValueTransferType.kStrideSlice:
135
- data = F.strided_slice(data, arg[0], arg[1], arg[2])
164
+ data = strided_slice(data, arg[0], arg[1], arg[2])
136
165
  else:
137
166
  raise IndexError(f"Inlvaid transfer type {transfer_type}.")
138
167
  return data
@@ -144,7 +173,7 @@ def value_update(transfer_types, args, data, value):
144
173
  if transfer_type == ValueTransferType.kByPass:
145
174
  continue
146
175
  if transfer_type == ValueTransferType.kNumberToTensor:
147
- value = F.fill(F.dtype(data), (), value)
176
+ value = F.cast(value, F.dtype(data))
148
177
  elif transfer_type == ValueTransferType.kHandleSequenceValue:
149
178
  op_type, index = arg
150
179
  if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
@@ -182,7 +211,10 @@ def _tensor_setitem(self, index, value):
182
211
  data_update_types = setitem_info[3]
183
212
  data_update_args = setitem_info[4]
184
213
  value = value_update(v_transfer_types, v_transfer_args, self, value)
185
- return data_update(data_update_types, data_update_args, self, new_index, value)
214
+ output = data_update(data_update_types, data_update_args, self, new_index, value)
215
+ if new_index == "view":
216
+ return (self,)
217
+ return output
186
218
 
187
219
 
188
220
  tensor_operator_registry.register("__getitem__", _tensor_getitem)
@@ -273,17 +305,27 @@ def _scalar_to_tensor(input_x):
273
305
  return ops.add(input_x, mutable(Tensor(0)))
274
306
 
275
307
 
308
+ @_primexpr
309
+ def _check_scalar_tensor_args(args):
310
+ """For the item, check that the index of the scalar tensor is set."""
311
+ if args not in ((None,), ()):
312
+ const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
313
+
314
+
276
315
  def tensor_item(data, *args):
277
316
  """Tensor getitem by index whose dtype is int or tuple with int."""
278
317
  # transform a.item(tuple(int)) -> a.item(int1,int2...intN)
318
+ if data.ndim == 0:
319
+ _check_scalar_tensor_args(args)
320
+ return data.asnumpy().item()
279
321
  if len(args) == 1 and isinstance(args[0], tuple):
280
322
  args = args[0]
281
323
 
282
324
  args_types = hyper_map(F.typeof, args)
283
325
  if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
284
326
  if data.shape == (1,):
285
- return data[0]
286
- const_utils.raise_value_error("Can only convert an array of size 1 to a Tensor scalar")
327
+ return data.asnumpy().item()
328
+ const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
287
329
 
288
330
  if not const_utils.judge_indexes_types(args_types, mstype.int64):
289
331
  const_utils.raise_type_error("The index object cannot be interpreted as an integer")
@@ -342,7 +384,8 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
342
384
  exp_msg = const_utils.gen_exception_msg(
343
385
  "Tuple index len({}) is not same to tensor dimension({})", len(tuple_index), data.ndim)
344
386
  const_utils.raise_index_error(exp_msg)
345
- return tensor_setitem_by_tuple_with_number(data, tuple_index, nubmer_value)
387
+ nubmer_value = F.cast(nubmer_value, F.dtype(data))
388
+ return tensor_itemset_by_tuple_with_tensor(data, tuple_index, nubmer_value)
346
389
 
347
390
 
348
391
  def _broadcast(broadcast_shape, x):
@@ -429,12 +472,39 @@ def handle_multi_dim_index_tensor(new_index, arg):
429
472
  return new_index
430
473
 
431
474
 
475
+ def format_index_tensor(index, arg):
476
+ """Format index tensor when tensor less than 0"""
477
+ format_indices, format_dims = arg
478
+ if isinstance(index, list):
479
+ for format_idx, format_dim in zip(format_indices, format_dims):
480
+ index_tensor = index[format_idx]
481
+ index[format_idx] = F.select(index_tensor < 0, index_tensor + format_dim, index_tensor)
482
+ return index
483
+ index = Tensor(index)
484
+ return F.select(index < 0, index + format_dims, index)
485
+
486
+
487
+ def handle_setitem_by_bool_tensor(data, index, value):
488
+ """Set a tensor item by a bool tensor with a tensor."""
489
+ value = F.cast(value, F.dtype(data))
490
+ indices = index.nonzero()
491
+ if indices.shape[0] == 0:
492
+ return data
493
+ value_shape = (indices.shape[0],) + data.shape[index.ndim:]
494
+ value = _broadcast(value_shape, value)
495
+ value = F.scatter_nd(indices, value, data.shape)
496
+ index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
497
+ index = _broadcast(data.shape, index)
498
+ result = F.select(index, value, data)
499
+ return result
500
+
501
+
432
502
  def _expand_data_dims(data, tuple_index):
433
503
  """expand the data's dim with 'None' and 'Boolean' in tuple_index"""
434
504
  indexes_types = hyper_map(toptypeof, tuple_index)
435
505
  expand_positions, tuple_index_new = (), ()
436
506
  for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
437
- if isinstance(index_type, mstype.none_type):
507
+ if isinstance(index_type, mstype.NoneType):
438
508
  tuple_index_new += (const_utils.make_empty_slice(),)
439
509
  expand_positions += (i,)
440
510
  elif isinstance(index_type, mstype.Bool):
@@ -471,29 +541,27 @@ def convert_variable_to_tensor_slice(slice_index):
471
541
  return slice_index
472
542
 
473
543
 
544
+ class _TensorIndexGetitem(base.TensorIndexGetitem_):
545
+ """
546
+ Getting item of Tensor.
547
+
548
+ Args:
549
+ data (Tensor): A tuple to be sliced.
550
+ index: Index of tensor.
551
+
552
+ Returns:
553
+ Type is the same as the element type of data.
554
+ """
555
+
556
+ def __call__(self, *args):
557
+ pass
558
+
559
+ _tensor_index_getitem = _TensorIndexGetitem('tensor_index_getitem')
560
+
561
+
474
562
  def tensor_index_by_slice(data, slice_index):
475
563
  """Tensor getitem by a slice."""
476
- min_data_dim, max_data_dim = 1, 8
477
- const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
478
- data_shape = F.shape(data)
479
- slice_index = convert_variable_to_tensor_slice(slice_index)
480
-
481
- is_dynamic = (F.is_sequence_value_unknown(data_shape)
482
- or isinstance(slice_get_item(slice_index, "start"), Tensor)
483
- or isinstance(slice_get_item(slice_index, "stop"), Tensor)
484
- or isinstance(slice_get_item(slice_index, "step"), Tensor))
485
- if is_dynamic:
486
- begin_strides, end_strides, step_strides = get_stride_info_from_slice(data, slice_index)
487
- else:
488
- begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(data_shape, slice_index)
489
- begin_mask = 1 if slice_get_item(slice_index, "start") is None else 0
490
- end_mask = 1 if slice_get_item(slice_index, "stop") is None else 0
491
- for i in range(1, len(data_shape)):
492
- begin_mask += 2 ** i
493
- end_mask += 2 ** i
494
- if begin_mask or end_mask:
495
- return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, 0)
496
- return F.strided_slice(data, begin_strides, end_strides, step_strides)
564
+ return _tensor_index_getitem(data, slice_index)
497
565
 
498
566
 
499
567
  def get_stride_info_from_slice(data, slice_index):
@@ -531,9 +599,12 @@ def _tensor_index_by_bool(data, bool_value):
531
599
  """Tensor getitem by a single bool value"""
532
600
  min_data_dim, max_data_dim = 0, 7
533
601
  const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
602
+ output = data
534
603
  if bool_value:
535
- return F.expand_dims(data, 0)
536
- return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
604
+ output = F.expand_dims(data, 0)
605
+ elif not F.is_sequence_value_unknown(F.shape(data)):
606
+ return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
607
+ return output
537
608
 
538
609
 
539
610
  def get_stride_info_from_integer(tensor_int):
@@ -550,15 +621,14 @@ def get_stride_info_from_integer(tensor_int):
550
621
  def _tensor_index_by_integer(data, int_index):
551
622
  """Tensor getitem by a single integer number"""
552
623
  data_shape = F.shape(data)
553
- if not data_shape:
554
- const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
555
- if data.ndim < 1 or data.ndim > 8:
556
- const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
557
-
558
624
  if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
559
625
  tensor_index = _scalar_to_tensor(int_index)
560
626
  begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
561
627
  else:
628
+ if not data_shape:
629
+ const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
630
+ if data.ndim < 1 or data.ndim > 8:
631
+ const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
562
632
  transformed_number = const_utils.check_range(int_index, data_shape[0])
563
633
  begin_strides, end_strides, step_strides = \
564
634
  const_utils.get_stride_info_from_integer(data_shape, transformed_number)
@@ -570,7 +640,6 @@ def _tensor_index_by_integer(data, int_index):
570
640
  end_mask += 2 ** i
571
641
  return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
572
642
 
573
-
574
643
  def _check_dim_shape_valid(data, tensor_index):
575
644
  """check dim and shape of tensor_index for tensor(bool) indexing"""
576
645
  if data.ndim < tensor_index.ndim:
@@ -583,7 +652,8 @@ def _check_dim_shape_valid(data, tensor_index):
583
652
 
584
653
  def tensor_index_by_bool_tensor(data, tensor_index):
585
654
  """Tensor getitem by a bool tensor"""
586
- _check_dim_shape_valid(data, tensor_index)
655
+ if not F.is_sequence_value_unknown(F.shape(data)):
656
+ _check_dim_shape_valid(data, tensor_index)
587
657
  tensor_index = tensor_index.nonzero()
588
658
  return F.gather_nd(data, tensor_index)
589
659
 
@@ -591,7 +661,8 @@ def tensor_index_by_bool_tensor(data, tensor_index):
591
661
  def tensor_index_by_tensor(data, tensor_index):
592
662
  """Tensor getitem by a single tensor"""
593
663
  min_data_dim, max_data_dim = 0, 7
594
- const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
664
+ if not F.is_sequence_value_unknown(F.shape(data)):
665
+ const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
595
666
  if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
596
667
  return F.gather(data, tensor_index, 0)
597
668
  if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
@@ -609,16 +680,22 @@ def tensor_index_by_list(data, list_index):
609
680
 
610
681
  data_shape = F.shape(data)
611
682
  indexes_types = hyper_map(toptypeof, list_index)
612
- if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)):
683
+ if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)) \
684
+ and not F.is_sequence_value_unknown(list_index):
613
685
  if not F.isconstant(data_shape[0]):
614
686
  if all(isinstance(i, bool) for i in list_index):
615
- const_utils.raise_unimplemented_error(
616
- "Not supported to the dynamic shape tensor slice by using list of Boolean type")
687
+ if F.dyn_shape(data)[0] != len(list_index):
688
+ raise IndexError(
689
+ f'dimension is {F.dyn_shape(data)[0]} but corresponding boolean dimension is {len(list_index)}')
690
+ tensor_index = Tensor(list_index).nonzero()
691
+ return F.gather_nd(data, tensor_index)
617
692
  tensor_index = const_utils.sequence_to_index(list_index, None)
618
693
  else:
619
- tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
694
+ tensor_index = const_utils.sequence_to_index(
695
+ list_index, data_shape[0])
620
696
  if tensor_index is False:
621
- const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
697
+ const_utils.raise_index_error(
698
+ "When tensor is indexed by list, the list can't be empty.")
622
699
  return F.gather(data, tensor_index, 0)
623
700
 
624
701
  tuple_index_new = ()
@@ -637,23 +714,92 @@ def convert_tupleslice_to_tensor(tuple_index):
637
714
  return tuple(new_tuple_index)
638
715
 
639
716
 
640
- def tensor_index_by_tuple(data, tuple_index):
641
- """Tensor getitem by tuple of various types with None"""
642
- if not tuple_index:
643
- return data
717
+ def judge_tuple_index_dim_check_error(index_dim, data_dim):
718
+ """raise IndexError when tuple_index's dim is invalid"""
719
+ if index_dim > data_dim:
720
+ raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
721
+ f"dim of index:{index_dim}, dim of data:{data_dim}")
644
722
 
645
- tuple_index = convert_tupleslice_to_tensor(tuple_index)
646
- op_name = const_utils.TENSOR_GETITEM
647
- tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
648
- data, tuple_index = _expand_data_dims(data, tuple_index)
649
723
 
650
- min_data_dim, max_data_dim = 1, 8
651
- const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
724
+ class _HandleEmptySlice(base.HandleEmptySlice_):
725
+ """
726
+ Getting item of Tensor.
727
+
728
+ Args:
729
+ data (Tensor): A tuple to be sliced.
730
+ index: Index of tensor.
731
+
732
+ Returns:
733
+ Type is the same as the element type of data.
734
+ """
735
+
736
+ def __init__(self, name):
737
+ """Initialize _HandleEmptySlice."""
738
+ base.HandleEmptySlice_.__init__(self, name)
739
+
740
+ def __call__(self, *args):
741
+ pass
742
+
743
+
744
+ _handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
745
+
746
+
747
+ def judge_tuple_index_dim(data, tuple_index):
748
+ """Judge whether tuple_index's dim is valid"""
749
+ data_dim = data.ndim
750
+ index_dim = 0
751
+ for index in tuple_index:
752
+ if isinstance(toptypeof(index), mstype.TensorType) and index.dtype == mstype.bool_:
753
+ index_dim += index.ndim
754
+ elif not isinstance(toptypeof(index), (mstype.NoneType, mstype.Ellipsis_, mstype.Bool)):
755
+ index_dim += 1
756
+ judge_tuple_index_dim_check_error(index_dim, data_dim)
757
+
758
+
759
+ def judge_simple_tuple_index(data, tuple_index):
760
+ """Judge whether tuple_index is simple index, which not rollback to cpu ops."""
761
+ op_name = const_utils.TENSOR_GETITEM
652
762
  indexes_types = hyper_map(toptypeof, tuple_index)
653
763
  contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
654
- if contain_type == const_utils.ALL_BASIC:
764
+ return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
765
+ and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
766
+
767
+
768
+ def tensor_index_by_tuple(data, tuple_index):
769
+ """Tensor getitem by tuple of various types with None"""
770
+ if not tuple_index:
771
+ return data
772
+ if judge_simple_tuple_index(data, tuple_index):
773
+ tuple_index = convert_tupleslice_to_tensor(tuple_index)
774
+ op_name = const_utils.TENSOR_GETITEM
775
+ tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
776
+ min_data_dim, max_data_dim = 1, 8
777
+ const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
655
778
  return _tensor_getitem_by_tuple_slice(data, tuple_index)
656
- return _tensor_getitem_by_tuple(data, tuple_index, op_name)
779
+
780
+ if not F.is_sequence_value_unknown(F.shape(data)):
781
+ judge_tuple_index_dim(data, tuple_index)
782
+ tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
783
+ for non_zero_shape in non_zero_shapes:
784
+ if F.reduce_min(non_zero_shape) == 0:
785
+ tuple_index = zero_index
786
+ break
787
+ if not F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(tuple_index):
788
+ _, stub_zero_dim_tensor = _handle_empty_slice(data, tuple_index)
789
+ if 0 in stub_zero_dim_tensor.shape:
790
+ return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
791
+ has_tensor_index = False
792
+ for i in tuple_index:
793
+ if isinstance(i, Tensor):
794
+ has_tensor_index = True
795
+ break
796
+ empty_broadcast_data_shape = False
797
+ _broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
798
+ if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
799
+ empty_broadcast_data_shape = True
800
+ if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
801
+ empty_broadcast_data_shape = True
802
+ return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
657
803
 
658
804
 
659
805
  def get_slice_stride(slice_index, dim_size):
@@ -895,6 +1041,15 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
895
1041
  return indices
896
1042
 
897
1043
 
1044
+ def parse_check_slice_index(index_out, dim_size):
1045
+ """ Parse and check slice index """
1046
+ has_false = False
1047
+ start, stop, step = const_utils.normalize_slice(index_out, dim_size)
1048
+ if F.isconstant(start) and F.isconstant(stop) and F.isconstant(step):
1049
+ has_false = const_utils.check_slice_empty(start, stop, step)
1050
+ return has_false
1051
+
1052
+
898
1053
  def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
899
1054
  """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
900
1055
  data_shape = F.shape(data)
@@ -925,8 +1080,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
925
1080
  tuple_index_new += (tensor_index,)
926
1081
  tensor_indexes.append(tensor_index)
927
1082
  elif i in slice_positions:
928
- start, stop, step = const_utils.normalize_slice(index, dim_size)
929
- if const_utils.check_slice_empty(start, stop, step):
1083
+ if parse_check_slice_index(index, dim_size):
930
1084
  return False
931
1085
  slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
932
1086
  slice_shapes += (len(slice_ele_list_index),)
@@ -962,7 +1116,7 @@ def sequence_to_tensor(value, dtype):
962
1116
 
963
1117
  if value_elements_type == const_utils.ALL_TENSOR:
964
1118
  value = F.stack(value).astype(dtype)
965
- elif value_elements_type == const_utils.NO_TENSOR:
1119
+ elif value_elements_type == const_utils.NO_TENSOR and not F.is_sequence_value_unknown(value):
966
1120
  value = const_utils.make_tensor(value, dtype)
967
1121
  else:
968
1122
  new_value = ()
@@ -984,7 +1138,7 @@ def _generate_updates_from_sequence(data, index, value, op_type):
984
1138
  def _generate_updates_from_tensor(data, index, value, op_type):
985
1139
  """Generate an updates tensor from a tensor."""
986
1140
  value = value.astype(data.dtype)
987
- if F.is_sequence_value_unknown(F.shape(data)):
1141
+ if F.is_sequence_value_unknown(F.shape(data)) or F.is_sequence_value_unknown(F.shape(index)):
988
1142
  data_shape = F.dyn_shape(data)
989
1143
  index_shape = F.dyn_shape(index)
990
1144
  updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
@@ -1025,13 +1179,49 @@ def tensor_setitem_by_number(self, index, value):
1025
1179
  return tensor_setitem_by_number_with_sequence(self, index, value)
1026
1180
 
1027
1181
 
1182
+ def _tuple_index_transfer(broadcast_shape, final_shape, new_shape, x, all_empty_tensor):
1183
+ """Transform tuple index tensor to the required."""
1184
+ if isinstance(broadcast_shape, Tensor):
1185
+ if not all_empty_tensor:
1186
+ x = F.broadcast_to(x, broadcast_shape)
1187
+ x = F.reshape(x, new_shape)
1188
+ x = F.broadcast_to(x, final_shape)
1189
+ return x
1190
+ item = _broadcast(broadcast_shape, x)
1191
+ return _broadcast(final_shape, F.reshape(item, new_shape))
1192
+
1193
+
1194
+ class _TensorIndexSetitem(base.TensorIndexSetitem_):
1195
+ """
1196
+ Getting item of Tensor.
1197
+
1198
+ Args:
1199
+ data (Tensor): A tuple to be sliced.
1200
+ index: Index of tensor.
1201
+
1202
+ Returns:
1203
+ Type is the same as the element type of data.
1204
+ """
1205
+
1206
+ def __call__(self, *args):
1207
+ pass
1208
+
1209
+
1210
+ _tensor_index_setitem = _TensorIndexSetitem('tensor_index_setitem')
1211
+
1212
+
1028
1213
  def tensor_setitem_by_slice(self, index, value):
1029
- index = convert_variable_to_tensor_slice(index)
1030
- if isinstance(value, (int, float, bool)):
1031
- return tensor_setitem_by_slice_with_number(self, index, value)
1032
- if isinstance(value, Tensor):
1033
- return tensor_setitem_by_slice_with_tensor(self, index, value)
1034
- return tensor_setitem_by_slice_with_sequence(self, index, value)
1214
+ """Set a tensor item by slice."""
1215
+ indices, value_shape, start, stop, step, value = _tensor_index_setitem(
1216
+ self, index, value)
1217
+ if start == stop:
1218
+ return self
1219
+ value = F.broadcast_to(value, value_shape)
1220
+ if not const_utils.is_ascend() and step == 1:
1221
+ if isinstance(step, Tensor):
1222
+ return copy_slice(self, value, start, stop, step)
1223
+ return copy_slice(self, value, (start,), (stop,), (step,))
1224
+ return F.tensor_scatter_update(self, indices, value)
1035
1225
 
1036
1226
 
1037
1227
  def tensor_setitem_by_ellipsis(self, index, value):
@@ -1049,8 +1239,6 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
1049
1239
  updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
1050
1240
  data_shape = F.shape(data)
1051
1241
  first_val = data_shape[0]
1052
- if not F.isconstant(first_val):
1053
- first_val = -1
1054
1242
  index = F.select(index < 0, index + first_val, index)
1055
1243
  index = F.expand_dims(index, -1)
1056
1244
  if F.rank(index) < 2:
@@ -1081,13 +1269,12 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
1081
1269
  return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
1082
1270
 
1083
1271
  if F.is_sequence_value_unknown(F.shape(data)):
1084
- const_utils.raise_unimplemented_error(
1085
- "Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
1272
+ return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
1086
1273
  return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
1087
1274
 
1088
1275
 
1089
1276
  def tensor_setitem_by_tensor_with_number(data, index, value):
1090
- value = F.fill(F.dtype(data), (), value)
1277
+ value = F.cast(value, F.dtype(data))
1091
1278
  return tensor_setitem_by_tensor_with_tensor(data, index, value)
1092
1279
 
1093
1280
 
@@ -1118,13 +1305,13 @@ def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
1118
1305
 
1119
1306
  def tensor_setitem_by_slice_with_number(data, input_slice, value):
1120
1307
  """Givens a scalar assign to tensor by slice"""
1121
- value = F.fill(F.dtype(data), (), value)
1308
+ value = F.cast(value, F.dtype(data))
1122
1309
  return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
1123
1310
 
1124
1311
 
1125
1312
  def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
1126
1313
  """Assigns the tensor by tuple with number value."""
1127
- value = F.fill(F.dtype(data), (), value)
1314
+ value = F.cast(value, F.dtype(data))
1128
1315
  return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
1129
1316
 
1130
1317
 
@@ -1202,7 +1389,123 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
1202
1389
  return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
1203
1390
 
1204
1391
 
1392
+ class _PreSetitemByTuple(base.PreSetitemByTuple_):
1393
+ """
1394
+ Getting item of Tensor.
1395
+
1396
+ Args:
1397
+ data (Tensor): A tuple to be sliced.
1398
+ index: Index of tensor.
1399
+
1400
+ Returns:
1401
+ Type is the same as the element type of data.
1402
+ """
1403
+
1404
+ def __init__(self, name):
1405
+ """Initialize _PreSetitemByTuple."""
1406
+ base.PreSetitemByTuple_.__init__(self, name)
1407
+
1408
+ def __call__(self, *args):
1409
+ pass
1410
+
1411
+
1412
+ _pre_setitem_by_tuple = _PreSetitemByTuple('pre_setitem_by_tuple')
1413
+
1414
+
1415
+ class _HandleBoolTensor(base.HandleBoolTensor_):
1416
+ """
1417
+ Getting item of Tensor.
1418
+
1419
+ Args:
1420
+ data (Tensor): A tuple to be sliced.
1421
+ index: Index of tensor.
1422
+
1423
+ Returns:
1424
+ Type is the same as the element type of data.
1425
+ """
1426
+
1427
+ def __init__(self, name):
1428
+ """Initialize _HandleBoolTensor."""
1429
+ base.HandleBoolTensor_.__init__(self, name)
1430
+
1431
+ def __call__(self, *args):
1432
+ pass
1433
+
1434
+
1435
+ _handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
1436
+
1437
+
1438
+ class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
1439
+ """
1440
+ Getting item of Tensor.
1441
+
1442
+ Args:
1443
+ data (Tensor): A tuple to be sliced.
1444
+ index: Index of tensor.
1445
+
1446
+ Returns:
1447
+ Type is the same as the element type of data.
1448
+ """
1449
+
1450
+ def __init__(self, name):
1451
+ """Initialize _HandleBoolTensor."""
1452
+ base.HandleScalarTensorIndex_.__init__(self, name)
1453
+
1454
+ def __call__(self, *args):
1455
+ pass
1456
+
1457
+
1458
+ _handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
1459
+
1460
+
1205
1461
  def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
1462
+ """Assigns the tensor by tuple with tensor value."""
1463
+ if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
1464
+ if F.is_sequence_value_unknown(F.shape(data)):
1465
+ return tensor_copy_slice_from_tuple(data, tuple_index, value)
1466
+ dim1_start, dim1_stop, _ = const_utils.normalize_slice(
1467
+ tuple_index[1], data.shape[1])
1468
+ if dim1_stop - dim1_start <= 0:
1469
+ return data
1470
+ dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
1471
+ start = (dim0_start, dim1_start)
1472
+ stop = (dim0_start + 1, dim1_stop)
1473
+ step = (1, 1)
1474
+ value_shape = (dim1_stop - dim1_start,) + \
1475
+ const_utils.tuple_slice(data.shape, 2, None)
1476
+ value = _broadcast(value_shape, value)
1477
+ return copy_slice(data, value.astype(data.dtype), start, stop, step)
1478
+ tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
1479
+
1480
+ for non_zero_shape in non_zero_shapes:
1481
+ if F.reduce_min(non_zero_shape) == 0:
1482
+ return data
1483
+ value = value.astype(data.dtype)
1484
+ special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
1485
+ = _pre_setitem_by_tuple(data, tuple_index, value)
1486
+ if special_index == 0:
1487
+ return data
1488
+ value = F.reshape(value, new_value_shape)
1489
+ if not tuple_index or special_index == 1:
1490
+ data[True] = value
1491
+ return data
1492
+
1493
+ empty_broadcast_data_shape = False
1494
+ if isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
1495
+ empty_broadcast_data_shape = True
1496
+ if isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
1497
+ empty_broadcast_data_shape = True
1498
+ indices = _tensor_index_setitem(
1499
+ data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)
1500
+
1501
+ updates = _generate_updates_from_tensor(
1502
+ data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
1503
+ if is_parameter(data):
1504
+ F.scatter_nd_update(data, indices, updates)
1505
+ return data
1506
+ return F.tensor_scatter_update(data, indices, updates)
1507
+
1508
+ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
1206
1509
  """Assigns the tensor by tuple with tensor value."""
1207
1510
  op_name = const_utils.TENSOR_SETITEM
1208
1511
  tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
@@ -1220,7 +1523,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
1220
1523
  value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
1221
1524
  value = _broadcast(value_shape, value)
1222
1525
  return copy_slice(data, value.astype(data.dtype), start, stop, step)
1223
-
1224
1526
  tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
1225
1527
 
1226
1528
  if tuple_index is False:
@@ -1248,7 +1550,7 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
1248
1550
 
1249
1551
  def tensor_setitem_by_number_with_number(data, index, value):
1250
1552
  """Assigns the tensor by number with number value."""
1251
- value = F.fill(F.dtype(data), (), value)
1553
+ value = F.cast(value, F.dtype(data))
1252
1554
  return tensor_setitem_by_number_with_tensor(data, index, value)
1253
1555
 
1254
1556
 
@@ -1283,7 +1585,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
1283
1585
  data_shape = F.shape(data)
1284
1586
  data_dtype = F.dtype(data)
1285
1587
  if F.is_sequence_value_unknown(data_shape):
1286
- value = F.fill(F.dtype(data), (), value)
1588
+ value = F.cast(value, F.dtype(data))
1287
1589
  return tensor_setitem_by_ellipsis_with_tensor(data, value)
1288
1590
  return F.fill(data_dtype, data_shape, value)
1289
1591
 
@@ -1315,6 +1617,7 @@ def tensor_setitem_by_ellipsis_with_sequence(data, value):
1315
1617
  def tensor_setitem_by_bool(data, index, value):
1316
1618
  """Assigns a value to the tensor by boolean."""
1317
1619
  data_shape = F.shape(data)
1620
+ data_dtype = F.dtype(data)
1318
1621
  if not index:
1319
1622
  data_shape = (0,) + data_shape
1320
1623
  if isinstance(value, (list, tuple)):
@@ -1326,6 +1629,7 @@ def tensor_setitem_by_bool(data, index, value):
1326
1629
 
1327
1630
  if F.is_sequence_value_unknown(data_shape) and index:
1328
1631
  data_shape = F.dyn_shape(data)
1632
+ value = value.astype(data_dtype)
1329
1633
  data = ops.broadcast_to(value, data_shape)
1330
1634
  return data
1331
1635
  value_shape = F.shape(value)
@@ -1333,7 +1637,7 @@ def tensor_setitem_by_bool(data, index, value):
1333
1637
  if index:
1334
1638
  value = F.reshape(value, source_shape)
1335
1639
  value = _broadcast(data_shape, value)
1336
- data = value
1640
+ data = F.cast(value, data_dtype)
1337
1641
  return data
1338
1642
 
1339
1643
 
@@ -1417,8 +1721,8 @@ def remove_expanded_dims(tuple_index, data_shape, value):
1417
1721
  elif const_utils.is_slice(index_out):
1418
1722
  indices_out += (index_out,)
1419
1723
  not_expanded_dim += (True,)
1420
- start, stop, step = const_utils.normalize_slice(index_out, data_shape[cur_dim])
1421
- has_false = has_false or const_utils.check_slice_empty(start, stop, step)
1724
+ has_false = has_false or parse_check_slice_index(
1725
+ index_out, data_shape[cur_dim])
1422
1726
  cur_dim += 1
1423
1727
  elif isinstance(index_out, (Tensor, bool)): # advanced index
1424
1728
  if idx_advanced == -1:
@@ -1490,7 +1794,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
1490
1794
  ndim = F.rank(a)
1491
1795
  if dtype is None:
1492
1796
  dtype = F.dtype(a)
1493
- axes = const_utils.check_axis_valid_const(axis, ndim)
1797
+ axes = validator.check_axis_valid(axis, ndim)
1494
1798
  if initial is not None:
1495
1799
  if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or
1496
1800
  not isinstance(initial, (int, float, bool, Tensor))):
@@ -1505,18 +1809,20 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
1505
1809
  initial = F.fill(dtype, shape, initial)
1506
1810
  a = cmp_fn(a, initial)
1507
1811
 
1508
- if isinstance(where, Tensor):
1812
+ if where is not None and not isinstance(where, Tensor):
1813
+ where = Tensor(where, dtype=mstype.bool_)
1814
+
1815
+ if where is not None and (where.shape or not where):
1509
1816
  if initial is None:
1510
1817
  const_utils.raise_value_error('initial value must be provided for where masks')
1511
1818
  ndim_orig = F.rank(a)
1512
1819
  # broadcasts input tensors
1513
1820
  shape_out = const_utils.infer_out_shape(F.shape(where), F.shape(a), F.shape(initial))
1514
- broadcast_to = P.BroadcastTo(shape_out)
1515
1821
  where = where.astype(mstype.float32)
1516
- where = broadcast_to(where)
1822
+ where = F.broadcast_to(where, shape_out)
1517
1823
  where = where.astype(mstype.bool_)
1518
- a = broadcast_to(a)
1519
- initial = broadcast_to(initial)
1824
+ a = F.broadcast_to(a, shape_out)
1825
+ initial = F.broadcast_to(initial, shape_out)
1520
1826
  a = F.select(where, a, initial)
1521
1827
  axes = const_utils.real_axes(ndim_orig, F.rank(a), axes)
1522
1828