mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1217)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -5
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +124 -25
  9. mindspore/_extends/builtin_operations.py +2 -1
  10. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  11. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  12. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  13. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  14. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  15. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  16. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  17. mindspore/_extends/parse/__init__.py +18 -14
  18. mindspore/_extends/parse/compile_config.py +299 -0
  19. mindspore/_extends/parse/namespace.py +2 -2
  20. mindspore/_extends/parse/parser.py +182 -68
  21. mindspore/_extends/parse/resources.py +45 -14
  22. mindspore/_extends/parse/standard_method.py +192 -252
  23. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  24. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  25. mindspore/_extends/remote/kernel_build_server.py +2 -0
  26. mindspore/_profiler.py +30 -0
  27. mindspore/amp.py +67 -26
  28. mindspore/atlprov.dll +0 -0
  29. mindspore/avcodec-59.dll +0 -0
  30. mindspore/avdevice-59.dll +0 -0
  31. mindspore/avfilter-8.dll +0 -0
  32. mindspore/avformat-59.dll +0 -0
  33. mindspore/avutil-57.dll +0 -0
  34. mindspore/boost/adasum.py +1 -1
  35. mindspore/boost/base.py +1 -1
  36. mindspore/boost/boost_cell_wrapper.py +2 -2
  37. mindspore/boost/grad_freeze.py +2 -2
  38. mindspore/boost/group_loss_scale_manager.py +1 -1
  39. mindspore/boost/less_batch_normalization.py +9 -6
  40. mindspore/c1.dll +0 -0
  41. mindspore/c1xx.dll +0 -0
  42. mindspore/c2.dll +0 -0
  43. mindspore/common/__init__.py +20 -7
  44. mindspore/common/_jit_fallback_utils.py +2 -3
  45. mindspore/common/_pijit_context.py +190 -0
  46. mindspore/common/_register_for_adapter.py +7 -0
  47. mindspore/common/_register_for_recompute.py +48 -0
  48. mindspore/common/_register_for_tensor.py +10 -10
  49. mindspore/common/_stub_tensor.py +7 -1
  50. mindspore/common/_tensor_overload.py +139 -0
  51. mindspore/common/_utils.py +5 -17
  52. mindspore/common/api.py +449 -129
  53. mindspore/common/auto_dynamic_shape.py +27 -14
  54. mindspore/common/dtype.py +17 -10
  55. mindspore/common/dump.py +8 -11
  56. mindspore/common/file_system.py +48 -0
  57. mindspore/common/generator.py +254 -0
  58. mindspore/common/hook_handle.py +65 -30
  59. mindspore/common/initializer.py +1 -1
  60. mindspore/common/jit_config.py +34 -14
  61. mindspore/common/lazy_inline.py +72 -19
  62. mindspore/common/mindir_util.py +12 -2
  63. mindspore/common/mutable.py +79 -14
  64. mindspore/common/no_inline.py +54 -0
  65. mindspore/common/np_dtype.py +25 -0
  66. mindspore/common/parameter.py +73 -21
  67. mindspore/common/recompute.py +292 -0
  68. mindspore/common/seed.py +9 -9
  69. mindspore/common/sparse_tensor.py +276 -24
  70. mindspore/common/symbol.py +122 -0
  71. mindspore/common/tensor.py +668 -514
  72. mindspore/communication/__init__.py +6 -11
  73. mindspore/communication/_comm_helper.py +43 -3
  74. mindspore/communication/comm_func.py +1395 -0
  75. mindspore/communication/management.py +117 -104
  76. mindspore/config/op_info.config +22 -54
  77. mindspore/context.py +455 -71
  78. mindspore/dataset/__init__.py +5 -5
  79. mindspore/dataset/audio/__init__.py +6 -6
  80. mindspore/dataset/audio/transforms.py +711 -158
  81. mindspore/dataset/callback/ds_callback.py +2 -2
  82. mindspore/dataset/core/config.py +7 -0
  83. mindspore/dataset/core/validator_helpers.py +7 -0
  84. mindspore/dataset/engine/cache_client.py +2 -2
  85. mindspore/dataset/engine/datasets.py +201 -116
  86. mindspore/dataset/engine/datasets_audio.py +14 -14
  87. mindspore/dataset/engine/datasets_standard_format.py +83 -3
  88. mindspore/dataset/engine/datasets_text.py +39 -39
  89. mindspore/dataset/engine/datasets_user_defined.py +230 -141
  90. mindspore/dataset/engine/datasets_vision.py +78 -74
  91. mindspore/dataset/engine/iterators.py +29 -0
  92. mindspore/dataset/engine/obs/util.py +7 -0
  93. mindspore/dataset/engine/offload.py +5 -7
  94. mindspore/dataset/engine/queue.py +138 -66
  95. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  96. mindspore/dataset/engine/validators.py +41 -15
  97. mindspore/dataset/text/__init__.py +2 -5
  98. mindspore/dataset/text/transforms.py +408 -121
  99. mindspore/dataset/text/utils.py +9 -9
  100. mindspore/dataset/transforms/__init__.py +0 -3
  101. mindspore/dataset/transforms/transforms.py +261 -76
  102. mindspore/dataset/utils/browse_dataset.py +9 -9
  103. mindspore/dataset/utils/line_reader.py +2 -0
  104. mindspore/dataset/vision/__init__.py +7 -10
  105. mindspore/dataset/vision/c_transforms.py +10 -10
  106. mindspore/dataset/vision/py_transforms_util.py +1 -1
  107. mindspore/dataset/vision/transforms.py +2844 -549
  108. mindspore/dataset/vision/utils.py +161 -10
  109. mindspore/dataset/vision/validators.py +16 -3
  110. mindspore/dnnl.dll +0 -0
  111. mindspore/dpcmi.dll +0 -0
  112. mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
  113. mindspore/experimental/es/embedding_service.py +883 -0
  114. mindspore/experimental/es/embedding_service_layer.py +581 -0
  115. mindspore/experimental/llm_boost/__init__.py +21 -0
  116. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  119. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  120. mindspore/experimental/llm_boost/register.py +129 -0
  121. mindspore/experimental/llm_boost/utils.py +31 -0
  122. mindspore/experimental/optim/__init__.py +12 -2
  123. mindspore/experimental/optim/adadelta.py +161 -0
  124. mindspore/experimental/optim/adagrad.py +168 -0
  125. mindspore/experimental/optim/adam.py +35 -34
  126. mindspore/experimental/optim/adamax.py +170 -0
  127. mindspore/experimental/optim/adamw.py +124 -15
  128. mindspore/experimental/optim/asgd.py +153 -0
  129. mindspore/experimental/optim/lr_scheduler.py +66 -121
  130. mindspore/experimental/optim/nadam.py +157 -0
  131. mindspore/experimental/optim/optimizer.py +18 -8
  132. mindspore/experimental/optim/radam.py +194 -0
  133. mindspore/experimental/optim/rmsprop.py +154 -0
  134. mindspore/experimental/optim/rprop.py +164 -0
  135. mindspore/experimental/optim/sgd.py +28 -19
  136. mindspore/hal/__init__.py +40 -0
  137. mindspore/hal/_ascend.py +57 -0
  138. mindspore/hal/_base.py +57 -0
  139. mindspore/hal/_cpu.py +56 -0
  140. mindspore/hal/_gpu.py +57 -0
  141. mindspore/hal/contiguous_tensors_handle.py +175 -0
  142. mindspore/hal/device.py +356 -0
  143. mindspore/hal/event.py +179 -0
  144. mindspore/hal/memory.py +326 -0
  145. mindspore/hal/stream.py +357 -0
  146. mindspore/include/api/data_type.h +2 -2
  147. mindspore/include/api/dual_abi_helper.h +16 -3
  148. mindspore/include/api/model.h +4 -3
  149. mindspore/include/api/model_group.h +13 -1
  150. mindspore/include/api/status.h +14 -0
  151. mindspore/include/api/types.h +10 -10
  152. mindspore/include/c_api/model_c.h +173 -0
  153. mindspore/include/c_api/types_c.h +19 -0
  154. mindspore/include/dataset/config.h +2 -2
  155. mindspore/include/dataset/constants.h +2 -2
  156. mindspore/include/dataset/execute.h +3 -5
  157. mindspore/include/dataset/vision.h +58 -2
  158. mindspore/jpeg62.dll +0 -0
  159. mindspore/log.py +3 -3
  160. mindspore/mindrecord/__init__.py +5 -1
  161. mindspore/mindrecord/config.py +809 -0
  162. mindspore/mindrecord/filereader.py +25 -0
  163. mindspore/mindrecord/filewriter.py +138 -103
  164. mindspore/mindrecord/mindpage.py +40 -6
  165. mindspore/mindrecord/shardutils.py +3 -2
  166. mindspore/mindrecord/shardwriter.py +7 -0
  167. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  168. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  169. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  170. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  171. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  172. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  173. mindspore/mindspore_backend.dll +0 -0
  174. mindspore/mindspore_common.dll +0 -0
  175. mindspore/mindspore_core.dll +0 -0
  176. mindspore/mindspore_glog.dll +0 -0
  177. mindspore/mindspore_np_dtype.dll +0 -0
  178. mindspore/mindspore_ops.dll +0 -0
  179. mindspore/mint/__init__.py +1586 -0
  180. mindspore/mint/distributed/__init__.py +31 -0
  181. mindspore/mint/distributed/distributed.py +254 -0
  182. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  183. mindspore/mint/nn/__init__.py +757 -0
  184. mindspore/mint/nn/functional.py +679 -0
  185. mindspore/mint/nn/layer/__init__.py +39 -0
  186. mindspore/mint/nn/layer/activation.py +133 -0
  187. mindspore/mint/nn/layer/normalization.py +477 -0
  188. mindspore/mint/nn/layer/pooling.py +110 -0
  189. mindspore/mint/optim/__init__.py +24 -0
  190. mindspore/mint/optim/adamw.py +206 -0
  191. mindspore/mint/special/__init__.py +63 -0
  192. mindspore/msobj140.dll +0 -0
  193. mindspore/mspdb140.dll +0 -0
  194. mindspore/mspdbcore.dll +0 -0
  195. mindspore/mspdbst.dll +0 -0
  196. mindspore/mspft140.dll +0 -0
  197. mindspore/msvcdis140.dll +0 -0
  198. mindspore/msvcp140_1.dll +0 -0
  199. mindspore/msvcp140_2.dll +0 -0
  200. mindspore/msvcp140_atomic_wait.dll +0 -0
  201. mindspore/msvcp140_codecvt_ids.dll +0 -0
  202. mindspore/multiprocessing/__init__.py +73 -0
  203. mindspore/nn/cell.py +461 -323
  204. mindspore/nn/dynamic_lr.py +2 -2
  205. mindspore/nn/layer/activation.py +292 -135
  206. mindspore/nn/layer/basic.py +288 -83
  207. mindspore/nn/layer/channel_shuffle.py +3 -16
  208. mindspore/nn/layer/container.py +3 -3
  209. mindspore/nn/layer/conv.py +75 -66
  210. mindspore/nn/layer/embedding.py +221 -45
  211. mindspore/nn/layer/image.py +4 -7
  212. mindspore/nn/layer/math.py +1 -1
  213. mindspore/nn/layer/normalization.py +150 -68
  214. mindspore/nn/layer/padding.py +64 -87
  215. mindspore/nn/layer/pooling.py +175 -12
  216. mindspore/nn/layer/rnn_cells.py +6 -16
  217. mindspore/nn/layer/rnns.py +6 -5
  218. mindspore/nn/layer/thor_layer.py +1 -2
  219. mindspore/nn/layer/timedistributed.py +1 -1
  220. mindspore/nn/layer/transformer.py +55 -53
  221. mindspore/nn/learning_rate_schedule.py +6 -5
  222. mindspore/nn/loss/__init__.py +2 -2
  223. mindspore/nn/loss/loss.py +145 -88
  224. mindspore/nn/optim/__init__.py +2 -1
  225. mindspore/nn/optim/ada_grad.py +4 -2
  226. mindspore/nn/optim/adadelta.py +4 -2
  227. mindspore/nn/optim/adafactor.py +1 -1
  228. mindspore/nn/optim/adam.py +102 -181
  229. mindspore/nn/optim/adamax.py +4 -2
  230. mindspore/nn/optim/adasum.py +3 -3
  231. mindspore/nn/optim/asgd.py +4 -2
  232. mindspore/nn/optim/ftrl.py +31 -61
  233. mindspore/nn/optim/lamb.py +5 -3
  234. mindspore/nn/optim/lars.py +2 -2
  235. mindspore/nn/optim/lazyadam.py +6 -4
  236. mindspore/nn/optim/momentum.py +13 -25
  237. mindspore/nn/optim/optimizer.py +6 -3
  238. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  239. mindspore/nn/optim/rmsprop.py +9 -3
  240. mindspore/nn/optim/rprop.py +4 -2
  241. mindspore/nn/optim/sgd.py +5 -3
  242. mindspore/nn/optim/tft_wrapper.py +127 -0
  243. mindspore/nn/optim/thor.py +2 -2
  244. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  245. mindspore/nn/probability/distribution/beta.py +2 -2
  246. mindspore/nn/probability/distribution/categorical.py +4 -6
  247. mindspore/nn/probability/distribution/cauchy.py +2 -2
  248. mindspore/nn/probability/distribution/exponential.py +2 -2
  249. mindspore/nn/probability/distribution/geometric.py +1 -1
  250. mindspore/nn/probability/distribution/gumbel.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +1 -1
  252. mindspore/nn/probability/distribution/poisson.py +2 -2
  253. mindspore/nn/probability/distribution/uniform.py +2 -2
  254. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  255. mindspore/nn/wrap/__init__.py +2 -1
  256. mindspore/nn/wrap/cell_wrapper.py +46 -12
  257. mindspore/nn/wrap/grad_reducer.py +148 -8
  258. mindspore/nn/wrap/loss_scale.py +44 -7
  259. mindspore/numpy/__init__.py +2 -0
  260. mindspore/numpy/array_creations.py +67 -68
  261. mindspore/numpy/array_ops.py +70 -66
  262. mindspore/numpy/dtypes.py +3 -3
  263. mindspore/numpy/fft.py +966 -0
  264. mindspore/numpy/logic_ops.py +11 -10
  265. mindspore/numpy/math_ops.py +147 -152
  266. mindspore/numpy/utils.py +3 -0
  267. mindspore/numpy/utils_const.py +4 -4
  268. mindspore/opencv_core452.dll +0 -0
  269. mindspore/opencv_imgcodecs452.dll +0 -0
  270. mindspore/opencv_imgproc452.dll +0 -0
  271. mindspore/ops/__init__.py +9 -6
  272. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  273. mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
  274. mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
  275. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  276. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  277. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  278. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  279. mindspore/ops/_op_impl/__init__.py +0 -1
  280. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  281. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  282. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  283. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  284. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  285. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  286. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  287. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  288. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  289. mindspore/ops/_vmap/vmap_array_ops.py +162 -101
  290. mindspore/ops/_vmap/vmap_base.py +8 -1
  291. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  292. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  293. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  294. mindspore/ops/_vmap/vmap_math_ops.py +147 -59
  295. mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
  296. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  297. mindspore/ops/auto_generate/__init__.py +31 -0
  298. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  299. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  300. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  301. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  302. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  303. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  304. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  305. mindspore/ops/composite/__init__.py +5 -2
  306. mindspore/ops/composite/base.py +201 -66
  307. mindspore/ops/composite/math_ops.py +10 -49
  308. mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
  309. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  310. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  313. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  315. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  316. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  317. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  319. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  321. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  323. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  324. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  325. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  326. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  327. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  328. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  330. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  331. mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  333. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  334. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  335. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  336. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  337. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  338. mindspore/ops/deprecated.py +14 -3
  339. mindspore/ops/function/__init__.py +53 -11
  340. mindspore/ops/function/array_func.py +1269 -1821
  341. mindspore/ops/function/clip_func.py +19 -31
  342. mindspore/ops/function/debug_func.py +114 -5
  343. mindspore/ops/function/fft_func.py +44 -0
  344. mindspore/ops/function/grad/grad_func.py +30 -22
  345. mindspore/ops/function/image_func.py +27 -21
  346. mindspore/ops/function/linalg_func.py +35 -68
  347. mindspore/ops/function/math_func.py +1170 -2697
  348. mindspore/ops/function/nn_func.py +2116 -1128
  349. mindspore/ops/function/other_func.py +8 -8
  350. mindspore/ops/function/parameter_func.py +5 -93
  351. mindspore/ops/function/random_func.py +435 -113
  352. mindspore/ops/function/reshard_func.py +104 -0
  353. mindspore/ops/function/sparse_func.py +4 -4
  354. mindspore/ops/function/sparse_unary_func.py +9 -16
  355. mindspore/ops/function/spectral_func.py +1 -1
  356. mindspore/ops/function/vmap_func.py +16 -15
  357. mindspore/ops/functional.py +355 -346
  358. mindspore/ops/op_info_register.py +18 -45
  359. mindspore/ops/operations/__init__.py +38 -24
  360. mindspore/ops/operations/_grad_ops.py +21 -927
  361. mindspore/ops/operations/_infer_ops.py +19 -0
  362. mindspore/ops/operations/_inner_ops.py +173 -607
  363. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  364. mindspore/ops/operations/_scalar_ops.py +5 -480
  365. mindspore/ops/operations/_sequence_ops.py +6 -36
  366. mindspore/ops/operations/_tensor_array.py +8 -8
  367. mindspore/ops/operations/array_ops.py +106 -2837
  368. mindspore/ops/operations/comm_ops.py +799 -127
  369. mindspore/ops/operations/custom_ops.py +124 -119
  370. mindspore/ops/operations/debug_ops.py +142 -41
  371. mindspore/ops/operations/image_ops.py +1 -217
  372. mindspore/ops/operations/inner_ops.py +5 -40
  373. mindspore/ops/operations/linalg_ops.py +1 -49
  374. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  375. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  376. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  377. mindspore/ops/operations/math_ops.py +666 -4972
  378. mindspore/ops/operations/nn_ops.py +205 -2213
  379. mindspore/ops/operations/other_ops.py +60 -49
  380. mindspore/ops/operations/random_ops.py +50 -54
  381. mindspore/ops/operations/reshard_ops.py +53 -0
  382. mindspore/ops/operations/sparse_ops.py +4 -4
  383. mindspore/ops/primitive.py +216 -103
  384. mindspore/ops_generate/__init__.py +27 -0
  385. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  386. mindspore/ops_generate/arg_handler.py +197 -0
  387. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  388. mindspore/ops_generate/gen_constants.py +36 -0
  389. mindspore/ops_generate/gen_ops.py +1099 -0
  390. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  391. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  392. mindspore/ops_generate/gen_utils.py +209 -0
  393. mindspore/ops_generate/op_proto.py +145 -0
  394. mindspore/ops_generate/pyboost_utils.py +367 -0
  395. mindspore/ops_generate/template.py +261 -0
  396. mindspore/parallel/__init__.py +8 -4
  397. mindspore/parallel/_auto_parallel_context.py +100 -10
  398. mindspore/parallel/_cell_wrapper.py +99 -9
  399. mindspore/parallel/_cost_model_context.py +1 -1
  400. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  401. mindspore/parallel/_parallel_serialization.py +67 -23
  402. mindspore/parallel/_ps_context.py +1 -1
  403. mindspore/parallel/_recovery_context.py +1 -1
  404. mindspore/parallel/_tensor.py +99 -22
  405. mindspore/parallel/_transformer/__init__.py +1 -1
  406. mindspore/parallel/_transformer/layers.py +1 -1
  407. mindspore/parallel/_transformer/loss.py +1 -1
  408. mindspore/parallel/_transformer/moe.py +1 -1
  409. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  410. mindspore/parallel/_transformer/transformer.py +2 -2
  411. mindspore/parallel/_utils.py +173 -6
  412. mindspore/parallel/algo_parameter_config.py +8 -10
  413. mindspore/parallel/checkpoint_transform.py +204 -38
  414. mindspore/parallel/cluster/__init__.py +15 -0
  415. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  416. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  417. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  418. mindspore/parallel/cluster/run.py +136 -0
  419. mindspore/parallel/mpi/__init__.py +1 -1
  420. mindspore/parallel/mpi/_mpi_config.py +1 -1
  421. mindspore/parallel/parameter_broadcast.py +151 -0
  422. mindspore/parallel/shard.py +279 -37
  423. mindspore/parallel/transform_safetensors.py +993 -0
  424. mindspore/pgodb140.dll +0 -0
  425. mindspore/pgort140.dll +0 -0
  426. mindspore/profiler/__init__.py +4 -2
  427. mindspore/profiler/common/constant.py +29 -0
  428. mindspore/profiler/common/process_pool.py +41 -0
  429. mindspore/profiler/common/registry.py +47 -0
  430. mindspore/profiler/common/singleton.py +28 -0
  431. mindspore/profiler/common/util.py +153 -0
  432. mindspore/profiler/dynamic_profiler.py +694 -0
  433. mindspore/profiler/envprofiling.py +18 -20
  434. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  435. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  436. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  437. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  438. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  439. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  440. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  441. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  442. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  443. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  444. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  445. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  446. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  447. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  448. mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
  449. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  450. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  451. mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
  452. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  453. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  454. mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
  455. mindspore/profiler/parser/base_timeline_generator.py +25 -25
  456. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  457. mindspore/profiler/parser/framework_parser.py +4 -393
  458. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  459. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  460. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  461. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  462. mindspore/profiler/parser/integrator.py +3 -1
  463. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  464. mindspore/profiler/parser/minddata_parser.py +72 -3
  465. mindspore/profiler/parser/profiler_info.py +94 -7
  466. mindspore/profiler/profiler.py +153 -0
  467. mindspore/profiler/profiling.py +631 -508
  468. mindspore/rewrite/__init__.py +2 -14
  469. mindspore/rewrite/api/node.py +122 -36
  470. mindspore/rewrite/api/pattern_engine.py +2 -3
  471. mindspore/rewrite/api/scoped_value.py +16 -15
  472. mindspore/rewrite/api/symbol_tree.py +45 -29
  473. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  474. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  475. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  476. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  477. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  478. mindspore/rewrite/common/__init__.py +1 -2
  479. mindspore/rewrite/common/config.py +24 -0
  480. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  481. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  482. mindspore/rewrite/common/namespace.py +118 -0
  483. mindspore/rewrite/node/__init__.py +5 -5
  484. mindspore/rewrite/node/call_function.py +23 -7
  485. mindspore/rewrite/node/cell_container.py +7 -3
  486. mindspore/rewrite/node/control_flow.py +53 -28
  487. mindspore/rewrite/node/node.py +212 -196
  488. mindspore/rewrite/node/node_manager.py +51 -22
  489. mindspore/rewrite/node/node_topological_manager.py +3 -23
  490. mindspore/rewrite/parsers/__init__.py +12 -0
  491. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  492. mindspore/rewrite/parsers/assign_parser.py +637 -413
  493. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  494. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  495. mindspore/rewrite/parsers/constant_parser.py +5 -5
  496. mindspore/rewrite/parsers/container_parser.py +4 -6
  497. mindspore/rewrite/parsers/expr_parser.py +55 -0
  498. mindspore/rewrite/parsers/for_parser.py +31 -98
  499. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  500. mindspore/rewrite/parsers/if_parser.py +28 -10
  501. mindspore/rewrite/parsers/module_parser.py +8 -182
  502. mindspore/rewrite/parsers/parser.py +1 -5
  503. mindspore/rewrite/parsers/parser_register.py +1 -1
  504. mindspore/rewrite/parsers/return_parser.py +5 -10
  505. mindspore/rewrite/parsers/while_parser.py +59 -0
  506. mindspore/rewrite/sparsify/utils.py +1 -1
  507. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  508. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
  509. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  510. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  511. mindspore/run_check/_check_version.py +40 -115
  512. mindspore/run_check/run_check.py +1 -1
  513. mindspore/safeguard/rewrite_obfuscation.py +597 -263
  514. mindspore/swresample-4.dll +0 -0
  515. mindspore/swscale-6.dll +0 -0
  516. mindspore/tbbmalloc.dll +0 -0
  517. mindspore/tinyxml2.dll +0 -0
  518. mindspore/train/__init__.py +7 -5
  519. mindspore/train/_utils.py +204 -4
  520. mindspore/train/amp.py +335 -295
  521. mindspore/train/anf_ir_pb2.py +14 -2
  522. mindspore/train/callback/__init__.py +5 -2
  523. mindspore/train/callback/_backup_and_restore.py +5 -5
  524. mindspore/train/callback/_callback.py +4 -4
  525. mindspore/train/callback/_checkpoint.py +220 -43
  526. mindspore/train/callback/_cluster_monitor.py +201 -0
  527. mindspore/train/callback/_early_stop.py +2 -2
  528. mindspore/train/callback/_flops_collector.py +239 -0
  529. mindspore/train/callback/_landscape.py +15 -9
  530. mindspore/train/callback/_loss_monitor.py +5 -5
  531. mindspore/train/callback/_on_request_exit.py +136 -33
  532. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  533. mindspore/train/callback/_summary_collector.py +12 -12
  534. mindspore/train/callback/_tft_register.py +352 -0
  535. mindspore/train/callback/_time_monitor.py +3 -3
  536. mindspore/train/data_sink.py +6 -5
  537. mindspore/train/dataset_helper.py +66 -23
  538. mindspore/train/loss_scale_manager.py +2 -2
  539. mindspore/train/metrics/accuracy.py +7 -7
  540. mindspore/train/metrics/confusion_matrix.py +8 -6
  541. mindspore/train/metrics/cosine_similarity.py +6 -4
  542. mindspore/train/metrics/error.py +2 -2
  543. mindspore/train/metrics/metric.py +3 -3
  544. mindspore/train/metrics/perplexity.py +2 -1
  545. mindspore/train/metrics/roc.py +4 -4
  546. mindspore/train/metrics/topk.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +116 -37
  548. mindspore/train/model.py +382 -76
  549. mindspore/train/serialization.py +787 -288
  550. mindspore/train/summary/_summary_adapter.py +1 -1
  551. mindspore/train/summary/summary_record.py +51 -28
  552. mindspore/train/train_thor/convert_utils.py +3 -3
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +21 -0
  555. mindspore/utils/utils.py +60 -0
  556. mindspore/vcmeta.dll +0 -0
  557. mindspore/vcruntime140.dll +0 -0
  558. mindspore/vcruntime140_1.dll +0 -0
  559. mindspore/version.py +1 -1
  560. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
  561. mindspore-2.4.0.dist-info/RECORD +1406 -0
  562. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
  563. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  564. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  565. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  566. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  567. mindspore/gen_ops.py +0 -273
  568. mindspore/include/c_api/ms/abstract.h +0 -67
  569. mindspore/include/c_api/ms/attribute.h +0 -197
  570. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  571. mindspore/include/c_api/ms/base/macros.h +0 -32
  572. mindspore/include/c_api/ms/base/status.h +0 -33
  573. mindspore/include/c_api/ms/base/types.h +0 -282
  574. mindspore/include/c_api/ms/context.h +0 -102
  575. mindspore/include/c_api/ms/graph.h +0 -160
  576. mindspore/include/c_api/ms/node.h +0 -606
  577. mindspore/include/c_api/ms/tensor.h +0 -161
  578. mindspore/include/c_api/ms/value.h +0 -84
  579. mindspore/mindspore_shared_lib.dll +0 -0
  580. mindspore/nn/layer/flash_attention.py +0 -189
  581. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  582. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  583. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  584. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  585. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  586. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  587. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  588. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  589. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  590. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  591. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  592. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  593. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  594. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  595. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  596. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  597. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  598. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  599. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  600. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  601. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  602. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  603. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  604. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  605. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  606. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  607. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  608. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  609. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  610. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  611. mindspore/ops/_op_impl/tbe/add.py +0 -42
  612. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  613. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  614. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  615. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  616. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  617. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  618. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  620. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  621. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  622. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  623. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  624. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  625. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  626. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  627. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  628. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  629. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  630. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  631. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  632. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  633. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  634. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  635. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  636. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  637. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  638. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  639. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  640. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  641. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  642. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  643. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  644. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  645. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  646. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  647. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  648. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  649. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  650. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  651. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  652. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  653. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  654. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  655. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  656. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  657. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  658. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  659. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  660. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  661. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  662. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  663. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  664. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  665. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  666. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  667. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  668. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  669. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  670. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  671. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  672. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  673. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  674. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  675. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  676. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  677. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  678. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  679. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  680. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  681. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  682. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  683. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  684. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  685. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  686. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  687. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  688. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  689. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  690. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  691. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  692. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  693. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  694. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  695. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  696. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  697. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  698. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  699. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  700. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  701. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  702. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  703. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  704. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  705. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  706. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  707. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  708. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  709. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  710. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  711. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  712. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  713. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  714. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  715. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  716. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  717. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  718. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  719. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  720. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  721. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  722. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  723. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  724. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  725. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  726. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  727. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  728. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  729. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  730. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  731. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  732. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  733. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  734. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  735. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  736. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  737. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  738. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  739. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  740. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  741. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  742. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  743. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  744. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  745. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  746. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  747. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  748. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  749. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  750. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  751. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  752. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  753. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  754. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  755. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  756. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  757. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  758. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  759. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  760. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  761. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  762. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  763. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  764. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  765. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  766. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  767. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  768. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  769. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  770. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  771. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  772. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  773. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  774. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  775. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  776. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  777. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  778. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  779. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  780. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  781. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  782. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  783. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  784. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  785. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  786. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  787. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  788. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  789. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  790. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  791. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  792. mindspore/ops/_op_impl/tbe/div.py +0 -41
  793. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  794. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  795. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  796. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  797. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  798. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  799. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  800. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  801. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  802. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  803. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  804. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  805. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  806. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  807. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  808. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  809. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  810. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  811. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  812. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  813. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  814. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  815. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  816. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  817. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  818. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  819. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  820. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  821. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  822. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  823. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  824. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  825. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  826. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  827. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  828. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  829. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  830. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  831. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  832. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  833. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  834. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  835. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  836. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  837. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  838. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  839. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  840. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  841. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  842. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  843. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  844. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  845. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  846. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  847. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  848. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  849. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  850. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  851. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  852. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  853. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  854. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  855. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  856. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  857. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  858. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  859. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  861. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  862. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  863. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  864. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  865. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  866. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  867. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  868. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  869. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  870. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  871. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  872. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  873. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  874. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  875. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  876. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  877. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  878. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  879. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  880. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  881. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  882. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  883. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  884. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  885. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  886. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  887. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  888. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  889. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  890. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  891. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  892. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  893. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  894. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  895. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  896. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  897. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  898. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  899. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  900. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  901. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  902. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  903. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  904. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  905. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  906. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  907. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  908. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  909. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  910. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  911. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  912. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  913. mindspore/ops/_op_impl/tbe/less.py +0 -41
  914. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  915. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  916. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  917. mindspore/ops/_op_impl/tbe/log.py +0 -40
  918. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  919. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  920. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  921. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  922. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  923. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  924. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  925. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  926. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  927. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  928. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  929. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  930. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  931. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  932. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  933. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  934. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  935. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  936. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  937. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  938. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  939. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  940. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  941. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  942. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  943. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  944. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  945. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  946. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  947. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  948. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  949. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  950. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  951. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  952. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  953. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  954. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  955. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  956. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  957. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  958. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  959. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  960. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  961. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  962. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  963. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  964. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  965. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  966. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  967. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  968. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  969. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  970. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  971. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  972. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  973. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  974. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  975. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  976. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  977. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  978. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  979. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  980. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  981. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  982. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  983. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  984. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  985. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  986. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  987. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  988. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  989. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  990. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  991. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  992. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  993. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  994. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  995. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  996. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  997. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  998. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  999. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  1000. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  1002. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  1003. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  1004. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  1005. mindspore/ops/_op_impl/tbe/range.py +0 -39
  1006. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  1007. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  1008. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  1009. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  1010. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  1011. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  1012. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  1013. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  1014. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  1015. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  1016. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  1017. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  1018. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  1019. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  1020. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  1021. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  1022. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  1023. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  1024. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  1025. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  1026. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  1027. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  1028. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  1029. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  1030. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  1031. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  1032. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  1033. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  1034. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  1035. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  1036. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  1037. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  1038. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  1039. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  1040. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  1041. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  1042. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  1043. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  1044. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  1045. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  1046. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  1047. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1048. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1049. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1050. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1051. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1052. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1053. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1054. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1055. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1057. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1058. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1060. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1061. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1062. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1063. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1064. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1065. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1066. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1067. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1068. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1069. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1070. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1071. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1072. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1073. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1074. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1075. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1076. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1077. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1078. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1079. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1080. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1081. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1082. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1083. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1084. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1085. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1086. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1087. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1088. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1089. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1090. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1091. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1092. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1093. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1095. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1096. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1097. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1098. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1099. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1100. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1101. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1102. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1103. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1104. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1105. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1106. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1107. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1108. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1109. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1110. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1111. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1112. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1113. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1114. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1115. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1116. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1117. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1118. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1119. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1120. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1122. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1123. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1124. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1125. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1126. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1127. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1128. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1129. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1130. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1131. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1132. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1133. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1134. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1135. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1136. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1137. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1138. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1139. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1140. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1141. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1142. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1143. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1144. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1145. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1146. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1147. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1148. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1149. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1150. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1151. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1152. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1153. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1154. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1155. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1156. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1157. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1158. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1159. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1160. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1161. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1162. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1163. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1164. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1165. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1166. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1167. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1168. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1169. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1170. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1171. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1172. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1173. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1174. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1175. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1176. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1177. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1178. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1179. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1180. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1181. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1182. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1183. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1184. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1185. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1186. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1187. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1188. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1189. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1190. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1191. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1192. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1193. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1194. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1195. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1196. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1197. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1198. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1199. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1200. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1201. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1202. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1203. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1204. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1205. mindspore/ops/_tracefunc.py +0 -241
  1206. mindspore/ops/arg_dtype_cast.py +0 -54
  1207. mindspore/ops/silent_check.py +0 -162
  1208. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  1209. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  1210. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1211. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1212. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1213. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1214. mindspore/rewrite/namespace.py +0 -53
  1215. mindspore-2.2.14.dist-info/RECORD +0 -1924
  1216. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  1217. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2022-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import numpy as np
24
24
  import mindspore as ms
25
25
  import mindspore.common.dtype as mstype
26
26
  from mindspore.ops import operations as P
27
+ from mindspore.ops import functional as F
27
28
  from mindspore.ops.primitive import constexpr
28
29
  from mindspore.ops.primitive import _primexpr
29
30
  import mindspore.ops as ops
@@ -31,18 +32,18 @@ from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
31
32
  from mindspore.ops.operations._sequence_ops import TupleToTensor
32
33
  from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
33
34
  from mindspore.ops.operations._sequence_ops import TensorToList
34
-
35
+ from mindspore.ops.auto_generate import OnesLikeExt, ZerosLikeExt, FillScalar, FillTensor, Arange, Chunk, UniqueDim, \
36
+ Unique2, SortExt, NonZero, NonZeroExt, Scatter, ScatterValue
37
+ from mindspore.ops.auto_generate.gen_ops_prim import SplitTensor
38
+ from mindspore.ops.auto_generate.gen_ops_prim import SplitWithSize, RepeatInterleaveInt, RepeatInterleaveTensor
39
+ from mindspore.ops.auto_generate.pyboost_inner_prim import _PyboostSearchSortedPrim
35
40
  from mindspore.ops.operations.array_ops import (
36
41
  UniqueConsecutive,
37
- SearchSorted,
38
- NonZero,
39
42
  MatrixDiagV3,
40
43
  MatrixDiagPartV3,
41
44
  MatrixSetDiagV3,
42
45
  Fills,
43
46
  Col2Im,
44
- ArgMaxWithValue,
45
- ArgMinWithValue,
46
47
  ScatterNdMax,
47
48
  ScatterNdMul,
48
49
  IndexFill,
@@ -52,62 +53,97 @@ from mindspore.ops.operations.array_ops import (
52
53
  Lstsq,
53
54
  Mvlgamma,
54
55
  Tril,
55
- Argmax
56
+ Argmax,
57
+ ArgMaxWithValue,
58
+ ArgMinWithValue
56
59
  )
57
- from mindspore.ops.operations.array_ops import TensorScatterElements
58
60
  from mindspore.common import Tensor
59
61
  from mindspore.ops._primitive_cache import _get_cache_prim
60
62
  from mindspore import _checkparam as validator
61
63
  from mindspore._c_expression import Tensor as Tensor_
62
64
  from mindspore.ops._utils.utils import ms_arrange
63
65
 
64
- tuple_to_tensor_ = TupleToTensor()
66
+ from mindspore.ops.auto_generate import cat, range, scatter_nd, deepcopy, masked_fill, diagonal, expand_dims, \
67
+ flip, transpose, triu, unsorted_segment_sum, diag, gather, gather_d, gather_nd, reshape, masked_select, \
68
+ broadcast_to, strided_slice, ones, zeros, max_, min_, select, zero_
69
+ from mindspore.ops.auto_generate import tensor_scatter_elements as tensor_scatter_elements_ext
70
+ from mindspore.ops.auto_generate.gen_ops_prim import scatter_add_ext_op, slice_ext_op, gather_d_op
71
+ from mindspore.ops.operations.manually_defined import tile, rank, scalar_cast
72
+ from mindspore.ops.auto_generate.pyboost_inner_prim import _PyboostOneHotExtPrim, tril_ext_impl
73
+
74
+ arg_max_with_value_ = ArgMaxWithValue()
75
+ arg_min_with_value_ = ArgMinWithValue()
76
+ batch_to_space_nd_v2_ = P.BatchToSpaceNDV2()
77
+ cast_ = P.Cast()
78
+ diag_ = P.Diag()
79
+ dynamic_broadcast_to_ = DynamicBroadcastTo()
65
80
  eye_ = P.Eye()
66
81
  fills_ = Fills()
82
+ fillv2_ = P.FillV2()
83
+ flatten_ = P.Flatten()
84
+ gather_ = P.Gather()
85
+ gather_d_ = P.GatherD()
86
+ gather_nd_ = P.GatherNd()
87
+ ger_ = P.Ger()
88
+ index_fill_ = IndexFill()
89
+ lstsq_ = Lstsq()
90
+ matrix_band_part_ = P.array_ops.MatrixBandPart()
67
91
  ones_ = P.Ones()
68
- ones_like_ = P.OnesLike()
69
- tile_ = P.Tile()
70
- unique_with_pad_ = P.UniqueWithPad()
71
- size_ = P.Size()
72
- shape_ = P.Shape()
92
+ population_count_ = P.PopulationCount()
93
+ range_ = P.Range()
73
94
  rank_ = P.Rank()
74
- tensor_shape_ = P.TensorShape()
95
+ reduce_max_ = P.ReduceMax()
96
+ reduce_min_ = P.ReduceMin()
75
97
  reshape_ = P.Reshape()
76
- tensor_slice = P.Slice()
77
- expand_dims_ = P.ExpandDims()
78
- transpose_ = P.Transpose()
98
+ scalar_to_tensor_ = P.ScalarToTensor()
79
99
  scatter_add_ = P.ScatterAdd()
100
+ scatter_div_ = P.ScatterDiv()
80
101
  scatter_max_ = P.ScatterMax()
81
102
  scatter_min_ = P.ScatterMin()
82
103
  scatter_mul_ = P.ScatterMul()
83
- scatter_div_ = P.ScatterDiv()
84
104
  scatter_nd_ = P.ScatterNd()
85
- gather_ = P.Gather()
86
- gather_d_ = P.GatherD()
87
- gather_nd_ = P.GatherNd()
88
- nonzero_ = NonZero()
89
- scalar_cast_ = P.ScalarCast()
105
+ scatter_update_ = P.ScatterUpdate()
106
+ search_sorted_ = _PyboostSearchSortedPrim()
107
+ shape_ = P.Shape()
108
+ split_tensor = SplitTensor()
109
+ split_with_size = SplitWithSize()
110
+ size_ = P.Size()
90
111
  tensor_scatter_add_ = P.TensorScatterAdd()
91
- tensor_scatter_sub_ = P.TensorScatterSub()
92
- tensor_scatter_mul_ = P.TensorScatterMul()
93
112
  tensor_scatter_div_ = P.TensorScatterDiv()
94
- tensor_scatter_min_ = P.TensorScatterMin()
95
113
  tensor_scatter_max_ = P.TensorScatterMax()
96
- scalar_to_tensor_ = P.ScalarToTensor()
97
- tuple_to_array_ = P.TupleToArray()
98
- masked_select_ = P.MaskedSelect()
99
- matrix_band_part_ = P.array_ops.MatrixBandPart()
100
- ger_ = P.Ger()
101
- diag_ = P.Diag()
102
- range_ = P.Range()
103
- zeros_like_ = P.ZerosLike()
104
- cast_ = P.Cast()
114
+ tensor_scatter_min_ = P.TensorScatterMin()
115
+ tensor_scatter_mul_ = P.TensorScatterMul()
116
+ tensor_scatter_sub_ = P.TensorScatterSub()
105
117
  tensor_select_ = P.Select()
106
- index_fill_ = IndexFill()
118
+ tensor_shape_ = P.TensorShape()
119
+ tensor_slice = P.Slice()
120
+ tile_ = P.Tile()
121
+ transpose_ = P.Transpose()
122
+ tuple_to_array_ = P.TupleToArray()
123
+ tuple_to_tensor_ = TupleToTensor()
124
+ unique_ = P.Unique()
125
+ unsorted_segment_max_ = P.UnsortedSegmentMax()
126
+ unsorted_segment_min_ = P.UnsortedSegmentMin()
127
+ unsorted_segment_prod_ = P.UnsortedSegmentProd()
107
128
  unsorted_segment_sum_ = P.UnsortedSegmentSum()
108
- population_count_ = P.PopulationCount()
109
- reduce_max = P.ReduceMax()
110
- reduce_min = P.ReduceMin()
129
+ ones_like_ = P.OnesLike()
130
+ one_hot_ext_impl = _PyboostOneHotExtPrim()
131
+ zeros_like_ = P.ZerosLike()
132
+ ones_like_ext_ = OnesLikeExt()
133
+ zeros_like_ext_ = ZerosLikeExt()
134
+ fill_scalar_ = FillScalar()
135
+ fill_tensor_ = FillTensor()
136
+ sort_ext_ = SortExt()
137
+ scatter_ = Scatter()
138
+ scatter_value_ = ScatterValue()
139
+ arange_ = Arange()
140
+ chunk_ = Chunk()
141
+ repeat_interleave_int_ = RepeatInterleaveInt()
142
+ repeat_interleave_tensor_ = RepeatInterleaveTensor()
143
+ unique_dim_ = UniqueDim()
144
+ unique2_ = Unique2()
145
+ non_zero_ = NonZero()
146
+ non_zero_ext_ = NonZeroExt()
111
147
 
112
148
 
113
149
  @_primexpr
@@ -165,7 +201,8 @@ def _get_max_type(start, end, step):
165
201
 
166
202
  type_map = {'Float64': '3', 'Float32': '2', "<class 'float'>": '2', 'Int64': '1', "<class 'int'>": '1',
167
203
  'Int32': '0'}
168
- type_map_reverse = {'3': mstype.float64, '2': mstype.float32, '1': mstype.int64, '0': mstype.int32}
204
+ type_map_reverse = {'3': mstype.float64,
205
+ '2': mstype.float32, '1': mstype.int64, '0': mstype.int32}
169
206
  type_level = [type_map.get(i) for i in arg_type_map]
170
207
  max_level = builtins.max(type_level)
171
208
  return type_map_reverse.get(max_level)
@@ -187,8 +224,11 @@ def arange(start=0, end=None, step=1, *, dtype=None):
187
224
 
188
225
  Keyword Args:
189
226
  dtype (mindspore.dtype, optional): The required data type of returned Tensor. Default: ``None`` .
190
- If the value is not specified or is ``None`` , the type with the highest precision in the
191
- `start`, `end`, and `step` parameters is inferred.
227
+ When `dtype` is not specified or ``None``:
228
+
229
+ If `start`, `end`, and `step` are all integers, the dtype of output is int64,
230
+
231
+ If `start`, `end`, and `step` contain at least one floating-point number, the dtype of output is float32.
192
232
 
193
233
  Returns:
194
234
  A 1-D Tensor, with the same type as the inputs.
@@ -225,7 +265,7 @@ def arange(start=0, end=None, step=1, *, dtype=None):
225
265
  >>> print(output)
226
266
  [12. 11. 10. 9. 8. 7. 6. 5. 4. 3.]
227
267
  >>> print(output.dtype)
228
- Float64
268
+ Float32
229
269
  """
230
270
  if end is None:
231
271
  start, end = 0, start
@@ -237,67 +277,84 @@ def arange(start=0, end=None, step=1, *, dtype=None):
237
277
  if start.shape != () or end.shape != () or step.shape != ():
238
278
  raise ValueError(f"For arange, the input args must be a TensorScalar,"
239
279
  f" but got start shape:{start.shape}, end shape:{end.shape}, step shape:{step.shape}")
240
- range_op = _get_cache_prim(P.Range)()
241
- data = range_op(start, end, step)
280
+ data = range_(start, end, step)
242
281
  if dtype is not None:
243
282
  data = cast_(data, dtype)
244
283
  return data
245
284
 
246
285
 
247
- def cat(tensors, axis=0):
286
+ def arange_ext(start=0, end=None, step=1, *, dtype=None):
248
287
  r"""
249
- Connect input tensors along with the given axis.
288
+ Creates a sequence of numbers that begins at `start` and extends by increments of
289
+ `step` up to but not including `end`.
250
290
 
251
- The input data is a tuple or a list of tensors. These tensors have the same rank :math:`R`.
252
- Set the given axis as :math:`m`, and :math:`0 \le m < R`. Set the number of input tensors as :math:`N`.
253
- For the :math:`i`-th tensor :math:`t_i`, it has the shape of :math:`(x_1, x_2, ..., x_{mi}, ..., x_R)`.
254
- :math:`x_{mi}` is the :math:`m`-th dimension of the :math:`t_i`. Then, the shape of the output tensor is
291
+ Args:
292
+ start (Union[float, int], optional): The start of the interval. Default: ``0`` .
293
+ end (Union[float, int], optional): The end of the interval, exclusive.
294
+ Default: ``None`` . If ``None`` , it defaults to the value of `start`, and 0 is used as the starting value.
295
+ step (Union[float, int], optional): The step size with which the array element increments. Default: ``1`` .
255
296
 
256
- .. math::
297
+ Keyword Args:
298
+ dtype (mindspore.dtype, optional): The required data type of returned Tensor. Default: ``None`` .
299
+ When `dtype` is not specified or ``None``:
257
300
 
258
- (x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)
301
+ If `start`, `end`, and `step` are all integers, the dtype of output is int64,
259
302
 
260
- Args:
261
- tensors (Union[tuple, list]): A tuple or a list of input tensors.
262
- Suppose there are two tensors in this tuple or list, namely t1 and t2.
263
- To perform `concat` in the axis 0 direction, except for the :math:`0`-th axis,
264
- all other dimensions should be equal, that is,
265
- :math:`t1.shape[1] = t2.shape[1], t1.shape[2] = t2.shape[2], ..., t1.shape[R-1] = t2.shape[R-1]`,
266
- where :math:`R` represents the rank of tensor.
267
- axis (int): The specified axis, whose value is in range :math:`[-R, R)`. Default: ``0`` .
303
+ If `start`, `end`, and `step` contain at least one floating-point number, the dtype of output is float32.
268
304
 
269
305
  Returns:
270
- Tensor, the shape is :math:`(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)`.
271
- The data type is the same with `tensors`.
306
+ A 1-D Tensor, cast to `dtype` if provided, may potentially lose precision due to casting.
272
307
 
273
308
  Raises:
274
- TypeError: If `axis` is not an int.
275
- ValueError: If `tensors` have different dimension of tensor.
276
- ValueError: If `axis` not in range :math:`[-R, R)`.
277
- RuntimeError: If tensor's shape in `tensors` except for `axis` are different.
309
+ TypeError: If `start`, `end` or `step` are not of type int or float.
310
+ ValueError: If `step` = 0.
311
+ ValueError: If `start` >= `end` when `step` > 0.
312
+ ValueError: If `start` <= `end` when `step` < 0.
278
313
 
279
314
  Supported Platforms:
280
- ``Ascend`` ``GPU`` ``CPU``
315
+ ``Ascend``
281
316
 
282
317
  Examples:
283
- >>> import mindspore
284
- >>> import numpy as np
318
+ >>> import mindspore as ms
285
319
  >>> from mindspore import Tensor, ops
286
- >>> input_x1 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
287
- >>> input_x2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
288
- >>> output = ops.cat((input_x1, input_x2))
320
+ >>> output = ops.arange_ext(1, 6)
289
321
  >>> print(output)
290
- [[0. 1.]
291
- [2. 1.]
292
- [0. 1.]
293
- [2. 1.]]
294
- >>> output = ops.cat((input_x1, input_x2), 1)
322
+ [1 2 3 4 5]
323
+ >>> print(output.dtype)
324
+ Int64
325
+ >>> output = ops.arange_ext(0, 3, 1.2)
326
+ >>> print(output)
327
+ [0. 1.2 2.4]
328
+ >>> print(output.dtype)
329
+ Float32
330
+ >>> output = ops.arange_ext(7, 1, -2)
295
331
  >>> print(output)
296
- [[0. 1. 0. 1.]
297
- [2. 1. 2. 1.]]
332
+ [7 5 3]
333
+ >>> print(output.dtype)
334
+ Int64
335
+ >>> output = ops.arange_ext(12, 2, -1, dtype=ms.bfloat16)
336
+ >>> print(output)
337
+ [12. 11. 10. 9. 8. 7. 6. 5. 4. 3.]
338
+ >>> print(output.dtype)
339
+ BFloat16
340
+ """
341
+ if end is None:
342
+ start, end = 0, start
343
+ return arange_(start, end, step, dtype)
344
+
345
+
346
+ def concat(tensors, axis=0):
298
347
  """
299
- _concat = _get_cache_prim(P.Concat)(axis)
300
- return _concat(tensors)
348
+ Alias for :func:`mindspore.ops.cat()`.
349
+
350
+ Tutorial Examples:
351
+ - `Tensor - Tensor Operation <https://mindspore.cn/tutorials/en/master/beginner/tensor.html#tensor-operation>`_
352
+ - `Vision Transformer Image Classification - Building ViT as a whole
353
+ <https://mindspore.cn/tutorials/en/master/cv/vit.html#building-vit-as-a-whole>`_
354
+ - `Sentiment Classification Implemented by RNN - Dense
355
+ <https://mindspore.cn/tutorials/en/master/nlp/sentiment_analysis.html#dense>`_
356
+ """
357
+ return cat(tensors, axis)
301
358
 
302
359
 
303
360
  def eye(n, m=None, dtype=None):
@@ -305,14 +362,14 @@ def eye(n, m=None, dtype=None):
305
362
  Creates a tensor with ones on the diagonal and zeros in the rest.
306
363
 
307
364
  Note:
308
- Combines ReverseV2 operator to get an anti-diagonal Tensor,
309
- but ReverseV2 only supports Ascend and GPU platforms currently.
365
+ The data type of returned tensor can be float16, float32, int8, int16, int32, int64, uint8
366
+ or bool on Ascend platforms.
310
367
 
311
368
  Args:
312
369
  n (int): The number of rows of returned tensor. Constant value only.
313
- m (int): The number of columns of returned tensor. Constant value only.
370
+ m (int, optional): The number of columns of returned tensor. Constant value only.
314
371
  Default: ``None`` , if ``None`` , the number of columns is as the same as n.
315
- dtype (mindspore.dtype): MindSpore's dtype, the data type of the returned tensor.
372
+ dtype (mindspore.dtype, optional): MindSpore's dtype, the data type of the returned tensor.
316
373
  The data type can be bool or Number.
317
374
  Default: ``None`` , the data type of the returned tensor is mindspore.float32.
318
375
 
@@ -336,11 +393,11 @@ def eye(n, m=None, dtype=None):
336
393
  [0 1]]
337
394
  >>> print(output.dtype)
338
395
  Int32
339
- >>> output = ops.eye(1, 2, mindspore.float64)
396
+ >>> output = ops.eye(1, 2, mindspore.float32)
340
397
  >>> print(output)
341
398
  [[1. 0.]]
342
399
  >>> print(output.dtype)
343
- Float64
400
+ Float32
344
401
  >>> output = ops.eye(2, dtype=mindspore.int32)
345
402
  >>> print(output)
346
403
  [[1 0]
@@ -397,20 +454,25 @@ def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype
397
454
  [0.08 0.39785218 0.91214782 0.91214782 0.39785218 0.08]
398
455
  """
399
456
  if not isinstance(window_length, int):
400
- raise TypeError(f"For array function 'hamming_window', 'window_length' must be int, but got" \
457
+ raise TypeError(f"For array function 'hamming_window', 'window_length' must be int, but got"
401
458
  f" {type(window_length)}.")
402
459
  if window_length < 0:
403
- raise ValueError(f"For array function 'hamming_window', 'window_length' must be non negative number.")
460
+ raise ValueError(
461
+ f"For array function 'hamming_window', 'window_length' must be non negative number.")
404
462
  if not isinstance(periodic, bool):
405
- raise TypeError(f"For array function 'hamming_window', 'periodic' must be bool, but got {type(periodic)}.")
463
+ raise TypeError(
464
+ f"For array function 'hamming_window', 'periodic' must be bool, but got {type(periodic)}.")
406
465
  if not isinstance(alpha, float):
407
- raise TypeError(f"For array function 'hamming_window', 'alpha' must be float, but got {type(alpha)}.")
466
+ raise TypeError(
467
+ f"For array function 'hamming_window', 'alpha' must be float, but got {type(alpha)}.")
408
468
  if not isinstance(beta, float):
409
- raise TypeError(f"For array function 'hamming_window', 'beta' must be float, but got {type(beta)}.")
469
+ raise TypeError(
470
+ f"For array function 'hamming_window', 'beta' must be float, but got {type(beta)}.")
410
471
  if window_length <= 1:
411
472
  return Tensor(np.ones(window_length))
412
473
  if dtype is not None and dtype not in mstype.float_type:
413
- raise TypeError(f"For array function 'hamming_window', 'dtype' must be floating point dtypes, but got {dtype}.")
474
+ raise TypeError(
475
+ f"For array function 'hamming_window', 'dtype' must be floating point dtypes, but got {dtype}.")
414
476
 
415
477
  dtype = mstype.float32 if dtype is None else dtype
416
478
  op = _get_cache_prim(P.HammingWindow)(periodic, alpha, beta, dtype)
@@ -419,25 +481,25 @@ def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype
419
481
  return out
420
482
 
421
483
 
422
- def where(condition, x, y):
484
+ def where(condition, input, other):
423
485
  r"""
424
- Selects elements from `x` or `y` based on `condition` and returns a tensor.
486
+ Selects elements from `input` or `other` based on `condition` and returns a tensor.
425
487
 
426
488
  .. math::
427
- output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
489
+ output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases}
428
490
 
429
491
  Args:
430
- condition (Tensor[bool]): If True, yield `x`, otherwise yield `y`.
431
- x (Union[Tensor, Scalar]): When `condition` is True, values to select from.
432
- y (Union[Tensor, Scalar]): When `condition` is False, values to select from.
492
+ condition (Tensor[bool]): If True, yield `input`, otherwise yield `other`.
493
+ input (Union[Tensor, Scalar]): When `condition` is True, values to select from.
494
+ other (Union[Tensor, Scalar]): When `condition` is False, values to select from.
433
495
 
434
496
  Returns:
435
- Tensor, elements are selected from `x` and `y`.
497
+ Tensor, elements are selected from `input` and `other`.
436
498
 
437
499
  Raises:
438
500
  TypeError: If `condition` is not a Tensor.
439
- TypeError: If both `x` and `y` are scalars.
440
- ValueError: If `condition`, `x` and `y` can not broadcast to each other.
501
+ TypeError: If both `input` and `other` are scalars.
502
+ ValueError: If `condition`, `input` and `other` can not broadcast to each other.
441
503
 
442
504
  Supported Platforms:
443
505
  ``Ascend`` ``GPU`` ``CPU``
@@ -454,66 +516,15 @@ def where(condition, x, y):
454
516
  [[0. 1.]
455
517
  [2. 1.]]
456
518
  """
457
- if not isinstance(condition, Tensor):
458
- raise TypeError(f"For 'where', 'condition' must be a Tensor, but got {type(condition)}.")
459
- if isinstance(x, (int, float)):
460
- if not isinstance(y, Tensor):
461
- raise TypeError(
462
- f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}."
463
- )
464
- x = cast_(x, y.dtype)
465
- elif isinstance(y, (int, float)):
466
- if not isinstance(x, Tensor):
467
- raise TypeError(
468
- f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}."
469
- )
470
- y = cast_(y, x.dtype)
471
- output_shape = _calc_broadcast_shape(x.shape, y.shape, condition.shape)
472
- condition = broadcast_to(condition, output_shape)
473
- x = broadcast_to(x, output_shape)
474
- y = broadcast_to(y, output_shape)
475
- _select = P.Select()
476
- return _select(condition, x, y)
519
+ return tensor_select_(condition, input, other)
477
520
 
478
521
 
479
522
  def reverse(x, axis):
480
523
  """
481
- Reverses specific dimensions of a tensor.
482
-
483
- .. warning::
484
- The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".
485
-
486
- Args:
487
- x (Tensor): The target tensor.
488
- The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
489
- axis (Union[tuple(int), list(int)]): The indices of the dimensions to reverse.
490
-
491
- Outputs:
492
- Tensor, has the same shape and type as `x`.
493
-
494
- Raises:
495
- TypeError: If `axis` is neither list nor tuple.
496
- TypeError: If element of `axis` is not an int.
497
-
498
- Supported Platforms:
499
- ``Ascend`` ``GPU`` ``CPU``
500
-
501
- Examples:
502
- >>> import mindspore
503
- >>> import numpy as np
504
- >>> from mindspore import Tensor, ops
505
- >>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32)
506
- >>> output = ops.reverse(input_x, axis=[1])
507
- >>> print(output)
508
- [[4 3 2 1]
509
- [8 7 6 5]]
510
- >>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32)
511
- >>> output = ops.reverse(input_x, axis=[1, 0])
512
- >>> print(output)
513
- [[8 7 6 5]
514
- [4 3 2 1]]
524
+ :func:`mindspore.ops.reverse` will be deprecated in the future.
525
+ Please use :func:`mindspore.ops.flip` instead.
515
526
  """
516
- return P.ReverseV2(axis)(x)
527
+ return flip(x, axis)
517
528
 
518
529
 
519
530
  def ravel(input):
@@ -638,7 +649,8 @@ def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True, ops_n
638
649
  if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
639
650
  for ax in axis:
640
651
  if not isinstance(ax, int):
641
- raise TypeError(f"For {ops_name}, each axis must be integer, but got {type(ax)} in {axis}.")
652
+ raise TypeError(
653
+ f"For {ops_name}, each axis must be integer, but got {type(ax)} in {axis}.")
642
654
  return True
643
655
 
644
656
  type_str = ""
@@ -648,7 +660,8 @@ def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True, ops_n
648
660
  type_str += "tuple, "
649
661
  if type_list:
650
662
  type_str += "list, "
651
- raise TypeError(f"For {ops_name}, the axis should be {type_str}, but got {type(axis)}.")
663
+ raise TypeError(
664
+ f"For {ops_name}, the axis should be {type_str}, but got {type(axis)}.")
652
665
 
653
666
 
654
667
  def one_hot(indices, depth, on_value=1, off_value=0, axis=-1):
@@ -659,8 +672,9 @@ def one_hot(indices, depth, on_value=1, off_value=0, axis=-1):
659
672
  other locations take value `off_value`.
660
673
 
661
674
  Note:
662
- If the input indices is rank `N`, the output will have rank `N+1`. The new axis is created at dimension `axis`.
663
- On Ascend, if `on_value` is int64 dtype, `indices` must be int64 dtype.
675
+ If the input `indices` has rank `N`, the output will have rank `N+1`.
676
+ The new axis is created at dimension `axis`. On Ascend, if `on_value` is int64 dtype, `indices` must be
677
+ int64 dtype, and the value for `on_value` and `off_value` can only be 1 and 0.
664
678
 
665
679
  Args:
666
680
  indices(Tensor): A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
@@ -682,6 +696,7 @@ def one_hot(indices, depth, on_value=1, off_value=0, axis=-1):
682
696
  Raises:
683
697
  TypeError: If `axis` or `depth` is not an int.
684
698
  TypeError: If dtype of `indices` is not int32 or int64.
699
+ TypeError: If dtype of `on_value` is not int32, int64, float16 or float32.
685
700
  TypeError: If `indices`, `on_value` or `off_value` is not a Tensor.
686
701
  ValueError: If `axis` is not in range [-1, ndim].
687
702
  ValueError: If `depth` is less than 0.
@@ -715,8 +730,8 @@ def fill(type, shape, value): # pylint: disable=redefined-outer-name
715
730
 
716
731
  Args:
717
732
  type (mindspore.dtype): The specified type of output tensor. The data type only supports
718
- `bool_ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.html#mindspore.dtype>`_ and
719
- `number <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.html#mindspore.dtype>`_ .
733
+ `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ and
734
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ .
720
735
  shape (Union(Tensor, tuple[int])): The specified shape of output tensor.
721
736
  value (Union(Tensor, number.Number, bool)): Value to fill the returned tensor.
722
737
 
@@ -743,7 +758,7 @@ def fill(type, shape, value): # pylint: disable=redefined-outer-name
743
758
  [0. 0. 0.]]
744
759
  """
745
760
  value = cast_(value, type)
746
- return _get_cache_prim(P.FillV2)()(shape, value)
761
+ return fillv2_(shape, value)
747
762
 
748
763
 
749
764
  def full(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-name
@@ -781,16 +796,56 @@ def full(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-na
781
796
  [0. 0. 0.]]
782
797
  """
783
798
  if not isinstance(size, (list, tuple)):
784
- raise TypeError(f"For 'ops.full', 'size' must be a tuple or list of ints, but got {type(size)}.")
799
+ raise TypeError(
800
+ f"For 'ops.full', 'size' must be a tuple or list of ints, but got {type(size)}.")
785
801
  if dtype is None:
786
802
  dtype = mstype.int64
787
803
  if dtype not in mstype.all_types:
788
- raise TypeError(f"For 'ops.full', 'dtype' must be mindspore.type, but got {dtype}.")
804
+ raise TypeError(
805
+ f"For 'ops.full', 'dtype' must be mindspore.type, but got {dtype}.")
789
806
  if isinstance(size, list):
790
807
  size = tuple(size)
791
808
  return ops.fill(dtype, size, fill_value)
792
809
 
793
810
 
811
+ def full_ext(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-name
812
+ """
813
+ Create a Tensor of the specified shape and fill it with the specified value.
814
+
815
+ Args:
816
+ size (Union(tuple[int], list[int])): The specified shape of output tensor.
817
+ fill_value (Union(number.Number, Tensor)): Value to fill the returned tensor. It can be a Scalar number, a 0-D
818
+ Tensor, or a 1-D Tensor with only one element.
819
+
820
+ Keyword Args:
821
+ dtype (mindspore.dtype): The specified type of output tensor. `bool_` and `number` are supported, for details,
822
+ please refer to :class:`mindspore.dtype` . Default: ``None`` .
823
+
824
+ Returns:
825
+ Tensor.
826
+
827
+ Raises:
828
+ TypeError: If `size` is not a tuple or list.
829
+ ValueError: The element in `size` is less than 0.
830
+
831
+ Supported Platforms:
832
+ ``Ascend`` ``GPU`` ``CPU``
833
+
834
+ Examples:
835
+ >>> from mindspore import ops
836
+ >>> output = ops.full_ext((2, 2), 1)
837
+ >>> print(output)
838
+ [[1. 1.]
839
+ [1. 1.]]
840
+ >>> output = ops.full_ext((3, 3), 0)
841
+ >>> print(output)
842
+ [[0. 0. 0.]
843
+ [0. 0. 0.]
844
+ [0. 0. 0.]]
845
+ """
846
+ return fill_scalar_(size, fill_value, dtype)
847
+
848
+
794
849
  def full_like(input, fill_value, *, dtype=None):
795
850
  """
796
851
  Return a Tensor of the same shape as `input` and filled with `fill_value`.
@@ -828,7 +883,8 @@ def full_like(input, fill_value, *, dtype=None):
828
883
  [0. 0. 0.]]
829
884
  """
830
885
  if not isinstance(input, Tensor):
831
- raise TypeError(f"For ops.full_like, the argument 'x' must be tensor, but got {type(input)}")
886
+ raise TypeError(
887
+ f"For ops.full_like, the argument 'x' must be tensor, but got {type(input)}")
832
888
  if dtype is None:
833
889
  dtype = input.dtype
834
890
  return full(input.shape, fill_value, dtype=dtype)
@@ -870,37 +926,86 @@ def chunk(input, chunks, axis=0):
870
926
  Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
871
927
  """
872
928
  if not isinstance(input, Tensor):
873
- raise TypeError(f'For ops.chunk parameter `input` must be Tensor, but got {type(input)}')
929
+ raise TypeError(
930
+ f'For ops.chunk parameter `input` must be Tensor, but got {type(input)}')
874
931
  _check_axis_type(axis, True, False, False, "ops.chunk")
875
932
  arr_axis = _canonicalize_axis(axis, input.ndim)
876
933
 
877
934
  if not isinstance(chunks, int):
878
- raise TypeError(f"For ops.chunk type of argument `chunks` should be integer, but got {type(chunks)}")
935
+ raise TypeError(
936
+ f"For ops.chunk type of argument `chunks` should be integer, but got {type(chunks)}")
879
937
  if chunks <= 0:
880
- raise ValueError(f"For ops.chunk parameter 'chunks' must be greater than 0, but got {chunks}")
938
+ raise ValueError(
939
+ f"For ops.chunk parameter 'chunks' must be greater than 0, but got {chunks}")
881
940
 
882
941
  arr_shape = input.shape
883
942
  length_along_dim = arr_shape[arr_axis]
884
943
 
885
- if chunks > length_along_dim:
886
- res = P.Split(arr_axis, length_along_dim)(input)
944
+ if length_along_dim == 0:
945
+ res = _get_cache_prim(P.Split)(arr_axis)(input)
946
+ elif chunks > length_along_dim:
947
+ res = _get_cache_prim(P.Split)(arr_axis, length_along_dim)(input)
887
948
  elif length_along_dim % chunks == 0:
888
- res = P.Split(arr_axis, chunks)(input)
949
+ res = _get_cache_prim(P.Split)(arr_axis, chunks)(input)
889
950
  else:
890
951
  block_size = int(np.ceil(length_along_dim / chunks))
891
952
  true_chunks = int(length_along_dim // block_size)
892
953
  length1 = true_chunks * block_size
893
954
  length2 = length_along_dim - length1
894
- start1 = _list_comprehensions(rank(input), 0, True)
955
+ start1 = _list_comprehensions(rank_(input), 0, True)
895
956
  size1 = _tuple_setitem(arr_shape, arr_axis, length1)
896
957
  start2 = _tuple_setitem(start1, arr_axis, length1)
897
958
  size2 = _tuple_setitem(arr_shape, arr_axis, length2)
898
- res = P.Split(arr_axis, true_chunks)(tensor_slice(input, start1, size1))
959
+ res = _get_cache_prim(P.Split)(arr_axis, true_chunks)(
960
+ tensor_slice(input, start1, size1))
899
961
  if length2:
900
- res += P.Split(arr_axis, 1)(tensor_slice(input, start2, size2))
962
+ res += _get_cache_prim(P.Split)(arr_axis,
963
+ 1)(tensor_slice(input, start2, size2))
901
964
  return res
902
965
 
903
966
 
967
+ def chunk_ext(input, chunks, dim=0):
968
+ """
969
+ Cut the input Tensor into `chunks` sub-tensors along the specified axis.
970
+
971
+ Note:
972
+ This function may return less than the specified number of chunks!
973
+
974
+ .. warning::
975
+ This is an experimental API that is subject to change or deletion.
976
+
977
+ Args:
978
+ input (Tensor): A Tensor to be cut.
979
+ chunks (int): Number of sub-tensors to cut.
980
+ dim (int, optional): Specify the dimensions that you want to split. Default: ``0`` .
981
+
982
+ Returns:
983
+ A tuple of sub-tensors.
984
+
985
+ Raises:
986
+ TypeError: If argument `input` is not Tensor.
987
+ TypeError: The sum of `chunks` is not int.
988
+ TypeError: If argument `dim` is not int.
989
+ ValueError: If argument `dim` is out of range of :math:`[-input.ndim, input.ndim)` .
990
+ ValueError: If argument `chunks` is not positive number.
991
+
992
+ Supported Platforms:
993
+ ``Ascend``
994
+
995
+ Examples:
996
+ >>> import numpy as np
997
+ >>> import mindspore
998
+ >>> from mindspore import Tensor
999
+ >>> input_x = np.arange(9).astype("float32")
1000
+ >>> output = mindspore.mint.chunk(Tensor(input_x), 3)
1001
+ >>> print(output)
1002
+ (Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
1003
+ Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
1004
+ Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
1005
+ """
1006
+ return chunk_(input, chunks, dim)
1007
+
1008
+
904
1009
  def fills(x, value):
905
1010
  """
906
1011
  `fills` is deprecated, please use `ops.fill` instead.
@@ -920,50 +1025,6 @@ def fills(x, value):
920
1025
  return fills_(x, value_)
921
1026
 
922
1027
 
923
- def ones(shape, dtype=None): # pylint: disable=redefined-outer-name
924
- r"""
925
- Creates a tensor filled with value ones.
926
-
927
- Creates a tensor with shape described by the first argument and fills it with value ones in type of the second
928
- argument.
929
-
930
- Args:
931
- shape (Union[tuple[int], int, Tensor]): The specified shape of output tensor. Only positive integer or
932
- tuple or Tensor containing positive integers are allowed. If it is a Tensor,
933
- it must be a 0-D or 1-D Tensor with int32 or int64 dtypes.
934
- dtype (:class:`mindspore.dtype`): The specified type of output tensor. If `dtype` is ``None`` ,
935
- `mindspore.float32` will be used. Default: ``None`` .
936
-
937
- Returns:
938
- Tensor, has the same type and shape as input shape value.
939
-
940
- Raises:
941
- TypeError: If `shape` is not tuple, int or Tensor.
942
-
943
- Supported Platforms:
944
- ``Ascend`` ``GPU`` ``CPU``
945
-
946
- Examples:
947
- >>> import mindspore
948
- >>> from mindspore import ops
949
- >>> output = ops.ones((2, 2), mindspore.float32)
950
- >>> print(output)
951
- [[1. 1.]
952
- [1. 1.]]
953
- """
954
- _dtype = mstype.float32 if dtype is None else dtype
955
- ones_op = _get_cache_prim(P.FillV2)()
956
- value = Tensor(1, _dtype)
957
- if isinstance(shape, int):
958
- shape = tuple([shape])
959
- elif isinstance(shape, list):
960
- shape = Tensor(shape, dtype=mstype.int64)
961
- elif isinstance(shape, Tensor) and shape.ndim == 0 and shape.size == 1:
962
- shape = shape.reshape(1)
963
- output = ones_op(shape, value)
964
- return output
965
-
966
-
967
1028
  def ones_like(input, *, dtype=None):
968
1029
  """
969
1030
  Returns a Tensor with a value of 1 and its shape is the same as the input.
@@ -993,57 +1054,15 @@ def ones_like(input, *, dtype=None):
993
1054
  [[1 1]
994
1055
  [1 1]]
995
1056
  """
996
- ones_like_op = _get_cache_prim(P.OnesLike)()
997
- output = ones_like_op(input)
1057
+ output = ones_like_(input)
998
1058
  _dtype = input.dtype if dtype is None else dtype
999
1059
  output = cast_(output, _dtype)
1000
1060
  return output
1001
1061
 
1002
1062
 
1003
- def zeros(size, dtype=None): # pylint: disable=redefined-outer-name
1004
- r"""
1005
- Creates a tensor filled with 0 with shape described by `shape` and fills it with value 0 in type of `dtype`.
1006
-
1007
- Args:
1008
- size (Union[tuple[int], int, Tensor]): The specified shape of output tensor. Only positive integer or
1009
- tuple or Tensor containing positive integers are allowed. If it is a Tensor,
1010
- it must be a 0-D or 1-D Tensor with int32 or int64 dtypes.
1011
- dtype (:class:`mindspore.dtype`, optional): The specified type of output tensor. If `dtype` is ``None`` ,
1012
- mindspore.float32 will be used. Default: ``None`` .
1013
-
1014
- Returns:
1015
- Tensor, has the same dtype and size as input.
1016
-
1017
- Raises:
1018
- TypeError: If `size` is not tuple, int or Tensor.
1019
-
1020
- Supported Platforms:
1021
- ``Ascend`` ``GPU`` ``CPU``
1022
-
1023
- Examples:
1024
- >>> import mindspore
1025
- >>> from mindspore import ops
1026
- >>> output = ops.zeros((2, 2), mindspore.float32)
1027
- >>> print(output)
1028
- [[0. 0.]
1029
- [0. 0.]]
1030
- """
1031
- zero_op = _get_cache_prim(P.FillV2)()
1032
- _dtype = mstype.float32 if dtype is None else dtype
1033
- value = Tensor(0, _dtype)
1034
- if isinstance(size, int):
1035
- size = tuple([size])
1036
- elif isinstance(size, list):
1037
- size = Tensor(size, dtype=mstype.int64)
1038
- elif isinstance(size, Tensor) and size.ndim == 0 and size.size == 1:
1039
- size = size.reshape(1)
1040
- output = zero_op(size, value)
1041
- return output
1042
-
1043
-
1044
1063
  def zeros_like(input, *, dtype=None):
1045
1064
  r"""
1046
- Creates a tensor filled with 0, with the same size as x, and the given dtype.
1065
+ Creates a tensor filled with 0, with the same size as input, and the given dtype.
1047
1066
 
1048
1067
  If `dtype = None`, the tensor will have the same dtype as input `input`.
1049
1068
 
@@ -1074,125 +1093,78 @@ def zeros_like(input, *, dtype=None):
1074
1093
  [0. 0.]]
1075
1094
  """
1076
1095
  _dtype = input.dtype if dtype is None else dtype
1077
- _zeros_like = _get_cache_prim(P.ZerosLike)()
1078
- _cast = _get_cache_prim(P.Cast)()
1079
- output = _zeros_like(input)
1080
- output = _cast(output, _dtype)
1096
+ output = zeros_like_(input)
1097
+ output = cast_(output, _dtype)
1081
1098
  return output
1082
1099
 
1083
1100
 
1084
- def tile(input, multiples):
1085
- r"""
1086
- Replicates an input tensor with given multiples times.
1087
-
1088
- Creates a new tensor by replicating `input` `multiples` times. The i'th dimension of
1089
- output tensor has `input.shape[i] * multiples[i]` elements, and the values of `input`
1090
- are replicated `multiples[i]` times along the i'th dimension.
1101
+ def ones_like_ext(input, *, dtype=None):
1102
+ """
1103
+ Creates a tensor filled with 1, with the same shape as input, and its data type is determined by the given dtype.
1091
1104
 
1092
- Note:
1093
- The length of `multiples` must be greater or equal to the length of dimension in `input`.
1105
+ If `dtype = None`, the tensor will have the same dtype as input `input`.
1094
1106
 
1095
1107
  Args:
1096
- input (Tensor): 1-D or higher dimensional Tensor. Set the shape of input tensor as
1097
- :math:`(x_1, x_2, ..., x_S)` .
1108
+ input (Tensor): Tensor of any dimension.
1098
1109
 
1099
- multiples (tuple[int]): The parameter that specifies the number of replications,
1100
- the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
1101
- The length of `multiples` cannot be smaller than the length of the shape of `input`.
1102
- Only constant value is allowed.
1110
+ Keyword Args:
1111
+ dtype (:class:`mindspore.dtype`, optional): The specified dtype of the output tensor. If `dtype` is ``None`` ,
1112
+ the dtype of the input tensor will be used. Default: ``None`` .
1103
1113
 
1104
1114
  Returns:
1105
- Tensor, has the same data type as the `input`. Suppose the length of `multiples` is `d`,
1106
- the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
1107
-
1108
- - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
1109
- the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
1110
- - If `input.dim < d`, fill in multiple 1 in the length of the shape of `input` until their
1111
- lengths are consistent. Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
1112
- then the shape of their corresponding positions can be multiplied, and the shape of Outputs is
1113
- :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`.
1115
+ Tensor, has the same shape as `input` but filled with ones.
1114
1116
 
1115
1117
  Raises:
1116
- TypeError: If `multiples` is not a tuple or its elements are not all int.
1117
- ValueError: If the elements of `multiples` are not all greater than 0.
1118
- ValueError: If the length of `multiples` are smaller than the length of dimension in `input`.
1118
+ TypeError: If `input` is not a Tensor.
1119
1119
 
1120
1120
  Supported Platforms:
1121
1121
  ``Ascend`` ``GPU`` ``CPU``
1122
1122
 
1123
1123
  Examples:
1124
- >>> import mindspore
1125
1124
  >>> import numpy as np
1126
1125
  >>> from mindspore import Tensor, ops
1127
- >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
1128
- >>> multiples = (2, 3)
1129
- >>> output = ops.tile(input, multiples)
1130
- >>> print(output)
1131
- [[1. 2. 1. 2. 1. 2.]
1132
- [3. 4. 3. 4. 3. 4.]
1133
- [1. 2. 1. 2. 1. 2.]
1134
- [3. 4. 3. 4. 3. 4.]]
1135
- >>> multiples = (2, 3, 2)
1136
- >>> output = ops.tile(input, multiples)
1137
- >>> print(output)
1138
- [[[1. 2. 1. 2.]
1139
- [3. 4. 3. 4.]
1140
- [1. 2. 1. 2.]
1141
- [3. 4. 3. 4.]
1142
- [1. 2. 1. 2.]
1143
- [3. 4. 3. 4.]]
1144
- [[1. 2. 1. 2.]
1145
- [3. 4. 3. 4.]
1146
- [1. 2. 1. 2.]
1147
- [3. 4. 3. 4.]
1148
- [1. 2. 1. 2.]
1149
- [3. 4. 3. 4.]]]
1150
- """
1151
- tile_op = _get_cache_prim(P.Tile)()
1152
- return tile_op(input, multiples)
1153
-
1154
-
1155
- def range(start, end, step):
1126
+ >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
1127
+ >>> output = ops.function.array_func.ones_like_ext(x)
1128
+ >>> print(output)
1129
+ [[1 1]
1130
+ [1 1]]
1131
+ """
1132
+ return ones_like_ext_(input, dtype)
1133
+
1134
+
1135
+ def zeros_like_ext(input, *, dtype=None):
1156
1136
  r"""
1157
- Creates a sequence of numbers that begins at `start` and extends by increments of
1158
- `limit` up to but not including `end`.
1137
+ Creates a tensor filled with 0, with the same size as input. Its data type is determined by the given dtype.
1159
1138
 
1160
- The types of all 3 inputs must be the same. The type of the resulting tensor is
1161
- the same as the type of the inputs.
1139
+ If `dtype = None`, the tensor will have the same dtype as input `input`.
1162
1140
 
1163
1141
  Args:
1164
- start (Tensor): A scalar Tensor. The first number in the sequence. Must have
1165
- type: int32 ,int64, float32 or float64.
1166
- end (Tensor): A scalar Tensor. Upper limit of the sequence, exclusive. Must
1167
- have type: int32 ,int64, float32 or float64.
1168
- step (Tensor): A scalar Tensor. Number that increments `start`. Must have
1169
- type: int32 ,int64, float32 or float64.
1142
+ input (Tensor): Tensor of any dimension.
1143
+
1144
+ Keyword Args:
1145
+ dtype (:class:`mindspore.dtype`, optional): The specified dtype of the output tensor. If `dtype` is ``None`` ,
1146
+ the dtype of the input tensor will be used. Default: ``None`` .
1170
1147
 
1171
1148
  Returns:
1172
- A 1-D Tensor, with the same type as the inputs.
1149
+ Tensor, filled with 0.
1173
1150
 
1174
1151
  Raises:
1175
- TypeError: If `start`, `end` or `step` is not scalar Tensor.
1176
- TypeError: If datatype of `start`, `end` or `step` is not same.
1177
- TypeError: If datatype of `start`, `end` or `step` is not supported.
1178
- ValueError: If `step` = 0.
1179
- ValueError: If `start` >= `end` when `step` > 0.
1180
- ValueError: If `start` <= `end` when `step` < 0.
1152
+ TypeError: If dtype is not a MindSpore dtype.
1181
1153
 
1182
1154
  Supported Platforms:
1183
- ``GPU`` ``CPU``
1155
+ ``Ascend`` ``GPU`` ``CPU``
1184
1156
 
1185
1157
  Examples:
1158
+ >>> import mindspore
1159
+ >>> import numpy as np
1186
1160
  >>> from mindspore import Tensor, ops
1187
- >>> from mindspore import dtype as mstype
1188
- >>> start = Tensor(0, mstype.int32)
1189
- >>> end = Tensor(10, mstype.int32)
1190
- >>> step = Tensor(4, mstype.int32)
1191
- >>> output = ops.range(start, end, step)
1161
+ >>> x = Tensor(np.arange(4).reshape(2, 2))
1162
+ >>> output = ops.function.array_func.zeros_like_ext(x, dtype=mindspore.float32)
1192
1163
  >>> print(output)
1193
- [0 4 8]
1164
+ [[0. 0.]
1165
+ [0. 0.]]
1194
1166
  """
1195
- return range_(start, end, step)
1167
+ return zeros_like_ext_(input, dtype)
1196
1168
 
1197
1169
 
1198
1170
  ##############################
@@ -1246,18 +1218,88 @@ def unique(input):
1246
1218
  >>> print(idx)
1247
1219
  [0 1 2 1]
1248
1220
  """
1249
-
1250
- unique_op = _get_cache_prim(P.Unique)()
1251
- reshape_op = _get_cache_prim(P.Reshape)()
1252
-
1253
1221
  shape_x = input.shape
1254
1222
  length_x = get_x_shape(shape_x)
1255
- input = reshape_op(input, length_x)
1256
- y, idx = unique_op(input)
1257
- idx = reshape_op(idx, shape_x)
1223
+ input = reshape_(input, length_x)
1224
+ y, idx = unique_(input)
1225
+ idx = reshape_(idx, shape_x)
1258
1226
  return y, idx
1259
1227
 
1260
1228
 
1229
+ def unique_ext(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
1230
+ """
1231
+ Returns the unique elements of input tensor.
1232
+
1233
+ when `return_inverse=True`, also return a tensor containing the index of each value of input
1234
+ tensor corresponding to the output unique tensor.
1235
+ when `return_counts=True`, also return a tensor containing the number of occurrences for each
1236
+ unique value or tensor
1237
+
1238
+ Args:
1239
+ input (Tensor): The input tensor.
1240
+ sorted(bool): Whether to sort the unique elements in ascending order before returning as output.
1241
+ Default: ``True`` .
1242
+ return_inverse(bool): Whether to also return the indices for where elements in the original input ended up in
1243
+ the returned unique list. Default: ``False`` .
1244
+ return_counts(bool): Whether to also return the counts for each unique element. Default: ``False`` .
1245
+ dim(int): the dimension to operate upon. If ``None``, the unique of the flattened input is returned.
1246
+ Otherwise, each of the tensors indexed by the given dimension is treated as one of the elements to apply the
1247
+ unique operation upon. Default: ``None`` .
1248
+
1249
+
1250
+ Returns:
1251
+ A tensor or a tuple of tensors containing some of tensor objects (`output`, `inverse_indices`, `counts`).
1252
+
1253
+ - output(Tensor) - The output tensor including the unique elements of input tensor, it has same dtype as input.
1254
+ - inverse_indices(Tensor) - Return when ``return_inverse`` is True. It represents the indices for where
1255
+ elements in the original input map to in the output. When ``dim`` is ``None``, it has same shape as input,
1256
+ otherwise, the shape is input.shape[dim].
1257
+ - counts(Tensor) - Return when ``return_counts`` is True. It represents the number of occurrences for each
1258
+ unique value or tensor. When ``dim`` is ``None``, it has same shape as output, otherwise, the shape is
1259
+ output.shape(dim).
1260
+
1261
+
1262
+ Raises:
1263
+ TypeError: If `input` is not a Tensor.
1264
+
1265
+ Supported Platforms:
1266
+ ``Ascend``
1267
+
1268
+ Examples:
1269
+ >>> import mindspore
1270
+ >>> import numpy as np
1271
+ >>> from mindspore import Tensor, nn
1272
+ >>> from mindspore import ops
1273
+ >>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
1274
+ >>> output = ops.unique_ext(x, return_inverse=True)
1275
+ >>> print(output)
1276
+ (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int64, value= [0, 1, 2, 1]))
1277
+ >>> y = output[0]
1278
+ >>> print(y)
1279
+ [1 2 5]
1280
+ >>> idx = output[1]
1281
+ >>> print(idx)
1282
+ [0 1 2 1]
1283
+ """
1284
+ if not F.isconstant(return_inverse) or not F.isconstant(return_counts):
1285
+ raise ValueError(
1286
+ f"For 'unique_ext', 'return_inverse' and 'return_counts' cannot be mutable")
1287
+ if dim is None:
1288
+ y, inverse, counts = unique2_(
1289
+ input, sorted, return_inverse, return_counts)
1290
+ else:
1291
+ validator.check_value_type(
1292
+ "return_counts", return_counts, [bool], "unique_ext")
1293
+ y, inverse, counts = unique_dim_(input, sorted, return_inverse, dim)
1294
+ if return_inverse and return_counts:
1295
+ return y, inverse, counts
1296
+ if return_inverse:
1297
+ return y, inverse
1298
+ if return_counts:
1299
+ return y, counts
1300
+ return y
1301
+
1302
+
1261
1303
  def unique_with_pad(x, pad_num):
1262
1304
  """
1263
1305
  Returns unique elements and relative indexes in 1-D tensor, filled with padding num.
@@ -1268,6 +1310,9 @@ def unique_with_pad(x, pad_num):
1268
1310
  the UniqueWithPad operator will fill the `y` Tensor with the `pad_num` specified by the user
1269
1311
  to make it have the same shape as the Tensor `idx`.
1270
1312
 
1313
+ .. warning::
1314
+ :func:`mindspore.ops.unique_with_pad` is deprecated from version 2.4 and will be removed in a future version.
1315
+
1271
1316
  Args:
1272
1317
  x (Tensor): The tensor need to be unique. Must be 1-D vector with types: int32, int64.
1273
1318
  pad_num (int): Pad num. The data type is an int.
@@ -1280,10 +1325,10 @@ def unique_with_pad(x, pad_num):
1280
1325
 
1281
1326
  Raises:
1282
1327
  TypeError: If dtype of `x` is neither int32 nor int64.
1283
- ValueError: If length of shape of `x` is not equal to 1.
1328
+ ValueError: If `x` is not a 1-D Tensor.
1284
1329
 
1285
1330
  Supported Platforms:
1286
- ``Ascend`` ``GPU`` ``CPU``
1331
+ Deprecated
1287
1332
 
1288
1333
  Examples:
1289
1334
  >>> import mindspore
@@ -1302,7 +1347,7 @@ def unique_with_pad(x, pad_num):
1302
1347
  >>> print(idx)
1303
1348
  [0 1 1 2 3 3]
1304
1349
  """
1305
- return unique_with_pad_(x, pad_num)
1350
+ return _get_cache_prim(P.UniqueWithPad)()(x, pad_num)
1306
1351
 
1307
1352
 
1308
1353
  def unique_consecutive(input, return_idx=False, return_counts=False, axis=None):
@@ -1352,7 +1397,8 @@ def unique_consecutive(input, return_idx=False, return_counts=False, axis=None):
1352
1397
 
1353
1398
  if not isinstance(input, (Tensor, Tensor_)):
1354
1399
  raise TypeError("For 'unique_consecutive', 'input' must be Tensor.")
1355
- unique_consecutive_op = _get_cache_prim(UniqueConsecutive)(return_idx, return_counts, axis)
1400
+ unique_consecutive_op = _get_cache_prim(
1401
+ UniqueConsecutive)(return_idx, return_counts, axis)
1356
1402
  output, idx, counts = unique_consecutive_op(input)
1357
1403
  if return_idx and return_counts:
1358
1404
  return output, idx, counts
@@ -1363,7 +1409,7 @@ def unique_consecutive(input, return_idx=False, return_counts=False, axis=None):
1363
1409
  return output
1364
1410
 
1365
1411
 
1366
- def searchsorted(sorted_sequence, values, *, out_int32=False, right=False):
1412
+ def searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, sorter=None):
1367
1413
  """
1368
1414
  Return the position indices such that after inserting the values into the `sorted_sequence`, the order of innermost
1369
1415
  dimension of the `sorted_sequence` remains unchanged.
@@ -1378,16 +1424,24 @@ def searchsorted(sorted_sequence, values, *, out_int32=False, right=False):
1378
1424
  if ``False`` , the output datatype will be int64. Default: ``False`` .
1379
1425
  right (bool, optional): Search Strategy. If ``True`` , return the last suitable index found;
1380
1426
  if ``False`` , return the first such index. Default: ``False`` .
1427
+ side (str, optional): the same as right but preferred. ``"left"`` corresponds to ``False`` for `right`
1428
+ and ``"right"`` corresponds to ``True`` for `right`. An error will be reported if this parameter is
1429
+ set to ``"left"`` while `right` is ``True``. Default: ``None`` .
1430
+ sorter(Tensor, optional): if provided, a tensor matching the shape of the unsorted sorted_sequence
1431
+ containing a sequence of indices that sort it in the ascending order on the innermost
1432
+ dimension and type must be int64. Default: ``None`` . CPU and GPU can only use default values
1381
1433
 
1382
1434
  Returns:
1383
1435
  Tensor containing the indices from the innermost dimension of `sorted_sequence` such that,
1384
- if insert the corresponding value in the `values` tensor, the order of `sorted_sequence` would be preserved,
1436
+ if insert the corresponding value in the `values` Tensor, the order of `sorted_sequence` would be preserved,
1385
1437
  whose datatype is int32 if out_int32 is ``True`` , otherwise int64, and shape is the same as the shape of
1386
1438
  `values`.
1387
1439
 
1388
1440
  Raises:
1389
1441
  ValueError: If the dimension of `sorted_sequence` isn't 1 and all dimensions except the last dimension of
1390
1442
  `sorted_sequence` and `values` are different.
1443
+ ValueError: If `sorted_sequence` value is a scalar.
1444
+ ValueError: If `values` is a scalar when `sorted_sequence` dimension is not 1.
1391
1445
 
1392
1446
  Supported Platforms:
1393
1447
  ``Ascend`` ``GPU`` ``CPU``
@@ -1404,10 +1458,15 @@ def searchsorted(sorted_sequence, values, *, out_int32=False, right=False):
1404
1458
  [1 2 4]]
1405
1459
  """
1406
1460
 
1407
- _check_attr_dtype("out_int32", out_int32, [bool], "search_sorted")
1408
- dtype = mstype.int64 if not out_int32 else mstype.int32
1409
- search_sorted_ = SearchSorted(dtype, right)
1410
- return search_sorted_(sorted_sequence, values)
1461
+ validator.check_value_type("out_int32", out_int32, [bool], "search_sorted")
1462
+ validator.check_value_type("right", right, [bool], "search_sorted")
1463
+ dtype = mstype.int32 if bool(out_int32) else mstype.int64
1464
+ if (side == "left" and right is True):
1465
+ raise ValueError(f"For 'searchsorted', side and right can't be set to opposites,"
1466
+ f"got side of left while right was True.")
1467
+ if side == "right":
1468
+ right = True
1469
+ return search_sorted_(sorted_sequence, values, sorter, dtype, right)
1411
1470
 
1412
1471
 
1413
1472
  def ger(input, vec2):
@@ -1457,7 +1516,7 @@ def size(input_x):
1457
1516
 
1458
1517
  Args:
1459
1518
  input_x (Tensor): Input parameters, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is
1460
- `number <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.html#mindspore.dtype>`_.
1519
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
1461
1520
 
1462
1521
  Returns:
1463
1522
  int. A scalar representing the elements' size of `input_x`, tensor is the number of elements
@@ -1538,76 +1597,6 @@ def dyn_shape(input_x):
1538
1597
  return tensor_shape_(input_x)
1539
1598
 
1540
1599
 
1541
- def rank(input_x):
1542
- """
1543
- Returns the rank of a tensor.
1544
-
1545
- Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
1546
- is the number of indices required to uniquely select each element of the tensor.
1547
-
1548
- Args:
1549
- input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.
1550
-
1551
- Returns:
1552
- Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`. The data type is an int.
1553
-
1554
- Raises:
1555
- TypeError: If `input_x` is not a Tensor.
1556
-
1557
- Supported Platforms:
1558
- ``Ascend`` ``GPU`` ``CPU``
1559
-
1560
- Examples:
1561
- >>> import mindspore
1562
- >>> import numpy as np
1563
- >>> from mindspore import Tensor, ops
1564
- >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
1565
- >>> output = ops.rank(input_tensor)
1566
- >>> print(output)
1567
- 2
1568
- >>> print(type(output))
1569
- <class 'int'>
1570
- """
1571
- return rank_(input_x)
1572
-
1573
-
1574
- def reshape(input, shape):
1575
- """
1576
- Rearranges the input Tensor based on the given shape.
1577
-
1578
- The 'shape' can only have one -1 at most, in which case it's inferred from the remaining dimensions and
1579
- the number of elements in the input.
1580
-
1581
- Args:
1582
- input (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1583
- shape (Union[tuple[int], Tensor[int]]): Constructed by multiple
1584
- integers, i.e., :math:`(y_1, y_2, ..., y_S)`. Only constant value is allowed.
1585
-
1586
- Returns:
1587
- Tensor, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
1588
-
1589
- Raises:
1590
- ValueError: Given a shape tuple, if it has several -1; or if the product
1591
- of its elements is less than or equal to 0 or cannot be divided by the product
1592
- of the input tensor shape; or if it does not match the input's array size.
1593
-
1594
- Supported Platforms:
1595
- ``Ascend`` ``GPU`` ``CPU``
1596
-
1597
- Examples:
1598
- >>> import mindspore
1599
- >>> import numpy as np
1600
- >>> from mindspore import Tensor, ops
1601
- >>> input = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
1602
- >>> output = ops.reshape(input, (3, 2))
1603
- >>> print(output)
1604
- [[-0.1 0.3]
1605
- [ 3.6 0.4]
1606
- [ 0.5 -3.2]]
1607
- """
1608
- return reshape_(input, shape)
1609
-
1610
-
1611
1600
  def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
1612
1601
  r"""
1613
1602
  Reverses variable length slices.
@@ -1672,7 +1661,7 @@ def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
1672
1661
  [[4. 3. 2. 1.]
1673
1662
  [8. 7. 6. 5.]]
1674
1663
  """
1675
- return P.ReverseSequence(seq_dim=seq_dim, batch_dim=batch_dim)(x, seq_lengths)
1664
+ return _get_cache_prim(P.ReverseSequence)(seq_dim=seq_dim, batch_dim=batch_dim)(x, seq_lengths)
1676
1665
 
1677
1666
 
1678
1667
  def flatten(input, order='C', *, start_dim=1, end_dim=-1):
@@ -1696,7 +1685,7 @@ def flatten(input, order='C', *, start_dim=1, end_dim=-1):
1696
1685
  Raises:
1697
1686
  TypeError: If `input` is not a Tensor.
1698
1687
  TypeError: If `order` is not string type.
1699
- ValueError: If `order` is string type, but not 'C' or 'F'.
1688
+ ValueError: If `order` is string type, but not ``'C'`` or ``'F'``.
1700
1689
  TypeError: If `start_dim` or `end_dim` is not int.
1701
1690
  ValueError: If `start_dim` is greater than `end_dim` after canonicalized.
1702
1691
  ValueError: If `start_dim` or `end_dim` is not in range of [-input.dim, input.dim-1].
@@ -1718,392 +1707,59 @@ def flatten(input, order='C', *, start_dim=1, end_dim=-1):
1718
1707
  if axis < -ndim or axis >= ndim:
1719
1708
  raise ValueError("'start_dim' or 'end_dim' out of range.")
1720
1709
 
1721
- def check_dim_valid(start_dim, end_dim):
1722
- if start_dim > end_dim:
1723
- raise ValueError("For 'flatten', 'start_dim' cannot come after 'end_dim'.")
1724
-
1725
- def canonicalize_axis(axis, x_rank):
1726
- ndim = x_rank if x_rank != 0 else 1
1727
- check_axis_valid(axis, ndim)
1728
- return axis if axis >= 0 else axis + ndim
1729
-
1730
- # Check the types of arguments.
1731
- if not isinstance(input, Tensor):
1732
- raise TypeError(f"For 'flatten', argument 'input' must be Tensor.")
1733
- if not isinstance(start_dim, int) or not isinstance(end_dim, int) or \
1734
- isinstance(start_dim, bool) or isinstance(end_dim, bool):
1735
- raise TypeError(f"For 'flatten', both 'start_dim' and 'end_dim' must be int.")
1736
- check_flatten_order_const(order)
1737
- if order == 'F':
1738
- x_rank = rank_(input)
1739
- # If input is a 0-dimensional Tensor, a 1-dimensional Tensor will be returned.
1740
- if x_rank in (0, 1):
1741
- return reshape_(input, (-1,))
1742
- perm = ops.make_range(0, x_rank)
1743
- new_order = ops.tuple_reversed(perm)
1744
- input = _get_cache_prim(P.Transpose)()(input, new_order)
1745
-
1746
- # Handle the default case.
1747
- x_shape = shape_(input)
1748
- x_rank = rank_(input)
1749
- if start_dim == 1 and end_dim == -1:
1750
- if x_rank in (0, 1):
1751
- return reshape_(input, (-1,))
1752
- return _get_cache_prim(P.Flatten)()(input)
1753
-
1754
- # Check axis.
1755
- start_dim = canonicalize_axis(start_dim, x_rank)
1756
- end_dim = canonicalize_axis(end_dim, x_rank)
1757
- check_dim_valid(start_dim, end_dim)
1758
- # If input is a 0-dimensional Tensor, a 1-dimensional Tensor will be returned.
1759
- if x_rank in (0, 1):
1760
- return reshape_(input, (-1,))
1761
- # If no dimensions to flatten, return the original object.
1762
- if start_dim == end_dim:
1763
- return input
1764
- # Flatten elements along specified dimensions.
1765
- dim_length = 1
1766
- idx = start_dim
1767
- while idx <= end_dim:
1768
- dim_length *= x_shape[idx]
1769
- idx += 1
1770
- new_shape = x_shape[:start_dim] + (dim_length,) + x_shape[end_dim + 1:]
1771
- return reshape_(input, new_shape)
1772
-
1773
-
1774
- @constexpr
1775
- def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
1776
- if isinstance(scalar, int) and tensor_type != mstype.int32:
1777
- raise TypeError(f"For functional operator[select], the input[{scalar_name}] is int, "
1778
- f"then the input[{tensor_name}] must be a Tensor of int32.")
1779
- if isinstance(scalar, float) and tensor_type != mstype.float32:
1780
- raise TypeError(f"For functional operator[select], the input[{scalar_name}] is float, "
1781
- f"then the input[{tensor_name}] must be a Tensor of float32.")
1782
-
1783
-
1784
- @_primexpr
1785
- def _check_select_shape_match(input_shape, cond_shape, tensor_name):
1786
- if input_shape != cond_shape:
1787
- raise ValueError(f"For functional operator[select], the cond shape must be same as {tensor_name} shape.")
1788
-
1789
-
1790
- @constexpr
1791
- def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
1792
- if not is_cond_tensor:
1793
- raise TypeError(f"For functional operator[select], the input[cond] must be a Tensor.")
1794
- if is_x_scalar and not is_y_tensor:
1795
- raise TypeError(f"For functional operator[select], the input[x] is int or float, "
1796
- f"then the input[y] must be a Tensor.")
1797
- if is_y_scalar and not is_x_tensor:
1798
- raise TypeError(f"For functional operator[select], the input[y] is int or float, "
1799
- f"then the input[x] must be a Tensor.")
1800
-
1801
-
1802
- @constexpr
1803
- def _check_select_shape_same(cond_shape, x_shape, y_shape):
1804
- """Check if input of select has same shape."""
1805
- return cond_shape == x_shape and x_shape == y_shape and cond_shape == y_shape
1806
-
1807
-
1808
- @constexpr
1809
- def get_max_value(x, y, z):
1810
- """Get the maximum value of x, y and z."""
1811
- if x >= y and x >= z:
1812
- return x
1813
- if y >= x and y >= z:
1814
- return y
1815
- return z
1816
-
1817
-
1818
- @constexpr
1819
- def _calc_broadcast_shape(cond_shape, x_shape, y_shape):
1820
- """Calculate broadcast shape for select"""
1821
- converted_shape = []
1822
- cond_reverse = cond_shape[::-1]
1823
- x_reverse = x_shape[::-1]
1824
- y_reverse = y_shape[::-1]
1825
- max_len = get_max_value(len(cond_reverse), len(x_reverse), len(y_reverse))
1826
- i = 0
1827
- while i < max_len:
1828
- cond_element = 1 if i >= len(cond_reverse) else cond_reverse[i]
1829
- x_element = 1 if i >= len(x_reverse) else x_reverse[i]
1830
- y_element = 1 if i >= len(y_reverse) else y_reverse[i]
1831
- broadcast_element = get_max_value(cond_element, x_element, y_element)
1832
- if cond_element not in (1, broadcast_element):
1833
- raise ValueError(f"For select, condition input can not broadcast at index {i}")
1834
- if x_element not in (1, broadcast_element):
1835
- raise ValueError(f"For select, x input can not broadcast at index {i}")
1836
- if y_element not in (1, broadcast_element):
1837
- raise ValueError(f"For select, y input can not broadcast at index {i}")
1838
- converted_shape.append(broadcast_element)
1839
- i = i + 1
1840
- converted_shape.reverse()
1841
- return tuple(converted_shape)
1842
-
1843
-
1844
- def select(cond, x, y):
1845
- r"""
1846
- The conditional tensor determines whether the corresponding element in the output must be
1847
- selected from `x` (if true) or `y` (if false) based on the value of each element.
1848
-
1849
- It can be defined as:
1850
-
1851
- .. math::
1852
- out_i = \begin{cases}
1853
- x_i, & \text{if } cond_i \\
1854
- y_i, & \text{otherwise}
1855
- \end{cases}
1856
-
1857
- Args:
1858
- cond (Tensor[bool]): The condition tensor, decides which element is chosen.
1859
- The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
1860
- x (Union[Tensor, int, float]): The first Tensor or number to be selected.
1861
- If x is a Tensor, the shape is or can be broadcadt to :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
1862
- If x is an int or a float, it will be cast to the type of int32 or float32,
1863
- and broadcast to the same shape as y. One of x and y must be a Tensor.
1864
- y (Union[Tensor, int, float]): The second Tensor or number to be selected.
1865
- If y is a Tensor, The shape is or can be broadcadt to :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
1866
- If y is an int or a float, it will be cast to the type of int32 or float32,
1867
- and broadcast to the same shape as x. One of x and y must be a Tensor.
1868
-
1869
- Returns:
1870
- Tensor, has the same shape as `cond`.
1871
-
1872
- Raises:
1873
- TypeError: If `x` or `y` is not a Tensor, int or float.
1874
- ValueError: The shapes of inputs can not be broadcast.
1875
-
1876
- Supported Platforms:
1877
- ``Ascend`` ``GPU`` ``CPU``
1878
-
1879
- Examples:
1880
- >>> import mindspore
1881
- >>> from mindspore import Tensor, ops
1882
- >>> # 1) Both inputs are Tensor
1883
- >>>
1884
- >>> cond = Tensor([True, False])
1885
- >>> x = Tensor([2,3], mindspore.float32)
1886
- >>> y = Tensor([1,2], mindspore.float32)
1887
- >>> output = ops.select(cond, x, y)
1888
- >>> print(output)
1889
- [2. 2.]
1890
- >>> # 2) y is a float
1891
- >>> cond = Tensor([True, False])
1892
- >>> x = Tensor([2,3], mindspore.float32)
1893
- >>> y = 2.0
1894
- >>> output = ops.select(cond, x, y)
1895
- >>> print(output)
1896
- [2. 2.]
1897
- """
1898
- is_x_scalar = isinstance(x, (int, float))
1899
- is_y_scalar = isinstance(y, (int, float))
1900
- is_x_tensor = isinstance(x, Tensor)
1901
- is_y_tensor = isinstance(y, Tensor)
1902
- is_cond_tensor = isinstance(cond, Tensor)
1903
- _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor)
1904
- input_x = x
1905
- input_y = y
1906
- if is_x_scalar:
1907
- _check_select_shape_match(y.shape, cond.shape, "y")
1908
- _check_select_type_match(x, y.dtype, "x", "y")
1909
- input_x = zeros_like_(y) + x
1910
- if isinstance(x, int):
1911
- input_x = cast_(input_x, mstype.int32)
1912
- else:
1913
- input_x = cast_(input_x, mstype.float32)
1914
-
1915
- if is_y_scalar:
1916
- _check_select_shape_match(x.shape, cond.shape, "x")
1917
- _check_select_type_match(y, x.dtype, "y", "x")
1918
- input_y = zeros_like_(x) + y
1919
- if isinstance(y, int):
1920
- input_y = cast_(input_y, mstype.int32)
1921
- else:
1922
- input_y = cast_(input_y, mstype.float32)
1923
-
1924
- if is_x_tensor and is_y_tensor and is_cond_tensor:
1925
- x_shape = ops.shape(x)
1926
- y_shape = ops.shape(y)
1927
- cond_shape = ops.shape(cond)
1928
- all_constant = ops.isconstant(cond_shape) and ops.isconstant(x_shape) and ops.isconstant(y_shape)
1929
- if all_constant and not _check_select_shape_same(cond_shape, x_shape, y_shape):
1930
- broadcast_shape = _calc_broadcast_shape(cond_shape, x_shape, y_shape)
1931
- new_cond = ops.broadcast_to(cond, broadcast_shape)
1932
- new_x = ops.broadcast_to(x, broadcast_shape)
1933
- new_y = ops.broadcast_to(y, broadcast_shape)
1934
- return tensor_select_(new_cond, new_x, new_y)
1935
-
1936
- return tensor_select_(cond, input_x, input_y)
1937
-
1938
-
1939
- def strided_slice(input_x,
1940
- begin,
1941
- end,
1942
- strides,
1943
- begin_mask=0,
1944
- end_mask=0,
1945
- ellipsis_mask=0,
1946
- new_axis_mask=0,
1947
- shrink_axis_mask=0):
1948
- r"""
1949
- Extracts a strided slice of a Tensor based on `begin/end` index and `strides`.
1950
-
1951
- This operation extracts a fragment of size (end-begin)/strides from the given 'input_tensor'.
1952
- Starting from the beginning position, the fragment continues adding strides to the index until
1953
- all dimensions are not less than the ending position.
1954
-
1955
- Note:
1956
- - `begin` , `end` and `strides` must have the same shape.
1957
- - `begin` , `end` and `strides` are all 1-D Tensor, and their shape size
1958
- must not greater than the dim of `input_x`.
1959
-
1960
- During the slicing process, the fragment (end-begin)/strides are extracted from each dimension.
1961
-
1962
- Example: For Tensor `input_x` with shape :math:`(5, 6, 7)`,
1963
- set `begin`, `end` and `strides` to (1, 3, 2), (3, 5, 6),
1964
- (1, 1, 2) respectively, then elements from index 1 to 3 are extrected for dim 0, index 3 to 5
1965
- are extrected for dim 1 and index 2 to 6 with a `stirded` of 2 are extrected for dim 2, this
1966
- process is equivalent to a pythonic slice `input_x[1:3, 3:5, 2:6:2]`.
1967
-
1968
- If the length of `begin` 、 `end` and `strides` is smaller than the dim of `input_x`,
1969
- then all elements are extracted from the missing dims, it behaves like all the
1970
- missing dims are filled with zeros, size of that missing dim and ones.
1971
-
1972
- Example: For Tensor `input_x` with shape :math:`(5, 6, 7)`,
1973
- set `begin`, `end` and `strides` to (1, 3),
1974
- (3, 5), (1, 1) respectively, then elements from index 1 to 3 are extrected
1975
- for dim 0, index 3 to 5 are extrected for dim 1 and index 3 to 5 are extrected
1976
- for dim 2, this process is equivalent to a pythonic slice `input_x[1:3, 3:5, 0:7]`.
1977
-
1978
- Here's how a mask works:
1979
- For each specific mask, it will be converted to a binary representation internally, and then
1980
- reverse the result to start the calculation. For Tensor `input_x` with
1981
- shape :math:`(5, 6, 7)`. Given mask value of 3 which
1982
- can be represented as 0b011. Reverse that we get 0b110, which implies the first and second dim of the
1983
- original Tensor will be effected by this mask. See examples below, for simplicity all mask mentioned
1984
- below are all in their reverted binary form:
1985
-
1986
- - `begin_mask` and `end_mask`
1987
-
1988
- If the ith bit of `begin_mask` is 1, `begin[i]` is ignored and the fullest
1989
- possible range in that dimension is used instead. `end_mask` is analogous,
1990
- except with the end range. For Tensor `input_x` with shape :math:`(5, 6, 7, 8)`, if `begin_mask`
1991
- is 0b110, `end_mask` is 0b011, the slice `input_x[0:3, 0:6, 2:7:2]` is produced.
1992
-
1993
- - `ellipsis_mask`
1994
-
1995
- If the ith bit of `ellipsis_mask` is 1, as many unspecified dimensions as needed
1996
- will be inserted between other dimensions. Only one non-zero bit is allowed
1997
- in `ellipsis_mask`. For Tensor `input_x` with shape :math:`(5, 6, 7, 8)`, `input_x[2:,...,:6]`
1998
- is equivalent to `input_x[2:5,:,:,0:6]` , `input_x[2:,...]` is equivalent
1999
- to `input_x[2:5,:,:,:]`.
2000
-
2001
- - `new_axis_mask`
2002
-
2003
- If the ith bit of `new_axis_mask` is 1, `begin`, `end` and `strides` are
2004
- ignored and a new length 1 dimension is added at the specified position
2005
- in the output Tensor. For Tensor `input_x` with shape :math:`(5, 6, 7)`, if `new_axis_mask`
2006
- is 0b110, a new dim is added to the second dim, which will produce
2007
- a Tensor with shape :math:`(5, 1, 6, 7)`.
2008
-
2009
- - `shrink_axis_mask`
2010
-
2011
- If the ith bit of `shrink_axis_mask` is 1, `begin`, `end` and `strides`
2012
- are ignored and dimension i will be shrunk to 0.
2013
- For Tensor `input_x` with shape :math:`(5, 6, 7)`,
2014
- if `shrink_axis_mask` is 0b010, it is equivalent to slice `x[:, 5, :]`
2015
- and results in an output shape of :math:`(5, 7)`.
2016
-
2017
- Note:
2018
- `new_axis_mask` and `shrink_axis_mask` are not recommended to
2019
- use at the same time, it might incur unexpected result.
2020
-
2021
- Args:
2022
- input_x (Tensor): The input Tensor to be extracted from.
2023
- begin (tuple[int]): A tuple which represents the location where to start.
2024
- end (tuple[int]): A tuple or which represents the maximum location where to end.
2025
- strides (tuple[int]): A tuple which represents the strides is continuously added
2026
- before reaching the maximum location. Only int is allowed, it can be negative
2027
- which results in reversed slicing.
2028
- begin_mask (int, optional): Starting index of the slice. Default: ``0`` .
2029
- end_mask (int, optional): Ending index of the slice. Default: ``0`` .
2030
- ellipsis_mask (int, optional): An int mask, ignore slicing operation when set to 1. Default: ``0`` .
2031
- new_axis_mask (int, optional): An int mask for adding new dims. Default: ``0`` .
2032
- shrink_axis_mask (int, optional): An int mask for shrinking dims. Default: ``0`` .
1710
+ def check_dim_valid(start_dim, end_dim):
1711
+ if start_dim > end_dim:
1712
+ raise ValueError(
1713
+ "For 'flatten', 'start_dim' cannot come after 'end_dim'.")
2033
1714
 
2034
- Returns:
2035
- Tensor, return the extracts a strided slice of a Tensor based on `begin/end` index and `strides`.
1715
+ def canonicalize_axis(axis, x_rank):
1716
+ ndim = x_rank if x_rank != 0 else 1
1717
+ check_axis_valid(axis, ndim)
1718
+ return axis if axis >= 0 else axis + ndim
2036
1719
 
2037
- Raises:
2038
- TypeError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or
2039
- `shrink_axis_mask` is not an int.
2040
- TypeError: If `begin`, `end` or `strides` is not tuple[int].
2041
- ValueError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or
2042
- `shrink_axis_mask` is less than 0.
2043
- ValueError: If `begin`, `end` and `strides` have different shapes.
1720
+ # Check the types of arguments.
1721
+ if not isinstance(input, Tensor):
1722
+ raise TypeError(f"For 'flatten', argument 'input' must be Tensor.")
1723
+ if not isinstance(start_dim, int) or not isinstance(end_dim, int) or \
1724
+ isinstance(start_dim, bool) or isinstance(end_dim, bool):
1725
+ raise TypeError(
1726
+ f"For 'flatten', both 'start_dim' and 'end_dim' must be int.")
1727
+ check_flatten_order_const(order)
1728
+ if order == 'F':
1729
+ x_rank = rank_(input)
1730
+ # If input is a 0-dimensional Tensor, a 1-dimensional Tensor will be returned.
1731
+ if x_rank in (0, 1):
1732
+ return reshape_(input, (-1,))
1733
+ perm = ops.make_range(0, x_rank)
1734
+ new_order = ops.tuple_reversed(perm)
1735
+ input = transpose_(input, new_order)
2044
1736
 
2045
- Supported Platforms:
2046
- ``Ascend`` ``GPU`` ``CPU``
1737
+ # Handle the default case.
1738
+ x_shape = shape_(input)
1739
+ x_rank = rank_(input)
1740
+ if start_dim == 1 and end_dim == -1:
1741
+ if x_rank in (0, 1):
1742
+ return reshape_(input, (-1,))
1743
+ return flatten_(input)
2047
1744
 
2048
- Examples:
2049
- >>> import mindspore
2050
- >>> from mindspore import Tensor, ops
2051
- >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
2052
- ... [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
2053
- >>> output = ops.strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
2054
- >>> # Take this " output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example,
2055
- >>> # start = [1, 0, 2] , end = [3, 1, 3], strides = [1, 1, 1], Find a segment of (start, end),
2056
- >>> # note that end is an open interval
2057
- >>> # To facilitate understanding, this operator can be divided into three steps:
2058
- >>> # Step 1: Calculation of the first dimension:
2059
- >>> # start = 1, end = 3, strides = 1, So can take 1st, 2nd rows, and then gets the final output at this time.
2060
- >>> # output_1th =
2061
- >>> # [
2062
- >>> # [
2063
- >>> # [3,3,3]
2064
- >>> # [4,4,4]
2065
- >>> # ]
2066
- >>> # [
2067
- >>> # [5,5,5]
2068
- >>> # [6,6,6]
2069
- >>> # ]
2070
- >>> # ]
2071
- >>> # Step 2: Calculation of the second dimension
2072
- >>> # 2nd dimension, start = 0, end = 1, strides = 1. So only 0th rows
2073
- >>> # can be taken, and the output at this time.
2074
- >>> # output_2nd =
2075
- >>> # [
2076
- >>> # [
2077
- >>> # [3,3,3]
2078
- >>> # ]
2079
- >>> # [
2080
- >>> # [5,5,5]
2081
- >>> # ]
2082
- >>> # ]
2083
- >>> # Step 3: Calculation of the third dimension
2084
- >>> # 3nd dimension,start = 2, end = 3, strides = 1, So can take 2th cols,
2085
- >>> # and you get the final output at this time.
2086
- >>> # output_3ed =
2087
- >>> # [
2088
- >>> # [
2089
- >>> # [3]
2090
- >>> # ]
2091
- >>> # [
2092
- >>> # [5]
2093
- >>> # ]
2094
- >>> # ]
2095
- >>> # The final output after finishing is:
2096
- >>> print(output)
2097
- [[[3.]]
2098
- [[5.]]]
2099
- >>> # another example like :
2100
- >>> output = strided_slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
2101
- >>> print(output)
2102
- [[[3. 3. 3.]]]
2103
- """
2104
- strided_slice_ = _get_cache_prim(P.StridedSlice)(
2105
- begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
2106
- return strided_slice_(input_x, begin, end, strides)
1745
+ # Check axis.
1746
+ start_dim = canonicalize_axis(start_dim, x_rank)
1747
+ end_dim = canonicalize_axis(end_dim, x_rank)
1748
+ check_dim_valid(start_dim, end_dim)
1749
+ # If input is a 0-dimensional Tensor, a 1-dimensional Tensor will be returned.
1750
+ if x_rank in (0, 1):
1751
+ return reshape_(input, (-1,))
1752
+ # If no dimensions to flatten, return the original object.
1753
+ if start_dim == end_dim:
1754
+ return input
1755
+ # Flatten elements along specified dimensions.
1756
+ dim_length = 1
1757
+ idx = start_dim
1758
+ while idx <= end_dim:
1759
+ dim_length *= x_shape[idx]
1760
+ idx += 1
1761
+ new_shape = x_shape[:start_dim] + (dim_length,) + x_shape[end_dim + 1:]
1762
+ return reshape_(input, new_shape)
2107
1763
 
2108
1764
 
2109
1765
  def slice(input_x, begin, size):
@@ -2160,20 +1816,6 @@ def slice(input_x, begin, size):
2160
1816
  return tensor_slice(input_x, begin, size)
2161
1817
 
2162
1818
 
2163
- def concat(tensors, axis=0):
2164
- """
2165
- Alias for :func:`mindspore.ops.cat()`.
2166
-
2167
- Tutorial Examples:
2168
- - `Tensor - Tensor Operation <https://mindspore.cn/tutorials/en/r2.2/beginner/tensor.html#tensor-operation>`_
2169
- - `Vision Transformer Image Classification - Building ViT as a whole
2170
- <https://mindspore.cn/tutorials/application/en/r2.2/cv/vit.html#building-vit-as-a-whole>`_
2171
- - `Sentiment Classification Implemented by RNN - Dense
2172
- <https://mindspore.cn/tutorials/application/en/r2.2/nlp/sentiment_analysis.html#dense>`_
2173
- """
2174
- return cat(tensors, axis)
2175
-
2176
-
2177
1819
  def stack(tensors, axis=0):
2178
1820
  r"""
2179
1821
  Stacks a list of tensors in specified axis.
@@ -2284,45 +1926,6 @@ def unbind(input, dim=0):
2284
1926
  return _unstack(input)
2285
1927
 
2286
1928
 
2287
- def expand_dims(input_x, axis):
2288
- """
2289
- Adds an additional dimension to `input_x` at the given axis, the dimension
2290
- of `input_x` should be greater than or equal to 1.
2291
-
2292
- Note:
2293
- If the specified axis is a negative number, the index is counted
2294
- backward from the end and starts at 1.
2295
-
2296
- Args:
2297
- input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
2298
- axis (int): Specifies the dimension index at which to expand
2299
- the shape of `input_x`. The value of axis must be in the range
2300
- `[-input_x.ndim-1, input_x.ndim]`. Only constant value is allowed.
2301
-
2302
- Returns:
2303
- Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the
2304
- value of `axis` is 0. It has the same data type as `input_x`.
2305
-
2306
- Raises:
2307
- TypeError: If `axis` is not an int.
2308
- ValueError: If `axis` is not in the valid range :math:`[-a.ndim-1, a.ndim]`.
2309
-
2310
- Supported Platforms:
2311
- ``Ascend`` ``GPU`` ``CPU``
2312
-
2313
- Examples:
2314
- >>> import mindspore
2315
- >>> import numpy as np
2316
- >>> from mindspore import Tensor, ops
2317
- >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
2318
- >>> output = ops.expand_dims(input_tensor, 0)
2319
- >>> print(output)
2320
- [[[2. 2.]
2321
- [2. 2.]]]
2322
- """
2323
- return expand_dims_(input_x, axis)
2324
-
2325
-
2326
1929
  def unsqueeze(input, dim):
2327
1930
  """
2328
1931
  Adds an additional dimension to `input` at the given dim.
@@ -2354,7 +1957,7 @@ def unsqueeze(input, dim):
2354
1957
  [[[2. 2.]
2355
1958
  [2. 2.]]]
2356
1959
  """
2357
- return expand_dims_(input, dim)
1960
+ return expand_dims(input, dim)
2358
1961
 
2359
1962
 
2360
1963
  def squeeze(input, axis=None):
@@ -2411,57 +2014,6 @@ def squeeze(input, axis=None):
2411
2014
  return squeeze_(input)
2412
2015
 
2413
2016
 
2414
- def transpose(input, input_perm):
2415
- """
2416
- Permutes the dimensions of the input tensor according to input permutation.
2417
-
2418
- For a 1-D array this has no effect, as a transposed vector is simply the same vector.
2419
- To convert a 1-D array into a 2D column vector please refer the class: mindspore.ops.ExpandDims.
2420
- For a 2-D array, this is a standard matrix transpose. For an n-D array, if axes are given,
2421
- their order indicates how the axes are permuted (see Examples).
2422
- If axes are not provided and a.shape is :math:`(i[0], i[1], ... i[n-2], i[n-1])`,
2423
- then a.transpose().shape is :math:`(i[n-1], i[n-2], ... i[1], i[0])`.
2424
-
2425
- Note:
2426
- On GPU and CPU, if the value of `input_perm` is negative, its actual value is `input_perm[i] + rank(input)`.
2427
- Negative value of `input_perm` is not supported on Ascend.
2428
-
2429
- Args:
2430
- input (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
2431
- input_perm (tuple[int]): The permutation to be converted. The elements in `input_perm` are composed of
2432
- the indexes of each dimension of `input`. The length of `input_perm` and the shape of `input` must be
2433
- the same. Only constant value is allowed. Must be in the range [-rank(input), rank(input)).
2434
-
2435
- Returns:
2436
- Tensor, the type of output tensor is the same as `input` and the shape of output tensor is decided by the
2437
- shape of `input` and the value of `input_perm`.
2438
-
2439
- Raises:
2440
- TypeError: If `input_perm` is not a tuple.
2441
- ValueError: If length of shape of `input` is not equal to length of shape of `input_perm`.
2442
- ValueError: If the same element exists in `input_perm`.
2443
-
2444
- Supported Platforms:
2445
- ``Ascend`` ``GPU`` ``CPU``
2446
-
2447
- Examples:
2448
- >>> import mindspore
2449
- >>> import numpy as np
2450
- >>> from mindspore import Tensor, ops
2451
- >>> input = Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), mindspore.float32)
2452
- >>> input_perm = (0, 2, 1)
2453
- >>> output = ops.transpose(input, input_perm)
2454
- >>> print(output)
2455
- [[[ 1. 4.]
2456
- [ 2. 5.]
2457
- [ 3. 6.]]
2458
- [[ 7. 10.]
2459
- [ 8. 11.]
2460
- [ 9. 12.]]]
2461
- """
2462
- return transpose_(input, input_perm)
2463
-
2464
-
2465
2017
  def scatter_mul(input_x, indices, updates):
2466
2018
  r"""
2467
2019
  Using given values to update tensor value through the mul operation, along with the input indices.
@@ -2792,111 +2344,6 @@ def scatter_div(input_x, indices, updates):
2792
2344
  return scatter_div_(input_x, indices, updates)
2793
2345
 
2794
2346
 
2795
- def scatter_nd(indices, updates, shape):
2796
- r"""
2797
- Scatters a tensor into a new tensor depending on the specified indices.
2798
-
2799
- Creates an empty tensor with the given `shape`, and set values by scattering the update tensor
2800
- depending on indices. The empty tensor has rank :math:`P` and `indices` has rank :math:`Q`.
2801
-
2802
- The `shape` is :math:`(s_0, s_1, ..., s_{P-1})`, where :math:`P \ge 1`.
2803
-
2804
- `indices` has shape :math:`(i_0, i_1, ..., i_{Q-2}, N)`, where :math:`Q \ge 2` and :math:`N \le P`.
2805
-
2806
- The last dimension of `indices` (with length :math:`N` ) indicates slices along the :math:`N` th dimension of the
2807
- empty tensor.
2808
-
2809
- `updates` is a tensor of rank :math:`Q-1+P-N`, and
2810
- its shape is :math:`(i_0, i_1, ..., i_{Q-2}, s_N, s_{N+1}, ..., s_{P-1})`.
2811
-
2812
- If `indices` contains duplicates, the duplicate `updates` are summed.
2813
-
2814
- The following figure shows the calculation process of inserting two new value matrices into the first dimension
2815
- with rank-3:
2816
-
2817
- .. image:: ScatterNd.png
2818
-
2819
- Args:
2820
- indices (Tensor): Define the index of scattering in the new tensor with int32 or int64 data type.
2821
- The rank of `indices` must be at least 2 and `indices.shape[-1] <= len(shape)`.
2822
- updates (Tensor): Define the source Tensor to be updated.
2823
- It has shape `indices.shape[:-1] + shape[indices.shape[-1]:]`.
2824
- shape (tuple[int]): Define the shape of the output tensor, has the same data type as indices.
2825
- `shape` can not be empty, and the elements in `shape` must be greater than or equal to 1.
2826
-
2827
- Returns:
2828
- Tensor, the new tensor, has the same type as `update` and the same shape as `shape`.
2829
-
2830
- Raises:
2831
- TypeError: If `shape` is not a tuple.
2832
- ValueError: If any element of `shape` is less than 1.
2833
-
2834
- Supported Platforms:
2835
- ``Ascend`` ``GPU`` ``CPU``
2836
-
2837
- Examples:
2838
- >>> import mindspore
2839
- >>> import numpy as np
2840
- >>> from mindspore import Tensor, ops
2841
- >>> indices = Tensor(np.array([[0], [2]]), mindspore.int32)
2842
- >>> updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2],
2843
- ... [3, 3, 3, 3], [4, 4, 4, 4]],
2844
- ... [[1, 1, 1, 1], [2, 2, 2, 2],
2845
- ... [3, 3, 3, 3], [4, 4, 4, 4]]]), mindspore.float32)
2846
- >>> shape = (4, 4, 4)
2847
- >>> output = ops.scatter_nd(indices, updates, shape)
2848
- >>> print(output)
2849
- [[[1. 1. 1. 1.]
2850
- [2. 2. 2. 2.]
2851
- [3. 3. 3. 3.]
2852
- [4. 4. 4. 4.]]
2853
- [[0. 0. 0. 0.]
2854
- [0. 0. 0. 0.]
2855
- [0. 0. 0. 0.]
2856
- [0. 0. 0. 0.]]
2857
- [[1. 1. 1. 1.]
2858
- [2. 2. 2. 2.]
2859
- [3. 3. 3. 3.]
2860
- [4. 4. 4. 4.]]
2861
- [[0. 0. 0. 0.]
2862
- [0. 0. 0. 0.]
2863
- [0. 0. 0. 0.]
2864
- [0. 0. 0. 0.]]]
2865
- >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
2866
- >>> updates = Tensor(np.array([3.2, 1.1]), mindspore.float32)
2867
- >>> shape = (3, 3)
2868
- >>> output = ops.scatter_nd(indices, updates, shape)
2869
- >>> # In order to facilitate understanding, explain the operator pseudo-operation process step by step:
2870
- >>> # Step 1: Generate an empty Tensor of the specified shape according to the shape
2871
- >>> # [
2872
- >>> # [0. 0. 0.]
2873
- >>> # [0. 0. 0.]
2874
- >>> # [0. 0. 0.]
2875
- >>> # ]
2876
- >>> # Step 2: Modify the data at the specified location according to the indicators
2877
- >>> # 0th row of indices is [0, 1], 0th row of updates is 3.2.
2878
- >>> # means that the empty tensor in the 0th row and 1st col set to 3.2
2879
- >>> # [
2880
- >>> # [0. 3.2. 0.]
2881
- >>> # [0. 0. 0.]
2882
- >>> # [0. 0. 0.]
2883
- >>> # ]
2884
- >>> # 1th row of indices is [1, 1], 1th row of updates is 1.1.
2885
- >>> # means that the empty tensor in the 1th row and 1st col set to 1.1
2886
- >>> # [
2887
- >>> # [0. 3.2. 0.]
2888
- >>> # [0. 1.1 0.]
2889
- >>> # [0. 0. 0.]
2890
- >>> # ]
2891
- >>> # The final result is as follows:
2892
- >>> print(output)
2893
- [[0. 3.2 0.]
2894
- [0. 1.1 0.]
2895
- [0. 0. 0.]]
2896
- """
2897
- return scatter_nd_(indices, updates, shape)
2898
-
2899
-
2900
2347
  def scatter_update(input_x, indices, updates):
2901
2348
  r"""
2902
2349
  Updates tensor values by using input indices and value.
@@ -2946,8 +2393,7 @@ def scatter_update(input_x, indices, updates):
2946
2393
  [[2. 1.2 1.]
2947
2394
  [3. 1.2 1.]]
2948
2395
  """
2949
- scatter_update_inner = _get_cache_prim(P.ScatterUpdate)()
2950
- return scatter_update_inner(input_x, indices, updates)
2396
+ return scatter_update_(input_x, indices, updates)
2951
2397
 
2952
2398
 
2953
2399
  def scatter_nd_add(input_x, indices, updates, use_locking=False):
@@ -3414,8 +2860,8 @@ def sort(input_x, axis=-1, descending=False):
3414
2860
  are sorted in descending order, or else sorted in ascending order. Default: ``False`` .
3415
2861
 
3416
2862
  .. warning::
3417
- Currently, the data types of Float16, UInt8, Int8, Int16, Int32, Int64 are well supported.
3418
- If use Float32, it may cause loss of accuracy.
2863
+ Currently, the data types of float16, uint8, int8, int16, int32, int64 are well supported.
2864
+ If use float32, it may cause loss of accuracy.
3419
2865
 
3420
2866
  Returns:
3421
2867
 
@@ -3452,129 +2898,72 @@ def sort(input_x, axis=-1, descending=False):
3452
2898
  return _sort(input_x)
3453
2899
 
3454
2900
 
3455
- def argsort(input, axis=-1, descending=False):
2901
+ def sort_ext(input, *, dim=-1, descending=False, stable=False):
3456
2902
  r"""
3457
- Sorts the input tensor along the given dimension in specified order and return the sorted indices.
2903
+ Sorts the elements of the input tensor along the given dimension in the specified order.
2904
+
2905
+ .. warning::
2906
+ Currently, the data types of float16, uint8, int8, int16, int32, int64 are well supported.
2907
+ If use float32, it may cause loss of accuracy.
3458
2908
 
3459
2909
  Args:
3460
2910
  input(Tensor): The input tensor to sort.
3461
- axis (int): The axis to sort along. Default: ``-1`` , means the last dimension.
3462
- The Ascend backend only supports sorting the last dimension.
3463
- descending (bool): The sort order. If `descending` is True then the elements
3464
- are sorted in descending order by value. Otherwise sort in ascending order. Default: ``False`` .
2911
+ The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
2912
+
2913
+ Keyword Args:
2914
+ dim (int, optional): The dimension to sort along. Default: ``-1``, means the last dimension.
2915
+ descending (bool, optional): Controls the sort order. If `descending` is True, the elements
2916
+ are sorted in descending order, or else sorted in ascending order. Default: ``False`` .
2917
+ stable (bool, optional): Controls the sort order. If stable is True then the sorting routine
2918
+ becomes stable, preserving the order of equivalent elements. Default: ``False`` .
3465
2919
 
3466
2920
  Returns:
3467
- Tensor, the indices of sorted input tensor. Data type is int32.
2921
+ - y1, a tensor whose values are the sorted values, with the same shape and data type as input.
2922
+ - y2, a tensor that consists of the indices of the elements in the original input tensor.
2923
+ Data type is int64.
2924
+
2925
+ Raises:
2926
+ TypeError: If `dim` is not an int.
2927
+ TypeError: If `descending` is not a bool.
2928
+ TypeError: If `input` not in float16, float32, uint8, int8, int16, int32, int64, bfloat16
2929
+ TypeError: If `stable` is not a bool.
2930
+ ValueError: If `dim` is not in range of [-len(input_x.shape), len(input_x.shape)).
3468
2931
 
3469
2932
  Supported Platforms:
3470
- ``Ascend`` ``GPU`` ``CPU``
2933
+ ``Ascend``
3471
2934
 
3472
2935
  Examples:
3473
2936
  >>> import mindspore
3474
2937
  >>> import numpy as np
3475
2938
  >>> from mindspore import Tensor, ops
3476
2939
  >>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
3477
- >>> sort = ops.argsort(x)
3478
- >>> print(sort)
3479
- [[2 1 0]
3480
- [2 0 1]
3481
- [0 1 2]]
2940
+ >>> output = ops.function.array_func.sort_ext(x)
2941
+ >>> # The output below is based on the Ascend platform.
2942
+ >>> print(output)
2943
+ (Tensor(shape=[3, 3], dtype=Float16, value=
2944
+ [[ 1.0000e+00, 2.0000e+00, 8.0000e+00],
2945
+ [ 3.0000e+00, 5.0000e+00, 9.0000e+00],
2946
+ [ 4.0000e+00, 6.0000e+00, 7.0000e+00]]), Tensor(shape=[3, 3], dtype=Int64, value=
2947
+ [[2, 1, 0],
2948
+ [2, 0, 1],
2949
+ [0, 1, 2]]))
3482
2950
  """
3483
- _sort = _get_cache_prim(P.Sort)(axis, descending)
3484
- _, arg_sort = _sort(input)
3485
- return arg_sort
2951
+ return sort_ext_(input, dim, descending, stable)
3486
2952
 
3487
2953
 
3488
- def gather(input_params, input_indices, axis, batch_dims=0):
2954
+ def argsort(input, axis=-1, descending=False):
3489
2955
  r"""
3490
- Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
3491
-
3492
- The following figure shows the calculation process of Gather commonly:
3493
-
3494
- .. image:: Gather.png
3495
-
3496
- where params represents the input `input_params`, and indices represents the index to be sliced `input_indices`.
3497
-
3498
- .. note::
3499
- 1. The value of input_indices must be in the range of `[0, input_param.shape[axis])`.
3500
- On CPU and GPU, an error is raised if an out of bound indice is found. On Ascend, the results may be
3501
- undefined.
3502
-
3503
- 2. The data type of input_params cannot be
3504
- `bool_ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.html#mindspore.dtype>`_ on Ascend
3505
- platform currently.
2956
+ Sorts the input tensor along the given dimension in specified order and return the sorted indices.
3506
2957
 
3507
2958
  Args:
3508
- input_params (Tensor): The original Tensor. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
3509
- input_indices (Tensor): Index tensor to be sliced, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
3510
- Specifies the indices of elements of the original Tensor. The data type can be int32 or int64.
3511
- axis (Union(int, Tensor[int])): Specifies the dimension index to gather indices.
3512
- It must be greater than or equal to `batch_dims`.
3513
- When `axis` is a Tensor, the size must be 1.
3514
- batch_dims (int): Specifies the number of batch dimensions. It must be less than or euqal to the rank
3515
- of `input_indices`. Default: ``0`` .
2959
+ input(Tensor): The input tensor to sort.
2960
+ axis (int): The axis to sort along. Default: ``-1`` , means the last dimension.
2961
+ The Ascend backend only supports sorting the last dimension.
2962
+ descending (bool): The sort order. If `descending` is True then the elements
2963
+ are sorted in descending order by value. Otherwise sort in ascending order. Default: ``False`` .
3516
2964
 
3517
2965
  Returns:
3518
- Tensor, the shape of tensor is
3519
- :math:`input\_params.shape[:axis] + input\_indices.shape[batch\_dims:] + input\_params.shape[axis + 1:]`.
3520
-
3521
- Raises:
3522
- TypeError: If `axis` is not an int or Tensor.
3523
- ValueError: If `axis` is a Tensor and its size is not 1.
3524
- TypeError: If `input_params` is not a tensor.
3525
- TypeError: If `input_indices` is not a tensor of type int.
3526
- RuntimeError: If `input_indices` is out of range `[0, input_param.shape[axis])` on CPU or GPU.
3527
-
3528
- Supported Platforms:
3529
- ``Ascend`` ``GPU`` ``CPU``
3530
-
3531
- Examples:
3532
- >>> import mindspore
3533
- >>> import numpy as np
3534
- >>> from mindspore import Tensor, ops
3535
- >>> # case1: input_indices is a Tensor with shape (5, ).
3536
- >>> input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
3537
- >>> input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
3538
- >>> axis = 0
3539
- >>> output = ops.gather(input_params, input_indices, axis)
3540
- >>> print(output)
3541
- [1. 3. 5. 3. 7.]
3542
- >>> # case2: input_indices is a Tensor with shape (2, 2). When the input_params has one dimension,
3543
- >>> # the output shape is equal to the input_indices shape.
3544
- >>> input_indices = Tensor(np.array([[0, 2], [2, 6]]), mindspore.int32)
3545
- >>> axis = 0
3546
- >>> output = ops.gather(input_params, input_indices, axis)
3547
- >>> print(output)
3548
- [[1. 3.]
3549
- [3. 7.]]
3550
- >>> # case3: input_indices is a Tensor with shape (2, ) and
3551
- >>> # input_params is a Tensor with shape (3, 4) and axis is 0.
3552
- >>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
3553
- >>> input_indices = Tensor(np.array([0, 2]), mindspore.int32)
3554
- >>> axis = 0
3555
- >>> output = ops.gather(input_params, input_indices, axis)
3556
- >>> print(output)
3557
- [[ 1. 2. 3. 4.]
3558
- [ 9. 10. 11. 12.]]
3559
- >>> # case4: input_indices is a Tensor with shape (2, ) and
3560
- >>> # input_params is a Tensor with shape (3, 4) and axis is 1, batch_dims is 1.
3561
- >>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
3562
- >>> input_indices = Tensor(np.array([0, 2, 1]), mindspore.int32)
3563
- >>> axis = 1
3564
- >>> batch_dims = 1
3565
- >>> output = ops.gather(input_params, input_indices, axis, batch_dims)
3566
- >>> print(output)
3567
- [ 1. 7. 10.]
3568
- """
3569
- _gather = _get_cache_prim(P.Gather)(batch_dims)
3570
- return _gather(input_params, input_indices, axis)
3571
-
3572
-
3573
- def gather_d(x, dim, index):
3574
- """
3575
- Gathers elements along an axis specified by dim.
3576
-
3577
- Refer to :func:`mindspore.ops.gather_elements` for more detail.
2966
+ Tensor, the indices of sorted input tensor. Data type is int32.
3578
2967
 
3579
2968
  Supported Platforms:
3580
2969
  ``Ascend`` ``GPU`` ``CPU``
@@ -3583,15 +2972,16 @@ def gather_d(x, dim, index):
3583
2972
  >>> import mindspore
3584
2973
  >>> import numpy as np
3585
2974
  >>> from mindspore import Tensor, ops
3586
- >>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
3587
- >>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
3588
- >>> dim = 1
3589
- >>> output = ops.gather_d(x, dim, index)
3590
- >>> print(output)
3591
- [[1 1]
3592
- [4 3]]
2975
+ >>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
2976
+ >>> sort = ops.argsort(x)
2977
+ >>> print(sort)
2978
+ [[2 1 0]
2979
+ [2 0 1]
2980
+ [0 1 2]]
3593
2981
  """
3594
- return gather_d_(x, dim, index)
2982
+ _sort = _get_cache_prim(P.Sort)(axis, descending)
2983
+ _, arg_sort = _sort(input)
2984
+ return arg_sort
3595
2985
 
3596
2986
 
3597
2987
  def gather_elements(input, dim, index):
@@ -3608,26 +2998,29 @@ def gather_elements(input, dim, index):
3608
2998
 
3609
2999
  output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
3610
3000
 
3611
- `input` and `index` have the same length of dimensions, and all dimensions except `dim` have the same size.
3612
- If `dim` = i, `input` is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})`,
3613
- the `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})`
3614
- where `y`>=1 and the output will have the same shape with `index`.
3001
+ `input` and `index` have the same length of dimensions, and `index.shape[axis] <= input.shape[axis]`
3002
+ where axis goes through all dimensions of `input` except `dim`.
3003
+
3004
+ .. warning::
3005
+ On Ascend, the behavior is unpredictable in the following cases:
3006
+
3007
+ - the value of `index` is not in the range `[-input.shape[dim], input.shape[dim])` in forward;
3008
+ - the value of `index` is not in the range `[0, input.shape[dim])` in backward.
3615
3009
 
3616
3010
  Args:
3617
3011
  input (Tensor): The input tensor.
3618
- dim (int): The axis along which to index. It must be int32 or int64. The value range is [-input.ndim,
3619
- input.ndim).
3012
+ dim (int): The axis along which to index. It must be int32 or int64. The value range is `[-input.ndim,
3013
+ input.ndim)`.
3620
3014
  index (Tensor): The indices of elements to gather. It can be one of the following data types:
3621
- int32, int64. The value range of each index element is [-input.shape(dim), input.shape(dim)).
3015
+ int32, int64. The value range of each index element is `[-input.shape(dim), input.shape(dim))`.
3622
3016
 
3623
3017
  Returns:
3624
- Tensor, has the same shape as index tensor, the shape of tensor is :math:`(z_0, z_1, ..., y, ..., z_{n-1})`,
3625
- and has the same data type with `input`.
3018
+ Tensor, has the same shape as `index` and has the same data type with `input`.
3626
3019
 
3627
3020
  Raises:
3628
3021
  TypeError: If dtype of `dim` or `index` is neither int32 nor int64.
3629
3022
  ValueError: If length of shape of `input` is not equal to length of shape of `index`.
3630
- ValueError: If the size of the dimension except `dim` is not equal between `input` and `index`.
3023
+ ValueError: If the size of the dimension except `dim` in `input` is less than size in `index`.
3631
3024
  ValueError: If the value of `dim` is not in the expected range.
3632
3025
 
3633
3026
  Supported Platforms:
@@ -3645,49 +3038,7 @@ def gather_elements(input, dim, index):
3645
3038
  [[1 1]
3646
3039
  [4 3]]
3647
3040
  """
3648
- return gather_d_(input, dim, index)
3649
-
3650
-
3651
- def gather_nd(input_x, indices):
3652
- r"""
3653
- Gathers slices from a tensor by indices.
3654
-
3655
- Using given indices to gather slices from a tensor with a specified shape.
3656
-
3657
- `indices` is an K-dimensional integer tensor. Supposes it as a (K-1)-dimensional tensor and each element of it
3658
- defines a slice of `input_x`:
3659
-
3660
- .. math::
3661
- output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]]
3662
-
3663
- The last dimension of `indices` can not more than the rank of `input_x`:
3664
- :math:`indices.shape[-1] <= input\_x.rank`.
3665
-
3666
- Args:
3667
- input_x (Tensor): The target tensor to gather values.
3668
- indices (Tensor): The index tensor, with int32 or int64 data type.
3669
-
3670
- Returns:
3671
- Tensor, has the same type as `input_x` and the shape is
3672
- :math:`indices\_shape[:-1] + input\_x\_shape[indices\_shape[-1]:]`.
3673
-
3674
- Raises:
3675
- ValueError: If length of shape of `input_x` is less than the last dimension of `indices`.
3676
-
3677
- Supported Platforms:
3678
- ``Ascend`` ``GPU`` ``CPU``
3679
-
3680
- Examples:
3681
- >>> import mindspore
3682
- >>> import numpy as np
3683
- >>> from mindspore import Tensor, ops
3684
- >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
3685
- >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
3686
- >>> output = ops.gather_nd(input_x, indices)
3687
- >>> print(output)
3688
- [-0.1 0.5]
3689
- """
3690
- return gather_nd_(input_x, indices)
3041
+ return gather_d_(input, dim, index)
3691
3042
 
3692
3043
 
3693
3044
  def tensor_scatter_add(input_x, indices, updates):
@@ -3700,7 +3051,7 @@ def tensor_scatter_add(input_x, indices, updates):
3700
3051
 
3701
3052
  The last axis of `indices` is the depth of each index vectors. For each index vector,
3702
3053
  there must be a corresponding value in `updates`. The shape of `updates` should be
3703
- equal to the shape of `input_x[indices]`. For more details, see use cases.
3054
+ equal to the shape of `input_x[indices]`. For more details, see Examples.
3704
3055
 
3705
3056
  .. math::
3706
3057
  output\left [indices \right ] = input\_x + update
@@ -3758,7 +3109,7 @@ def tensor_scatter_sub(input_x, indices, updates):
3758
3109
 
3759
3110
  The last axis of `indices` is the depth of each index vectors. For each index vector,
3760
3111
  there must be a corresponding value in `updates`. The shape of `updates` should be
3761
- equal to the shape of `input_x[indices]`. For more details, see use cases.
3112
+ equal to the shape of `input_x[indices]`. For more details, see Examples.
3762
3113
 
3763
3114
  .. math::
3764
3115
  output[indices] = input\_x - update
@@ -3943,20 +3294,18 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
3943
3294
  nondeterministic.
3944
3295
  - On Ascend, the reduction only support set to "none" for now.
3945
3296
  - On Ascend, the data type of `input_x` must be float16 or float32.
3297
+ - This is an experimental API that is subject to change or deletion.
3946
3298
 
3947
3299
  Note:
3948
3300
  If some values of the `indices` exceed the upper or lower bounds of the index of `input_x`, instead of raising
3949
3301
  an index error, the corresponding `updates` will not be updated to `input_x`.
3950
-
3951
- .. warning::
3952
- This is an experimental API that is subject to change or deletion.
3302
+ The backward is supported only for the case `updates.shape == indices.shape`.
3953
3303
 
3954
3304
  Args:
3955
3305
  input_x (Tensor): The target tensor. The rank must be at least 1.
3956
3306
  indices (Tensor): The index of `input_x` to do scatter operation whose data type must be mindspore.int32 or
3957
3307
  mindspore.int64. Same rank as `input_x`. And accepted range is [-s, s) where s is the size along axis.
3958
- updates (Tensor): The tensor doing the scatter operation with `input_x`, has the same type as `input_x` and
3959
- the same shape as `indices`.
3308
+ updates (Tensor): The tensor doing the scatter operation with `input_x`.
3960
3309
  axis (int): Which axis to scatter. Accepted range is [-r, r) where r = rank(input_x). Default: ``0``.
3961
3310
  reduction (str): Which reduction operation to scatter, supports ``"none"`` , ``"add"`` . Default: ``"none"``.
3962
3311
  When `reduction` is set to ``"none"``, `updates` will be assigned to `input_x` according to `indices`.
@@ -3968,7 +3317,6 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
3968
3317
  Raises:
3969
3318
  TypeError: If `indices` is neither int32 nor int64.
3970
3319
  ValueError: If anyone of the rank among `input_x`, `indices` and `updates` less than 1.
3971
- ValueError: If the shape of `updates` is not equal to the shape of `indices`.
3972
3320
  ValueError: If the rank of `updates` is not equal to the rank of `input_x`.
3973
3321
  RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
3974
3322
  is required when data type conversion of Parameter is not supported.
@@ -4000,8 +3348,7 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
4000
3348
  [ 5 5 14]
4001
3349
  [ 7 15 11]]
4002
3350
  """
4003
- _tensor_scatter_elements = _get_cache_prim(TensorScatterElements)(axis, reduction)
4004
- return _tensor_scatter_elements(input_x, indices, updates)
3351
+ return tensor_scatter_elements_ext(input_x, indices, updates, axis, reduction)
4005
3352
 
4006
3353
 
4007
3354
  def scatter(input, axis, index, src):
@@ -4009,24 +3356,26 @@ def scatter(input, axis, index, src):
4009
3356
  Update the value in `src` to `input` according to the specified index.
4010
3357
  Refer to :func:`mindspore.ops.tensor_scatter_elements` for more details.
4011
3358
 
3359
+ .. note::
3360
+ The backward is supported only for the case `src.shape == index.shape`.
3361
+
4012
3362
  Args:
4013
3363
  input (Tensor): The target tensor. The rank of `input` must be at least 1.
4014
3364
  axis (int): Which axis to scatter. Accepted range is [-r, r) where r = rank(input).
4015
- index (Tensor): The index to do update operation whose data type must be mindspore.int32 or
4016
- mindspore.int64. Same rank as `input` . And accepted range is [-s, s) where s is the size along axis.
4017
- src (Tensor): The tensor doing the update operation with `input` , has the same type as `input` ,
4018
- and the shape of `src` should be equal to the shape of `index` .
3365
+ index (Tensor): The index to do update operation whose data must be positive number with type of mindspore.int32
3366
+ or mindspore.int64. Same rank as `input` . And accepted range is [-s, s) where s is the size along axis.
3367
+ src (Tensor, float): The data doing the update operation with `input`. Can be a tensor with the same data type
3368
+ as `input` or a float number to scatter.
4019
3369
 
4020
3370
  Returns:
4021
- Tensor, has the same shape and type as `input` .
3371
+ The backward is supported only for the case `src.shape == index.shape` when `src` is a tensor.
4022
3372
 
4023
3373
  Raises:
4024
3374
  TypeError: If `index` is neither int32 nor int64.
4025
- ValueError: If anyone of the rank among `input` , `index` and `src` less than 1.
4026
- ValueError: If the shape of `src` is not equal to the shape of `index` .
3375
+ ValueError: If rank of any of `input` , `index` and `src` less than 1.
4027
3376
  ValueError: If the rank of `src` is not equal to the rank of `input` .
4028
- RuntimeError: If the data type of `input` and `src` conversion of Parameter
4029
- is required when data type conversion of Parameter is not supported.
3377
+ TypeError: If the data type of `input` and `src` have different dtypes.
3378
+ RuntimeError: If `index` has negative elements.
4030
3379
 
4031
3380
  Supported Platforms:
4032
3381
  ``Ascend`` ``GPU`` ``CPU``
@@ -4062,7 +3411,82 @@ def scatter(input, axis, index, src):
4062
3411
  [0. 0. 0. 0. 0.]
4063
3412
  [0. 0. 0. 0. 0.]]
4064
3413
  """
4065
- return ops.tensor_scatter_elements(input_x=input, indices=index, updates=src, axis=axis)
3414
+ if isinstance(src, Tensor):
3415
+ return scatter_(input, axis, index, src)
3416
+ return scatter_value_(input, axis, index, src)
3417
+
3418
+
3419
+ def scatter_add_ext(input, dim, index, src):
3420
+ """
3421
+ Add all elements in `src` to the index specified by `index` to `input` along dimension specified by `dim`.
3422
+ It takes three inputs `input`, `src` and `index` of the same rank r >= 1.
3423
+
3424
+ For a 3-D tensor, the operation updates input as follows:
3425
+
3426
+ .. code-block::
3427
+
3428
+ input[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
3429
+
3430
+ input[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
3431
+
3432
+ input[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
3433
+
3434
+ Args:
3435
+ input (Tensor): The target tensor. The rank must be at least 1.
3436
+ dim (int): Which dim to scatter. Accepted range is [-r, r) where r = rank(`input`). Default: ``0``.
3437
+ index (Tensor): The index of `input` to do scatter operation whose data type must be mindspore.int32 or
3438
+ mindspore.int64. Same rank as `input`. Except for the dimension specified by `dim`,
3439
+ the size of each dimension of `index` must be less than or equal to the size of
3440
+ the corresponding dimension of `input`.
3441
+ src (Tensor): The tensor doing the scatter operation with `input`, has the same type as `input` and
3442
+ the size of each dimension must be greater than or equal to that of `index`.
3443
+
3444
+ Returns:
3445
+ Tensor, has the same shape and type as `input`.
3446
+
3447
+ Raises:
3448
+ TypeError: If `index` is neither int32 nor int64.
3449
+ ValueError: If anyone of the rank among `input`, `index` and `src` less than 1.
3450
+ ValueError: If the rank of `input`, `index` and `src` is not the same.
3451
+ ValueError: If, outside dimension `dim`, the size of any dimension of `index` is greater than the size of
3452
+ the corresponding dimension of `input` .
3453
+ ValueError: If the size of any dimension of `src` is less than that of `index`.
3454
+
3455
+ Supported Platforms:
3456
+ ``Ascend``
3457
+
3458
+ Examples:
3459
+ >>> import numpy as np
3460
+ >>> import mindspore as ms
3461
+ >>> from mindspore import Tensor, ops
3462
+ >>> input = Tensor(np.array([[1, 2, 3, 4, 5]]), dtype=ms.float32)
3463
+ >>> src = Tensor(np.array([[8, 8]]), dtype=ms.float32)
3464
+ >>> index = Tensor(np.array([[2, 4]]), dtype=ms.int64)
3465
+ >>> out = ops.function.array_func.scatter_add_ext(input=input, dim=1, index=index, src=src)
3466
+ >>> print(out)
3467
+ [[1. 2. 11. 4. 13.]]
3468
+ >>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
3469
+ >>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
3470
+ >>> index = Tensor(np.array([[0, 0, 0], [2, 2, 2], [4, 4, 4]]), dtype=ms.int64)
3471
+ >>> out = ops.function.array_func.scatter_add_ext(input=input, dim=0, index=index, src=src)
3472
+ >>> print(out)
3473
+ [[1. 2. 3. 0. 0.]
3474
+ [0. 0. 0. 0. 0.]
3475
+ [4. 5. 6. 0. 0.]
3476
+ [0. 0. 0. 0. 0.]
3477
+ [7. 8. 9. 0. 0.]]
3478
+ >>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
3479
+ >>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
3480
+ >>> index = Tensor(np.array([[0, 2, 4], [0, 2, 4], [0, 2, 4]]), dtype=ms.int64)
3481
+ >>> out = ops.function.array_func.scatter_add_ext(input=input, dim=1, index=index, src=src)
3482
+ >>> print(out)
3483
+ [[1. 0. 2. 0. 3.]
3484
+ [4. 0. 5. 0. 6.]
3485
+ [7. 0. 8. 0. 9.]
3486
+ [0. 0. 0. 0. 0.]
3487
+ [0. 0. 0. 0. 0.]]
3488
+ """
3489
+ return scatter_add_ext_op(input, dim, index, src)
4066
3490
 
4067
3491
 
4068
3492
  def _get_slice_scatter_const(x_shape, axis, start, end, step):
@@ -4074,7 +3498,7 @@ def _get_slice_scatter_const(x_shape, axis, start, end, step):
4074
3498
  start = start if start is not None else 0
4075
3499
  start = start if start >= 0 else start + x_rank
4076
3500
  end = end if end is not None else x_shape[axis]
4077
- end = end if end >= 0 else end + x_rank
3501
+ end = end if end >= 0 else end + x_shape[axis]
4078
3502
  end = end if end < x_shape[axis] else x_shape[axis]
4079
3503
  index = list(builtins.range(start, end, step))
4080
3504
  return x_rank, index, axis
@@ -4121,8 +3545,11 @@ def slice_scatter(input, src, axis=0, start=None, end=None, step=1):
4121
3545
  [1. 0. 1. 0. 1. 0.]
4122
3546
  [1. 0. 1. 0. 1. 0.]]
4123
3547
  """
3548
+ _check_is_tensor("input", input, "slice_scatter")
3549
+ _check_is_tensor("src", src, "slice_scatter")
4124
3550
  input_shape = input.shape
4125
- input_rank, index, axis = _get_slice_scatter_const(input_shape, axis, start, end, step)
3551
+ input_rank, index, axis = _get_slice_scatter_const(
3552
+ input_shape, axis, start, end, step)
4126
3553
 
4127
3554
  src_shape = src.shape
4128
3555
  index_shape = input_shape[:axis] + (len(index),) + input_shape[axis + 1:]
@@ -4136,6 +3563,8 @@ def slice_scatter(input, src, axis=0, start=None, end=None, step=1):
4136
3563
  for _ in builtins.range(input_rank - axis - 1):
4137
3564
  index_tensor = index_tensor.expand_dims(-1)
4138
3565
  index_tensor = index_tensor.broadcast_to(src.shape)
3566
+ if index_tensor.dtype not in mstype.int_type:
3567
+ index_tensor = index_tensor.astype(mstype.int64)
4139
3568
  return tensor_scatter_elements(input, axis=axis, indices=index_tensor, updates=src)
4140
3569
 
4141
3570
 
@@ -4174,10 +3603,12 @@ def select_scatter(input, src, axis, index):
4174
3603
  [1. 1. 1.]
4175
3604
  [0. 0. 0.]]]
4176
3605
  """
3606
+ _check_is_tensor("input", input, "select_scatter")
3607
+ _check_is_tensor("src", src, "select_scatter")
4177
3608
  src = src.expand_dims(axis=axis)
4178
3609
  x_rank = input.ndim
4179
3610
  axis = axis if axis >= 0 else axis + x_rank
4180
- index = index if index >= 0 else index + x_rank
3611
+ index = index if index >= 0 else index + input.shape[axis]
4181
3612
  return slice_scatter(input, src, axis, start=index, end=index + 1)
4182
3613
 
4183
3614
 
@@ -4240,7 +3671,8 @@ def space_to_batch_nd(input_x, block_size, paddings):
4240
3671
  [[[3.]]]
4241
3672
  [[[4.]]]]
4242
3673
  """
4243
- _space_to_batch_nd = _get_cache_prim(P.SpaceToBatchND)(block_size, paddings)
3674
+ _space_to_batch_nd = _get_cache_prim(
3675
+ P.SpaceToBatchND)(block_size, paddings)
4244
3676
  return _space_to_batch_nd(input_x)
4245
3677
 
4246
3678
 
@@ -4303,49 +3735,11 @@ def batch_to_space_nd(input_x, block_shape, crops):
4303
3735
  [3. 4.]]]]
4304
3736
  """
4305
3737
  if isinstance(block_shape, Tensor):
4306
- _batch_to_space_ndv2 = _get_cache_prim(P.BatchToSpaceNDV2)()
4307
- return _batch_to_space_ndv2(input_x, block_shape, crops)
3738
+ return batch_to_space_nd_v2_(input_x, block_shape, crops)
4308
3739
  _batch_to_space_nd = _get_cache_prim(P.BatchToSpaceND)(block_shape, crops)
4309
3740
  return _batch_to_space_nd(input_x)
4310
3741
 
4311
3742
 
4312
- def nonzero(input):
4313
- """
4314
- Return a Tensor of the positions of all non-zero values.
4315
-
4316
- Args:
4317
- input (Tensor): The input Tensor, its rank should be greater than or eaqual to 1.
4318
-
4319
- Returns:
4320
- Tensor, a 2-D Tensor whose data type is int64, containing the positions of all non-zero values of the input.
4321
-
4322
- Raises:
4323
- TypeError: If `input` is not Tensor.
4324
- ValueError: If dim of `x` equals to 0.
4325
-
4326
- Supported Platforms:
4327
- ``Ascend`` ``GPU`` ``CPU``
4328
-
4329
- Examples:
4330
- >>> import mindspore
4331
- >>> import numpy as np
4332
- >>> from mindspore import Tensor
4333
- >>> import mindspore.ops as ops
4334
- >>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
4335
- >>> output = ops.nonzero(x)
4336
- >>> print(output)
4337
- [[0 0 0]
4338
- [0 1 0]]
4339
- >>> x = Tensor(np.array([1, 0, 2, 0, 3]), mindspore.int32)
4340
- >>> output = ops.nonzero(x)
4341
- >>> print(output)
4342
- [[0]
4343
- [2]
4344
- [4]]
4345
- """
4346
- return nonzero_(input)
4347
-
4348
-
4349
3743
  def matrix_diag(x, k=0, num_rows=-1, num_cols=-1, padding_value=0, align="RIGHT_LEFT"):
4350
3744
  r"""
4351
3745
  Returns a Tensor with the contents in `x` as k[0]-th to k[1]-th diagonals of a matrix, with everything else padded
@@ -4605,18 +3999,19 @@ def meshgrid(*inputs, indexing='xy'):
4605
3999
 
4606
4000
  Keyword Args:
4607
4001
  indexing (str, optional): Cartesian ('xy', default) or
4608
- matrix ('ij') indexing of output. Valid options: xy' or 'ij'. In the 2-D case with
4002
+ matrix ('ij') indexing of output. Valid options: xy' or ``'ij'``. In the 2-D case with
4609
4003
  inputs of length `M` and `N`, the outputs are of shape :math:`(N, M)`
4610
- for 'xy' indexing and :math:`(M, N)` for 'ij' indexing. In the 3-D
4004
+ for ``'xy'`` indexing and :math:`(M, N)` for ``'ij'`` indexing. In the 3-D
4611
4005
  case with inputs of length `M`, `N` and `P`, outputs are of shape
4612
- :math:`(N, M, P)` for 'xy' indexing and :math:`(M, N, P)` for 'ij' indexing. Default: ``'xy'`` .
4006
+ :math:`(N, M, P)` for ``'xy'`` indexing and :math:`(M, N, P)` for ``'ij'`` indexing.
4007
+ Default: ``'xy'`` .
4613
4008
 
4614
4009
  Returns:
4615
4010
  Tensors, a Tuple of N N-D Tensor objects. The data type is the same with the Inputs.
4616
4011
 
4617
4012
  Raises:
4618
4013
  TypeError: If `indexing` is not a str or `inputs` is not a tuple.
4619
- ValueError: If `indexing` is neither 'xy' nor 'ij'.
4014
+ ValueError: If `indexing` is neither ``'xy'`` nor ``'ij'``.
4620
4015
 
4621
4016
  Supported Platforms:
4622
4017
  ``Ascend`` ``GPU`` ``CPU``
@@ -4624,7 +4019,7 @@ def meshgrid(*inputs, indexing='xy'):
4624
4019
  Examples:
4625
4020
  >>> import numpy as np
4626
4021
  >>> from mindspore import Tensor
4627
- >>> import mindspore.ops as ops
4022
+ >>> from mindspore import ops
4628
4023
  >>> x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
4629
4024
  >>> y = Tensor(np.array([5, 6, 7]).astype(np.int32))
4630
4025
  >>> z = Tensor(np.array([8, 9, 0, 1, 2]).astype(np.int32))
@@ -4707,7 +4102,7 @@ def affine_grid(theta, size, align_corners=False):
4707
4102
  Examples:
4708
4103
  >>> import mindspore
4709
4104
  >>> from mindspore import Tensor
4710
- >>> import mindspore.ops as ops
4105
+ >>> from mindspore import ops
4711
4106
  >>> theta = Tensor([[[0.8, 0.5, 0],[-0.5, 0.8, 0]]], mindspore.float32)
4712
4107
  >>> out_size = (1, 3, 2, 3)
4713
4108
  >>> output = ops.affine_grid(theta, out_size, False)
@@ -4723,87 +4118,6 @@ def affine_grid(theta, size, align_corners=False):
4723
4118
  return affine_grid_op(theta, size)
4724
4119
 
4725
4120
 
4726
- def broadcast_to(input, shape): # pylint: disable=redefined-outer-name
4727
- """
4728
- Broadcasts input tensor to a given shape. The dim of input shape must be smaller
4729
- than or equal to that of target shape. Suppose input shape is :math:`(x_1, x_2, ..., x_m)`,
4730
- target shape is :math:`(*, y_1, y_2, ..., y_m)`, where :math:`*` means any additional dimension.
4731
- The broadcast rules are as follows:
4732
-
4733
- Compare the value of :math:`x_m` and :math:`y_m`, :math:`x_{m-1}` and :math:`y_{m-1}`, ...,
4734
- :math:`x_1` and :math:`y_1` consecutively and
4735
- decide whether these shapes are broadcastable and what the broadcast result is.
4736
-
4737
- If the value pairs at a specific dim are equal, then that value goes right into that dim of output shape.
4738
- With an input shape :math:`(2, 3)`, target shape :math:`(2, 3)` , the inferred output shape is :math:`(2, 3)`.
4739
-
4740
- If the value pairs are unequal, there are three cases:
4741
-
4742
- Case 1: If the value of the target shape in the dimension is -1, the value of the
4743
- output shape in the dimension is the value of the corresponding input shape in the dimension.
4744
- With an input shape :math:`(3, 3)`, target
4745
- shape :math:`(-1, 3)`, the output shape is :math:`(3, 3)`.
4746
-
4747
- Case 2: If the value of target shape in the dimension is not -1, but the corresponding
4748
- value in the input shape is 1, then the corresponding value of the output shape
4749
- is that of the target shape. With an input shape :math:`(1, 3)`, target
4750
- shape :math:`(8, 3)`, the output shape is :math:`(8, 3)`.
4751
-
4752
- Case 3: If the corresponding values of the two shapes do not satisfy the above cases,
4753
- it means that broadcasting from the input shape to the target shape is not supported.
4754
-
4755
- So far we got the last m dims of the outshape, now focus on the first :math:`*` dims, there are
4756
- two cases:
4757
-
4758
- If the first :math:`*` dims of output shape does not have -1 in it, then fill the input
4759
- shape with ones until their length are the same, and then refer to
4760
- Case 2 mentioned above to calculate the output shape. With target shape :math:`(3, 1, 4, 1, 5, 9)`,
4761
- input shape :math:`(1, 5, 9)`, the filled input shape will be :math:`(1, 1, 1, 1, 5, 9)` and thus the
4762
- output shape is :math:`(3, 1, 4, 1, 5, 9)`.
4763
-
4764
- If the first :math:`*` dims of output shape have -1 in it, it implies this -1 is corresponding to
4765
- a non-existing dim so they're not broadcastable. With target shape :math:`(3, -1, 4, 1, 5, 9)`,
4766
- input shape :math:`(1, 5, 9)`, instead of operating the dim-filling process first, it raises errors directly.
4767
-
4768
- Args:
4769
- input (Tensor): The input Tensor.
4770
- shape (tuple): The target shape to broadcast. Can be fully specified, or have -1 in one position
4771
- where it will be substituted by the input tensor's shape in that position, see example.
4772
-
4773
- Returns:
4774
- Tensor, with the given `shape` and the same data type as `input`.
4775
-
4776
- Raises:
4777
- TypeError: If `shape` is not a tuple.
4778
- ValueError: If the target and input shapes are incompatible, or if a - 1 in the target shape is in an invalid
4779
- location.
4780
-
4781
- Supported Platforms:
4782
- ``Ascend`` ``GPU`` ``CPU``
4783
-
4784
- Examples:
4785
- >>> import numpy as np
4786
- >>> from mindspore import Tensor, ops
4787
- >>> shape = (2, 3)
4788
- >>> x = Tensor(np.array([1, 2, 3]).astype(np.float32))
4789
- >>> output = ops.broadcast_to(x, shape)
4790
- >>> print(output)
4791
- [[1. 2. 3.]
4792
- [1. 2. 3.]]
4793
- >>> shape = (-1, 2)
4794
- >>> x = Tensor(np.array([[1], [2]]).astype(np.float32))
4795
- >>> output = ops.broadcast_to(x, shape)
4796
- >>> print(output)
4797
- [[1. 1.]
4798
- [2. 2.]]
4799
- """
4800
- if isinstance(shape, Tensor) or ops.is_sequence_value_unknown(shape):
4801
- _dyn_broadcast_to = _get_cache_prim(DynamicBroadcastTo)()
4802
- return _dyn_broadcast_to(input, shape)
4803
- _broadcast_to = _get_cache_prim(P.BroadcastTo)(shape)
4804
- return _broadcast_to(input)
4805
-
4806
-
4807
4121
  def unsorted_segment_min(x, segment_ids, num_segments):
4808
4122
  r"""
4809
4123
  Computes the minimum of a tensor along segments.
@@ -4827,14 +4141,13 @@ def unsorted_segment_min(x, segment_ids, num_segments):
4827
4141
  x (Tensor): The shape is :math:`(x_1, x_2, ..., x_R)`. With float16, float32 or int32 data type.
4828
4142
  segment_ids (Tensor): TThe label indicates the segment to which each element belongs.
4829
4143
  Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
4830
- num_segments (int): The value specifies the number of distinct `segment_ids`.
4144
+ num_segments (Union[int, Tensor], optional): Set :math:`z` as num_segments, it can be an int or 0-D Tensor.
4831
4145
 
4832
4146
  Returns:
4833
- Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
4147
+ Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
4834
4148
 
4835
4149
  Raises:
4836
4150
  TypeError: If `num_segments` is not an int.
4837
- ValueError: If length of shape of `segment_ids` is not equal to 1.
4838
4151
 
4839
4152
  Supported Platforms:
4840
4153
  ``Ascend`` ``GPU`` ``CPU``
@@ -4851,7 +4164,6 @@ def unsorted_segment_min(x, segment_ids, num_segments):
4851
4164
  [[1. 2. 3.]
4852
4165
  [4. 2. 1.]]
4853
4166
  """
4854
- unsorted_segment_min_ = P.UnsortedSegmentMin()
4855
4167
  return unsorted_segment_min_(x, segment_ids, num_segments)
4856
4168
 
4857
4169
 
@@ -4878,14 +4190,13 @@ def unsorted_segment_max(x, segment_ids, num_segments):
4878
4190
  x (Tensor): The shape is :math:`(x_1, x_2, ..., x_R)`. With float16, float32 or int32 data type.
4879
4191
  segment_ids (Tensor): TThe label indicates the segment to which each element belongs.
4880
4192
  Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
4881
- num_segments (int): The value specifies the number of distinct `segment_ids`.
4193
+ num_segments (Union[int, Tensor], optional): Set :math:`z` as num_segments, it can be an int or 0-D Tensor.
4882
4194
 
4883
4195
  Returns:
4884
- Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
4196
+ Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
4885
4197
 
4886
4198
  Raises:
4887
4199
  TypeError: If `num_segments` is not an int.
4888
- ValueError: If length of shape of `segment_ids` is not equal to 1.
4889
4200
 
4890
4201
  Supported Platforms:
4891
4202
  ``Ascend`` ``GPU`` ``CPU``
@@ -4902,7 +4213,6 @@ def unsorted_segment_max(x, segment_ids, num_segments):
4902
4213
  [[1. 2. 3.]
4903
4214
  [4. 5. 6.]]
4904
4215
  """
4905
- unsorted_segment_max_ = P.UnsortedSegmentMax()
4906
4216
  return unsorted_segment_max_(x, segment_ids, num_segments)
4907
4217
 
4908
4218
 
@@ -4920,16 +4230,15 @@ def unsorted_segment_prod(x, segment_ids, num_segments):
4920
4230
 
4921
4231
  Args:
4922
4232
  x (Tensor): The shape is :math:`(x_1, x_2, ..., x_R)`. With float16, float32 or int32 data type.
4923
- segment_ids (Tensor): A `1-D` tensor whose shape is :math:`(x_1)`,
4924
- the value must be non-negative tensor. The data type must be int32.
4925
- num_segments (int): The value specifies the number of distinct `segment_ids`.
4233
+ segment_ids (Tensor): TThe label indicates the segment to which each element belongs.
4234
+ Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R. The data type must be int32.
4235
+ num_segments (Union[int, Tensor], optional): Set :math:`z` as num_segments, it can be an int or 0-D Tensor.
4926
4236
 
4927
4237
  Returns:
4928
- Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
4238
+ Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
4929
4239
 
4930
4240
  Raises:
4931
4241
  TypeError: If `num_segments` is not an int.
4932
- ValueError: If length of shape of `segment_ids` is not equal to 1.
4933
4242
 
4934
4243
  Supported Platforms:
4935
4244
  ``Ascend`` ``GPU`` ``CPU``
@@ -4946,7 +4255,6 @@ def unsorted_segment_prod(x, segment_ids, num_segments):
4946
4255
  [[4. 4. 3.]
4947
4256
  [4. 5. 6.]]
4948
4257
  """
4949
- unsorted_segment_prod_ = P.UnsortedSegmentProd()
4950
4258
  return unsorted_segment_prod_(x, segment_ids, num_segments)
4951
4259
 
4952
4260
 
@@ -4987,7 +4295,7 @@ def index_fill(x, axis, index, value):
4987
4295
  Examples:
4988
4296
  >>> import mindspore
4989
4297
  >>> import numpy as np
4990
- >>> import mindspore.ops as ops
4298
+ >>> from mindspore import ops
4991
4299
  >>> from mindspore import Tensor
4992
4300
  >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32))
4993
4301
  >>> index = Tensor([0, 2], mindspore.int32)
@@ -5056,9 +4364,11 @@ def index_select(input, axis, index):
5056
4364
  [[ 8. 9. 10. 11.]]]
5057
4365
  """
5058
4366
  if not (isinstance(input, Tensor) and isinstance(index, Tensor)):
5059
- raise TypeError(f"For 'index_select', `input` and `index` must be all tensors.")
4367
+ raise TypeError(
4368
+ f"For 'index_select', `input` and `index` must be all tensors.")
5060
4369
  if index.ndim != 1:
5061
- raise ValueError(f"For 'index_select', the dimension of `index` must be 1, but got {index.ndim}")
4370
+ raise ValueError(
4371
+ f"For 'index_select', the dimension of `index` must be 1, but got {index.ndim}")
5062
4372
  axis = _check_check_axis_in_range(axis, input.ndim)
5063
4373
  return gather_(input, index, axis)
5064
4374
 
@@ -5151,40 +4461,15 @@ def is_nonzero(input):
5151
4461
  True
5152
4462
  """
5153
4463
  if not isinstance(input, Tensor):
5154
- raise TypeError(f'For is_nonzero, the input must be a Tensor, but got {type(input)}.')
4464
+ raise TypeError(
4465
+ f'For is_nonzero, the input must be a Tensor, but got {type(input)}.')
5155
4466
  if input.numel() != 1:
5156
- raise ValueError(f"For is_nonzero, the numel of input must be 1, but got {input.numel()}.")
4467
+ raise ValueError(
4468
+ f"For is_nonzero, the numel of input must be 1, but got {input.numel()}.")
5157
4469
  out = ops.squeeze(input)
5158
4470
  return bool(out)
5159
4471
 
5160
4472
 
5161
- def scalar_cast(input_x, input_y):
5162
- """
5163
- Casts the input scalar to another type.
5164
-
5165
- Args:
5166
- input_x (scalar): The input scalar. Only constant value is allowed.
5167
- input_y (mindspore.dtype): The type to be cast. Only constant value is allowed.
5168
-
5169
- Returns:
5170
- Scalar. The type is the same as the python type corresponding to `input_y`.
5171
-
5172
- Raises:
5173
- TypeError: If neither `input_x` nor `input_y` is a constant value.
5174
-
5175
- Supported Platforms:
5176
- ``Ascend`` ``GPU`` ``CPU``
5177
-
5178
- Examples:
5179
- >>> import mindspore
5180
- >>> from mindspore import ops
5181
- >>> output = ops.scalar_cast(255.0, mindspore.int32)
5182
- >>> print(output)
5183
- 255
5184
- """
5185
- return scalar_cast_(input_x, input_y)
5186
-
5187
-
5188
4473
  def tensor_scatter_mul(input_x, indices, updates):
5189
4474
  r"""
5190
4475
  Creates a new tensor by multiplying the values from the positions in `input_x` indicated by
@@ -5194,10 +4479,10 @@ def tensor_scatter_mul(input_x, indices, updates):
5194
4479
 
5195
4480
  The last axis of `indices` is the depth of each index vectors. For each index vector,
5196
4481
  there must be a corresponding value in `updates`. The shape of `updates` should be
5197
- equal to the shape of `input_x[indices]`. For more details, see use cases.
4482
+ equal to the shape of `input_x[indices]`. For more details, see Examples.
5198
4483
 
5199
4484
  .. math::
5200
- output[indices] = input\_x \times update
4485
+ output\left [indices \right ] = input\_x\times update
5201
4486
 
5202
4487
  Note:
5203
4488
  - If some values of the `indices` are out of bound, instead of raising an index error,
@@ -5254,7 +4539,7 @@ def tensor_scatter_div(input_x, indices, updates):
5254
4539
 
5255
4540
  The last axis of `indices` is the depth of each index vectors. For each index vector,
5256
4541
  there must be a corresponding value in `updates`. The shape of `updates` should be
5257
- equal to the shape of `input_x[indices]`. For more details, see use cases.
4542
+ equal to the shape of `input_x[indices]`. For more details, see Examples.
5258
4543
 
5259
4544
  .. math::
5260
4545
  output\left [indices \right ] = input\_x \div update
@@ -5375,115 +4660,6 @@ def tuple_to_array(input_x):
5375
4660
  return tuple_to_tensor_(input_x, dtype)
5376
4661
 
5377
4662
 
5378
- def masked_select(input, mask):
5379
- """
5380
- Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
5381
- The shapes of the `mask` tensor and the `x` tensor don't need to match, but they must be broadcastable.
5382
-
5383
- Args:
5384
- input (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
5385
- mask (Tensor[bool]): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
5386
-
5387
- Returns:
5388
- A 1-D Tensor, with the same type as `input`.
5389
-
5390
- Raises:
5391
- TypeError: If `input` or `mask` is not a Tensor.
5392
- TypeError: If dtype of `mask` is not bool.
5393
-
5394
- Supported Platforms:
5395
- ``Ascend`` ``GPU`` ``CPU``
5396
-
5397
- Examples:
5398
- >>> import numpy as np
5399
- >>> import mindspore
5400
- >>> from mindspore import Tensor, ops
5401
- >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
5402
- >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
5403
- >>> output = ops.masked_select(x, mask)
5404
- >>> print(output)
5405
- [1 3]
5406
- """
5407
- return masked_select_(input, mask)
5408
-
5409
-
5410
- def masked_fill(input_x, mask, value):
5411
- """
5412
- Fills elements of Tensor with value where mask is True.
5413
- The shapes of `input_x` and `mask` need to be the same or broadcastable.
5414
-
5415
- Args:
5416
- input_x (Tensor): The source Tensor whose data type is one of bool, uint8, int8, int16, int32,
5417
- int64, float16, float32, float64, complex64, complex128.
5418
- mask (Tensor[bool]): The boolean mask.
5419
- value (Union[float, Tensor]): The value to fill in with, which dtype is the same as `input_x`.
5420
-
5421
- Returns:
5422
- Tensor, has the same type and shape as `input_x`.
5423
-
5424
- Raises:
5425
- TypeError: If dtype of `mask` is not bool.
5426
- TypeError: If `input_x` or `mask` is not a Tensor.
5427
- ValueError: If the shapes of `input_x` and `mask` could not be broadcast.
5428
- TypeError: If dtype of `input_x` or `value` is not one of bool, uint8, int8, int16, int32,
5429
- int64, float16, float32, float64, complex64, complex128.
5430
- TypeError: If dtype of `value` is different from that of `input_x`.
5431
- TypeError: If `value` is neither float number nor Tensor.
5432
-
5433
- Supported Platforms:
5434
- ``Ascend`` ``GPU`` ``CPU``
5435
-
5436
- Examples:
5437
- >>> import mindspore
5438
- >>> import numpy as np
5439
- >>> from mindspore import Tensor, ops
5440
- >>> input_x = Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
5441
- >>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
5442
- >>> output = ops.masked_fill(input_x, mask, 0.5)
5443
- >>> print(output)
5444
- [0.5 0.5 3. 0.5]
5445
- """
5446
- if isinstance(value, (float, int)) and isinstance(input_x, Tensor):
5447
- value = scalar_to_tensor_(value, input_x.dtype)
5448
- masked_fill_ = _get_cache_prim(P.MaskedFill)()
5449
- return masked_fill_(input_x, mask, value)
5450
-
5451
-
5452
- def diag(input):
5453
- r"""
5454
- Constructs a diagonal tensor with a given diagonal values.
5455
-
5456
- Assume `input` has dimensions :math:`(D_1,... D_k)` , the output is a tensor of
5457
- rank 2k with dimensions :math:`(D_1,..., D_k, D_1,..., D_k)` where:
5458
- :math:`output[i_1,..., i_k, i_1,..., i_k] = input[i_1,..., i_k]` and 0 everywhere else.
5459
-
5460
- Args:
5461
- input (Tensor): The input tensor.
5462
-
5463
- Returns:
5464
- Tensor, has the same dtype as the `input`.
5465
-
5466
- Raises:
5467
- TypeError: If `input` is not a Tensor.
5468
- ValueError: If rank of `input` is less than 1.
5469
-
5470
- Supported Platforms:
5471
- ``Ascend`` ``GPU`` ``CPU``
5472
-
5473
- Examples:
5474
- >>> from mindspore import Tensor
5475
- >>> import mindspore.ops as ops
5476
- >>> input_x = Tensor([1, 2, 3, 4]).astype('int32')
5477
- >>> output = ops.diag(input_x)
5478
- >>> print(output)
5479
- [[1 0 0 0]
5480
- [0 2 0 0]
5481
- [0 0 3 0]
5482
- [0 0 0 4]]
5483
- """
5484
- return diag_(input)
5485
-
5486
-
5487
4663
  def diagflat(input, offset=0):
5488
4664
  r"""
5489
4665
  Create a 2-D Tensor which diagonal is the flattened `input` .
@@ -5517,9 +4693,11 @@ def diagflat(input, offset=0):
5517
4693
  [0. 0. 0.]]
5518
4694
  """
5519
4695
  if not isinstance(input, Tensor):
5520
- raise TypeError(f"For diagflat, the input x must be tensor, but got {type(input)}")
4696
+ raise TypeError(
4697
+ f"For diagflat, the input x must be tensor, but got {type(input)}")
5521
4698
  if not isinstance(offset, int):
5522
- raise TypeError(f"For diagflat, the offset must be int, but got {type(offset)}")
4699
+ raise TypeError(
4700
+ f"For diagflat, the offset must be int, but got {type(offset)}")
5523
4701
  offset_abs = abs(offset)
5524
4702
  if input.size == 0:
5525
4703
  return zeros((offset_abs, offset_abs), input.dtype)
@@ -5542,7 +4720,7 @@ def col2im(input_x, output_size, kernel_size, dilation, padding_value, stride):
5542
4720
  Combines an array of sliding local blocks into a large containing tensor.
5543
4721
 
5544
4722
  Args:
5545
- input_x (Tensor): 4D tensor with data type float16 or float.
4723
+ input_x (Tensor): 4D tensor with data type float16 or float32.
5546
4724
  output_size (Tensor): 1D tensor with 2 elements of data type int.
5547
4725
  kernel_size (Union[int, tuple[int], list[int]]): The size of the kernel, should be two int
5548
4726
  for height and width. If type is int, it means that height equal with width. Must be specified.
@@ -5589,7 +4767,9 @@ def _split_int(x, split_size_or_sections, axis):
5589
4767
  """
5590
4768
  arr_shape = x.shape
5591
4769
  length_along_dim = arr_shape[axis]
5592
- if split_size_or_sections > length_along_dim:
4770
+ if length_along_dim == 0:
4771
+ res = _get_cache_prim(P.Split)(axis)(x)
4772
+ elif split_size_or_sections > length_along_dim:
5593
4773
  res = _get_cache_prim(P.Split)(axis, 1)(x)
5594
4774
  elif length_along_dim % split_size_or_sections == 0:
5595
4775
  sections = length_along_dim // split_size_or_sections
@@ -5598,12 +4778,12 @@ def _split_int(x, split_size_or_sections, axis):
5598
4778
  num_sections = length_along_dim // split_size_or_sections
5599
4779
  length1 = num_sections * split_size_or_sections
5600
4780
  length2 = length_along_dim - length1
5601
- start1 = _list_comprehensions(rank(x), 0, True)
4781
+ start1 = _list_comprehensions(rank_(x), 0, True)
5602
4782
  size1 = _tuple_setitem(arr_shape, axis, length1)
5603
4783
  start2 = _tuple_setitem(start1, axis, length1)
5604
4784
  size2 = _tuple_setitem(arr_shape, axis, length2)
5605
4785
  res = _get_cache_prim(P.Split)(axis, num_sections)(tensor_slice(x, start1, size1)) + \
5606
- _get_cache_prim(P.Split)(axis, 1)(tensor_slice(x, start2, size2))
4786
+ _get_cache_prim(P.Split)(axis, 1)(tensor_slice(x, start2, size2))
5607
4787
  return res
5608
4788
 
5609
4789
 
@@ -5650,9 +4830,9 @@ def split(tensor, split_size_or_sections, axis=0):
5650
4830
  TypeError: If argument `tensor` is not Tensor.
5651
4831
  TypeError: If argument `axis` is not Tensor.
5652
4832
  ValueError: If argument `axis` is out of range of :math:`[-tensor.ndim, tensor.ndim)` .
5653
- TypeError: If each element in 'split_size_or_sections' is not integer.
5654
- TypeError: If argument `indices_or_sections` is not int, tuple(int) or list(int).
5655
- ValueError: The sum of 'split_size_or_sections' is not equal to x.shape[axis].
4833
+ TypeError: If each element in `split_size_or_sections` is not integer.
4834
+ TypeError: If argument `split_size_or_sections` is not int, tuple(int) or list(int).
4835
+ ValueError: The sum of `split_size_or_sections` is not equal to x.shape[axis].
5656
4836
 
5657
4837
  Supported Platforms:
5658
4838
  ``Ascend`` ``GPU`` ``CPU``
@@ -5670,7 +4850,8 @@ def split(tensor, split_size_or_sections, axis=0):
5670
4850
  if not isinstance(tensor, Tensor):
5671
4851
  raise TypeError(f'expect `tensor` is a Tensor, but got {type(tensor)}')
5672
4852
  if type(axis) is not int:
5673
- raise TypeError(f"Type of Argument `axis` should be integer but got {type(axis)}")
4853
+ raise TypeError(
4854
+ f"Type of Argument `axis` should be integer but got {type(axis)}")
5674
4855
  arr_axis = _canonicalize_axis(axis, tensor.ndim)
5675
4856
 
5676
4857
  if type(split_size_or_sections) is int:
@@ -5682,7 +4863,8 @@ def split(tensor, split_size_or_sections, axis=0):
5682
4863
  elif isinstance(split_size_or_sections, (list, tuple)):
5683
4864
  for item in split_size_or_sections:
5684
4865
  if type(item) is not int:
5685
- raise TypeError(f"Each element in 'split_size_or_sections' should be integer, but got {type(item)}.")
4866
+ raise TypeError(
4867
+ f"Each element in 'split_size_or_sections' should be integer, but got {type(item)}.")
5686
4868
  if item < 0:
5687
4869
  raise TypeError(f"Each element in 'split_size_or_sections' should be non-negative, "
5688
4870
  f"but got {split_size_or_sections}.")
@@ -5692,14 +4874,62 @@ def split(tensor, split_size_or_sections, axis=0):
5692
4874
  f"but got {sum(split_size_or_sections)}.")
5693
4875
  res = _split_sub_tensors(tensor, split_size_or_sections, arr_axis)
5694
4876
  else:
5695
- raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), " \
4877
+ raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), "
5696
4878
  f"but got {type(split_size_or_sections)}")
5697
4879
  return tuple(res)
5698
4880
 
5699
4881
 
4882
+ def split_ext(tensor, split_size_or_sections, axis=0):
4883
+ """
4884
+ Splits the Tensor into chunks along the given axis.
4885
+
4886
+ Args:
4887
+ tensor (Tensor): A Tensor to be divided.
4888
+ split_size_or_sections (Union[int, tuple(int), list(int)]):
4889
+ If `split_size_or_sections` is an int type, `tensor` will be split into equally sized chunks,
4890
+ each chunk with size `split_size_or_sections`. Last chunk will be smaller than `split_size_or_sections`
4891
+ if `tensor.shape[axis]` is not divisible by `split_size_or_sections`.
4892
+ If `split_size_or_sections` is a list type, then `tensor` will be split into len(split_size_or_sections)
4893
+ chunks with sizes `split_size_or_sections` along the given `axis`.
4894
+ axis (int): The axis along which to split. Default: ``0`` .
4895
+
4896
+ Returns:
4897
+ A tuple of sub-tensors.
4898
+
4899
+ Raises:
4900
+ TypeError: If argument `tensor` is not Tensor.
4901
+ TypeError: If argument `axis` is not int.
4902
+ ValueError: If argument `axis` is out of range of :[-tensor.ndim, tensor.ndim).
4903
+ TypeError: If each element in `split_size_or_sections` is not integer.
4904
+ TypeError: If argument `split_size_or_sections` is not int, tuple(int) or list(int).
4905
+ ValueError: The sum of `split_size_or_sections` is not equal to x.shape[axis].
4906
+
4907
+ Supported Platforms:
4908
+ ``Ascend``
4909
+
4910
+ Examples:
4911
+ >>> import numpy as np
4912
+ >>> from mindspore import ops, Tensor
4913
+ >>> input_x = np.arange(9).astype("float32")
4914
+ >>> output = ops.split_ext(Tensor(input_x), 3)
4915
+ >>> print(output)
4916
+ (Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
4917
+ Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
4918
+ Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
4919
+ """
4920
+ if isinstance(split_size_or_sections, int):
4921
+ res = split_tensor(tensor, split_size_or_sections, axis)
4922
+ elif isinstance(split_size_or_sections, (list, tuple)):
4923
+ res = split_with_size(tensor, split_size_or_sections, axis)
4924
+ else:
4925
+ raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), "
4926
+ f"but got {type(split_size_or_sections)}")
4927
+ return res
4928
+
4929
+
5700
4930
  def tril(input, diagonal=0): # pylint: disable=redefined-outer-name
5701
4931
  """
5702
- Returns the lower triangle part of 'input' (elements that contain the diagonal and below),
4932
+ Returns the lower triangle part of `input` (elements that contain the diagonal and below),
5703
4933
  and set the other elements to zeros.
5704
4934
 
5705
4935
  Args:
@@ -5709,13 +4939,13 @@ def tril(input, diagonal=0): # pylint: disable=redefined-outer-name
5709
4939
  indicating the main diagonal.
5710
4940
 
5711
4941
  Returns:
5712
- Tensor, the same shape and data type as the input `x`.
4942
+ Tensor, the same shape and data type as the `input`.
5713
4943
 
5714
4944
  Raises:
5715
- TypeError: If `x` is not a Tensor.
4945
+ TypeError: If `input` is not a Tensor.
5716
4946
  TypeError: If `diagonal` is not an int.
5717
- TypeError: If the type of `x` is neither number nor bool.
5718
- ValueError: If the rank of `x` is less than 2.
4947
+ TypeError: If the type of `input` is neither number nor bool.
4948
+ ValueError: If the rank of `input` is less than 2.
5719
4949
 
5720
4950
  Supported Platforms:
5721
4951
  ``Ascend`` ``GPU`` ``CPU``
@@ -5754,33 +4984,32 @@ def tril(input, diagonal=0): # pylint: disable=redefined-outer-name
5754
4984
  [10 11 0 0]
5755
4985
  [14 15 16 0]]
5756
4986
  """
5757
- tril_ = Tril(diagonal)
4987
+ tril_ = _get_cache_prim(Tril)(diagonal)
5758
4988
  return tril_(input)
5759
4989
 
5760
4990
 
5761
- def triu(input, diagonal=0): # pylint: disable=redefined-outer-name
5762
- r"""
5763
- Returns the upper triangle part of 'input' (elements that contain the diagonal and below),
4991
+ def tril_ext(input, diagonal=0):
4992
+ """
4993
+ Returns the lower triangle part of `input` (elements that contain the diagonal and below),
5764
4994
  and set the other elements to zeros.
5765
4995
 
5766
- .. warning::
5767
- This is an experimental API that is subject to change or deletion.
5768
-
5769
4996
  Args:
5770
- input (Tensor): The input tensor with shape :math:`(M, N, *)` where * means any number of additional dimensions.
4997
+ input (Tensor): A Tensor with shape :math:`(x_1, x_2, ..., x_R)`. The rank must be at least 2.
4998
+ Supporting all number types including bool.
5771
4999
  diagonal (int, optional): An optional attribute indicates the diagonal to consider, default: 0,
5772
5000
  indicating the main diagonal.
5773
5001
 
5774
5002
  Returns:
5775
- Tensor, a tensor has the same shape and data type as input.
5003
+ Tensor, the same shape and data type as the `input`.
5776
5004
 
5777
5005
  Raises:
5778
- TypeError: If `diagonal` is not an int.
5779
5006
  TypeError: If `input` is not a Tensor.
5780
- ValueError: If the dimension of `input` is less than 2.
5007
+ TypeError: If `diagonal` is not an int.
5008
+ TypeError: If the type of `input` is neither number nor bool.
5009
+ ValueError: If the rank of `input` is less than 2.
5781
5010
 
5782
5011
  Supported Platforms:
5783
- ``Ascend`` ``GPU`` ``CPU``
5012
+ ``Ascend``
5784
5013
 
5785
5014
  Examples:
5786
5015
  >>> import numpy as np
@@ -5789,34 +5018,34 @@ def triu(input, diagonal=0): # pylint: disable=redefined-outer-name
5789
5018
  ... [ 5, 6, 7, 8],
5790
5019
  ... [10, 11, 12, 13],
5791
5020
  ... [14, 15, 16, 17]]))
5792
- >>> result = ops.triu(x)
5021
+ >>> result = ops.function.array_func.tril_ext(x)
5793
5022
  >>> print(result)
5794
- [[ 1 2 3 4]
5795
- [ 0 6 7 8]
5796
- [ 0 0 12 13]
5797
- [ 0 0 0 17]]
5023
+ [[ 1 0 0 0]
5024
+ [ 5 6 0 0]
5025
+ [10 11 12 0]
5026
+ [14 15 16 17]]
5798
5027
  >>> x = Tensor(np.array([[ 1, 2, 3, 4],
5799
5028
  ... [ 5, 6, 7, 8],
5800
5029
  ... [10, 11, 12, 13],
5801
5030
  ... [14, 15, 16, 17]]))
5802
- >>> result = ops.triu(x, diagonal=1)
5031
+ >>> result = ops.function.array_func.tril_ext(x, diagonal=1)
5803
5032
  >>> print(result)
5804
- [[ 0 2 3 4]
5805
- [ 0 0 7 8]
5806
- [ 0 0 0 13]
5807
- [ 0 0 0 0]]
5033
+ [[ 1 2 0 0]
5034
+ [ 5 6 7 0]
5035
+ [10 11 12 13]
5036
+ [14 15 16 17]]
5808
5037
  >>> x = Tensor(np.array([[ 1, 2, 3, 4],
5809
5038
  ... [ 5, 6, 7, 8],
5810
5039
  ... [10, 11, 12, 13],
5811
5040
  ... [14, 15, 16, 17]]))
5812
- >>> result = ops.triu(x, diagonal=-1)
5041
+ >>> result = ops.function.array_func.tril_ext(x, diagonal=-1)
5813
5042
  >>> print(result)
5814
- [[ 1 2 3 4]
5815
- [ 5 6 7 8]
5816
- [ 0 11 12 13]
5817
- [ 0 0 16 17]]
5043
+ [[ 0 0 0 0]
5044
+ [ 5 0 0 0]
5045
+ [10 11 0 0]
5046
+ [14 15 16 0]]
5818
5047
  """
5819
- return _get_cache_prim(P.Triu)(diagonal)(input)
5048
+ return tril_ext_impl(input, diagonal)
5820
5049
 
5821
5050
 
5822
5051
  @_primexpr
@@ -5837,7 +5066,8 @@ def _canonicalize_axis(axis, ndim):
5837
5066
  if not isinstance(ax, int):
5838
5067
  raise TypeError(f'axis should be integers, not {type(ax)}')
5839
5068
  if not -ndim <= ax < ndim:
5840
- raise ValueError(f'axis {ax} is out of bounds for array of dimension {ndim}')
5069
+ raise ValueError(
5070
+ f'axis {ax} is out of bounds for array of dimension {ndim}')
5841
5071
 
5842
5072
  def canonicalizer(ax):
5843
5073
  return ax + ndim if ax < 0 else ax
@@ -5917,25 +5147,29 @@ def _tensor_split_sub_int(x, indices_or_sections, axis):
5917
5147
  """
5918
5148
  arr_shape = x.shape
5919
5149
  length_along_dim = arr_shape[axis]
5920
- if indices_or_sections > length_along_dim:
5921
- res = P.Split(axis, length_along_dim)(x)
5150
+ if length_along_dim == 0:
5151
+ res = _get_cache_prim(P.Split)(axis)(x)
5152
+ elif indices_or_sections > length_along_dim:
5153
+ res = _get_cache_prim(P.Split)(axis, length_along_dim)(x)
5922
5154
  indices_or_sections_n = [length_along_dim, length_along_dim + 1]
5923
5155
  res2 = _tensor_split_sub_tensors(x, indices_or_sections_n, axis)
5924
5156
  for _ in np.arange(length_along_dim, indices_or_sections):
5925
5157
  res += tuple(res2)[1:]
5926
5158
  elif length_along_dim % indices_or_sections == 0:
5927
- res = P.Split(axis, indices_or_sections)(x)
5159
+ res = _get_cache_prim(P.Split)(axis, indices_or_sections)(x)
5928
5160
  else:
5929
5161
  num_long_tensor = length_along_dim % indices_or_sections
5930
5162
  num_short_tensor = indices_or_sections - num_long_tensor
5931
- length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1)
5163
+ length1 = num_long_tensor * \
5164
+ (length_along_dim // indices_or_sections + 1)
5932
5165
  length2 = length_along_dim - length1
5933
- start1 = _list_comprehensions(rank(x), 0, True)
5166
+ start1 = _list_comprehensions(rank_(x), 0, True)
5934
5167
  size1 = _tuple_setitem(arr_shape, axis, length1)
5935
5168
  start2 = _tuple_setitem(start1, axis, length1)
5936
5169
  size2 = _tuple_setitem(arr_shape, axis, length2)
5937
- res = P.Split(axis, num_long_tensor)(tensor_slice(x, start1, size1)) + \
5938
- P.Split(axis, num_short_tensor)(tensor_slice(x, start2, size2))
5170
+ res = _get_cache_prim(P.Split)(axis, num_long_tensor)(tensor_slice(x, start1, size1)) + \
5171
+ _get_cache_prim(P.Split)(axis, num_short_tensor)(
5172
+ tensor_slice(x, start2, size2))
5939
5173
  return res
5940
5174
 
5941
5175
 
@@ -5949,11 +5183,11 @@ def tensor_split(input, indices_or_sections, axis=0):
5949
5183
 
5950
5184
  - If `indices_or_sections` is an integer n, input tensor will be split into n sections.
5951
5185
 
5952
- - If :math:`input.shape(axis)` can be divisible by n, sub-sections will have equal size
5953
- :math:`input.shape(axis) / n` .
5954
- - If :math:`input.shape(axis)` is not divisible by n, the first :math:`input.shape(axis) % n` sections
5955
- will have size :math:`input.shape(axis) // n + 1` , and the rest will have
5956
- size :math:`input.shape(axis) // n` .
5186
+ - If :math:`input.shape[axis]` can be divisible by n, sub-sections will have equal size
5187
+ :math:`input.shape[axis] / n` .
5188
+ - If :math:`input.shape[axis]` is not divisible by n, the first :math:`input.shape[axis] \bmod n` sections
5189
+ will have size :math:`input.shape[axis] // n + 1` , and the rest will have
5190
+ size :math:`input.shape[axis] // n` .
5957
5191
  - If `indices_or_sections` is of type tuple(int) or list(int), the input tensor will be split at the
5958
5192
  indices in the list or tuple. For example, given parameters :math:`indices\_or\_sections=[1, 4]`
5959
5193
  and :math:`axis=0` , the input tensor will be split into sections :math:`input[:1]` ,
@@ -5988,21 +5222,25 @@ def tensor_split(input, indices_or_sections, axis=0):
5988
5222
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
5989
5223
 
5990
5224
  if type(axis) is not int:
5991
- raise TypeError(f"Type of Argument `axis` should be integer but got {type(axis)}")
5225
+ raise TypeError(
5226
+ f"Type of Argument `axis` should be integer but got {type(axis)}")
5992
5227
  handle_axis = _canonicalize_axis(axis, input.ndim)
5993
5228
  if type(indices_or_sections) is int:
5994
5229
  if indices_or_sections > 0:
5995
- res = _tensor_split_sub_int(input, indices_or_sections, handle_axis)
5230
+ res = _tensor_split_sub_int(
5231
+ input, indices_or_sections, handle_axis)
5996
5232
  else:
5997
5233
  raise ValueError(f"For tensor_split, the value of 'indices_or_sections' must be more than zero "
5998
5234
  f"but got {indices_or_sections}")
5999
5235
  elif isinstance(indices_or_sections, (list, tuple)):
6000
5236
  for item in indices_or_sections:
6001
5237
  if type(item) is not int:
6002
- raise TypeError(f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.")
6003
- res = _tensor_split_sub_tensors(input, indices_or_sections, handle_axis)
5238
+ raise TypeError(
5239
+ f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.")
5240
+ res = _tensor_split_sub_tensors(
5241
+ input, indices_or_sections, handle_axis)
6004
5242
  else:
6005
- raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), " \
5243
+ raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), "
6006
5244
  f"but got {type(indices_or_sections)}")
6007
5245
 
6008
5246
  return res
@@ -6038,7 +5276,8 @@ def vsplit(input, indices_or_sections):
6038
5276
  if not isinstance(input, Tensor):
6039
5277
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
6040
5278
  if input.ndim < 1:
6041
- raise ValueError(f'vsplit expect `x` is a Tensor with at least 1 dimension, but got {input.ndim}')
5279
+ raise ValueError(
5280
+ f'vsplit expect `x` is a Tensor with at least 1 dimension, but got {input.ndim}')
6042
5281
  return tensor_split(input, indices_or_sections, 0)
6043
5282
 
6044
5283
 
@@ -6074,7 +5313,8 @@ def hsplit(input, indices_or_sections):
6074
5313
  if not isinstance(input, Tensor):
6075
5314
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
6076
5315
  if input.ndim < 2:
6077
- raise ValueError(f'hsplit expect `x` is a Tensor with at least 2 dimension, but got {input.ndim}')
5316
+ raise ValueError(
5317
+ f'hsplit expect `x` is a Tensor with at least 2 dimension, but got {input.ndim}')
6078
5318
 
6079
5319
  return tensor_split(input, indices_or_sections, 1)
6080
5320
 
@@ -6107,7 +5347,8 @@ def dsplit(input, indices_or_sections):
6107
5347
  if not isinstance(input, Tensor):
6108
5348
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
6109
5349
  if input.ndim < 3:
6110
- raise ValueError(f'dsplit expect `x` is a Tensor with at least 3 dimension, but got {input.ndim}')
5350
+ raise ValueError(
5351
+ f'dsplit expect `x` is a Tensor with at least 3 dimension, but got {input.ndim}')
6111
5352
 
6112
5353
  return tensor_split(input, indices_or_sections, 2)
6113
5354
 
@@ -6166,7 +5407,7 @@ def max(input, axis=None, keepdims=False, *, initial=None, where=None): # pylin
6166
5407
  tensor.
6167
5408
 
6168
5409
  - values (Tensor) - The maximum value of input tensor, with the same shape as index, and same dtype as x.
6169
- - index (Tensor) - The index for the maximum value of the input tensor, with dtype int32. If `keepdims`
5410
+ - index (Tensor) - The index for the maximum value of the input tensor, with dtype int64. If `keepdims`
6170
5411
  is true, the shape of output tensors is :math:`(input_1, input_2, ..., input_{axis-1}, 1, input_{axis+1},
6171
5412
  ..., input_N)` . Otherwise, the shape is :math:`(input_1, input_2, ..., input_{axis-1}, input_{axis+1},
6172
5413
  ..., input_N)` .
@@ -6195,16 +5436,16 @@ def max(input, axis=None, keepdims=False, *, initial=None, where=None): # pylin
6195
5436
  [[3.2 0.4 0.4 2.9 4. ]] [[1 1 0 1 1]]
6196
5437
  """
6197
5438
  if not input.shape:
6198
- return (input, Tensor(0, dtype=mstype.int32))
5439
+ return (input, Tensor(0, dtype=mstype.int64))
6199
5440
  if axis is None:
6200
- reduce_max_op = _get_cache_prim(P.ReduceMax)()
6201
- return (reduce_max_op(input), Tensor(0, dtype=mstype.int32))
5441
+ return (max_(input), Tensor(0, dtype=mstype.int64))
6202
5442
  if initial is not None and not isinstance(initial, numbers.Number):
6203
- raise TypeError(f"For 'max', 'initial' must be a scalar, but got {type(initial)}")
5443
+ raise TypeError(
5444
+ f"For 'max', 'initial' must be a scalar, but got {type(initial)}")
6204
5445
  if axis is not None and not isinstance(axis, int):
6205
5446
  raise TypeError(f"For 'max', 'axis' must be int, but got {type(axis)}")
6206
5447
  input = _init_and_select_elem(input, initial, where, ops.maximum)
6207
- argmax_with_value_op = ArgMaxWithValue(axis, keepdims)
5448
+ argmax_with_value_op = _get_cache_prim(ArgMaxWithValue)(axis, keepdims)
6208
5449
  indices, values = argmax_with_value_op(input)
6209
5450
  return values, indices
6210
5451
 
@@ -6250,7 +5491,7 @@ def argmax(input, dim=None, keepdim=False):
6250
5491
  is_dim_none = True
6251
5492
  out = _get_cache_prim(Argmax)(dim, mstype.int64)(input)
6252
5493
  if keepdim and not is_dim_none:
6253
- out = expand_dims_(out, dim)
5494
+ out = expand_dims(out, dim)
6254
5495
  return out
6255
5496
 
6256
5497
 
@@ -6312,16 +5553,17 @@ def min(input, axis=None, keepdims=False, *, initial=None, where=None): # pylin
6312
5553
  0.0 0
6313
5554
  """
6314
5555
  if not input.shape:
6315
- return (input, Tensor(0, dtype=mstype.int32))
5556
+ return (input, Tensor(0, dtype=mstype.int64))
6316
5557
  if axis is None:
6317
- return (reduce_min(input), Tensor(0, dtype=mstype.int32))
5558
+ return (min_(input), Tensor(0, dtype=mstype.int64))
6318
5559
  if initial is not None and not isinstance(initial, numbers.Number):
6319
- raise TypeError(f"For 'min', 'initial' must be a scalar, but got {type(initial)}")
5560
+ raise TypeError(
5561
+ f"For 'min', 'initial' must be a scalar, but got {type(initial)}")
6320
5562
  if axis is not None and not isinstance(axis, int):
6321
5563
  raise TypeError(f"For 'min', 'axis' must be int, but got {type(axis)}")
6322
5564
  input = _init_and_select_elem(input, initial, where, ops.minimum)
6323
- argmin_with_value_ = ArgMinWithValue(axis=axis, keep_dims=keepdims)
6324
- indices, values = argmin_with_value_(input)
5565
+ argmin_with_value_op = _get_cache_prim(ArgMinWithValue)(axis, keepdims)
5566
+ indices, values = argmin_with_value_op(input)
6325
5567
  return values, indices
6326
5568
 
6327
5569
 
@@ -6379,8 +5621,8 @@ def aminmax(input, *, axis=0, keepdims=False):
6379
5621
  output0 = ops.reshape(output0, [1] * input.ndim)
6380
5622
  output1 = ops.reshape(output1, [1] * input.ndim)
6381
5623
  return output0, output1
6382
- argmin_with_value_op = P.ArgMinWithValue(axis, keepdims)
6383
- argmax_with_value_op = P.ArgMaxWithValue(axis, keepdims)
5624
+ argmin_with_value_op = _get_cache_prim(ArgMinWithValue)(axis, keepdims)
5625
+ argmax_with_value_op = _get_cache_prim(ArgMaxWithValue)(axis, keepdims)
6384
5626
  _, output0 = argmin_with_value_op(input)
6385
5627
  _, output1 = argmax_with_value_op(input)
6386
5628
  if keepdims is True and input.ndim == 0:
@@ -6429,72 +5671,55 @@ def narrow(input, axis, start, length):
6429
5671
  validator.check_value_type("input", input, Tensor, "narrow")
6430
5672
  validator.check_axis_in_range(axis, input.ndim)
6431
5673
  validator.check_int_range(start, 0, input.shape[axis], validator.INC_LEFT)
6432
- validator.check_int_range(length, 1, input.shape[axis] - start, validator.INC_BOTH)
5674
+ validator.check_int_range(
5675
+ length, 1, input.shape[axis] - start, validator.INC_BOTH)
6433
5676
 
6434
5677
  begins = [0] * input.ndim
6435
5678
  begins[axis] = start
6436
5679
  sizes = list(input.shape)
6437
5680
  sizes[axis] = length
6438
- return P.Slice()(input, begins, sizes)
6439
-
5681
+ return tensor_slice(input, begins, sizes)
6440
5682
 
6441
- def unsorted_segment_sum(input_x, segment_ids, num_segments):
6442
- r"""
6443
- Computes the sum of a tensor along segments.
6444
-
6445
- Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
6446
- :math:`j,...` is a tuple describing the index of element in data.
6447
- `segment_ids` selects which elements in data to sum
6448
- up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
6449
- range.
6450
-
6451
- The following figure shows the calculation process of unsorted_segment_sum:
6452
-
6453
- .. image:: UnsortedSegmentSum.png
6454
5683
 
6455
- Note:
6456
- - If the segment_id i is absent in the segment_ids, then output[i] will be filled with 0.
6457
- - On Ascend, if the value of segment_id is less than 0 or greater than the length of the input data shape, an
6458
- execution error will occur.
6459
-
6460
- If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
6461
- is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.
5684
+ def narrow_ext(input, dim, start, length):
5685
+ """
5686
+ Returns a narrowed tensor from input tensor, and
5687
+ the dimension axis is input from start to start + length.
6462
5688
 
6463
5689
  Args:
6464
- input_x (Tensor): Input Tensor contains the data to be summed.
6465
- The shape is :math:`(x_1, x_2, ..., x_R)`.
6466
- segment_ids (Tensor): TThe label indicates the segment to which each element belongs.
6467
- Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
6468
- num_segments (Union[int, Tensor], optional): Set :math:`z` as num_segments, it can be an int or 0-D Tensor.
5690
+ input (Tensor): the tensor to narrow.
5691
+ dim (int): dimension along which to narrow.
5692
+ start (int): the starting dimension.
5693
+ length (int): the distance to the ending dimension.
6469
5694
 
6470
5695
  Returns:
6471
- Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
5696
+ Tensor.
6472
5697
 
6473
5698
  Raises:
6474
- TypeError: If `num_segments` is not an int or 0-D Tensor.
6475
- ValueError: If length of shape of `segment_ids` is less than 1.
5699
+ ValueError: If dim is out of range [-input.ndim, input.ndim).
5700
+ ValueError: If start is out of range [-input.shape[dim], input.shape[dim]].
5701
+ ValueError: It length is out of range [0, input.shape[dim]-start].
6476
5702
 
6477
5703
  Supported Platforms:
6478
- ``Ascend`` ``GPU`` ``CPU``
5704
+ ``Ascend``
6479
5705
 
6480
5706
  Examples:
6481
- >>> from mindspore import Tensor
6482
- >>> from mindspore import ops
6483
5707
  >>> import mindspore
6484
- >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
6485
- >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
6486
- >>> num_segments = 4
6487
- >>> output = ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
5708
+ >>> from mindspore import ops
5709
+ >>> from mindspore import Tensor
5710
+ >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.int32)
5711
+ >>> output = ops.narrow(x, 0, 0, 2)
6488
5712
  >>> print(output)
6489
- [3. 3. 4. 0.]
6490
- >>> input_x = Tensor([1, 2, 3, 4, 2, 5], mindspore.float32)
6491
- >>> segment_ids = Tensor([0, 0, 1, 2, 3, 4], mindspore.int32)
6492
- >>> num_segments = 6
6493
- >>> output = ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
5713
+ [[ 1 2 3]
5714
+ [ 4 5 6]]
5715
+ >>> output = ops.narrow(x, 1, 1, 2)
6494
5716
  >>> print(output)
6495
- [3. 3. 4. 2. 5. 0.]
5717
+ [[ 2 3]
5718
+ [ 5 6]
5719
+ [ 8 9]]
6496
5720
  """
6497
- return unsorted_segment_sum_(input_x, segment_ids, num_segments)
5721
+ validator.check_value_type("input", input, Tensor, "narrow")
5722
+ return slice_ext_op(input, dim, start, start + length, 1)
6498
5723
 
6499
5724
 
6500
5725
  def topk(input, k, dim=None, largest=True, sorted=True):
@@ -6651,8 +5876,8 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
6651
5876
  A Tensor, with same type as `input` . And its shape is as described above.
6652
5877
 
6653
5878
  Raises:
6654
- TypeError: If `kernel_size`, `dilation`, `padding`, `stride` data type is not int, tuple or list.
6655
- ValueError: If `kernel_size`, `dilation`, `stride` value is not
5879
+ TypeError: If `output_size`, `kernel_size`, `stride`, `dilation`, `padding` data type is not int, tuple or list.
5880
+ ValueError: If `output_size`, `kernel_size`, `dilation`, `stride` value is not
6656
5881
  greater than zero or elements number more than `2`.
6657
5882
  ValueError: If `padding` value is less than zero or elements number more than `2`.
6658
5883
  ValueError: If `input.shape[1] != kernel_size[0] * kernel_size[1]`
@@ -6688,7 +5913,8 @@ def _check_unfold_params(param, param_name, param_size):
6688
5913
  """Check the parameters of unfold op."""
6689
5914
  validator.check_value_type(param_name, param, [int, tuple, list], 'unfold')
6690
5915
  param = (param, param) if isinstance(param, int) else param
6691
- validator.check(param_name + " size", len(param), "", param_size, validator.IN, 'unfold')
5916
+ validator.check(param_name + " size", len(param), "",
5917
+ param_size, validator.IN, 'unfold')
6692
5918
  if param_name == "padding":
6693
5919
  validator.check_non_negative_int_sequence(param, param_name, 'unfold')
6694
5920
  else:
@@ -6728,9 +5954,7 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
6728
5954
  .. warning::
6729
5955
  - The output is a 3-dimensional Tensor whose shape is
6730
5956
  :math:`(N, C \times \prod(\text{kernel_size}), L)` .
6731
-
6732
- .. warning::
6733
- This is an experimental API that is subject to change or deletion.
5957
+ - This is an experimental API that is subject to change or deletion.
6734
5958
 
6735
5959
  Args:
6736
5960
  input (Tensor): 4-D Tensor, supported dtypes: float16, float32, float64, complex64 and complex128.
@@ -6739,10 +5963,11 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
6739
5963
  dilation (Union[int, tuple[int], list[int]], optional): The dilation of the window, should be two int
6740
5964
  for height and width. If type is int, it means that height equal with width. Default: ``1`` .
6741
5965
  padding (Union[int, tuple[int], list[int]], optional): The pad of the window, that must be
6742
- a tuple/list of one or two `int` for height and width.
6743
- If one int, pad_height = pad_width.
6744
- If two int, pad_height = padding[0], pad_width = padding[1].
6745
- Default: ``0`` .
5966
+ a tuple/list of one or two `int` for height and width. Default: ``0`` .
5967
+
5968
+ - If one int, pad_height = pad_width.
5969
+ - If two int, pad_height = padding[0], pad_width = padding[1].
5970
+
6746
5971
  stride (Union[int, tuple[int], list[int]], optional): The stride of the window, should be two int
6747
5972
  for height and width. If type is int, it means that height equal with width. Default: ``1`` .
6748
5973
 
@@ -6789,102 +6014,11 @@ def _check_diagonal_axes(dim1, dim2, x_ndim):
6789
6014
  return axes
6790
6015
 
6791
6016
 
6792
- def diagonal(input, offset=0, dim1=0, dim2=1):
6793
- """
6794
- Returns specified diagonals of `input`.
6795
-
6796
- If `input` is 2-D, returns the diagonal of `input` with the given offset.
6797
- If `input` has more than two
6798
- dimensions, then the axes specified by `dim1` and `dim2` are used to determine
6799
- the 2-D sub-array whose diagonal is returned. In this case, remove the `dim1` and `dim2` dimensions of `input`
6800
- and insert the last dimension of `input` by the diagonal elements determined by `dim1` and `dim2`.
6801
-
6802
- Args:
6803
- input (Tensor): Array from which the diagonals are taken.
6804
- offset (int, optional): Offset of the diagonal from the main diagonal.
6805
- Can be positive or negative. Default: ``0`` .
6806
- dim1 (int, optional): Axis to be used as the first axis of the 2-D
6807
- sub-arrays from which the diagonals should be taken. Defaults to
6808
- first axis (0). Default: ``0`` .
6809
- dim2 (int, optional): Axis to be used as the second axis of the 2-D
6810
- sub-arrays from which the diagonals should be taken. Defaults to
6811
- second axis (1). Default: ``1`` .
6812
-
6813
- Returns:
6814
- Tensor, if `input` is 2-D, then `input` 1-D array containing the diagonal. If
6815
- ``input.ndim > 2``, then the dimensions specified by `dim1` and `dim2` are removed,
6816
- and a new axis inserted at the end corresponding to the diagonal.
6817
-
6818
- Raises:
6819
- TypeError: if `dim1` or `dim2` are not an int.
6820
- ValueError: if the input tensor has less than two dimensions.
6821
-
6822
- Supported Platforms:
6823
- ``Ascend`` ``GPU`` ``CPU``
6824
-
6825
- Examples:
6826
- >>> from mindspore import Tensor, ops
6827
- >>> from mindspore import dtype as mstype
6828
- >>> x = Tensor([[0, 1], [2, 3]], mstype.float32)
6829
- >>> output = ops.diagonal(x)
6830
- >>> print(output)
6831
- [0 3]
6832
- """
6833
- x_ndim = input.ndim
6834
- if x_ndim < 2:
6835
- raise ValueError(f"For 'ops.diagonal', the original tensor requires at least two dimensions, but got {x_ndim}")
6836
- _check_attr_dtype("dim1", dim1, [int], "diagonal")
6837
- _check_attr_dtype("dim2", dim2, [int], "diagonal")
6838
- dtype = input.dtype
6839
-
6840
- axes = _check_diagonal_axes(dim1, dim2, x_ndim)
6841
- perm = ()
6842
- for i in ms_arrange(x_ndim):
6843
- if i not in axes:
6844
- perm += (i,)
6845
- perm += axes
6846
- input = input.transpose(perm)
6847
-
6848
- x_shape = input.shape
6849
- n, m = x_shape[-2:]
6850
-
6851
- e = ops.eye(n, m, dtype)
6852
- if offset >= m or offset <= -n:
6853
- zero_shape = x_shape[:-2] + (0,)
6854
- return ops.zeros(zero_shape, dtype)
6855
- if offset != 0:
6856
- e = e.astype(mstype.float32)
6857
- if offset > 0:
6858
- e_left = ops.fill(mstype.float32, (n, offset), 0)
6859
- e_right = e[..., 0:m - offset:1]
6860
- e = ops.cat((e_left, e_right), 1).astype(dtype)
6861
- elif offset < 0:
6862
- e_upper = ops.fill(mstype.float32, (-offset, m), 0)
6863
- e_lower = e[0:n + offset:1, ...]
6864
- e = ops.cat((e_upper, e_lower), 0).astype(dtype)
6865
- e = ops.broadcast_to(e, x_shape)
6866
-
6867
- prod_val = ops.mul(input, e)
6868
- res = ops.ReduceSum()(prod_val.astype(mstype.float32), -1)
6869
-
6870
- begin = ()
6871
- for _ in ms_arrange(x_ndim - 2):
6872
- begin += (0,)
6873
- last_dim_begin = builtins.max(0, -offset)
6874
- begin += (last_dim_begin,)
6875
- res_size = res.shape[:-1]
6876
- last_dim_end = builtins.min(x_shape[-2], builtins.max(0, x_shape[-1] - offset)) - last_dim_begin
6877
- if last_dim_end <= 0:
6878
- return Tensor([])
6879
- res_size += (last_dim_end,)
6880
- res = ops.slice(res, begin, res_size)
6881
- return res.astype(dtype)
6882
-
6883
-
6884
6017
  def _check_is_tensor(param_name, input, cls_name):
6885
6018
  """Returns True if input is Tensor."""
6886
6019
  if not isinstance(input, Tensor):
6887
- raise TypeError(f"For {cls_name}, {param_name} must be a Tensor, but got {type(input)}.")
6020
+ raise TypeError(
6021
+ f"For {cls_name}, {param_name} must be a Tensor, but got {type(input)}.")
6888
6022
 
6889
6023
 
6890
6024
  @_primexpr
@@ -6900,6 +6034,9 @@ def diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
6900
6034
  the elements in these two dimensions will be treated as elements of a matrix,
6901
6035
  and `src` is embedded on the diagonal of the matrix.
6902
6036
 
6037
+ Note:
6038
+ Currently, ``inf`` value of elements in `input` or `src` is not supported.
6039
+
6903
6040
  Args:
6904
6041
  input (Tensor): Input Tensor, whose dimension is larger than 1.
6905
6042
  src (Tensor): The source Tensor to embed.
@@ -6936,16 +6073,39 @@ def diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
6936
6073
  """
6937
6074
  _check_is_tensor("input", input, "diagonal_scatter")
6938
6075
  _check_is_tensor("src", src, "diagonal_scatter")
6939
- _check_is_int(offset, "offset", "diagonal_scatter")
6940
- _check_is_int(dim1, "dim1", "diagonal_scatter")
6941
- _check_is_int(dim2, "dim2", "diagonal_scatter")
6942
6076
  input_diag = input.diagonal(offset, dim1, dim2)
6943
6077
  _check_diagonal_scatter_shape(input_diag.shape, src.shape)
6944
- embed = ones_like(src)
6945
- embed = ops.diag_embed(embed, offset, dim1, dim2)
6946
- embed = input * embed
6078
+ input_shape = input.shape
6079
+ zeros_shape = list(input_shape)
6080
+ m, n = input_shape[dim1], input_shape[dim2]
6081
+ if m == n:
6082
+ src = src - input_diag
6083
+ src = ops.diag_embed(src, offset, dim1, dim2)
6084
+ return input + src
6085
+ if m > n:
6086
+ axis = dim2
6087
+ zeros_shape[axis] = m - n
6088
+ else:
6089
+ axis = dim1
6090
+ zeros_shape[axis] = n - m
6091
+ zeros_tensor = zeros(zeros_shape, dtype=input.dtype)
6092
+ input = concat((input, zeros_tensor), axis)
6093
+ input_diag = input.diagonal(offset, dim1, dim2)
6094
+ if src.shape != input_diag.shape:
6095
+ zeros_shape = []
6096
+ for i, ax in enumerate(src.shape):
6097
+ if ax == input_diag.shape[i]:
6098
+ zeros_shape.append(ax)
6099
+ else:
6100
+ axis = i
6101
+ zeros_shape.append(input_diag.shape[i] - ax)
6102
+ zeros_tensor = zeros(zeros_shape, dtype=src.dtype)
6103
+ src = concat((src, zeros_tensor), axis)
6104
+ src = src - input_diag
6947
6105
  src = ops.diag_embed(src, offset, dim1, dim2)
6948
- return input + src - embed
6106
+ input = input + src
6107
+ begin = (0,) * input.ndim
6108
+ return slice(input, begin, input_shape)
6949
6109
 
6950
6110
 
6951
6111
  def lstsq(input, A):
@@ -7004,8 +6164,7 @@ def lstsq(input, A):
7004
6164
  [-6.5000005 -4.500001 ]
7005
6165
  [-3.500002 -2.5000017]]
7006
6166
  """
7007
- lstsq_op = _get_cache_prim(Lstsq)()
7008
- return lstsq_op(input, A)
6167
+ return lstsq_(input, A)
7009
6168
 
7010
6169
 
7011
6170
  def mvlgamma(input, p):
@@ -7053,6 +6212,64 @@ def mvlgamma(input, p):
7053
6212
  return mvlgamma_op(input)
7054
6213
 
7055
6214
 
6215
+ def nonzero(input, as_tuple=False):
6216
+ r"""
6217
+ Return the positions of all non-zero values.
6218
+
6219
+ Args:
6220
+ input (Tensor): The input Tensor, its rank should be greater than or equal to 1.
6221
+ as_tuple (bool, optional): Whether the output is tuple.
6222
+ If ``False`` , return Tensor. Default: ``False`` .
6223
+ If ``True`` , return Tuple of Tensor, only support ``Ascend`` .
6224
+
6225
+
6226
+ Returns:
6227
+ - If `as_tuple` is ``False``, return the Tensor, a 2-D Tensor whose data type is int64,
6228
+ containing the positions of all non-zero values of the input.
6229
+ - If `as_tuple` is ``True``, return the Tuple of Tensor and data type is int64.
6230
+ The Tuple length is the dimension of the input tensor,
6231
+ and each element is the 1D tensor of the subscript of all non-zero elements of
6232
+ the input tensor in that dimension.
6233
+
6234
+ Raises:
6235
+ TypeError: If `input` is not Tensor.
6236
+ TypeError: If `as_tuple` is not bool.
6237
+ ValueError: If dim of `input` equals to 0.
6238
+
6239
+ Supported Platforms:
6240
+ ``Ascend`` ``GPU`` ``CPU``
6241
+
6242
+ Examples:
6243
+ >>> import mindspore
6244
+ >>> import numpy as np
6245
+ >>> from mindspore import Tensor, ops
6246
+ >>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
6247
+ >>> output = ops.nonzero(x)
6248
+ >>> print(output)
6249
+ [[0 0 0]
6250
+ [0 1 0]]
6251
+ >>> x = Tensor(np.array([1, 0, 2, 0, 3]), mindspore.int32)
6252
+ >>> output = ops.nonzero(x, False)
6253
+ >>> print(output)
6254
+ [[0]
6255
+ [2]
6256
+ [4]]
6257
+ >>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
6258
+ >>> output = ops.nonzero(x, True)
6259
+ >>> print(output)
6260
+ (Tensor(shape=[2], dtype=Int64, value=[0, 0]),
6261
+ Tensor(shape=[2], dtype=Int64, value=[0, 1]),
6262
+ Tensor(shape=[2], dtype=Int64, value=[0, 0]))
6263
+ >>> x = Tensor(np.array([1, 0, 2, 0, 3]), mindspore.int32)
6264
+ >>> output = ops.nonzero(x, True)
6265
+ >>> print(output)
6266
+ (Tensor(shape=[3], dtype=Int64, value=[0, 2, 4]), )
6267
+ """
6268
+ if as_tuple:
6269
+ return non_zero_ext_(input)
6270
+ return non_zero_(input)
6271
+
6272
+
7056
6273
  def argwhere(input):
7057
6274
  """
7058
6275
  Return a Tensor of the positions of all non-zero values.
@@ -7080,7 +6297,7 @@ def argwhere(input):
7080
6297
  [[0 0 0]
7081
6298
  [0 1 0]]
7082
6299
  """
7083
- return nonzero_(input)
6300
+ return nonzero(input)
7084
6301
 
7085
6302
 
7086
6303
  def column_stack(tensors):
@@ -7114,20 +6331,22 @@ def column_stack(tensors):
7114
6331
  [1 2]]
7115
6332
  """
7116
6333
  if not isinstance(tensors, (list, tuple)):
7117
- raise TypeError(f"For column_stack, the input must be list or tuple of tensors, but got {type(tensors)}.")
6334
+ raise TypeError(
6335
+ f"For column_stack, the input must be list or tuple of tensors, but got {type(tensors)}.")
7118
6336
 
7119
6337
  trans_x = ()
7120
- _expand_dims = _get_cache_prim(P.ExpandDims)()
7121
6338
  for tensor in tensors:
7122
6339
  if not isinstance(tensor, Tensor):
7123
- raise TypeError(f"For column_stack, the input element must be tensor, but got {type(tensor)}.")
6340
+ raise TypeError(
6341
+ f"For column_stack, the input element must be tensor, but got {type(tensor)}.")
7124
6342
  if tensor.ndim < 1:
7125
- tensor = _expand_dims(tensor, 0)
6343
+ tensor = expand_dims(tensor, 0)
7126
6344
  if tensor.ndim == 1:
7127
- tensor = _expand_dims(tensor, 1)
6345
+ tensor = expand_dims(tensor, 1)
7128
6346
  trans_x += (tensor,)
7129
6347
  if not trans_x:
7130
- raise ValueError(f"For column_stack, the input must have at least 1 tensor, but got 0.")
6348
+ raise ValueError(
6349
+ f"For column_stack, the input must have at least 1 tensor, but got 0.")
7131
6350
  _concat = _get_cache_prim(P.Concat)(1)
7132
6351
  return _concat(trans_x)
7133
6352
 
@@ -7163,17 +6382,20 @@ def hstack(tensors):
7163
6382
  [1. 1. 1. 2. 2. 2.]
7164
6383
  """
7165
6384
  if not isinstance(tensors, (list, tuple)):
7166
- raise TypeError(f"For hstack, the input must be list or tuple, but got {type(tensors)}.")
6385
+ raise TypeError(
6386
+ f"For hstack, the input must be list or tuple, but got {type(tensors)}.")
7167
6387
 
7168
6388
  tuple_of_tensor = ()
7169
6389
  for tensor in tensors:
7170
6390
  if not isinstance(tensor, Tensor):
7171
- raise TypeError(f"For hstack, the input element must be tensor, but got {type(tensor)}.")
6391
+ raise TypeError(
6392
+ f"For hstack, the input element must be tensor, but got {type(tensor)}.")
7172
6393
  if tensor.ndim < 1:
7173
- tensor = expand_dims_(tensor, 0)
6394
+ tensor = expand_dims(tensor, 0)
7174
6395
  tuple_of_tensor += (tensor,)
7175
6396
  if not tuple_of_tensor:
7176
- raise ValueError("For hstack, the input must have at least 1 tensor, but got 0.")
6397
+ raise ValueError(
6398
+ "For hstack, the input must have at least 1 tensor, but got 0.")
7177
6399
  if tuple_of_tensor[0].ndim <= 1:
7178
6400
  _concat = _get_cache_prim(P.Concat)(0)
7179
6401
  return _concat(tuple_of_tensor)
@@ -7202,7 +6424,8 @@ def _get_moved_perm(ndim, source, destination):
7202
6424
  Helper function for movedim, returns permutation after moving axis
7203
6425
  from source to destination.
7204
6426
  """
7205
- dest_sorted_idx = [i for i, _ in sorted(enumerate(destination), key=operator.itemgetter(1))]
6427
+ dest_sorted_idx = [i for i, _ in sorted(
6428
+ enumerate(destination), key=operator.itemgetter(1))]
7206
6429
  axis_orig = [i for i in builtins.range(0, ndim) if i not in source]
7207
6430
 
7208
6431
  k = 0
@@ -7270,7 +6493,7 @@ def movedim(x, source, destination):
7270
6493
  f"For `source` and `destination` arguments, the number of elements must be the same, but got 'source':"
7271
6494
  f" {len(source)} and 'destination': {len(destination)}.")
7272
6495
  perm = _get_moved_perm(ndim, source, destination)
7273
- return _get_cache_prim(P.Transpose)()(x, perm)
6496
+ return transpose_(x, perm)
7274
6497
 
7275
6498
 
7276
6499
  def moveaxis(x, source, destination):
@@ -7321,7 +6544,7 @@ def swapaxes(input, axis0, axis1):
7321
6544
 
7322
6545
  Examples:
7323
6546
  >>> import numpy as np
7324
- >>> import mindspore.ops as ops
6547
+ >>> from mindspore import ops
7325
6548
  >>> from mindspore import Tensor
7326
6549
  >>> input = Tensor(np.ones((2,3,4), dtype=np.float32))
7327
6550
  >>> output = ops.swapaxes(input, 0, 2)
@@ -7329,7 +6552,8 @@ def swapaxes(input, axis0, axis1):
7329
6552
  (4, 3, 2)
7330
6553
  '''
7331
6554
  if not isinstance(input, Tensor):
7332
- raise TypeError(f'For ops.swapaxes, parameter `input` must be Tensor, but got {type(input)}')
6555
+ raise TypeError(
6556
+ f'For ops.swapaxes, parameter `input` must be Tensor, but got {type(input)}')
7333
6557
 
7334
6558
  axis0, axis1 = _check_swapaxes_axis((axis0, axis1), input.ndim)
7335
6559
  if axis0 == axis1:
@@ -7340,12 +6564,12 @@ def swapaxes(input, axis0, axis1):
7340
6564
  perm = ops.make_range(0, input.ndim)
7341
6565
  if axis1 + 1 < input.ndim:
7342
6566
  new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
7343
- perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + perm[axis1 + 1:]
6567
+ perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + perm[axis1 + 1:]
7344
6568
  else:
7345
6569
  new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
7346
- perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1]
6570
+ perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1]
7347
6571
 
7348
- return _get_cache_prim(P.Transpose)()(input, new_perm)
6572
+ return transpose_(input, new_perm)
7349
6573
 
7350
6574
 
7351
6575
  def swapdims(input, dim0, dim1):
@@ -7371,7 +6595,7 @@ def swapdims(input, dim0, dim1):
7371
6595
 
7372
6596
  Examples:
7373
6597
  >>> import numpy as np
7374
- >>> import mindspore.ops as ops
6598
+ >>> from mindspore import ops
7375
6599
  >>> from mindspore import Tensor
7376
6600
  >>> input = Tensor(np.ones((2,3,4), dtype=np.float32))
7377
6601
  >>> output = ops.swapdims(input, 0, 2)
@@ -7389,13 +6613,15 @@ def _check_is_int(arg_value, arg_name, op_name):
7389
6613
 
7390
6614
  @_primexpr
7391
6615
  def _check_positive_int(arg_value, arg_name, op_name):
7392
- arg_value = validator.check_int_range(arg_value, 0, 2147483647, validator.INC_RIGHT, arg_name, op_name)
6616
+ arg_value = validator.check_int_range(
6617
+ arg_value, 0, 2147483647, validator.INC_RIGHT, arg_name, op_name)
7393
6618
  return arg_value
7394
6619
 
7395
6620
 
7396
6621
  @constexpr
7397
6622
  def _check_axis_range(arg_value, limit, arg_name, op_name):
7398
- arg_value = validator.check_int_range(arg_value, -limit, limit, validator.INC_LEFT, arg_name, op_name)
6623
+ arg_value = validator.check_int_range(
6624
+ arg_value, -limit, limit, validator.INC_LEFT, arg_name, op_name)
7399
6625
  return arg_value
7400
6626
 
7401
6627
 
@@ -7413,6 +6639,14 @@ def _cal_reshape(x_shape, rep, axis):
7413
6639
  return tuple(x_reshape)
7414
6640
 
7415
6641
 
6642
@_primexpr
def _check_rank_range(x_rank, limit, arg_name, op_name):
    """Ensure a tensor rank does not exceed ``limit``.

    Args:
        x_rank (int): The rank to check.
        limit (int): Maximum allowed rank (inclusive).
        arg_name (str): Name of the tensor argument, used in the error message.
        op_name (str): Name of the calling operator, used in the error message.

    Returns:
        ``x_rank`` unchanged when it is within the limit.

    Raises:
        ValueError: If ``x_rank`` is greater than ``limit``.
    """
    too_large = x_rank > limit
    if too_large:
        message = f"For {op_name}, the rank of {arg_name} should be less than or equal to {limit}, but got {x_rank}."
        raise ValueError(message)
    return x_rank
6648
+
6649
+
7416
6650
  def repeat_interleave(input, repeats, axis=None):
7417
6651
  """
7418
6652
  Repeat elements of a tensor along an axis, like `numpy.repeat`.
@@ -7453,13 +6687,58 @@ def repeat_interleave(input, repeats, axis=None):
7453
6687
  return output
7454
6688
 
7455
6689
 
6690
def repeat_interleave_ext(input, repeats, dim=None, output_size=None):
    r"""
    Repeat elements of a tensor along an axis, like `numpy.repeat`.

    .. warning::
        Only supported on Atlas A2 training series.

    Args:
        input (Tensor): The tensor to repeat values for. Must be of type: float16,
            float32, int8, uint8, int16, int32, or int64.
        repeats (Union[int, tuple, list, Tensor]): The number of times to repeat, must be positive.
        dim (int, optional): The dim along which to repeat. Default: ``None``. If `dim` is ``None``,
            the input Tensor will be flattened and the output will also be flattened.
        output_size (int, optional): Total output size for the given axis (e.g. sum of repeats).
            Default: ``None``.

    Returns:
        One tensor with values repeated along the specified dim. If input has shape
        :math:`(s1, s2, ..., sn)`, `dim` is i and `repeats` is an int, the output will have shape
        :math:`(s1, s2, ..., si * repeats, ..., sn)`. The output type will be the same as the type of `input`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
        >>> output = ops.function.array_func.repeat_interleave_ext(input, repeats=2, dim=0)
        >>> print(output)
        [[0 1 2]
         [0 1 2]
         [3 4 5]
         [3 4 5]]
    """
    # A plain Python int goes to the scalar-repeats primitive; every other form of
    # `repeats` (tuple, list, Tensor) is handed to the tensor-repeats primitive.
    if isinstance(repeats, int):
        return repeat_interleave_int_(input, repeats, dim, output_size)
    return repeat_interleave_tensor_(input, repeats, dim, output_size)
6729
+
6730
+
7456
6731
  def repeat_elements(x, rep, axis=0):
7457
6732
  """
7458
- Repeat elements of a tensor along an axis, like `np.repeat` .
6733
+ Repeat elements of a tensor along an axis, like `numpy.repeat` .
6734
+
6735
+ Note:
6736
+ It is recommended to use :func:'mindspore.mint.repeat_interleave', the dimension of input 'x' can support
6737
+ a maximum of 8, and get better performance.
7459
6738
 
7460
6739
  Args:
7461
- x (Tensor): The tensor to repeat values for. Must be of type: float16,
7462
- float32, int8, uint8, int16, int32, or int64.
6740
+ x (Tensor): The tensor to repeat values for. Must be of type: float16, float32, int8, uint8, int16, int32,
6741
+ or int64. The rank of `x` must be less than or equal to 7.
7463
6742
  rep (int): The number of times to repeat, must be positive.
7464
6743
  axis (int): The axis along which to repeat. Default: 0.
7465
6744
 
@@ -7468,6 +6747,9 @@ def repeat_elements(x, rep, axis=0):
7468
6747
  :math:`(s1, s2, ..., sn)` and axis is i, the output will have shape :math:`(s1, s2, ..., si * rep, ..., sn)`.
7469
6748
  The output type will be the same as the type of `x`.
7470
6749
 
6750
+ Raises:
6751
+ ValueError: If the rank of `x` is greater than 7.
6752
+
7471
6753
  Supported Platforms:
7472
6754
  ``Ascend`` ``GPU`` ``CPU``
7473
6755
 
@@ -7493,34 +6775,20 @@ def repeat_elements(x, rep, axis=0):
7493
6775
  const_utils.check_type_valid(ops.dtype(x), mstype.number_type, 'input x')
7494
6776
  rep = _check_positive_int(rep, "rep", "repeat_elements")
7495
6777
  axis = _check_is_int(axis, "axis", "repeat_elements")
7496
- shape_op = P.Shape()
7497
- rank_op = P.Rank()
7498
- tile_op = P.Tile()
7499
- expand_dims_op = P.ExpandDims()
7500
- reshape_op = P.Reshape()
7501
- x_rank = rank_op(x)
6778
+ x_rank = rank_(x)
6779
+ x_rank = _check_rank_range(x_rank, 7, "x", "repeat_elements")
7502
6780
  axis = _check_axis_range(axis, x_rank, "axis", "repeat_elements")
6781
+ axis = axis + x.ndim if axis < 0 else axis
7503
6782
  expand_axis = axis + 1
7504
- x_expand = expand_dims_op(x, expand_axis)
6783
+ x_expand = expand_dims(x, expand_axis)
7505
6784
  rep_dims = _cal_repeat_dims(x_rank, rep, expand_axis)
7506
- x_expand = tile_op(x_expand, rep_dims)
7507
- x_shape = shape_op(x)
6785
+ x_expand = tile_(x_expand, rep_dims)
6786
+ x_shape = shape_(x)
7508
6787
  x_reshape = _cal_reshape(x_shape, rep, axis)
7509
- x_rep = reshape_op(x_expand, x_reshape)
6788
+ x_rep = reshape_(x_expand, x_reshape)
7510
6789
  return x_rep
7511
6790
 
7512
6791
 
7513
@_primexpr
def _check_sequence_mask_input_len(input_shape, prim_name=None):
    """Validate the shape of the `lengths` input of sequence_mask.

    Args:
        input_shape (tuple): Shape of the input tensor.
        prim_name (str, optional): Operator name used to prefix error messages.

    Raises:
        ValueError: If ``input_shape`` is empty, or has 7 or more dimensions.
    """
    prefix = "The"
    if prim_name:
        prefix = f"For '{prim_name}', the"
    if not input_shape:
        raise ValueError(f"{prefix} input_shape must be greater than 0, but got {input_shape}.")
    # Original note: "broadcast only supports 7d shape"; the check below rejects
    # any shape whose rank is 7 or more.
    rank = len(input_shape)
    if rank >= 7:
        raise ValueError(f"{prefix} dimension of input_shape must be less than 7, but got {rank}d.")
-
7523
-
7524
6792
  def sequence_mask(lengths, maxlen=None):
7525
6793
  """
7526
6794
  Returns a mask tensor representing the first N positions of each cell.
@@ -7573,29 +6841,21 @@ def sequence_mask(lengths, maxlen=None):
7573
6841
  [[ True True False False ]
7574
6842
  [ True True True True ]]]
7575
6843
  """
7576
-
7577
- argmax_op = P.ArgMaxWithValue()
7578
- reshape_op = P.Reshape()
7579
- range_op = P.Range()
7580
- expand_op = P.ExpandDims()
7581
- cast_op = P.Cast()
7582
- to_tensor_op = P.ScalarToTensor()
7583
- shape_op = P.Shape()
7584
-
7585
- const_utils.check_type_valid(ops.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
7586
- _check_sequence_mask_input_len(shape_op(lengths), "sequence_mask")
6844
+ const_utils.check_type_valid(
6845
+ ops.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
7587
6846
 
7588
6847
  if maxlen is None:
7589
- flatten_data = reshape_op(lengths, (-1,))
7590
- flatten_data = cast_op(flatten_data, mstype.float32)
7591
- _, value = argmax_op(flatten_data)
7592
- maxlen = cast_op(value, mstype.int32)
6848
+ flatten_data = reshape_(lengths, (-1,))
6849
+ flatten_data = cast_(flatten_data, mstype.float32)
6850
+ _, value = arg_max_with_value_(flatten_data)
6851
+ maxlen = cast_(value, mstype.int32)
7593
6852
  else:
7594
6853
  maxlen = _check_positive_int(maxlen, "maxlen", "sequence_mask")
7595
- maxlen = to_tensor_op(maxlen, mstype.int32)
6854
+ maxlen = scalar_to_tensor_(maxlen, mstype.int32)
7596
6855
 
7597
- range_vector = range_op(to_tensor_op(0, mstype.int32), maxlen, to_tensor_op(1, mstype.int32))
7598
- mask = expand_op(lengths, -1)
6856
+ range_vector = range_(scalar_to_tensor_(0, mstype.int32),
6857
+ maxlen, scalar_to_tensor_(1, mstype.int32))
6858
+ mask = expand_dims(lengths, -1)
7599
6859
  result = range_vector < mask
7600
6860
  return result
7601
6861
 
@@ -7608,34 +6868,220 @@ def top_k(input_x, k, sorted=True):
7608
6868
  return top_k_(input_x, k)
7609
6869
 
7610
6870
 
7611
- def deepcopy(input_x):
6871
def gather_ext(input, dim, index):
    r"""
    Gather data from a tensor by indices.

    .. math::
        output[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)] =
        input[(i_0, i_1, ..., index[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)], i_{dim+1}, ..., i_n)]

    .. warning::
        On Ascend, the behavior is unpredictable in the following cases:

        - the value of `index` is not in the range `[-input.shape[dim], input.shape[dim])` in forward;
        - the value of `index` is not in the range `[0, input.shape[dim])` in backward.

    Args:
        input (Tensor): The target tensor to gather values.
        dim (int): The axis to index along, must be in range `[-input.rank, input.rank)`.
        index (Tensor): The index tensor, with int32 or int64 data type. A valid `index` should be:

            - `index.rank == input.rank`;
            - for `axis != dim`, `index.shape[axis] <= input.shape[axis]`;
            - the value of `index` is in range `[-input.shape[dim], input.shape[dim])`.

    Returns:
        Tensor, has the same type as `input` and the same shape as `index`.

    Raises:
        ValueError: If the shape of `index` is illegal.
        ValueError: If `dim` is not in `[-input.rank, input.rank)`.
        ValueError: If the value of `index` is out of the valid range.
        TypeError: If the type of `index` is illegal.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> from mindspore.ops.function.array_func import gather_ext
        >>> input = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> index = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> output = gather_ext(input, 1, index)
        >>> print(output)
        [[-0.1 -0.1]
         [0.5 0.5]]
    """
    # Thin wrapper: all validation and the gather itself are done by the GatherD primitive.
    return gather_d_op(input, dim, index)
6919
+
6920
+
6921
def max_ext(input, dim=None, keepdim=False):
    """
    Calculates the maximum value along with the given dimension for the input tensor.

    Args:
        input (Tensor): The input tensor, can be any dimension. Complex tensor is not supported for now.
        dim (int, optional): The dimension to reduce. Default: ``None`` .
        keepdim (bool, optional): Whether to reduce dimension, if true, the output will keep same dimension
            with the input, the output will reduce dimension if false. Default: ``False`` .

    Returns:
        Tensor if `dim` is the default value ``None`` , the maximum value of input tensor, with the shape :math:`()` ,
        and same dtype as `input`.

        tuple (Tensor) if `dim` is not the default value ``None`` , tuple of 2 tensors, containing the maximum
        value of the input tensor along the given dimension `dim` and the corresponding index.

        - **values (Tensor)** - The maximum value of input tensor along the given dimension `dim`, with same dtype as
          `input`. If `keepdim` is ``True`` , the shape of output tensors is :math:`(input_1, input_2, ...,
          input_{axis-1}, 1, input_{axis+1}, ..., input_N)` . Otherwise, the shape is :math:`(input_1, input_2, ...,
          input_{axis-1}, input_{axis+1}, ..., input_N)` .
        - **index (Tensor)** - The index for the maximum value of the input tensor along the given dimension `dim`,
          with the same shape as `values`.

    Raises:
        ValueError: If `dim` is the default value ``None`` and `keepdim` is not ``False`` .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> from mindspore.ops.function.array_func import max_ext
        >>> y = Tensor(np.array([[0.0, 0.3, 0.4, 0.5, 0.1],
        ...                      [3.2, 0.4, 0.1, 2.9, 4.0]]), mindspore.float32)
        >>> output, index = max_ext(y, 0, True)
        >>> print(output, index)
        [[3.2 0.4 0.4 2.9 4. ]] [[1 1 0 1 1]]
    """
    if dim is not None:
        # ArgMaxWithValue yields (indices, values); the public order is (values, indices).
        reducer = _get_cache_prim(ArgMaxWithValue)(dim, keepdim)
        max_indices, max_values = reducer(input)
        return max_values, max_indices
    # Global reduction: keepdim has no meaning without a dim, so only False is accepted.
    if keepdim is not False:
        raise ValueError(
            f"For 'max', the `keepdim` must be False when the `dim` is None, but got {keepdim}")
    return max_(input)
6970
+
6971
+
6972
def min_ext(input, dim=None, keepdim=False):
    """
    Calculates the minimum value along with the given dimension for the input tensor.

    Args:
        input (Tensor): The input tensor, can be any dimension. Complex tensor is not supported for now.
        dim (int, optional): The dimension to reduce. Default: ``None`` .
        keepdim (bool, optional): Whether to reduce dimension, if true, the output will keep same dimension
            with the input, the output will reduce dimension if false. Default: ``False`` .

    Returns:
        Tensor if `dim` is the default value ``None`` , the minimum value of input tensor, with the shape :math:`()` ,
        and same dtype as `input`.

        tuple (Tensor) if `dim` is not the default value ``None`` , tuple of 2 tensors, containing the minimum value
        of the input tensor along the given dimension `dim` and the corresponding index.

        - **values (Tensor)** - The minimum value of input tensor along the given dimension `dim`, with same dtype as
          `input`. If `keepdim` is ``True`` , the shape of output tensors is :math:`(input_1, input_2, ...,
          input_{axis-1}, 1, input_{axis+1}, ..., input_N)` . Otherwise, the shape is :math:`(input_1, input_2, ...,
          input_{axis-1}, input_{axis+1}, ..., input_N)` .
        - **index (Tensor)** - The index for the minimum value of the input tensor along the given dimension `dim`,
          with the same shape as `values`.

    Raises:
        ValueError: If `dim` is the default value ``None`` and `keepdim` is not ``False`` .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> from mindspore.ops.function.array_func import min_ext
        >>> x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
        >>> output, index = min_ext(x, 0, keepdim=True)
        >>> print(output, index)
        [0.] [0]
    """
    if dim is None:
        # Global reduction: keepdim has no meaning without a dim, so only False is accepted.
        if keepdim is not False:
            raise ValueError(
                f"For 'min', the `keepdim` must be False when the `dim` is None, but got {keepdim}")
        return min_(input)
    argmin_with_value_op = _get_cache_prim(ArgMinWithValue)(dim, keepdim)
    # The primitive yields (indices, values); the public order is (values, indices).
    indices, values = argmin_with_value_op(input)
    return values, indices
7020
+
7021
+
7022
def one_hot_ext(tensor, num_classes):
    r"""
    Computes a one-hot tensor.

    The locations represented by tensor in `tensor` take value `1`, while all
    other locations take value `0`.

    Args:
        tensor (Tensor): A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
            Data type must be int32 or int64.
        num_classes (int): A scalar defining the depth of the one-hot dimension.

    Returns:
        Tensor, one-hot tensor. The on/off values are created with the dtype of `tensor`,
        so the result is expected to share the dtype of `tensor`.

    Raises:
        TypeError: If `num_classes` is not an int.
        TypeError: If dtype of `tensor` is not int32 or int64.
        ValueError: If `num_classes` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> from mindspore.ops.function.array_func import one_hot_ext
        >>> tensor = Tensor(np.array([0, 1, 2]), mindspore.int32)
        >>> num_classes = 3
        >>> output = one_hot_ext(tensor, num_classes)
        >>> print(output)
        [[1 0 0]
         [0 1 0]
         [0 0 1]]
    """
    # On/off values take the index dtype, so an integer input gives an integer one-hot result.
    on_value = Tensor(1, dtype=tensor.dtype)
    off_value = Tensor(0, dtype=tensor.dtype)
    # The trailing -1 is presumably the axis at which the one-hot dimension is inserted
    # (the last axis); confirm against the OneHotExt primitive.
    return one_hot_ext_impl(tensor, num_classes, on_value, off_value, -1)
7062
+
7063
+
7064
def from_numpy(array):
    r"""
    Convert numpy array to Tensor.
    If the data is not C contiguous, the data will be copied to C contiguous to construct the tensor.
    Otherwise, the tensor will be constructed using this numpy array without copy.

    Args:
        array (numpy.ndarray): The input array.

    Returns:
        Tensor, has the same data type as input array.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> x = np.array([1, 2])
        >>> output = ms.from_numpy(x)
        >>> print(output)
        [1 2]
    """
    # Delegates entirely to the Tensor class constructor helper.
    return Tensor.from_numpy(array)
7639
7085
 
7640
7086
  __all__ = [
7641
7087
  'unique',
@@ -7653,6 +7099,7 @@ __all__ = [
7653
7099
  'ones_like',
7654
7100
  'zeros',
7655
7101
  'zeros_like',
7102
+ 'zero_',
7656
7103
  'shape',
7657
7104
  'shape_',
7658
7105
  'reverse',
@@ -7663,8 +7110,8 @@ __all__ = [
7663
7110
  'full_like',
7664
7111
  'dyn_shape',
7665
7112
  'rank',
7666
- 'range',
7667
7113
  'arange',
7114
+ 'range',
7668
7115
  'reshape',
7669
7116
  'reshape_',
7670
7117
  'flatten',
@@ -7773,6 +7220,7 @@ __all__ = [
7773
7220
  'aminmax',
7774
7221
  'sort',
7775
7222
  'top_k',
7776
- 'deepcopy'
7223
+ 'deepcopy',
7224
+ 'flip',
7777
7225
  ]
7778
7226
  __all__.sort()