mindspore 2.1.0__cp38-none-any.whl → 2.2.11__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (578) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +61 -95
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/__init__.py +4 -2
  257. mindspore/nn/wrap/cell_wrapper.py +87 -34
  258. mindspore/nn/wrap/grad_reducer.py +8 -5
  259. mindspore/nn/wrap/loss_scale.py +105 -42
  260. mindspore/numpy/array_creations.py +1 -2
  261. mindspore/numpy/array_ops.py +3 -2
  262. mindspore/numpy/utils_const.py +5 -5
  263. mindspore/offline_debug/convert_async.py +2 -2
  264. mindspore/ops/_grad_experimental/__init__.py +0 -5
  265. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  266. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  267. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  268. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  269. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  270. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  271. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  272. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  273. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  274. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  275. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  276. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  277. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  278. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  279. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  280. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  281. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  282. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  283. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  284. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  285. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  286. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  287. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  288. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  289. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  290. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  291. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  293. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  294. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  295. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  296. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  298. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  299. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  300. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  301. mindspore/ops/_primitive_cache.py +1 -1
  302. mindspore/ops/_tracefunc.py +45 -13
  303. mindspore/ops/_utils/utils.py +6 -1
  304. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  305. mindspore/ops/_vmap/vmap_base.py +3 -3
  306. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  307. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  308. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  309. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  310. mindspore/ops/arg_dtype_cast.py +54 -0
  311. mindspore/ops/composite/base.py +37 -10
  312. mindspore/ops/composite/math_ops.py +5 -4
  313. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  314. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  315. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  316. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  317. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  320. mindspore/ops/deprecated.py +304 -0
  321. mindspore/ops/function/__init__.py +4 -1
  322. mindspore/ops/function/array_func.py +174 -193
  323. mindspore/ops/function/clip_func.py +81 -13
  324. mindspore/ops/function/debug_func.py +1 -1
  325. mindspore/ops/function/grad/grad_func.py +18 -9
  326. mindspore/ops/function/image_func.py +10 -4
  327. mindspore/ops/function/linalg_func.py +5 -5
  328. mindspore/ops/function/math_func.py +575 -386
  329. mindspore/ops/function/nn_func.py +568 -260
  330. mindspore/ops/function/random_func.py +88 -57
  331. mindspore/ops/function/sparse_func.py +1 -1
  332. mindspore/ops/function/sparse_unary_func.py +14 -12
  333. mindspore/ops/function/vmap_func.py +6 -5
  334. mindspore/ops/functional.py +15 -10
  335. mindspore/ops/op_info_register.py +244 -25
  336. mindspore/ops/operations/__init__.py +31 -19
  337. mindspore/ops/operations/_grad_ops.py +71 -7
  338. mindspore/ops/operations/_inner_ops.py +350 -17
  339. mindspore/ops/operations/_quant_ops.py +4 -8
  340. mindspore/ops/operations/_sequence_ops.py +42 -0
  341. mindspore/ops/operations/array_ops.py +68 -282
  342. mindspore/ops/operations/comm_ops.py +107 -59
  343. mindspore/ops/operations/custom_ops.py +94 -70
  344. mindspore/ops/operations/debug_ops.py +8 -4
  345. mindspore/ops/operations/image_ops.py +18 -12
  346. mindspore/ops/operations/inner_ops.py +26 -3
  347. mindspore/ops/operations/math_ops.py +192 -144
  348. mindspore/ops/operations/nn_ops.py +857 -489
  349. mindspore/ops/operations/other_ops.py +0 -22
  350. mindspore/ops/operations/random_ops.py +53 -111
  351. mindspore/ops/operations/sparse_ops.py +3 -1
  352. mindspore/ops/primitive.py +24 -18
  353. mindspore/parallel/_auto_parallel_context.py +68 -8
  354. mindspore/parallel/_cost_model_context.py +2 -2
  355. mindspore/parallel/_offload_context.py +17 -3
  356. mindspore/parallel/_parallel_serialization.py +12 -5
  357. mindspore/parallel/_ps_context.py +12 -0
  358. mindspore/parallel/_tensor.py +18 -13
  359. mindspore/parallel/_transformer/layers.py +5 -3
  360. mindspore/parallel/_transformer/loss.py +1 -0
  361. mindspore/parallel/_transformer/moe.py +2 -2
  362. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  363. mindspore/parallel/_transformer/transformer.py +23 -3
  364. mindspore/parallel/_utils.py +11 -7
  365. mindspore/parallel/algo_parameter_config.py +85 -5
  366. mindspore/parallel/checkpoint_transform.py +19 -12
  367. mindspore/parallel/shard.py +21 -14
  368. mindspore/profiler/common/struct_type.py +3 -3
  369. mindspore/profiler/common/util.py +4 -2
  370. mindspore/profiler/envprofiling.py +1 -1
  371. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  372. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  373. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  374. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  375. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  376. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  377. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  378. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  379. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  380. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  381. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  382. mindspore/profiler/parser/flops_parser.py +15 -11
  383. mindspore/profiler/parser/framework_parser.py +38 -22
  384. mindspore/profiler/parser/hccl_parser.py +16 -12
  385. mindspore/profiler/parser/integrator.py +22 -11
  386. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  387. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  388. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  389. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  390. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  391. mindspore/profiler/parser/optime_parser.py +1 -1
  392. mindspore/profiler/parser/profiler_info.py +21 -2
  393. mindspore/profiler/parser/step_trace_parser.py +11 -14
  394. mindspore/profiler/profiling.py +179 -89
  395. mindspore/rewrite/api/node.py +102 -19
  396. mindspore/rewrite/api/node_type.py +5 -1
  397. mindspore/rewrite/api/pattern_engine.py +1 -1
  398. mindspore/rewrite/api/scoped_value.py +9 -17
  399. mindspore/rewrite/api/symbol_tree.py +131 -47
  400. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  401. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  402. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  403. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  404. mindspore/rewrite/common/rewrite_elog.py +5 -1
  405. mindspore/rewrite/namer.py +33 -24
  406. mindspore/rewrite/namespace.py +14 -5
  407. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  408. mindspore/rewrite/node/call_function.py +79 -0
  409. mindspore/rewrite/node/cell_container.py +135 -0
  410. mindspore/rewrite/node/control_flow.py +88 -0
  411. mindspore/rewrite/{node.py → node/node.py} +273 -234
  412. mindspore/rewrite/node/node_manager.py +254 -0
  413. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  414. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  415. mindspore/rewrite/parsers/assign_parser.py +216 -221
  416. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  417. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  418. mindspore/rewrite/parsers/constant_parser.py +9 -6
  419. mindspore/rewrite/parsers/container_parser.py +9 -7
  420. mindspore/rewrite/parsers/for_parser.py +42 -21
  421. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  422. mindspore/rewrite/parsers/if_parser.py +28 -24
  423. mindspore/rewrite/parsers/module_parser.py +196 -25
  424. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  425. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  426. mindspore/rewrite/parsers/return_parser.py +6 -6
  427. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  428. mindspore/rewrite/sparsify/utils.py +1 -1
  429. mindspore/rewrite/symbol_tree.py +523 -578
  430. mindspore/rewrite/symbol_tree_builder.py +9 -193
  431. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  432. mindspore/run_check/_check_version.py +6 -4
  433. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  434. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  435. mindspore/scipy/linalg.py +1 -1
  436. mindspore/scipy/ops.py +55 -5
  437. mindspore/scipy/optimize/__init__.py +3 -2
  438. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  439. mindspore/scipy/optimize/minimize.py +7 -3
  440. mindspore/train/_utils.py +7 -3
  441. mindspore/train/amp.py +323 -123
  442. mindspore/train/anf_ir_pb2.py +14 -2
  443. mindspore/train/callback/_backup_and_restore.py +2 -12
  444. mindspore/train/callback/_callback.py +29 -4
  445. mindspore/train/callback/_checkpoint.py +23 -8
  446. mindspore/train/callback/_early_stop.py +2 -2
  447. mindspore/train/callback/_landscape.py +4 -4
  448. mindspore/train/callback/_loss_monitor.py +2 -2
  449. mindspore/train/callback/_on_request_exit.py +2 -2
  450. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  451. mindspore/train/callback/_summary_collector.py +15 -8
  452. mindspore/train/callback/_time_monitor.py +58 -5
  453. mindspore/train/data_sink.py +5 -11
  454. mindspore/train/dataset_helper.py +84 -57
  455. mindspore/train/loss_scale_manager.py +2 -2
  456. mindspore/train/metrics/__init__.py +3 -3
  457. mindspore/train/metrics/cosine_similarity.py +1 -1
  458. mindspore/train/metrics/hausdorff_distance.py +3 -2
  459. mindspore/train/metrics/mean_surface_distance.py +3 -2
  460. mindspore/train/metrics/metric.py +39 -19
  461. mindspore/train/metrics/roc.py +2 -2
  462. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  463. mindspore/train/mind_ir_pb2.py +85 -36
  464. mindspore/train/model.py +187 -47
  465. mindspore/train/serialization.py +487 -161
  466. mindspore/train/summary/_summary_adapter.py +1 -1
  467. mindspore/train/summary/_writer_pool.py +3 -2
  468. mindspore/train/summary/summary_record.py +37 -17
  469. mindspore/train/train_thor/convert_utils.py +3 -3
  470. mindspore/train/train_thor/dataset_helper.py +1 -1
  471. mindspore/version.py +1 -1
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  475. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  478. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  479. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  480. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  481. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  482. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  483. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  486. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  487. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  488. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  489. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  490. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  491. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  492. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  493. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  494. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  495. mindspore/_extends/graph_kernel/expander.py +0 -80
  496. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  497. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  498. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  499. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  500. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  501. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  502. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  503. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  504. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  505. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  506. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  507. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  508. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  509. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  510. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  511. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  512. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  513. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  514. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  515. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  516. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  517. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  518. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  519. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  520. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  523. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  524. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  525. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  526. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  527. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  528. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  532. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  533. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  534. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  535. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  536. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  537. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  538. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  539. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  541. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  542. mindspore/dataset/datapreprocess/__init__.py +0 -20
  543. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  544. mindspore/include/api/net.h +0 -142
  545. mindspore/nn/lr_scheduler.py +0 -262
  546. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  547. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  548. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  549. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  550. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  559. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  560. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  564. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  575. mindspore/rewrite/node_visitor.py +0 -44
  576. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  578. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,8 @@ from mindspore.ops import operations as P
24
24
  from mindspore.ops.composite import base
25
25
  from mindspore.ops._primitive_cache import _get_cache_prim
26
26
  from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
27
- TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
27
+ TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
28
+ SelectView, CopyWithSlice
28
29
  from mindspore.common import dtype as mstype
29
30
  from mindspore.common._register_for_tensor import tensor_operator_registry
30
31
  from mindspore.common.initializer import Zero
@@ -33,6 +34,7 @@ from mindspore.common import mutable
33
34
  from mindspore import ops
34
35
  from mindspore.ops.primitive import _primexpr
35
36
  from mindspore import _checkparam as validator
37
+ from mindspore.common._stub_tensor import _convert_stub
36
38
 
37
39
  slice_get_item = SliceGetItem()
38
40
  hyper_map = base.HyperMap()
@@ -43,6 +45,8 @@ is_parameter = IsParameter()
43
45
  getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
44
46
  setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
45
47
 
48
+ selevt_view = SelectView()
49
+ copy_with_slice = CopyWithSlice()
46
50
 
47
51
  def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
48
52
  new_axis_mask=0, shrink_axis_mask=0):
@@ -66,19 +70,23 @@ class ValueTransferType(IntEnum):
66
70
  kGatherND = 9
67
71
  kScatterNdUpdate = 10
68
72
  kReshape = 11
69
- kScatterND = 12
70
- kNumberToTensor = 13
71
- kHandleSequenceValue = 14
72
- kByPass = 15
73
- kReSetItemByIndex = 16
74
- kCopySlice = 17
75
- kSetItemByBool = 18
76
- kEmptyTensor = 19
77
- kSetItemByEllipsis = 20
78
- kFormatIndexTensor = 21
79
- kGetitemByBoolTensor = 22
80
- kSetitemByBoolTensor = 23
81
- kRaiseIndexError = 24
73
+ kSelectView = 12
74
+ kUnsqueeze = 13
75
+ kCopyView = 14
76
+ kScatterND = 15
77
+ kNumberToTensor = 16
78
+ kHandleSequenceValue = 17
79
+ kByPass = 18
80
+ kReSetItemByIndex = 19
81
+ kCopySlice = 20
82
+ kSetItemByBool = 21
83
+ kEmptyTensor = 22
84
+ kSetItemByEllipsis = 23
85
+ kFormatIndexTensor = 24
86
+ kGetitemByBoolTensor = 25
87
+ kSetitemByBoolTensor = 26
88
+ kJustReturn = 27
89
+ kRaiseIndexError = 28
82
90
 
83
91
 
84
92
  def data_update(transfer_types, args, data, new_index, value=None):
@@ -86,11 +94,14 @@ def data_update(transfer_types, args, data, new_index, value=None):
86
94
  We finally generate a new tensor when handling tensor getitem/setitem
87
95
  by transfer data and value with index.
88
96
  """
97
+ origin_data = data
89
98
  for transfer_type, arg in zip(transfer_types, args):
90
99
  if transfer_type == ValueTransferType.kUnknown:
91
100
  raise IndexError(f"Inlvaid transfer type {transfer_type}.")
92
101
  if transfer_type <= ValueTransferType.kScatterND:
93
- data = data_update_by_ops(transfer_type, arg, data, new_index, value)
102
+ data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
103
+ if transfer_type == ValueTransferType.kJustReturn:
104
+ return _convert_stub(arg)
94
105
  if transfer_type == ValueTransferType.kSetItemByBool:
95
106
  return tensor_setitem_by_bool(data, new_index, value)
96
107
  if transfer_type == ValueTransferType.kCopySlice:
@@ -114,7 +125,7 @@ def data_update(transfer_types, args, data, new_index, value=None):
114
125
  return data
115
126
 
116
127
 
117
- def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
128
+ def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=None):
118
129
  """
119
130
  Generate a new tensor when handling tensor getitem/setitem
120
131
  by ops.
@@ -135,14 +146,22 @@ def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
135
146
  F.scatter_nd_update(data, new_index, value)
136
147
  elif transfer_type == ValueTransferType.kSelect:
137
148
  data = F.select(Tensor(new_index), value, data)
149
+ elif transfer_type == ValueTransferType.kSelectView:
150
+ data = selevt_view(data, arg[0], arg[1])
151
+ elif transfer_type == ValueTransferType.kCopyView:
152
+ value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
153
+ data = copy_with_slice(data, value)
154
+ return origin_data
138
155
  elif transfer_type == ValueTransferType.kReshape:
139
156
  data = F.reshape(data, arg)
140
157
  elif transfer_type == ValueTransferType.kGather:
141
158
  data = F.gather(data, new_index, 0)
142
159
  elif transfer_type == ValueTransferType.kExpandDims:
143
160
  data = F.expand_dims(data, 0)
161
+ elif transfer_type == ValueTransferType.kUnsqueeze:
162
+ data = F.unsqueeze(data, arg)
144
163
  elif transfer_type == ValueTransferType.kStrideSlice:
145
- data = F.strided_slice(data, arg[0], arg[1], arg[2])
164
+ data = strided_slice(data, arg[0], arg[1], arg[2])
146
165
  else:
147
166
  raise IndexError(f"Inlvaid transfer type {transfer_type}.")
148
167
  return data
@@ -154,7 +173,7 @@ def value_update(transfer_types, args, data, value):
154
173
  if transfer_type == ValueTransferType.kByPass:
155
174
  continue
156
175
  if transfer_type == ValueTransferType.kNumberToTensor:
157
- value = F.fill(F.dtype(data), (), value)
176
+ value = F.cast(value, F.dtype(data))
158
177
  elif transfer_type == ValueTransferType.kHandleSequenceValue:
159
178
  op_type, index = arg
160
179
  if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
@@ -192,7 +211,10 @@ def _tensor_setitem(self, index, value):
192
211
  data_update_types = setitem_info[3]
193
212
  data_update_args = setitem_info[4]
194
213
  value = value_update(v_transfer_types, v_transfer_args, self, value)
195
- return data_update(data_update_types, data_update_args, self, new_index, value)
214
+ output = data_update(data_update_types, data_update_args, self, new_index, value)
215
+ if new_index == "view":
216
+ return (self,)
217
+ return output
196
218
 
197
219
 
198
220
  tensor_operator_registry.register("__getitem__", _tensor_getitem)
@@ -286,7 +308,7 @@ def _scalar_to_tensor(input_x):
286
308
  @_primexpr
287
309
  def _check_scalar_tensor_args(args):
288
310
  """For the item, check that the index of the scalar tensor is set."""
289
- if args != (None,) and args != ():
311
+ if args not in ((None,), ()):
290
312
  const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
291
313
 
292
314
 
@@ -295,15 +317,15 @@ def tensor_item(data, *args):
295
317
  # transform a.item(tuple(int)) -> a.item(int1,int2...intN)
296
318
  if data.ndim == 0:
297
319
  _check_scalar_tensor_args(args)
298
- return data
320
+ return data.asnumpy().item()
299
321
  if len(args) == 1 and isinstance(args[0], tuple):
300
322
  args = args[0]
301
323
 
302
324
  args_types = hyper_map(F.typeof, args)
303
325
  if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
304
326
  if data.shape == (1,):
305
- return data[0]
306
- const_utils.raise_value_error("Can only convert an array of size 1 to a Tensor scalar")
327
+ return data.asnumpy().item()
328
+ const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
307
329
 
308
330
  if not const_utils.judge_indexes_types(args_types, mstype.int64):
309
331
  const_utils.raise_type_error("The index object cannot be interpreted as an integer")
@@ -362,7 +384,8 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
362
384
  exp_msg = const_utils.gen_exception_msg(
363
385
  "Tuple index len({}) is not same to tensor dimension({})", len(tuple_index), data.ndim)
364
386
  const_utils.raise_index_error(exp_msg)
365
- return tensor_setitem_by_tuple_with_number(data, tuple_index, nubmer_value)
387
+ nubmer_value = F.cast(nubmer_value, F.dtype(data))
388
+ return tensor_itemset_by_tuple_with_tensor(data, tuple_index, nubmer_value)
366
389
 
367
390
 
368
391
  def _broadcast(broadcast_shape, x):
@@ -530,10 +553,6 @@ class _TensorIndexGetitem(base.TensorIndexGetitem_):
530
553
  Type is the same as the element type of data.
531
554
  """
532
555
 
533
- def __init__(self, name):
534
- """Initialize _TensorIndexGetitem."""
535
- base.TensorIndexGetitem_.__init__(self, name)
536
-
537
556
  def __call__(self, *args):
538
557
  pass
539
558
 
@@ -580,9 +599,12 @@ def _tensor_index_by_bool(data, bool_value):
580
599
  """Tensor getitem by a single bool value"""
581
600
  min_data_dim, max_data_dim = 0, 7
582
601
  const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
602
+ output = data
583
603
  if bool_value:
584
- return F.expand_dims(data, 0)
585
- return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
604
+ output = F.expand_dims(data, 0)
605
+ elif not F.is_sequence_value_unknown(F.shape(data)):
606
+ return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
607
+ return output
586
608
 
587
609
 
588
610
  def get_stride_info_from_integer(tensor_int):
@@ -599,15 +621,14 @@ def get_stride_info_from_integer(tensor_int):
599
621
  def _tensor_index_by_integer(data, int_index):
600
622
  """Tensor getitem by a single integer number"""
601
623
  data_shape = F.shape(data)
602
- if not data_shape:
603
- const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
604
- if data.ndim < 1 or data.ndim > 8:
605
- const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
606
-
607
624
  if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
608
625
  tensor_index = _scalar_to_tensor(int_index)
609
626
  begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
610
627
  else:
628
+ if not data_shape:
629
+ const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
630
+ if data.ndim < 1 or data.ndim > 8:
631
+ const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
611
632
  transformed_number = const_utils.check_range(int_index, data_shape[0])
612
633
  begin_strides, end_strides, step_strides = \
613
634
  const_utils.get_stride_info_from_integer(data_shape, transformed_number)
@@ -619,7 +640,6 @@ def _tensor_index_by_integer(data, int_index):
619
640
  end_mask += 2 ** i
620
641
  return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
621
642
 
622
-
623
643
  def _check_dim_shape_valid(data, tensor_index):
624
644
  """check dim and shape of tensor_index for tensor(bool) indexing"""
625
645
  if data.ndim < tensor_index.ndim:
@@ -632,7 +652,8 @@ def _check_dim_shape_valid(data, tensor_index):
632
652
 
633
653
  def tensor_index_by_bool_tensor(data, tensor_index):
634
654
  """Tensor getitem by a bool tensor"""
635
- _check_dim_shape_valid(data, tensor_index)
655
+ if not F.is_sequence_value_unknown(F.shape(data)):
656
+ _check_dim_shape_valid(data, tensor_index)
636
657
  tensor_index = tensor_index.nonzero()
637
658
  return F.gather_nd(data, tensor_index)
638
659
 
@@ -640,7 +661,8 @@ def tensor_index_by_bool_tensor(data, tensor_index):
640
661
  def tensor_index_by_tensor(data, tensor_index):
641
662
  """Tensor getitem by a single tensor"""
642
663
  min_data_dim, max_data_dim = 0, 7
643
- const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
664
+ if not F.is_sequence_value_unknown(F.shape(data)):
665
+ const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
644
666
  if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
645
667
  return F.gather(data, tensor_index, 0)
646
668
  if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
@@ -658,16 +680,22 @@ def tensor_index_by_list(data, list_index):
658
680
 
659
681
  data_shape = F.shape(data)
660
682
  indexes_types = hyper_map(toptypeof, list_index)
661
- if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)):
683
+ if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)) \
684
+ and not F.is_sequence_value_unknown(list_index):
662
685
  if not F.isconstant(data_shape[0]):
663
686
  if all(isinstance(i, bool) for i in list_index):
664
- const_utils.raise_unimplemented_error(
665
- "Not supported to the dynamic shape tensor slice by using list of Boolean type")
687
+ if F.dyn_shape(data)[0] != len(list_index):
688
+ raise IndexError(
689
+ f'dimension is {F.dyn_shape(data)[0]} but corresponding boolean dimension is {len(list_index)}')
690
+ tensor_index = Tensor(list_index).nonzero()
691
+ return F.gather_nd(data, tensor_index)
666
692
  tensor_index = const_utils.sequence_to_index(list_index, None)
667
693
  else:
668
- tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
694
+ tensor_index = const_utils.sequence_to_index(
695
+ list_index, data_shape[0])
669
696
  if tensor_index is False:
670
- const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
697
+ const_utils.raise_index_error(
698
+ "When tensor is indexed by list, the list can't be empty.")
671
699
  return F.gather(data, tensor_index, 0)
672
700
 
673
701
  tuple_index_new = ()
@@ -693,6 +721,29 @@ def judge_tuple_index_dim_check_error(index_dim, data_dim):
693
721
  f"dim of index:{index_dim}, dim of data:{data_dim}")
694
722
 
695
723
 
724
+ class _HandleEmptySlice(base.HandleEmptySlice_):
725
+ """
726
+ Getting item of Tensor.
727
+
728
+ Args:
729
+ data (Tensor): A tuple to be sliced.
730
+ index: Index of tensor.
731
+
732
+ Returns:
733
+ Type is the same as the element type of data.
734
+ """
735
+
736
+ def __init__(self, name):
737
+ """Initialize _HandleEmptySlice."""
738
+ base.HandleEmptySlice_.__init__(self, name)
739
+
740
+ def __call__(self, *args):
741
+ pass
742
+
743
+
744
+ _handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
745
+
746
+
696
747
  def judge_tuple_index_dim(data, tuple_index):
697
748
  """Judge whether tuple_index's dim is valid"""
698
749
  data_dim = data.ndim
@@ -700,29 +751,55 @@ def judge_tuple_index_dim(data, tuple_index):
700
751
  for index in tuple_index:
701
752
  if isinstance(toptypeof(index), mstype.TensorType) and index.dtype == mstype.bool_:
702
753
  index_dim += index.ndim
703
- else:
754
+ elif not isinstance(toptypeof(index), (mstype.NoneType, mstype.Ellipsis_, mstype.Bool)):
704
755
  index_dim += 1
705
756
  judge_tuple_index_dim_check_error(index_dim, data_dim)
706
757
 
707
758
 
759
+ def judge_simple_tuple_index(data, tuple_index):
760
+ """Judge whether tuple_index is simple index, which not rollback to cpu ops."""
761
+ op_name = const_utils.TENSOR_GETITEM
762
+ indexes_types = hyper_map(toptypeof, tuple_index)
763
+ contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
764
+ return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
765
+ and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
766
+
767
+
708
768
  def tensor_index_by_tuple(data, tuple_index):
709
769
  """Tensor getitem by tuple of various types with None"""
710
770
  if not tuple_index:
711
771
  return data
712
-
713
- tuple_index = convert_tupleslice_to_tensor(tuple_index)
714
- op_name = const_utils.TENSOR_GETITEM
715
- tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
716
- data, tuple_index = _expand_data_dims(data, tuple_index)
717
-
718
- min_data_dim, max_data_dim = 1, 8
719
- const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
720
- judge_tuple_index_dim(data, tuple_index)
721
- indexes_types = hyper_map(toptypeof, tuple_index)
722
- contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
723
- if contain_type == const_utils.ALL_BASIC:
772
+ if judge_simple_tuple_index(data, tuple_index):
773
+ tuple_index = convert_tupleslice_to_tensor(tuple_index)
774
+ op_name = const_utils.TENSOR_GETITEM
775
+ tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
776
+ min_data_dim, max_data_dim = 1, 8
777
+ const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
724
778
  return _tensor_getitem_by_tuple_slice(data, tuple_index)
725
- return _tensor_getitem_by_tuple(data, tuple_index, op_name)
779
+
780
+ if not F.is_sequence_value_unknown(F.shape(data)):
781
+ judge_tuple_index_dim(data, tuple_index)
782
+ tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
783
+ for non_zero_shape in non_zero_shapes:
784
+ if F.reduce_min(non_zero_shape) == 0:
785
+ tuple_index = zero_index
786
+ break
787
+ if not F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(tuple_index):
788
+ _, stub_zero_dim_tensor = _handle_empty_slice(data, tuple_index)
789
+ if 0 in stub_zero_dim_tensor.shape:
790
+ return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
791
+ has_tensor_index = False
792
+ for i in tuple_index:
793
+ if isinstance(i, Tensor):
794
+ has_tensor_index = True
795
+ break
796
+ empty_broadcast_data_shape = False
797
+ _broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
798
+ if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
799
+ empty_broadcast_data_shape = True
800
+ if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
801
+ empty_broadcast_data_shape = True
802
+ return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
726
803
 
727
804
 
728
805
  def get_slice_stride(slice_index, dim_size):
@@ -1039,7 +1116,7 @@ def sequence_to_tensor(value, dtype):
1039
1116
 
1040
1117
  if value_elements_type == const_utils.ALL_TENSOR:
1041
1118
  value = F.stack(value).astype(dtype)
1042
- elif value_elements_type == const_utils.NO_TENSOR:
1119
+ elif value_elements_type == const_utils.NO_TENSOR and not F.is_sequence_value_unknown(value):
1043
1120
  value = const_utils.make_tensor(value, dtype)
1044
1121
  else:
1045
1122
  new_value = ()
@@ -1061,7 +1138,7 @@ def _generate_updates_from_sequence(data, index, value, op_type):
1061
1138
  def _generate_updates_from_tensor(data, index, value, op_type):
1062
1139
  """Generate an updates tensor from a tensor."""
1063
1140
  value = value.astype(data.dtype)
1064
- if F.is_sequence_value_unknown(F.shape(data)):
1141
+ if F.is_sequence_value_unknown(F.shape(data)) or F.is_sequence_value_unknown(F.shape(index)):
1065
1142
  data_shape = F.dyn_shape(data)
1066
1143
  index_shape = F.dyn_shape(index)
1067
1144
  updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
@@ -1102,6 +1179,18 @@ def tensor_setitem_by_number(self, index, value):
1102
1179
  return tensor_setitem_by_number_with_sequence(self, index, value)
1103
1180
 
1104
1181
 
1182
+ def _tuple_index_transfer(broadcast_shape, final_shape, new_shape, x, all_empty_tensor):
1183
+ """Transform tuple index tensor to the required."""
1184
+ if isinstance(broadcast_shape, Tensor):
1185
+ if not all_empty_tensor:
1186
+ x = F.broadcast_to(x, broadcast_shape)
1187
+ x = F.reshape(x, new_shape)
1188
+ x = F.broadcast_to(x, final_shape)
1189
+ return x
1190
+ item = _broadcast(broadcast_shape, x)
1191
+ return _broadcast(final_shape, F.reshape(item, new_shape))
1192
+
1193
+
1105
1194
  class _TensorIndexSetitem(base.TensorIndexSetitem_):
1106
1195
  """
1107
1196
  Getting item of Tensor.
@@ -1114,10 +1203,6 @@ class _TensorIndexSetitem(base.TensorIndexSetitem_):
1114
1203
  Type is the same as the element type of data.
1115
1204
  """
1116
1205
 
1117
- def __init__(self, name):
1118
- """Initialize _TensorIndexGetitem."""
1119
- base.TensorIndexSetitem_.__init__(self, name)
1120
-
1121
1206
  def __call__(self, *args):
1122
1207
  pass
1123
1208
 
@@ -1170,7 +1255,8 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
1170
1255
  index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
1171
1256
  index = F.broadcast_to(index, data.shape)
1172
1257
  value = F.cast(value, F.dtype(data))
1173
- value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
1258
+ while value.ndim < data.ndim:
1259
+ value = value.unsqueeze(-1)
1174
1260
  value = F.broadcast_to(value, data.shape)
1175
1261
  result = F.select(index, value, data)
1176
1262
  return result
@@ -1184,13 +1270,12 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
1184
1270
  return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
1185
1271
 
1186
1272
  if F.is_sequence_value_unknown(F.shape(data)):
1187
- const_utils.raise_unimplemented_error(
1188
- "Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
1273
+ return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
1189
1274
  return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
1190
1275
 
1191
1276
 
1192
1277
  def tensor_setitem_by_tensor_with_number(data, index, value):
1193
- value = F.fill(F.dtype(data), (), value)
1278
+ value = F.cast(value, F.dtype(data))
1194
1279
  return tensor_setitem_by_tensor_with_tensor(data, index, value)
1195
1280
 
1196
1281
 
@@ -1221,13 +1306,13 @@ def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
1221
1306
 
1222
1307
  def tensor_setitem_by_slice_with_number(data, input_slice, value):
1223
1308
  """Givens a scalar assign to tensor by slice"""
1224
- value = F.fill(F.dtype(data), (), value)
1309
+ value = F.cast(value, F.dtype(data))
1225
1310
  return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
1226
1311
 
1227
1312
 
1228
1313
  def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
1229
1314
  """Assigns the tensor by tuple with number value."""
1230
- value = F.fill(F.dtype(data), (), value)
1315
+ value = F.cast(value, F.dtype(data))
1231
1316
  return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
1232
1317
 
1233
1318
 
@@ -1305,7 +1390,123 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
1305
1390
  return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
1306
1391
 
1307
1392
 
1393
+ class _PreSetitemByTuple(base.PreSetitemByTuple_):
1394
+ """
1395
+ Getting item of Tensor.
1396
+
1397
+ Args:
1398
+ data (Tensor): A tuple to be sliced.
1399
+ index: Index of tensor.
1400
+
1401
+ Returns:
1402
+ Type is the same as the element type of data.
1403
+ """
1404
+
1405
+ def __init__(self, name):
1406
+ """Initialize _PreSetitemByTuple."""
1407
+ base.PreSetitemByTuple_.__init__(self, name)
1408
+
1409
+ def __call__(self, *args):
1410
+ pass
1411
+
1412
+
1413
+ _pre_setitem_by_tuple = _PreSetitemByTuple('pre_setitem_by_tuple')
1414
+
1415
+
1416
+ class _HandleBoolTensor(base.HandleBoolTensor_):
1417
+ """
1418
+ Getting item of Tensor.
1419
+
1420
+ Args:
1421
+ data (Tensor): A tuple to be sliced.
1422
+ index: Index of tensor.
1423
+
1424
+ Returns:
1425
+ Type is the same as the element type of data.
1426
+ """
1427
+
1428
+ def __init__(self, name):
1429
+ """Initialize _HandleBoolTensor."""
1430
+ base.HandleBoolTensor_.__init__(self, name)
1431
+
1432
+ def __call__(self, *args):
1433
+ pass
1434
+
1435
+
1436
+ _handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
1437
+
1438
+
1439
+ class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
1440
+ """
1441
+ Getting item of Tensor.
1442
+
1443
+ Args:
1444
+ data (Tensor): A tuple to be sliced.
1445
+ index: Index of tensor.
1446
+
1447
+ Returns:
1448
+ Type is the same as the element type of data.
1449
+ """
1450
+
1451
+ def __init__(self, name):
1452
+ """Initialize _HandleBoolTensor."""
1453
+ base.HandleScalarTensorIndex_.__init__(self, name)
1454
+
1455
+ def __call__(self, *args):
1456
+ pass
1457
+
1458
+
1459
+ _handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
1460
+
1461
+
1308
1462
  def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
1463
+ """Assigns the tensor by tuple with tensor value."""
1464
+ if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
1465
+ if F.is_sequence_value_unknown(F.shape(data)):
1466
+ return tensor_copy_slice_from_tuple(data, tuple_index, value)
1467
+ dim1_start, dim1_stop, _ = const_utils.normalize_slice(
1468
+ tuple_index[1], data.shape[1])
1469
+ if dim1_stop - dim1_start <= 0:
1470
+ return data
1471
+ dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
1472
+ start = (dim0_start, dim1_start)
1473
+ stop = (dim0_start + 1, dim1_stop)
1474
+ step = (1, 1)
1475
+ value_shape = (dim1_stop - dim1_start,) + \
1476
+ const_utils.tuple_slice(data.shape, 2, None)
1477
+ value = _broadcast(value_shape, value)
1478
+ return copy_slice(data, value.astype(data.dtype), start, stop, step)
1479
+ tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
1480
+
1481
+ for non_zero_shape in non_zero_shapes:
1482
+ if F.reduce_min(non_zero_shape) == 0:
1483
+ return data
1484
+ value = value.astype(data.dtype)
1485
+ special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
1486
+ = _pre_setitem_by_tuple(data, tuple_index, value)
1487
+ if special_index == 0:
1488
+ return data
1489
+ value = F.reshape(value, new_value_shape)
1490
+ if not tuple_index or special_index == 1:
1491
+ data[True] = value
1492
+ return data
1493
+
1494
+ empty_broadcast_data_shape = False
1495
+ if isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
1496
+ empty_broadcast_data_shape = True
1497
+ if isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
1498
+ empty_broadcast_data_shape = True
1499
+ indices = _tensor_index_setitem(
1500
+ data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)
1501
+
1502
+ updates = _generate_updates_from_tensor(
1503
+ data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
1504
+ if is_parameter(data):
1505
+ F.scatter_nd_update(data, indices, updates)
1506
+ return data
1507
+ return F.tensor_scatter_update(data, indices, updates)
1508
+
1509
+ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
1309
1510
  """Assigns the tensor by tuple with tensor value."""
1310
1511
  op_name = const_utils.TENSOR_SETITEM
1311
1512
  tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
@@ -1323,7 +1524,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
1323
1524
  value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
1324
1525
  value = _broadcast(value_shape, value)
1325
1526
  return copy_slice(data, value.astype(data.dtype), start, stop, step)
1326
-
1327
1527
  tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
1328
1528
 
1329
1529
  if tuple_index is False:
@@ -1351,7 +1551,7 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
1351
1551
 
1352
1552
  def tensor_setitem_by_number_with_number(data, index, value):
1353
1553
  """Assigns the tensor by number with number value."""
1354
- value = F.fill(F.dtype(data), (), value)
1554
+ value = F.cast(value, F.dtype(data))
1355
1555
  return tensor_setitem_by_number_with_tensor(data, index, value)
1356
1556
 
1357
1557
 
@@ -1386,7 +1586,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
1386
1586
  data_shape = F.shape(data)
1387
1587
  data_dtype = F.dtype(data)
1388
1588
  if F.is_sequence_value_unknown(data_shape):
1389
- value = F.fill(F.dtype(data), (), value)
1589
+ value = F.cast(value, F.dtype(data))
1390
1590
  return tensor_setitem_by_ellipsis_with_tensor(data, value)
1391
1591
  return F.fill(data_dtype, data_shape, value)
1392
1592
 
@@ -1418,6 +1618,7 @@ def tensor_setitem_by_ellipsis_with_sequence(data, value):
1418
1618
  def tensor_setitem_by_bool(data, index, value):
1419
1619
  """Assigns a value to the tensor by boolean."""
1420
1620
  data_shape = F.shape(data)
1621
+ data_dtype = F.dtype(data)
1421
1622
  if not index:
1422
1623
  data_shape = (0,) + data_shape
1423
1624
  if isinstance(value, (list, tuple)):
@@ -1429,6 +1630,7 @@ def tensor_setitem_by_bool(data, index, value):
1429
1630
 
1430
1631
  if F.is_sequence_value_unknown(data_shape) and index:
1431
1632
  data_shape = F.dyn_shape(data)
1633
+ value = value.astype(data_dtype)
1432
1634
  data = ops.broadcast_to(value, data_shape)
1433
1635
  return data
1434
1636
  value_shape = F.shape(value)
@@ -1436,7 +1638,7 @@ def tensor_setitem_by_bool(data, index, value):
1436
1638
  if index:
1437
1639
  value = F.reshape(value, source_shape)
1438
1640
  value = _broadcast(data_shape, value)
1439
- data = value
1641
+ data = F.cast(value, data_dtype)
1440
1642
  return data
1441
1643
 
1442
1644