mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.3.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package's registry page for more details.

Files changed (1166)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -5
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +76 -18
  9. mindspore/_extends/builtin_operations.py +2 -1
  10. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  11. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  12. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  13. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  14. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  15. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  16. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  17. mindspore/_extends/parse/__init__.py +18 -14
  18. mindspore/_extends/parse/compile_config.py +258 -0
  19. mindspore/_extends/parse/namespace.py +2 -2
  20. mindspore/_extends/parse/parser.py +174 -62
  21. mindspore/_extends/parse/resources.py +45 -14
  22. mindspore/_extends/parse/standard_method.py +142 -240
  23. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  24. mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
  25. mindspore/_extends/remote/kernel_build_server.py +2 -0
  26. mindspore/_profiler.py +30 -0
  27. mindspore/amp.py +51 -24
  28. mindspore/atlprov.dll +0 -0
  29. mindspore/avcodec-59.dll +0 -0
  30. mindspore/avdevice-59.dll +0 -0
  31. mindspore/avfilter-8.dll +0 -0
  32. mindspore/avformat-59.dll +0 -0
  33. mindspore/avutil-57.dll +0 -0
  34. mindspore/boost/adasum.py +1 -1
  35. mindspore/boost/base.py +1 -1
  36. mindspore/boost/boost_cell_wrapper.py +2 -2
  37. mindspore/boost/grad_freeze.py +2 -2
  38. mindspore/boost/group_loss_scale_manager.py +1 -1
  39. mindspore/boost/less_batch_normalization.py +9 -6
  40. mindspore/c1.dll +0 -0
  41. mindspore/c1xx.dll +0 -0
  42. mindspore/c2.dll +0 -0
  43. mindspore/common/__init__.py +15 -4
  44. mindspore/common/_jit_fallback_utils.py +2 -3
  45. mindspore/common/_register_for_adapter.py +7 -0
  46. mindspore/common/_register_for_recompute.py +48 -0
  47. mindspore/common/_register_for_tensor.py +8 -9
  48. mindspore/common/_stub_tensor.py +7 -1
  49. mindspore/common/_utils.py +5 -17
  50. mindspore/common/api.py +411 -106
  51. mindspore/common/auto_dynamic_shape.py +27 -14
  52. mindspore/common/dtype.py +17 -10
  53. mindspore/common/dump.py +6 -8
  54. mindspore/common/file_system.py +48 -0
  55. mindspore/common/generator.py +260 -0
  56. mindspore/common/hook_handle.py +51 -4
  57. mindspore/common/initializer.py +1 -1
  58. mindspore/common/jit_config.py +34 -14
  59. mindspore/common/lazy_inline.py +72 -19
  60. mindspore/common/mindir_util.py +12 -2
  61. mindspore/common/mutable.py +79 -14
  62. mindspore/common/no_inline.py +54 -0
  63. mindspore/common/np_dtype.py +25 -0
  64. mindspore/common/parameter.py +30 -11
  65. mindspore/common/recompute.py +262 -0
  66. mindspore/common/seed.py +9 -9
  67. mindspore/common/sparse_tensor.py +272 -24
  68. mindspore/common/symbol.py +122 -0
  69. mindspore/common/tensor.py +468 -494
  70. mindspore/communication/__init__.py +6 -11
  71. mindspore/communication/_comm_helper.py +5 -0
  72. mindspore/communication/comm_func.py +1140 -0
  73. mindspore/communication/management.py +115 -102
  74. mindspore/config/op_info.config +22 -54
  75. mindspore/context.py +346 -63
  76. mindspore/dataset/__init__.py +5 -5
  77. mindspore/dataset/audio/__init__.py +6 -6
  78. mindspore/dataset/audio/transforms.py +711 -158
  79. mindspore/dataset/callback/ds_callback.py +2 -2
  80. mindspore/dataset/engine/cache_client.py +2 -2
  81. mindspore/dataset/engine/datasets.py +140 -83
  82. mindspore/dataset/engine/datasets_audio.py +14 -14
  83. mindspore/dataset/engine/datasets_standard_format.py +33 -3
  84. mindspore/dataset/engine/datasets_text.py +38 -38
  85. mindspore/dataset/engine/datasets_user_defined.py +78 -59
  86. mindspore/dataset/engine/datasets_vision.py +77 -73
  87. mindspore/dataset/engine/offload.py +5 -7
  88. mindspore/dataset/engine/queue.py +56 -38
  89. mindspore/dataset/engine/validators.py +11 -5
  90. mindspore/dataset/text/__init__.py +3 -3
  91. mindspore/dataset/text/transforms.py +408 -121
  92. mindspore/dataset/text/utils.py +9 -9
  93. mindspore/dataset/transforms/__init__.py +1 -1
  94. mindspore/dataset/transforms/transforms.py +261 -76
  95. mindspore/dataset/utils/browse_dataset.py +9 -9
  96. mindspore/dataset/vision/__init__.py +8 -8
  97. mindspore/dataset/vision/c_transforms.py +10 -10
  98. mindspore/dataset/vision/py_transforms_util.py +1 -1
  99. mindspore/dataset/vision/transforms.py +2844 -549
  100. mindspore/dataset/vision/utils.py +161 -10
  101. mindspore/dataset/vision/validators.py +14 -2
  102. mindspore/dnnl.dll +0 -0
  103. mindspore/dpcmi.dll +0 -0
  104. mindspore/experimental/optim/__init__.py +12 -2
  105. mindspore/experimental/optim/adadelta.py +161 -0
  106. mindspore/experimental/optim/adagrad.py +168 -0
  107. mindspore/experimental/optim/adam.py +35 -34
  108. mindspore/experimental/optim/adamax.py +170 -0
  109. mindspore/experimental/optim/adamw.py +40 -16
  110. mindspore/experimental/optim/asgd.py +153 -0
  111. mindspore/experimental/optim/lr_scheduler.py +66 -121
  112. mindspore/experimental/optim/nadam.py +157 -0
  113. mindspore/experimental/optim/optimizer.py +15 -8
  114. mindspore/experimental/optim/radam.py +194 -0
  115. mindspore/experimental/optim/rmsprop.py +154 -0
  116. mindspore/experimental/optim/rprop.py +164 -0
  117. mindspore/experimental/optim/sgd.py +28 -19
  118. mindspore/hal/__init__.py +40 -0
  119. mindspore/hal/_ascend.py +57 -0
  120. mindspore/hal/_base.py +57 -0
  121. mindspore/hal/_cpu.py +56 -0
  122. mindspore/hal/_gpu.py +57 -0
  123. mindspore/hal/device.py +356 -0
  124. mindspore/hal/event.py +179 -0
  125. mindspore/hal/memory.py +326 -0
  126. mindspore/hal/stream.py +339 -0
  127. mindspore/include/api/data_type.h +2 -2
  128. mindspore/include/api/dual_abi_helper.h +16 -3
  129. mindspore/include/api/model.h +4 -3
  130. mindspore/include/api/status.h +14 -0
  131. mindspore/include/c_api/model_c.h +173 -0
  132. mindspore/include/c_api/ms/base/types.h +1 -0
  133. mindspore/include/c_api/types_c.h +19 -0
  134. mindspore/include/dataset/execute.h +1 -3
  135. mindspore/include/dataset/vision.h +54 -2
  136. mindspore/jpeg62.dll +0 -0
  137. mindspore/log.py +2 -2
  138. mindspore/mindrecord/__init__.py +5 -1
  139. mindspore/mindrecord/config.py +809 -0
  140. mindspore/mindrecord/filereader.py +25 -0
  141. mindspore/mindrecord/filewriter.py +76 -58
  142. mindspore/mindrecord/mindpage.py +40 -6
  143. mindspore/mindrecord/shardutils.py +3 -2
  144. mindspore/mindrecord/shardwriter.py +7 -0
  145. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  146. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  147. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  148. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  149. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  150. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  151. mindspore/mindspore_backend.dll +0 -0
  152. mindspore/mindspore_common.dll +0 -0
  153. mindspore/mindspore_core.dll +0 -0
  154. mindspore/mindspore_glog.dll +0 -0
  155. mindspore/mindspore_np_dtype.dll +0 -0
  156. mindspore/mindspore_shared_lib.dll +0 -0
  157. mindspore/mint/__init__.py +1137 -0
  158. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  159. mindspore/mint/nn/__init__.py +512 -0
  160. mindspore/mint/nn/functional.py +573 -0
  161. mindspore/mint/optim/__init__.py +24 -0
  162. mindspore/mint/optim/adamw.py +185 -0
  163. mindspore/msobj140.dll +0 -0
  164. mindspore/mspdb140.dll +0 -0
  165. mindspore/mspdbcore.dll +0 -0
  166. mindspore/mspdbst.dll +0 -0
  167. mindspore/mspft140.dll +0 -0
  168. mindspore/msvcdis140.dll +0 -0
  169. mindspore/msvcp140_1.dll +0 -0
  170. mindspore/msvcp140_2.dll +0 -0
  171. mindspore/msvcp140_atomic_wait.dll +0 -0
  172. mindspore/msvcp140_codecvt_ids.dll +0 -0
  173. mindspore/multiprocessing/__init__.py +72 -0
  174. mindspore/nn/__init__.py +1 -0
  175. mindspore/nn/cell.py +213 -257
  176. mindspore/nn/dynamic_lr.py +2 -2
  177. mindspore/nn/extend/__init__.py +29 -0
  178. mindspore/nn/extend/basic.py +140 -0
  179. mindspore/nn/extend/embedding.py +143 -0
  180. mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
  181. mindspore/nn/extend/layer/normalization.py +109 -0
  182. mindspore/nn/extend/pooling.py +117 -0
  183. mindspore/nn/layer/activation.py +83 -93
  184. mindspore/nn/layer/basic.py +177 -82
  185. mindspore/nn/layer/channel_shuffle.py +3 -16
  186. mindspore/nn/layer/container.py +3 -3
  187. mindspore/nn/layer/conv.py +75 -66
  188. mindspore/nn/layer/embedding.py +101 -43
  189. mindspore/nn/layer/embedding_service.py +531 -0
  190. mindspore/nn/layer/embedding_service_layer.py +393 -0
  191. mindspore/nn/layer/image.py +4 -7
  192. mindspore/nn/layer/math.py +1 -1
  193. mindspore/nn/layer/normalization.py +52 -66
  194. mindspore/nn/layer/padding.py +30 -39
  195. mindspore/nn/layer/pooling.py +18 -9
  196. mindspore/nn/layer/rnn_cells.py +6 -16
  197. mindspore/nn/layer/rnns.py +6 -5
  198. mindspore/nn/layer/thor_layer.py +1 -2
  199. mindspore/nn/layer/timedistributed.py +1 -1
  200. mindspore/nn/layer/transformer.py +52 -50
  201. mindspore/nn/learning_rate_schedule.py +6 -5
  202. mindspore/nn/loss/loss.py +62 -83
  203. mindspore/nn/optim/ada_grad.py +4 -2
  204. mindspore/nn/optim/adadelta.py +3 -1
  205. mindspore/nn/optim/adafactor.py +1 -1
  206. mindspore/nn/optim/adam.py +102 -181
  207. mindspore/nn/optim/adamax.py +4 -2
  208. mindspore/nn/optim/adasum.py +3 -3
  209. mindspore/nn/optim/asgd.py +4 -2
  210. mindspore/nn/optim/ftrl.py +31 -61
  211. mindspore/nn/optim/lamb.py +5 -3
  212. mindspore/nn/optim/lars.py +2 -2
  213. mindspore/nn/optim/lazyadam.py +6 -4
  214. mindspore/nn/optim/momentum.py +13 -25
  215. mindspore/nn/optim/optimizer.py +6 -3
  216. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  217. mindspore/nn/optim/rmsprop.py +9 -3
  218. mindspore/nn/optim/rprop.py +4 -2
  219. mindspore/nn/optim/sgd.py +5 -3
  220. mindspore/nn/optim/thor.py +2 -2
  221. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  222. mindspore/nn/probability/distribution/beta.py +2 -2
  223. mindspore/nn/probability/distribution/categorical.py +4 -6
  224. mindspore/nn/probability/distribution/cauchy.py +2 -2
  225. mindspore/nn/probability/distribution/exponential.py +2 -2
  226. mindspore/nn/probability/distribution/geometric.py +1 -1
  227. mindspore/nn/probability/distribution/gumbel.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +1 -1
  229. mindspore/nn/probability/distribution/poisson.py +2 -2
  230. mindspore/nn/probability/distribution/uniform.py +2 -2
  231. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  232. mindspore/nn/wrap/__init__.py +2 -1
  233. mindspore/nn/wrap/cell_wrapper.py +58 -13
  234. mindspore/nn/wrap/grad_reducer.py +148 -8
  235. mindspore/nn/wrap/loss_scale.py +32 -9
  236. mindspore/numpy/__init__.py +2 -0
  237. mindspore/numpy/array_creations.py +2 -0
  238. mindspore/numpy/array_ops.py +6 -6
  239. mindspore/numpy/dtypes.py +3 -3
  240. mindspore/numpy/fft.py +431 -0
  241. mindspore/numpy/math_ops.py +62 -68
  242. mindspore/numpy/utils.py +3 -0
  243. mindspore/opencv_core452.dll +0 -0
  244. mindspore/opencv_imgcodecs452.dll +0 -0
  245. mindspore/opencv_imgproc452.dll +0 -0
  246. mindspore/ops/__init__.py +6 -5
  247. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  248. mindspore/ops/_grad_experimental/grad_comm_ops.py +89 -34
  249. mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
  250. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  251. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  252. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  253. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  254. mindspore/ops/_op_impl/__init__.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  256. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  257. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  258. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  259. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  260. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  261. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  262. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  263. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  264. mindspore/ops/_vmap/vmap_array_ops.py +164 -101
  265. mindspore/ops/_vmap/vmap_base.py +8 -1
  266. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  267. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  268. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  269. mindspore/ops/_vmap/vmap_math_ops.py +130 -58
  270. mindspore/ops/_vmap/vmap_nn_ops.py +249 -115
  271. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  272. mindspore/ops/auto_generate/__init__.py +31 -0
  273. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
  274. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
  275. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  276. mindspore/ops/auto_generate/gen_extend_func.py +980 -0
  277. mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
  278. mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
  279. mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
  280. mindspore/ops/composite/__init__.py +5 -2
  281. mindspore/ops/composite/base.py +121 -23
  282. mindspore/ops/composite/math_ops.py +10 -49
  283. mindspore/ops/composite/multitype_ops/_compile_utils.py +191 -618
  284. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  285. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  286. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  287. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  288. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  289. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  290. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  291. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  292. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  293. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  294. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  296. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  297. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  298. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  299. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  300. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  301. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  302. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  303. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  304. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  305. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
  307. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  308. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  309. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  310. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  311. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  313. mindspore/ops/deprecated.py +14 -3
  314. mindspore/ops/extend/__init__.py +53 -0
  315. mindspore/ops/extend/array_func.py +218 -0
  316. mindspore/ops/extend/math_func.py +76 -0
  317. mindspore/ops/extend/nn_func.py +308 -0
  318. mindspore/ops/function/__init__.py +31 -11
  319. mindspore/ops/function/array_func.py +846 -1735
  320. mindspore/ops/function/clip_func.py +19 -31
  321. mindspore/ops/function/debug_func.py +1 -4
  322. mindspore/ops/function/fft_func.py +31 -0
  323. mindspore/ops/function/grad/grad_func.py +27 -20
  324. mindspore/ops/function/image_func.py +27 -21
  325. mindspore/ops/function/linalg_func.py +35 -68
  326. mindspore/ops/function/math_func.py +913 -2791
  327. mindspore/ops/function/nn_func.py +1439 -885
  328. mindspore/ops/function/other_func.py +6 -7
  329. mindspore/ops/function/parameter_func.py +5 -93
  330. mindspore/ops/function/random_func.py +254 -108
  331. mindspore/ops/function/reshard_func.py +102 -0
  332. mindspore/ops/function/sparse_func.py +4 -4
  333. mindspore/ops/function/sparse_unary_func.py +9 -16
  334. mindspore/ops/function/spectral_func.py +1 -1
  335. mindspore/ops/function/vmap_func.py +14 -14
  336. mindspore/ops/functional.py +342 -343
  337. mindspore/ops/op_info_register.py +16 -43
  338. mindspore/ops/operations/__init__.py +32 -23
  339. mindspore/ops/operations/_grad_ops.py +21 -853
  340. mindspore/ops/operations/_infer_ops.py +19 -0
  341. mindspore/ops/operations/_inner_ops.py +107 -518
  342. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  343. mindspore/ops/operations/_scalar_ops.py +5 -480
  344. mindspore/ops/operations/_sequence_ops.py +6 -36
  345. mindspore/ops/operations/_tensor_array.py +8 -8
  346. mindspore/ops/operations/array_ops.py +108 -2705
  347. mindspore/ops/operations/comm_ops.py +801 -118
  348. mindspore/ops/operations/custom_ops.py +61 -120
  349. mindspore/ops/operations/debug_ops.py +104 -35
  350. mindspore/ops/operations/image_ops.py +1 -217
  351. mindspore/ops/operations/inner_ops.py +5 -40
  352. mindspore/ops/operations/linalg_ops.py +1 -49
  353. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  354. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  355. mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
  356. mindspore/ops/operations/math_ops.py +572 -4667
  357. mindspore/ops/operations/nn_ops.py +248 -2162
  358. mindspore/ops/operations/other_ops.py +53 -45
  359. mindspore/ops/operations/random_ops.py +4 -53
  360. mindspore/ops/operations/reshard_ops.py +53 -0
  361. mindspore/ops/operations/sparse_ops.py +4 -4
  362. mindspore/ops/primitive.py +204 -103
  363. mindspore/ops/silent_check.py +5 -5
  364. mindspore/ops_generate/__init__.py +27 -0
  365. mindspore/ops_generate/arg_dtype_cast.py +250 -0
  366. mindspore/ops_generate/arg_handler.py +197 -0
  367. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  368. mindspore/ops_generate/gen_ops.py +1084 -0
  369. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  370. mindspore/ops_generate/gen_pyboost_func.py +968 -0
  371. mindspore/ops_generate/gen_utils.py +209 -0
  372. mindspore/ops_generate/op_proto.py +138 -0
  373. mindspore/ops_generate/pyboost_utils.py +354 -0
  374. mindspore/ops_generate/template.py +239 -0
  375. mindspore/parallel/__init__.py +6 -4
  376. mindspore/parallel/_auto_parallel_context.py +73 -3
  377. mindspore/parallel/_cell_wrapper.py +16 -9
  378. mindspore/parallel/_cost_model_context.py +1 -1
  379. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  380. mindspore/parallel/_parallel_serialization.py +29 -13
  381. mindspore/parallel/_ps_context.py +1 -1
  382. mindspore/parallel/_recovery_context.py +1 -1
  383. mindspore/parallel/_tensor.py +18 -11
  384. mindspore/parallel/_transformer/__init__.py +1 -1
  385. mindspore/parallel/_transformer/layers.py +1 -1
  386. mindspore/parallel/_transformer/loss.py +1 -1
  387. mindspore/parallel/_transformer/moe.py +1 -1
  388. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  389. mindspore/parallel/_transformer/transformer.py +2 -2
  390. mindspore/parallel/_utils.py +161 -6
  391. mindspore/parallel/algo_parameter_config.py +6 -8
  392. mindspore/parallel/checkpoint_transform.py +191 -32
  393. mindspore/parallel/cluster/__init__.py +15 -0
  394. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  395. mindspore/parallel/cluster/process_entity/_api.py +344 -0
  396. mindspore/parallel/cluster/process_entity/_utils.py +126 -0
  397. mindspore/parallel/cluster/run.py +136 -0
  398. mindspore/parallel/mpi/__init__.py +1 -1
  399. mindspore/parallel/mpi/_mpi_config.py +1 -1
  400. mindspore/parallel/parameter_broadcast.py +152 -0
  401. mindspore/parallel/shard.py +128 -17
  402. mindspore/pgodb140.dll +0 -0
  403. mindspore/pgort140.dll +0 -0
  404. mindspore/profiler/__init__.py +3 -2
  405. mindspore/profiler/common/process_pool.py +41 -0
  406. mindspore/profiler/common/singleton.py +28 -0
  407. mindspore/profiler/common/util.py +125 -0
  408. mindspore/profiler/envprofiling.py +2 -2
  409. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  410. mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
  411. mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
  412. mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
  413. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
  414. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
  415. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
  416. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
  417. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  418. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
  419. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  420. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  421. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  422. mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
  423. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  424. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  425. mindspore/profiler/parser/ascend_msprof_exporter.py +147 -146
  426. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  427. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  428. mindspore/profiler/parser/ascend_timeline_generator.py +296 -133
  429. mindspore/profiler/parser/base_timeline_generator.py +6 -0
  430. mindspore/profiler/parser/framework_parser.py +3 -2
  431. mindspore/profiler/parser/integrator.py +3 -1
  432. mindspore/profiler/parser/minddata_parser.py +72 -3
  433. mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
  434. mindspore/profiler/parser/msadvisor_parser.py +1 -1
  435. mindspore/profiler/parser/profiler_info.py +16 -1
  436. mindspore/profiler/profiling.py +445 -190
  437. mindspore/rewrite/__init__.py +2 -13
  438. mindspore/rewrite/api/node.py +122 -36
  439. mindspore/rewrite/api/pattern_engine.py +2 -3
  440. mindspore/rewrite/api/scoped_value.py +16 -15
  441. mindspore/rewrite/api/symbol_tree.py +45 -29
  442. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  443. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  444. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  445. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  446. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  447. mindspore/rewrite/common/__init__.py +1 -2
  448. mindspore/rewrite/common/config.py +24 -0
  449. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  450. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  451. mindspore/rewrite/common/namespace.py +118 -0
  452. mindspore/rewrite/node/__init__.py +5 -5
  453. mindspore/rewrite/node/call_function.py +23 -7
  454. mindspore/rewrite/node/cell_container.py +7 -3
  455. mindspore/rewrite/node/control_flow.py +53 -28
  456. mindspore/rewrite/node/node.py +212 -196
  457. mindspore/rewrite/node/node_manager.py +51 -22
  458. mindspore/rewrite/node/node_topological_manager.py +3 -23
  459. mindspore/rewrite/parsers/__init__.py +12 -0
  460. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  461. mindspore/rewrite/parsers/assign_parser.py +637 -413
  462. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  463. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  464. mindspore/rewrite/parsers/constant_parser.py +5 -5
  465. mindspore/rewrite/parsers/container_parser.py +4 -6
  466. mindspore/rewrite/parsers/expr_parser.py +55 -0
  467. mindspore/rewrite/parsers/for_parser.py +31 -98
  468. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  469. mindspore/rewrite/parsers/if_parser.py +28 -10
  470. mindspore/rewrite/parsers/module_parser.py +8 -182
  471. mindspore/rewrite/parsers/parser.py +1 -5
  472. mindspore/rewrite/parsers/parser_register.py +1 -1
  473. mindspore/rewrite/parsers/return_parser.py +5 -10
  474. mindspore/rewrite/parsers/while_parser.py +59 -0
  475. mindspore/rewrite/sparsify/utils.py +1 -1
  476. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  477. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
  478. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  479. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  480. mindspore/run_check/_check_version.py +6 -14
  481. mindspore/run_check/run_check.py +1 -1
  482. mindspore/safeguard/rewrite_obfuscation.py +9 -19
  483. mindspore/swresample-4.dll +0 -0
  484. mindspore/swscale-6.dll +0 -0
  485. mindspore/tbbmalloc.dll +0 -0
  486. mindspore/tinyxml2.dll +0 -0
  487. mindspore/train/__init__.py +6 -5
  488. mindspore/train/_utils.py +178 -4
  489. mindspore/train/amp.py +167 -245
  490. mindspore/train/anf_ir_pb2.py +14 -2
  491. mindspore/train/callback/__init__.py +5 -2
  492. mindspore/train/callback/_backup_and_restore.py +5 -5
  493. mindspore/train/callback/_callback.py +4 -4
  494. mindspore/train/callback/_checkpoint.py +143 -29
  495. mindspore/train/callback/_cluster_monitor.py +201 -0
  496. mindspore/train/callback/_early_stop.py +2 -2
  497. mindspore/train/callback/_flops_collector.py +238 -0
  498. mindspore/train/callback/_landscape.py +15 -9
  499. mindspore/train/callback/_loss_monitor.py +2 -2
  500. mindspore/train/callback/_mindio_ttp.py +443 -0
  501. mindspore/train/callback/_on_request_exit.py +2 -2
  502. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  503. mindspore/train/callback/_summary_collector.py +7 -7
  504. mindspore/train/callback/_time_monitor.py +3 -3
  505. mindspore/train/data_sink.py +6 -5
  506. mindspore/train/dataset_helper.py +60 -21
  507. mindspore/train/loss_scale_manager.py +2 -2
  508. mindspore/train/metrics/accuracy.py +7 -7
  509. mindspore/train/metrics/confusion_matrix.py +8 -6
  510. mindspore/train/metrics/cosine_similarity.py +6 -4
  511. mindspore/train/metrics/error.py +2 -2
  512. mindspore/train/metrics/metric.py +3 -3
  513. mindspore/train/metrics/perplexity.py +2 -1
  514. mindspore/train/metrics/topk.py +2 -2
  515. mindspore/train/mind_ir_pb2.py +89 -15
  516. mindspore/train/model.py +290 -60
  517. mindspore/train/serialization.py +495 -220
  518. mindspore/train/summary/_summary_adapter.py +1 -1
  519. mindspore/train/summary/summary_record.py +51 -28
  520. mindspore/train/train_thor/convert_utils.py +3 -3
  521. mindspore/turbojpeg.dll +0 -0
  522. mindspore/vcmeta.dll +0 -0
  523. mindspore/vcruntime140.dll +0 -0
  524. mindspore/vcruntime140_1.dll +0 -0
  525. mindspore/version.py +1 -1
  526. {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/METADATA +3 -3
  527. mindspore-2.3.0.dist-info/RECORD +1400 -0
  528. {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/entry_points.txt +1 -0
  529. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  530. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  531. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  532. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  533. mindspore/gen_ops.py +0 -273
  534. mindspore/nn/layer/flash_attention.py +0 -189
  535. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  536. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  537. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  538. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  539. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  540. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  541. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  542. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  543. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  544. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  545. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  546. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  547. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  548. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  549. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  550. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  551. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  552. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  553. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  554. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  555. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  556. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  557. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  558. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  559. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  560. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  561. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  562. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  563. mindspore/ops/_op_impl/tbe/add.py +0 -42
  564. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  565. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  566. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  567. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  568. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  569. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  570. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  571. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  572. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  573. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  574. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  575. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  576. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  577. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  578. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  579. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  580. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  581. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  582. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  583. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  584. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  585. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  586. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  587. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  588. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  589. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  590. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  591. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  592. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  593. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  594. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  595. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  596. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  597. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  598. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  599. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  600. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  601. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  602. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  603. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  604. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  605. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  606. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  607. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  608. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  609. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  610. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  611. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  612. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  613. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  614. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  615. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  616. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  617. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  618. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  620. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  621. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  622. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  623. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  624. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  625. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  626. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  627. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  628. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  629. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  630. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  631. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  632. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  633. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  634. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  635. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  636. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  637. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  638. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  639. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  640. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  641. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  642. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  643. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  644. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  645. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  646. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  647. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  648. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  649. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  650. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  651. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  652. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  653. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  654. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  655. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  656. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  657. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  658. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  659. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  660. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  661. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  662. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  663. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  664. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  665. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  666. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  667. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  668. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  669. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  670. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  671. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  672. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  673. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  674. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  675. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  676. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  677. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  678. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  679. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  680. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  681. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  682. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  683. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  684. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  685. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  686. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  687. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  688. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  689. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  690. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  691. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  692. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  693. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  694. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  695. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  696. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  697. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  698. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  699. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  700. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  701. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  702. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  703. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  704. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  705. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  706. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  707. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  708. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  709. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  710. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  711. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  712. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  713. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  714. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  715. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  716. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  717. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  718. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  719. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  720. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  721. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  722. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  723. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  724. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  725. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  726. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  727. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  728. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  729. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  730. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  731. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  732. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  733. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  734. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  735. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  736. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  737. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  738. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  739. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  740. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  741. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  742. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  743. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  744. mindspore/ops/_op_impl/tbe/div.py +0 -41
  745. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  746. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  747. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  748. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  749. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  750. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  751. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  752. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  753. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  754. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  755. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  756. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  757. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  758. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  759. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  760. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  761. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  762. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  763. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  764. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  765. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  766. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  767. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  768. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  769. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  770. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  771. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  772. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  773. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  774. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  775. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  776. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  777. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  778. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  779. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  780. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  781. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  782. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  783. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  784. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  785. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  786. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  787. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  788. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  789. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  790. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  791. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  792. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  793. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  794. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  795. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  796. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  797. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  798. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  799. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  800. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  801. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  802. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  803. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  804. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  805. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  806. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  807. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  808. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  809. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  810. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  811. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  812. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  813. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  814. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  815. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  816. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  817. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  818. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  819. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  820. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  821. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  822. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  823. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  824. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  825. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  826. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  827. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  828. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  829. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  830. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  831. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  832. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  833. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  834. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  835. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  836. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  837. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  838. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  839. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  840. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  841. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  842. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  843. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  844. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  845. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  846. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  847. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  848. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  849. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  850. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  851. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  852. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  853. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  854. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  855. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  856. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  857. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  858. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  859. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  860. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  861. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  862. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  863. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  864. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  865. mindspore/ops/_op_impl/tbe/less.py +0 -41
  866. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  867. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  868. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  869. mindspore/ops/_op_impl/tbe/log.py +0 -40
  870. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  871. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  872. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  873. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  874. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  875. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  876. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  877. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  878. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  879. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  880. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  881. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  882. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  883. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  884. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  885. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  886. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  887. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  888. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  889. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  890. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  891. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  892. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  893. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  894. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  895. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  896. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  897. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  898. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  899. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  900. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  901. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  902. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  903. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  904. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  905. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  906. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  907. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  908. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  909. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  910. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  911. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  912. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  913. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  914. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  915. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  916. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  917. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  918. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  919. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  920. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  921. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  922. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  923. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  924. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  925. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  926. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  927. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  928. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  929. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  930. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  931. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  932. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  933. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  934. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  935. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  936. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  937. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  938. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  939. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  940. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  941. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  942. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  943. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  944. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  945. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  946. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  947. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  948. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  949. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  950. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  951. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  952. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  953. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  954. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  955. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  956. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  957. mindspore/ops/_op_impl/tbe/range.py +0 -39
  958. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  959. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  960. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  961. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  962. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  963. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  964. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  965. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  966. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  967. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  968. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  969. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  970. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  971. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  972. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  973. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  974. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  975. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  976. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  977. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  978. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  979. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  980. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  981. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  982. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  983. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  984. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  985. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  986. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  987. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  988. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  989. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  990. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  991. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  992. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  993. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  994. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  995. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  996. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  997. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  998. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  999. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1000. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1001. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1002. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1003. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1004. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1005. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1006. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1007. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1008. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1009. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1010. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1011. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1012. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1013. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1014. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1015. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1016. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1017. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1018. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1019. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1020. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1021. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1022. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1023. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1024. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1025. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1026. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1027. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1028. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1029. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1030. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1031. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1032. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1033. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1034. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1035. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1036. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1037. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1038. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1039. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1040. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1041. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1042. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1043. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1044. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1045. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1046. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1047. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1048. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1049. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1050. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1051. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1052. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1053. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1054. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1055. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1057. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1058. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1060. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1061. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1062. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1063. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1064. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1065. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1066. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1067. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1068. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1069. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1070. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1071. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1072. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1073. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1074. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1075. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1076. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1077. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1078. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1079. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1080. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1081. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1082. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1083. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1084. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1085. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1086. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1087. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1088. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1089. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1090. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1091. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1092. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1093. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1095. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1096. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1097. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1098. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1099. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1100. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1101. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1102. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1103. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1104. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1105. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1106. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1107. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1108. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1109. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1110. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1111. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1112. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1113. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1114. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1115. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1116. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1117. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1118. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1119. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1120. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1122. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1123. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1124. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1125. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1126. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1127. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1128. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1129. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1130. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1131. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1132. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1133. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1134. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1135. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1136. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1137. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1138. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1139. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1140. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1141. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1142. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1143. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1144. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1145. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1146. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1147. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1148. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1149. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1150. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1151. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1152. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1153. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1154. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1155. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1156. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1157. mindspore/ops/_tracefunc.py +0 -241
  1158. mindspore/ops/arg_dtype_cast.py +0 -54
  1159. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1160. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1161. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1162. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1163. mindspore/rewrite/namespace.py +0 -53
  1164. mindspore-2.2.14.dist-info/RECORD +0 -1924
  1165. {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/WHEEL +0 -0
  1166. {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/top_level.txt +0 -0
@@ -24,13 +24,13 @@ from mindspore.ops import operations as P
24
24
  from mindspore.ops.composite import base
25
25
  from mindspore.ops._primitive_cache import _get_cache_prim
26
26
  from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
27
- TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
27
+ TopTypeof, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
28
28
  SelectView, CopyWithSlice
29
+ from mindspore.ops.operations._sequence_ops import TensorToTuple, TensorToScalar, TupleToTensor
29
30
  from mindspore.common import dtype as mstype
30
31
  from mindspore.common._register_for_tensor import tensor_operator_registry
31
32
  from mindspore.common.initializer import Zero
32
- from mindspore.common import Tensor, CSRTensor, COOTensor
33
- from mindspore.common import mutable
33
+ from mindspore.common import Tensor, CSRTensor, COOTensor, mutable
34
34
  from mindspore import ops
35
35
  from mindspore.ops.primitive import _primexpr
36
36
  from mindspore import _checkparam as validator
@@ -137,7 +137,8 @@ def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=N
137
137
  elif transfer_type == ValueTransferType.kGatherND:
138
138
  if isinstance(new_index, list):
139
139
  new_index = handle_multi_dim_index_tensor(new_index, arg)
140
- data = F.gather_nd(data, Tensor(new_index))
140
+ new_index = format_index_tensor(new_index, (None, F.shape(data)[:F.shape(new_index)[-1]]))
141
+ data = F.gather_nd(data, new_index)
141
142
  elif transfer_type == ValueTransferType.kTensorScatterUpdate:
142
143
  if isinstance(new_index, list):
143
144
  new_index = handle_multi_dim_index_tensor(new_index, arg)
@@ -217,8 +218,8 @@ def _tensor_setitem(self, index, value):
217
218
  return output
218
219
 
219
220
 
220
- tensor_operator_registry.register("__getitem__", _tensor_getitem)
221
- tensor_operator_registry.register("__setitem__", _tensor_setitem)
221
+ setattr(tensor_operator_registry, "__getitem__", _tensor_getitem)
222
+ setattr(tensor_operator_registry, "__setitem__", _tensor_setitem)
222
223
 
223
224
 
224
225
  def _tensor_add(self, other):
@@ -287,15 +288,15 @@ def _tensor_floordiv(self, other):
287
288
  return F.floordiv(self, other)
288
289
 
289
290
 
290
- tensor_operator_registry.register('__add__', _tensor_add)
291
- tensor_operator_registry.register('__sub__', _tensor_sub)
292
- tensor_operator_registry.register('__mul__', _tensor_mul)
293
- tensor_operator_registry.register('__matmul__', _tensor_matmul)
294
- tensor_operator_registry.register('__truediv__', _tensor_div)
295
- tensor_operator_registry.register('__mod__', _tensor_mod)
296
- tensor_operator_registry.register('__pow__', _tensor_pow)
297
- tensor_operator_registry.register('__rpow__', _tensor_rpow)
298
- tensor_operator_registry.register('__floordiv__', _tensor_floordiv)
291
+ setattr(tensor_operator_registry, '__add__', _tensor_add)
292
+ setattr(tensor_operator_registry, '__sub__', _tensor_sub)
293
+ setattr(tensor_operator_registry, '__mul__', _tensor_mul)
294
+ setattr(tensor_operator_registry, '__matmul__', _tensor_matmul)
295
+ setattr(tensor_operator_registry, '__truediv__', _tensor_div)
296
+ setattr(tensor_operator_registry, '__mod__', _tensor_mod)
297
+ setattr(tensor_operator_registry, '__pow__', _tensor_pow)
298
+ setattr(tensor_operator_registry, '__rpow__', _tensor_rpow)
299
+ setattr(tensor_operator_registry, '__floordiv__', _tensor_floordiv)
299
300
 
300
301
 
301
302
  def _scalar_to_tensor(input_x):
@@ -317,24 +318,25 @@ def tensor_item(data, *args):
317
318
  # transform a.item(tuple(int)) -> a.item(int1,int2...intN)
318
319
  if data.ndim == 0:
319
320
  _check_scalar_tensor_args(args)
320
- return data.asnumpy().item()
321
+ return TensorToScalar()(data)
321
322
  if len(args) == 1 and isinstance(args[0], tuple):
322
323
  args = args[0]
323
324
 
324
325
  args_types = hyper_map(F.typeof, args)
325
326
  if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
326
327
  if data.shape == (1,):
327
- return data.asnumpy().item()
328
+ return TensorToScalar()(data[0])
328
329
  const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
329
330
 
330
331
  if not const_utils.judge_indexes_types(args_types, mstype.int64):
331
332
  const_utils.raise_type_error("The index object cannot be interpreted as an integer")
332
333
 
333
334
  if len(args) == data.ndim:
334
- return _tensor_getitem_by_tuple_slice(data, args)
335
+ return tensor_index_by_tuple(data, args)
335
336
  if len(args) > 1:
336
337
  const_utils.raise_value_error("Incorrect number of indices for array")
337
- return _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
338
+ output = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
339
+ return TensorToScalar()(output)
338
340
 
339
341
 
340
342
  def tensor_itemset(data, *args):
@@ -354,8 +356,8 @@ def tensor_itemset(data, *args):
354
356
  return tensor_itemset_with_number(data, args[0])
355
357
 
356
358
 
357
- tensor_operator_registry.register("item", tensor_item)
358
- tensor_operator_registry.register("itemset", tensor_itemset)
359
+ setattr(tensor_operator_registry, "item", tensor_item)
360
+ setattr(tensor_operator_registry, "itemset", tensor_itemset)
359
361
 
360
362
 
361
363
  def tensor_itemset_with_number(data, number_value):
@@ -521,24 +523,45 @@ def _expand_data_dims(data, tuple_index):
521
523
  return data, tuple_index_new
522
524
 
523
525
 
524
- def convert_variable_to_tensor_slice(slice_index):
525
- """convert mutable scalar to tensor"""
526
- start = slice_get_item(slice_index, "start")
527
- stop = slice_get_item(slice_index, "stop")
528
- step = slice_get_item(slice_index, "step")
529
- find_mutable_scalar = False
530
- if isinstance(start, int) and not F.isconstant(start):
531
- start = ops.Cast()(start, mstype.int64)
532
- find_mutable_scalar = True
533
- if isinstance(stop, int) and not F.isconstant(stop):
534
- stop = ops.Cast()(stop, mstype.int64)
535
- find_mutable_scalar = True
536
- if isinstance(step, int) and not F.isconstant(step):
537
- step = ops.Cast()(step, mstype.int64)
538
- find_mutable_scalar = True
539
- if find_mutable_scalar:
540
- return F.make_slice(start, stop, step)
541
- return slice_index
526
+ def _convert_list_index_to_tensor(list_index):
527
+ """convert list to tensor"""
528
+ has_bool = False
529
+ has_int = False
530
+ has_no_bool_int = False
531
+ for idx in list_index:
532
+ if isinstance(idx, bool):
533
+ has_bool = True
534
+ elif isinstance(idx, int):
535
+ has_int = True
536
+ else:
537
+ has_no_bool_int = True
538
+
539
+ all_bool = has_bool and not has_int and not has_no_bool_int
540
+ all_int = has_int and not has_bool and not has_no_bool_int
541
+ all_bool_or_int = not has_no_bool_int
542
+
543
+ if all_int:
544
+ index_tensor = TupleToTensor()(tuple(list_index), mstype.int64)
545
+ return index_tensor
546
+
547
+
548
+ if all_bool:
549
+ index_tensor = TupleToTensor()(tuple(list_index), mstype.bool_)
550
+ return index_tensor
551
+
552
+ # convert bool to int if index is mixture of (bool, int)
553
+ if all_bool_or_int:
554
+ new_index = []
555
+ for idx in list_index:
556
+ if isinstance(idx, bool):
557
+ new_idx = int(idx)
558
+ new_index.append(new_idx)
559
+ else:
560
+ new_index.append(idx)
561
+ index_tensor = TupleToTensor()(tuple(new_index), mstype.int64)
562
+ return index_tensor
563
+
564
+ return None
542
565
 
543
566
 
544
567
  class _TensorIndexGetitem(base.TensorIndexGetitem_):
@@ -564,26 +587,6 @@ def tensor_index_by_slice(data, slice_index):
564
587
  return _tensor_index_getitem(data, slice_index)
565
588
 
566
589
 
567
- def get_stride_info_from_slice(data, slice_index):
568
- """get the stride info from slice index"""
569
- data_shape = F.dyn_shape(data)
570
- begin_strides, end_strides, step_strides = [], [], []
571
- start, stop, step = get_slice_stride(slice_index, data_shape[0])
572
- if start.ndim > 0:
573
- start = start.item()
574
- if stop.ndim > 0:
575
- stop = stop.item()
576
- if step.ndim > 0:
577
- step = step.item()
578
- begin_strides.append(start)
579
- end_strides.append(stop)
580
- step_strides.append(step)
581
- begin_tensor = stack(begin_strides)
582
- end_tensor = stack(end_strides)
583
- step_tensor = stack(step_strides)
584
- return begin_tensor, end_tensor, step_tensor
585
-
586
-
587
590
  def tensor_index_by_number(data, number_index):
588
591
  """Tensor getitem by a Number which may be integer/float/bool value"""
589
592
  if isinstance(number_index, bool):
@@ -607,31 +610,18 @@ def _tensor_index_by_bool(data, bool_value):
607
610
  return output
608
611
 
609
612
 
610
- def get_stride_info_from_integer(tensor_int):
613
+ def get_stride_info_from_integer(int_index):
611
614
  """Convert integer to slice"""
612
- begin_strides = [tensor_int]
613
- end_strides = [tensor_int + 1]
614
- step_strides = [const_utils.make_tensor(1)]
615
- begin_tensor = stack(begin_strides)
616
- end_tensor = stack(end_strides)
617
- step_tensor = stack(step_strides)
618
- return begin_tensor, end_tensor, step_tensor
615
+ begin_strides = (int_index,)
616
+ end_strides = (int_index + 1,)
617
+ step_strides = (1,)
618
+ return begin_strides, end_strides, step_strides
619
619
 
620
620
 
621
621
  def _tensor_index_by_integer(data, int_index):
622
622
  """Tensor getitem by a single integer number"""
623
- data_shape = F.shape(data)
624
- if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
625
- tensor_index = _scalar_to_tensor(int_index)
626
- begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
627
- else:
628
- if not data_shape:
629
- const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
630
- if data.ndim < 1 or data.ndim > 8:
631
- const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
632
- transformed_number = const_utils.check_range(int_index, data_shape[0])
633
- begin_strides, end_strides, step_strides = \
634
- const_utils.get_stride_info_from_integer(data_shape, transformed_number)
623
+ begin_strides, end_strides, step_strides = get_stride_info_from_integer(int_index)
624
+
635
625
  shrink_axis_mask = 1
636
626
  begin_mask = 0
637
627
  end_mask = 0
@@ -664,6 +654,7 @@ def tensor_index_by_tensor(data, tensor_index):
664
654
  if not F.is_sequence_value_unknown(F.shape(data)):
665
655
  const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
666
656
  if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
657
+ tensor_index = F.select(tensor_index < 0, tensor_index + F.shape(data)[0], tensor_index)
667
658
  return F.gather(data, tensor_index, 0)
668
659
  if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
669
660
  return tensor_index_by_bool_tensor(data, tensor_index)
@@ -676,27 +667,23 @@ def tensor_index_by_tensor(data, tensor_index):
676
667
  def tensor_index_by_list(data, list_index):
677
668
  """Tensor getitem by list of int and bool"""
678
669
  min_data_dim, max_data_dim = 1, 8
679
- const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
670
+ if F.isconstant(data.ndim):
671
+ const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
680
672
 
681
673
  data_shape = F.shape(data)
682
- indexes_types = hyper_map(toptypeof, list_index)
683
- if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)) \
684
- and not F.is_sequence_value_unknown(list_index):
685
- if not F.isconstant(data_shape[0]):
686
- if all(isinstance(i, bool) for i in list_index):
687
- if F.dyn_shape(data)[0] != len(list_index):
688
- raise IndexError(
689
- f'dimension is {F.dyn_shape(data)[0]} but corresponding boolean dimension is {len(list_index)}')
690
- tensor_index = Tensor(list_index).nonzero()
691
- return F.gather_nd(data, tensor_index)
692
- tensor_index = const_utils.sequence_to_index(list_index, None)
693
- else:
694
- tensor_index = const_utils.sequence_to_index(
695
- list_index, data_shape[0])
696
- if tensor_index is False:
697
- const_utils.raise_index_error(
698
- "When tensor is indexed by list, the list can't be empty.")
699
- return F.gather(data, tensor_index, 0)
674
+ if F.isconstant(data_shape[0]) and all(isinstance(i, bool) for i in list_index):
675
+ if data_shape[0] != len(list_index):
676
+ raise IndexError(
677
+ f'dimension is {data_shape[0]} but corresponding boolean dimension is {len(list_index)}')
678
+ tensor_index = Tensor(list_index).nonzero()
679
+ return F.gather_nd(data, tensor_index)
680
+
681
+ if not list_index:
682
+ const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
683
+
684
+ index_tensor = _convert_list_index_to_tensor(list_index)
685
+ if index_tensor is not None:
686
+ return tensor_index_by_tensor(data, index_tensor)
700
687
 
701
688
  tuple_index_new = ()
702
689
  for index in list_index:
@@ -704,16 +691,6 @@ def tensor_index_by_list(data, list_index):
704
691
  return tensor_index_by_tuple(data, tuple_index_new)
705
692
 
706
693
 
707
- def convert_tupleslice_to_tensor(tuple_index):
708
- """convert mutable scalar in slice to tensor"""
709
- new_tuple_index = []
710
- for item in tuple_index:
711
- if isinstance(item, slice):
712
- item = convert_variable_to_tensor_slice(item)
713
- new_tuple_index.append(item)
714
- return tuple(new_tuple_index)
715
-
716
-
717
694
  def judge_tuple_index_dim_check_error(index_dim, data_dim):
718
695
  """raise IndexError when tuple_index's dim is invalid"""
719
696
  if index_dim > data_dim:
@@ -721,29 +698,6 @@ def judge_tuple_index_dim_check_error(index_dim, data_dim):
721
698
  f"dim of index:{index_dim}, dim of data:{data_dim}")
722
699
 
723
700
 
724
- class _HandleEmptySlice(base.HandleEmptySlice_):
725
- """
726
- Getting item of Tensor.
727
-
728
- Args:
729
- data (Tensor): A tuple to be sliced.
730
- index: Index of tensor.
731
-
732
- Returns:
733
- Type is the same as the element type of data.
734
- """
735
-
736
- def __init__(self, name):
737
- """Initialize _HandleEmptySlice."""
738
- base.HandleEmptySlice_.__init__(self, name)
739
-
740
- def __call__(self, *args):
741
- pass
742
-
743
-
744
- _handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
745
-
746
-
747
701
  def judge_tuple_index_dim(data, tuple_index):
748
702
  """Judge whether tuple_index's dim is valid"""
749
703
  data_dim = data.ndim
@@ -756,50 +710,20 @@ def judge_tuple_index_dim(data, tuple_index):
756
710
  judge_tuple_index_dim_check_error(index_dim, data_dim)
757
711
 
758
712
 
759
- def judge_simple_tuple_index(data, tuple_index):
760
- """Judge whether tuple_index is simple index, which not rollback to cpu ops."""
761
- op_name = const_utils.TENSOR_GETITEM
762
- indexes_types = hyper_map(toptypeof, tuple_index)
763
- contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
764
- return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
765
- and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
766
-
767
-
768
713
  def tensor_index_by_tuple(data, tuple_index):
769
714
  """Tensor getitem by tuple of various types with None"""
770
715
  if not tuple_index:
771
716
  return data
772
- if judge_simple_tuple_index(data, tuple_index):
773
- tuple_index = convert_tupleslice_to_tensor(tuple_index)
774
- op_name = const_utils.TENSOR_GETITEM
775
- tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
776
- min_data_dim, max_data_dim = 1, 8
777
- const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
778
- return _tensor_getitem_by_tuple_slice(data, tuple_index)
779
717
 
780
718
  if not F.is_sequence_value_unknown(F.shape(data)):
781
719
  judge_tuple_index_dim(data, tuple_index)
782
720
  tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
783
721
  for non_zero_shape in non_zero_shapes:
784
- if F.reduce_min(non_zero_shape) == 0:
722
+ if 0 in non_zero_shape:
785
723
  tuple_index = zero_index
786
724
  break
787
- if not F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(tuple_index):
788
- _, stub_zero_dim_tensor = _handle_empty_slice(data, tuple_index)
789
- if 0 in stub_zero_dim_tensor.shape:
790
- return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
791
- has_tensor_index = False
792
- for i in tuple_index:
793
- if isinstance(i, Tensor):
794
- has_tensor_index = True
795
- break
796
- empty_broadcast_data_shape = False
797
- _broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
798
- if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
799
- empty_broadcast_data_shape = True
800
- if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
801
- empty_broadcast_data_shape = True
802
- return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
725
+
726
+ return _tensor_index_getitem(data, tuple_index)
803
727
 
804
728
 
805
729
  def get_slice_stride(slice_index, dim_size):
@@ -809,20 +733,20 @@ def get_slice_stride(slice_index, dim_size):
809
733
  step = slice_get_item(slice_index, "step")
810
734
 
811
735
  if start is None:
812
- start = const_utils.make_tensor(0)
736
+ start = 0
813
737
  if stop is None:
814
738
  stop = dim_size
815
739
  if step is None:
816
- step = const_utils.make_tensor(1)
740
+ step = 1
817
741
 
818
- if issubclass_(F.typeof(start), mstype.number):
819
- start = const_utils.make_tensor(start)
742
+ if isinstance(start, Tensor):
743
+ start = int(start)
820
744
 
821
- if issubclass_(F.typeof(stop), mstype.number):
822
- stop = const_utils.make_tensor(stop)
745
+ if isinstance(stop, Tensor):
746
+ stop = int(stop)
823
747
 
824
- if issubclass_(F.typeof(step), mstype.number):
825
- step = const_utils.make_tensor(step)
748
+ if isinstance(step, Tensor):
749
+ step = int(step)
826
750
 
827
751
  return start, stop, step
828
752
 
@@ -841,190 +765,6 @@ def cal_tuple_slice_mask(data_shape, tuple_index):
841
765
  return begin_mask, end_mask
842
766
 
843
767
 
844
- def _get_stride_info_from_tuple(data, tuple_index):
845
- """get the stride info from tuple"""
846
- data_shape = F.dyn_shape(data)
847
- begin_strides, end_strides, step_strides = [], [], []
848
- tuple_index_len = len(tuple_index)
849
- data_dim = data.ndim
850
- shrink_axis, index_count, ellipsis_count = 0, 0, 0
851
- for item in range(data_dim):
852
- if item >= tuple_index_len or item >= data_dim:
853
- break
854
- index = tuple_index[item]
855
- dim_size = data_shape[item]
856
- if isinstance(index, slice):
857
- start, stop, step = get_slice_stride(index, dim_size)
858
- begin_strides.append(start)
859
- end_strides.append(stop)
860
- step_strides.append(step)
861
- index_count = index_count + 1
862
- elif isinstance(index, int):
863
- int_tensor = _scalar_to_tensor(index)
864
- begin_strides.append(int_tensor)
865
- end_strides.append(int_tensor + const_utils.make_tensor(1))
866
- step_strides.append(const_utils.make_tensor(1))
867
- shrink_axis = shrink_axis + (2 ** index_count)
868
- index_count = index_count + 1
869
- elif index is ...:
870
- ellipsis_count = ellipsis_count + 1
871
- if ellipsis_count > 1:
872
- const_utils.raise_value_error("An index can have only one ellipsis (...)")
873
- ellipsis_range_size = data_dim - tuple_index_len + 1
874
- begin_strides.extend([const_utils.make_tensor(0)] * ellipsis_range_size)
875
- end_strides.extend(
876
- [shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
877
- step_strides.extend([const_utils.make_tensor(1)] * ellipsis_range_size)
878
- index_count = index_count + ellipsis_range_size
879
- else:
880
- exp_msg = const_utils.gen_exception_msg("Not supported index data type, got {}, type is {}", index,
881
- type(index))
882
- const_utils.raise_index_error(exp_msg)
883
- begin_tensor = stack(begin_strides)
884
- end_tensor = stack(end_strides)
885
- step_tensor = stack(step_strides)
886
- strides_v = {
887
- 'begin': begin_tensor,
888
- 'end': end_tensor,
889
- 'step': step_tensor
890
- }
891
- return strides_v, shrink_axis
892
-
893
-
894
- def _tensor_getitem_by_tuple_slice(data, tuple_index):
895
- """Tensor getitem by a tuple of slice"""
896
- data_shape = F.shape(data)
897
- is_dynamic = F.is_sequence_value_unknown(data_shape)
898
- for item in tuple_index:
899
- if isinstance(item, slice):
900
- is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
901
- or isinstance(slice_get_item(item, "stop"), Tensor) \
902
- or isinstance(slice_get_item(item, "step"), Tensor)
903
-
904
- strides_v = {}
905
- shrink_axis_mask = 0
906
- if not is_dynamic:
907
- strides_v, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
908
- data_shape, tuple_index)
909
- else:
910
- strides_v, shrink_axis_mask = _get_stride_info_from_tuple(
911
- data, tuple_index)
912
- begin_mask, end_mask = cal_tuple_slice_mask(data_shape, tuple_index)
913
- begin_v = strides_v['begin']
914
- end_v = strides_v['end']
915
- step_v = strides_v['step']
916
- return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)
917
-
918
-
919
- @_primexpr
920
- def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
921
- tensor_positions_new):
922
- """ parse index of bool tensor type """
923
- indices = index.nonzero()
924
- if indices.shape[0] == 0:
925
- return None, tensor_indexes, tensor_positions_new
926
- indices = F.cast(indices, mstype.int64)
927
- indices = indices.T
928
- for sub_index in indices:
929
- tensor_positions_new.append(len(tuple_index_new))
930
- tuple_index_new += (sub_index,)
931
- tensor_indexes.append(sub_index)
932
- return tuple_index_new, tensor_indexes, tensor_positions_new
933
-
934
-
935
- def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
936
- """ parse index of tensor type """
937
- if F.dtype(index) in mstype.int_type:
938
- tensor_index = F.cast(index, mstype.int64)
939
- tensor_positions_new.append(len(tuple_index_new))
940
- tuple_index_new += (tensor_index,)
941
- tensor_indexes.append(tensor_index)
942
- elif F.dtype(index) == mstype.bool_:
943
- return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
944
- tensor_positions_new)
945
- else:
946
- exp_msg = const_utils.gen_exception_msg(
947
- "The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
948
- const_utils.raise_index_error(exp_msg)
949
- return tuple_index_new, tensor_indexes, tensor_positions_new
950
-
951
-
952
- def _tensor_getitem_by_tuple(data, tuple_index, op_name):
953
- """Tensor getitem by a tuple of mixed tensor."""
954
- slice_is_tensor = False
955
- for item in tuple_index:
956
- if isinstance(item, slice):
957
- slice_is_tensor = isinstance(slice_get_item(item, "start"), Tensor) \
958
- or isinstance(slice_get_item(item, "stop"), Tensor) \
959
- or isinstance(slice_get_item(item, "step"), Tensor)
960
- if slice_is_tensor:
961
- const_utils.raise_index_error("Not supported when slice has tensor")
962
-
963
- indexes_types = hyper_map(toptypeof, tuple_index)
964
- slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
965
- const_utils.get_pos_of_indexes_types(indexes_types, op_name)
966
- data_shape = F.shape(data)
967
- tensor_indexes, slice_indexes = [], []
968
- tuple_index_new, slice_shapes = (), ()
969
- slice_positions_new, tensor_positions_new = [], []
970
- for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
971
- if i in int_positions:
972
- int_index = const_utils.check_range(index, dim_size)
973
- tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
974
- if F.is_sequence_value_unknown(data_shape):
975
- tensor_index = _scalar_to_tensor(int_index)
976
- tensor_index = F.cast(tensor_index, mstype.int64)
977
- tensor_positions_new.append(len(tuple_index_new))
978
- tuple_index_new += (tensor_index,)
979
- tensor_indexes.append(tensor_index)
980
- elif i in sequence_positions:
981
- tensor_index = const_utils.sequence_to_index(index, dim_size)
982
- if tensor_index is False:
983
- const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
984
- tensor_positions_new.append(len(tuple_index_new))
985
- tuple_index_new += (tensor_index,)
986
- tensor_indexes.append(tensor_index)
987
- elif i in tensor_positions:
988
- tuple_index_new, tensor_indexes, tensor_positions_new = \
989
- _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
990
- tensor_indexes, tensor_positions_new)
991
- if tuple_index_new is None:
992
- return Tensor([])
993
- elif i in slice_positions:
994
- slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
995
- slice_shapes += (len(slice_ele_list_index),)
996
- slice_positions_new.append(len(tuple_index_new))
997
- tuple_index_new += (slice_ele_list_index,)
998
- slice_indexes.append(slice_ele_list_index)
999
- tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
1000
- broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
1001
- const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
1002
- slice_shapes, op_name)
1003
-
1004
- tuple_index_len = len(tuple_index)
1005
- if 0 in final_shape + data_shape:
1006
- if tuple_index_len < len(data_shape):
1007
- final_shape = final_shape + data_shape[tuple_index_len:]
1008
- return const_utils.make_tensor([], data.dtype, final_shape)
1009
-
1010
- final_index_tensors = []
1011
- slice_cnt = 0
1012
- for i, index in enumerate(tuple_index_new):
1013
- if i in tensor_positions_new:
1014
- transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
1015
- index)
1016
- final_index_tensors.append(transform_tensor)
1017
- elif i in slice_positions_new:
1018
- slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
1019
- slice_shapes, fancy_position)
1020
- final_index_tensors.append(slice_index_tensor)
1021
- slice_cnt += 1
1022
-
1023
- indices = stack(final_index_tensors)
1024
- result = F.gather_nd(data, indices)
1025
- return result
1026
-
1027
-
1028
768
  def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
1029
769
  """Generate an indices tensor from a tuple of tensor."""
1030
770
  indexes_types = hyper_map(F.dtype, tuple_index)
@@ -1116,8 +856,15 @@ def sequence_to_tensor(value, dtype):
1116
856
 
1117
857
  if value_elements_type == const_utils.ALL_TENSOR:
1118
858
  value = F.stack(value).astype(dtype)
1119
- elif value_elements_type == const_utils.NO_TENSOR and not F.is_sequence_value_unknown(value):
1120
- value = const_utils.make_tensor(value, dtype)
859
+ elif value_elements_type == const_utils.NO_TENSOR:
860
+ if isinstance(value, list):
861
+ value = tuple(value)
862
+
863
+ if dtype == mstype.float16:
864
+ value = TupleToTensor()(value, mstype.float32)
865
+ value = F.cast(value, dtype)
866
+ else:
867
+ value = TupleToTensor()(value, dtype)
1121
868
  else:
1122
869
  new_value = ()
1123
870
  for ele in value:
@@ -1138,57 +885,31 @@ def _generate_updates_from_sequence(data, index, value, op_type):
1138
885
  def _generate_updates_from_tensor(data, index, value, op_type):
1139
886
  """Generate an updates tensor from a tensor."""
1140
887
  value = value.astype(data.dtype)
1141
- if F.is_sequence_value_unknown(F.shape(data)) or F.is_sequence_value_unknown(F.shape(index)):
1142
- data_shape = F.dyn_shape(data)
1143
- index_shape = F.dyn_shape(index)
1144
- updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
1145
- updates = ops.broadcast_to(value, updates_shape)
1146
- return updates
1147
- updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
1148
- need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
1149
- if need_broadcast:
1150
- return _broadcast(updates_shape, value)
1151
- return value
888
+ updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
889
+ updates = ops.broadcast_to(value, updates_shape)
890
+ return updates
1152
891
 
1153
892
 
1154
893
  # Tensor getitem implementations are above this line, setitem implementations below.
1155
894
 
1156
- def tensor_setitem_by_tensor(self, index, value):
1157
- if isinstance(value, (int, float, bool)):
1158
- return tensor_setitem_by_tensor_with_number(self, index, value)
1159
- if isinstance(value, Tensor):
1160
- return tensor_setitem_by_tensor_with_tensor(self, index, value)
1161
- return tensor_setitem_by_tensor_with_sequence(self, index, value)
1162
-
1163
-
1164
- def tensor_setitem_by_tuple(self, index, value):
1165
- index = convert_tupleslice_to_tensor(index)
1166
- if isinstance(value, (int, float, bool)):
1167
- index = format_tuple_indices(index)
1168
- return tensor_setitem_by_tuple_with_number(self, index, value)
1169
- if isinstance(value, Tensor):
1170
- return tensor_setitem_by_tuple_with_tensor(self, index, value)
1171
- return tensor_setitem_by_tuple_with_sequence(self, index, value)
895
+ def _tensor_index_transfer(index, broadcast_shape, final_shape, new_shape):
896
+ """Transform tuple index tensor to the required."""
897
+ if 0 in final_shape:
898
+ return F.fill(index.dtype, final_shape, 0)
1172
899
 
900
+ if broadcast_shape == ():
901
+ # broadcast_to () is not support on Ascend
902
+ item = index
903
+ else:
904
+ item = F.broadcast_to(index, broadcast_shape)
905
+ item = F.reshape(item, new_shape)
906
+ return F.broadcast_to(item, final_shape)
1173
907
 
1174
- def tensor_setitem_by_number(self, index, value):
1175
- if isinstance(value, (int, float, bool)):
1176
- return tensor_setitem_by_number_with_number(self, index, value)
1177
- if isinstance(value, Tensor):
1178
- return tensor_setitem_by_number_with_tensor(self, index, value)
1179
- return tensor_setitem_by_number_with_sequence(self, index, value)
1180
908
 
1181
-
1182
- def _tuple_index_transfer(broadcast_shape, final_shape, new_shape, x, all_empty_tensor):
1183
- """Transform tuple index tensor to the required."""
1184
- if isinstance(broadcast_shape, Tensor):
1185
- if not all_empty_tensor:
1186
- x = F.broadcast_to(x, broadcast_shape)
1187
- x = F.reshape(x, new_shape)
1188
- x = F.broadcast_to(x, final_shape)
1189
- return x
1190
- item = _broadcast(broadcast_shape, x)
1191
- return _broadcast(final_shape, F.reshape(item, new_shape))
909
+ def reshape_with_check(x, new_shape):
910
+ if isinstance(new_shape, Tensor):
911
+ new_shape = TensorToTuple()(new_shape)
912
+ return F.reshape(x, new_shape)
1192
913
 
1193
914
 
1194
915
  class _TensorIndexSetitem(base.TensorIndexSetitem_):
@@ -1218,9 +939,10 @@ def tensor_setitem_by_slice(self, index, value):
1218
939
  return self
1219
940
  value = F.broadcast_to(value, value_shape)
1220
941
  if not const_utils.is_ascend() and step == 1:
1221
- if isinstance(step, Tensor):
1222
- return copy_slice(self, value, start, stop, step)
1223
- return copy_slice(self, value, (start,), (stop,), (step,))
942
+ start = (start,)
943
+ stop = (stop,)
944
+ step = (step,)
945
+ return copy_slice(self, value, start, stop, step)
1224
946
  return F.tensor_scatter_update(self, indices, value)
1225
947
 
1226
948
 
@@ -1236,14 +958,14 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
1236
958
  """Set a tensor item by an int tensor with a tensor."""
1237
959
  if F.rank(index) == 0:
1238
960
  index = F.expand_dims(index, -1)
1239
- updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
961
+
1240
962
  data_shape = F.shape(data)
963
+ updates_shape = index.shape + data_shape[1:]
964
+ value = F.cast(value, F.dtype(data))
965
+ updates = ops.broadcast_to(value, updates_shape)
1241
966
  first_val = data_shape[0]
1242
967
  index = F.select(index < 0, index + first_val, index)
1243
968
  index = F.expand_dims(index, -1)
1244
- if F.rank(index) < 2:
1245
- index = F.expand_dims(index, 0)
1246
- updates = F.expand_dims(updates, 0)
1247
969
  if is_parameter(data):
1248
970
  F.scatter_nd_update(data, index, updates)
1249
971
  return data
@@ -1255,8 +977,7 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
1255
977
  index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
1256
978
  index = F.broadcast_to(index, data.shape)
1257
979
  value = F.cast(value, F.dtype(data))
1258
- while value.ndim < data.ndim:
1259
- value = value.unsqueeze(-1)
980
+ value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
1260
981
  value = F.broadcast_to(value, data.shape)
1261
982
  result = F.select(index, value, data)
1262
983
  return result
@@ -1269,8 +990,6 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
1269
990
  if tensor_dtype == const_utils.INT_:
1270
991
  return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
1271
992
 
1272
- if F.is_sequence_value_unknown(F.shape(data)):
1273
- return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
1274
993
  return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
1275
994
 
1276
995
 
@@ -1281,33 +1000,8 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
1281
1000
 
1282
1001
  def tensor_setitem_by_tensor_with_sequence(data, index, value):
1283
1002
  """Assigns the tensor by tensor with tuple value."""
1284
- index_dtype = F.dtype(index)
1285
- if index_dtype in (mstype.int32, mstype.int64):
1286
- return _tensor_setitem_by_tensor_with_sequence(data, index, value)
1287
- if index_dtype == mstype.bool_:
1288
- return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
1289
- exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
1290
- const_utils.raise_index_error(exp_msg)
1291
- return None
1292
-
1293
-
1294
- def _tensor_setitem_by_tensor_with_sequence(data, index, value):
1295
- """Set a tensor item by a tensor with a tuple."""
1296
- updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
1297
- index = F.expand_dims(index, -1)
1298
- return F.tensor_scatter_update(data, index, updates)
1299
-
1300
-
1301
- def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
1302
- """Set a tensor item by a bool tensor with a tuple."""
1303
1003
  value = sequence_to_tensor(value, F.dtype(data))
1304
- return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value)
1305
-
1306
-
1307
- def tensor_setitem_by_slice_with_number(data, input_slice, value):
1308
- """Givens a scalar assign to tensor by slice"""
1309
- value = F.cast(value, F.dtype(data))
1310
- return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
1004
+ return tensor_setitem_by_tensor_with_tensor(data, index, value)
1311
1005
 
1312
1006
 
1313
1007
  def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
@@ -1316,78 +1010,14 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
1316
1010
  return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
1317
1011
 
1318
1012
 
1319
- def tensor_copy_slice_from_slice(data, input_slice, value):
1320
- """using TensorCopySlices by slice."""
1321
- data_shape = F.dyn_shape(data)
1322
- start, stop, step = get_slice_stride(input_slice, data_shape[0])
1323
- start_tensor = stack((start,))
1324
- stop_tensor = stack((stop,))
1325
- step_tensor = stack((step,))
1326
- dim0_size = stop_tensor - start_tensor
1327
- if dim0_size <= 0:
1328
- return data
1329
- if dim0_size >= data_shape[0]:
1330
- dim0_size = data_shape[0:1]
1331
- value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
1332
- value = ops.broadcast_to(value, value_shape)
1333
- return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
1334
-
1335
-
1336
- def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
1337
- """Assigns a tensor value to the tensor by slice."""
1338
- result = None
1339
- check_result = const_utils.check_tensor_setitem_index(input_slice)
1340
- if check_result:
1341
- data_shape = F.shape(data)
1342
- step = const_utils.get_step_from_slice(input_slice)
1343
- if step == 1 and not const_utils.is_ascend():
1344
- if F.is_sequence_value_unknown(data_shape):
1345
- return tensor_copy_slice_from_slice(data, input_slice, value)
1346
- start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
1347
- dim0_size = stop - start
1348
- if dim0_size <= 0:
1349
- return data
1350
- value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
1351
- value = _broadcast(value_shape, value)
1352
- return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
1353
- if F.is_sequence_value_unknown(data_shape):
1354
- const_utils.raise_unimplemented_error(
1355
- "Not supported to take the subscript of dynamic shape tensor slice setitem")
1356
- indices = const_utils.slice2indices(input_slice, data_shape)
1357
- if indices is False:
1358
- return data
1359
- value_shape = const_utils.tuple_slice(F.shape(indices), None, -1)
1360
- value = _broadcast(value_shape, value)
1361
- result = F.tensor_scatter_update(data, indices, value.astype(F.dtype(data)))
1362
- return result
1363
-
1364
-
1365
- def tensor_setitem_by_slice_with_sequence(data, input_slice, value):
1366
- """Assigns a list/tuple value to the tensor by slice."""
1367
- value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR)
1368
- return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
1013
+ def tensor_setitem_by_list(data, index, value):
1014
+ """list indices will be converted to tuple or tensor based on its contents."""
1015
+ index_tensor = _convert_list_index_to_tensor(index)
1016
+ if index_tensor is not None:
1017
+ return tensor_setitem_by_tensor_with_tensor(data, index_tensor, value)
1369
1018
 
1019
+ return tensor_setitem_by_tuple_with_tensor(data, tuple(index), value)
1370
1020
 
1371
- def tensor_copy_slice_from_tuple(data, tuple_index, value):
1372
- """using TensorCopySlices by fixed model tuple."""
1373
- data_shape = F.dyn_shape(data)
1374
- dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
1375
- if dim1_stop - dim1_start <= 0:
1376
- return data
1377
- dim0_start = _scalar_to_tensor(tuple_index[0])
1378
- dim0_stop = dim0_start + const_utils.make_tensor(1)
1379
- start = (dim0_start, dim1_start)
1380
- stop = (dim0_stop, dim1_stop)
1381
- step = (const_utils.make_tensor(1), const_utils.make_tensor(1))
1382
- start_tensor = stack(start)
1383
- stop_tensor = stack(stop)
1384
- step_tensor = stack(step)
1385
- dim1_size = stack((dim1_stop - dim1_start,))
1386
- if dim1_size > data_shape[1]:
1387
- dim1_size = data_shape[1:2]
1388
- value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
1389
- value = ops.broadcast_to(value, value_shape)
1390
- return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
1391
1021
 
1392
1022
 
1393
1023
  class _PreSetitemByTuple(base.PreSetitemByTuple_):
@@ -1436,50 +1066,28 @@ class _HandleBoolTensor(base.HandleBoolTensor_):
1436
1066
  _handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
1437
1067
 
1438
1068
 
1439
- class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
1440
- """
1441
- Getting item of Tensor.
1442
-
1443
- Args:
1444
- data (Tensor): A tuple to be sliced.
1445
- index: Index of tensor.
1446
-
1447
- Returns:
1448
- Type is the same as the element type of data.
1449
- """
1450
-
1451
- def __init__(self, name):
1452
- """Initialize _HandleBoolTensor."""
1453
- base.HandleScalarTensorIndex_.__init__(self, name)
1454
-
1455
- def __call__(self, *args):
1456
- pass
1457
-
1458
-
1459
- _handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
1460
-
1461
-
1462
1069
  def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
1463
1070
  """Assigns the tensor by tuple with tensor value."""
1464
1071
  if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
1465
- if F.is_sequence_value_unknown(F.shape(data)):
1466
- return tensor_copy_slice_from_tuple(data, tuple_index, value)
1467
1072
  dim1_start, dim1_stop, _ = const_utils.normalize_slice(
1468
1073
  tuple_index[1], data.shape[1])
1074
+ if isinstance(dim1_start, Tensor):
1075
+ dim1_start = int(dim1_start)
1076
+ if isinstance(dim1_stop, Tensor):
1077
+ dim1_stop = int(dim1_stop)
1469
1078
  if dim1_stop - dim1_start <= 0:
1470
1079
  return data
1471
1080
  dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
1472
1081
  start = (dim0_start, dim1_start)
1473
1082
  stop = (dim0_start + 1, dim1_stop)
1474
1083
  step = (1, 1)
1475
- value_shape = (dim1_stop - dim1_start,) + \
1476
- const_utils.tuple_slice(data.shape, 2, None)
1477
- value = _broadcast(value_shape, value)
1084
+ value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
1085
+ value = F.broadcast_to(value, value_shape)
1478
1086
  return copy_slice(data, value.astype(data.dtype), start, stop, step)
1479
1087
  tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
1480
1088
 
1481
1089
  for non_zero_shape in non_zero_shapes:
1482
- if F.reduce_min(non_zero_shape) == 0:
1090
+ if 0 in non_zero_shape:
1483
1091
  return data
1484
1092
  value = value.astype(data.dtype)
1485
1093
  special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
@@ -1512,17 +1120,19 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
1512
1120
  tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
1513
1121
 
1514
1122
  if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
1515
- if F.is_sequence_value_unknown(F.shape(data)):
1516
- return tensor_copy_slice_from_tuple(data, tuple_index, value)
1517
1123
  dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
1124
+ if isinstance(dim1_start, Tensor):
1125
+ dim1_start = int(dim1_start)
1126
+ if isinstance(dim1_stop, Tensor):
1127
+ dim1_stop = int(dim1_stop)
1518
1128
  if dim1_stop - dim1_start <= 0:
1519
1129
  return data
1520
1130
  dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
1521
1131
  start = (dim0_start, dim1_start)
1522
1132
  stop = (dim0_start + 1, dim1_stop)
1523
1133
  step = (1, 1)
1524
- value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
1525
- value = _broadcast(value_shape, value)
1134
+ value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
1135
+ value = F.broadcast_to(value, value_shape)
1526
1136
  return copy_slice(data, value.astype(data.dtype), start, stop, step)
1527
1137
  tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
1528
1138
 
@@ -1545,49 +1155,45 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
1545
1155
 
1546
1156
 
1547
1157
  def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
1548
- value = _generate_updates_from_sequence(data, tuple_index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
1158
+ value = sequence_to_tensor(value, F.dtype(data))
1549
1159
  return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
1550
1160
 
1551
1161
 
1552
1162
  def tensor_setitem_by_number_with_number(data, index, value):
1553
1163
  """Assigns the tensor by number with number value."""
1554
- value = F.cast(value, F.dtype(data))
1555
- return tensor_setitem_by_number_with_tensor(data, index, value)
1164
+ data_shape = F.shape(data)
1165
+ dim_size = data_shape[0]
1166
+ if index < 0:
1167
+ index += dim_size
1168
+ if index < -dim_size or index >= dim_size:
1169
+ raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
1170
+ index = F.cast(index, mstype.int64)
1171
+ index = F.reshape(index, (1, 1))
1172
+
1173
+ updates = F.cast(value, data.dtype)
1174
+ updates_shape = (1,) + data_shape[1:]
1175
+ updates = ops.broadcast_to(updates, updates_shape)
1176
+
1177
+ if is_parameter(data):
1178
+ F.scatter_nd_update(data, index, updates)
1179
+ return data
1180
+ return F.tensor_scatter_update(data, index, updates)
1556
1181
 
1557
1182
 
1558
1183
  def tensor_setitem_by_number_with_sequence(data, index, value):
1559
1184
  """Assigns a list/tuple value to the tensor by slice."""
1560
- value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
1185
+ value = sequence_to_tensor(value, F.dtype(data))
1561
1186
  return tensor_setitem_by_number_with_tensor(data, index, value)
1562
1187
 
1563
1188
 
1564
1189
  def tensor_setitem_by_number_with_tensor(data, index, value):
1565
- """Assigns the tensor by number with tensor value."""
1566
- data_shape = F.shape(data)
1567
- if F.is_sequence_value_unknown(data_shape):
1568
- index = _scalar_to_tensor(index)
1569
- index = F.expand_dims(index, -1)
1570
- return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
1571
-
1572
- dim_size = data_shape[0]
1573
- if index < -dim_size or index >= dim_size:
1574
- raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
1575
- index = const_utils.int_to_index(index, data_shape)
1576
- value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
1577
- value = _broadcast(value_shape, value.astype(F.dtype(data)))
1578
- if is_parameter(data):
1579
- F.scatter_nd_update(data, index, value)
1580
- return data
1581
- return F.tensor_scatter_update(data, index, value)
1190
+ return tensor_setitem_by_number_with_number(data, index, value)
1582
1191
 
1583
1192
 
1584
1193
  def tensor_setitem_by_ellipsis_with_number(data, value):
1585
1194
  """Assigns the tensor by ellipsis with number value."""
1586
1195
  data_shape = F.shape(data)
1587
1196
  data_dtype = F.dtype(data)
1588
- if F.is_sequence_value_unknown(data_shape):
1589
- value = F.cast(value, F.dtype(data))
1590
- return tensor_setitem_by_ellipsis_with_tensor(data, value)
1591
1197
  return F.fill(data_dtype, data_shape, value)
1592
1198
 
1593
1199
 
@@ -1597,21 +1203,20 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
1597
1203
  data_dtype = F.dtype(data)
1598
1204
  value = value.astype(data_dtype)
1599
1205
 
1600
- if F.is_sequence_value_unknown(data_shape):
1601
- data_shape = F.dyn_shape(data)
1602
- data = ops.broadcast_to(value, data_shape)
1603
- return data
1604
1206
  value_shape = F.shape(value)
1605
- source_shape = const_utils.get_source_shape(data_shape, value_shape)
1207
+
1208
+ if len(value_shape) > len(data_shape):
1209
+ source_shape = data_shape
1210
+ else:
1211
+ source_shape = value_shape
1606
1212
  value = F.reshape(value, source_shape)
1607
- value = _broadcast(data_shape, value)
1608
- data = F.cast(value, data_dtype)
1213
+ data = F.broadcast_to(value, data_shape)
1609
1214
  return data
1610
1215
 
1611
1216
 
1612
1217
  def tensor_setitem_by_ellipsis_with_sequence(data, value):
1613
1218
  """Assigns a list/tuple value to the tensor by ellipsis."""
1614
- value = _generate_updates_from_sequence(data, None, value, const_utils.SET_ITEM_BY_NON_TENSOR)
1219
+ value = sequence_to_tensor(value, F.dtype(data))
1615
1220
  return tensor_setitem_by_ellipsis_with_tensor(data, value)
1616
1221
 
1617
1222
 
@@ -1622,23 +1227,18 @@ def tensor_setitem_by_bool(data, index, value):
1622
1227
  if not index:
1623
1228
  data_shape = (0,) + data_shape
1624
1229
  if isinstance(value, (list, tuple)):
1625
- value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
1626
- elif isinstance(value, (int, bool)):
1627
- value = const_utils.make_tensor(value, mstype.int32)
1628
- elif isinstance(value, float):
1629
- value = const_utils.make_tensor(value, mstype.float32)
1630
-
1631
- if F.is_sequence_value_unknown(data_shape) and index:
1632
- data_shape = F.dyn_shape(data)
1633
- value = value.astype(data_dtype)
1634
- data = ops.broadcast_to(value, data_shape)
1635
- return data
1636
- value_shape = F.shape(value)
1637
- source_shape = const_utils.get_source_shape(data_shape, value_shape)
1230
+ value = sequence_to_tensor(value, data_dtype)
1231
+ else:
1232
+ value = F.cast(value, data_dtype)
1233
+
1638
1234
  if index:
1235
+ value_shape = F.shape(value)
1236
+ if len(value_shape) > len(data_shape):
1237
+ source_shape = data_shape
1238
+ else:
1239
+ source_shape = value_shape
1639
1240
  value = F.reshape(value, source_shape)
1640
- value = _broadcast(data_shape, value)
1641
- data = F.cast(value, data_dtype)
1241
+ data = F.broadcast_to(value, data_shape)
1642
1242
  return data
1643
1243
 
1644
1244
 
@@ -1651,33 +1251,6 @@ def tensor_in_sequence(x, y):
1651
1251
  return result
1652
1252
 
1653
1253
 
1654
- def format_list_indices(list_indices, length):
1655
- """Convert list indices to tensor or tuple indices based on its contents."""
1656
- indices_types = hyper_map(F.typeof, list_indices)
1657
- # If eyery element in list is bool, it's treated as 1-D bool tensor.
1658
- # If every element in list is int(not all bool), it's treated as int tensor.
1659
- if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
1660
- if not F.isconstant(length):
1661
- return const_utils.sequence_to_index(list_indices, None)
1662
- return const_utils.sequence_to_index(list_indices, length)
1663
- # If list contains other types(.../list/tuple/None), it's treated as a tuple
1664
- return const_utils.deep_tuple(list_indices)
1665
-
1666
-
1667
- def format_tuple_indices(tuple_indices):
1668
- """
1669
- Format tuple indices by unpacking high-dimension tuple and removing expand
1670
- dimension signs(Bool and None).
1671
- """
1672
- res = ()
1673
- for i in tuple_indices:
1674
- if isinstance(i, (list, tuple)):
1675
- res += (const_utils.unpack(i),)
1676
- else:
1677
- res += (i,)
1678
- return res
1679
-
1680
-
1681
1254
  @_primexpr
1682
1255
  def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
1683
1256
  """ Parse bool tensor index """
@@ -1830,7 +1403,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
1830
1403
  return reduce_fn(a, axes).astype(dtype)
1831
1404
 
1832
1405
 
1833
- tensor_operator_registry.register("reduce", reduce_)
1406
+ setattr(tensor_operator_registry, "reduce", reduce_)
1834
1407
 
1835
1408
 
1836
1409
  def check_indices(dims, indices, mode, allow_negative_index=True):
@@ -1857,7 +1430,7 @@ def check_indices(dims, indices, mode, allow_negative_index=True):
1857
1430
  return clipped
1858
1431
 
1859
1432
 
1860
- tensor_operator_registry.register('check_indices', check_indices)
1433
+ setattr(tensor_operator_registry, 'check_indices', check_indices)
1861
1434
 
1862
1435
 
1863
1436
  def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):