mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (287)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -32,14 +32,13 @@ from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
32
32
  from mindspore.ops.operations._sequence_ops import TupleToTensor
33
33
  from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
34
34
  from mindspore.ops.operations._sequence_ops import TensorToList
35
- from mindspore.ops.auto_generate import OnesLikeExt, ZerosLikeExt, FillScalar, FillTensor, Arange, Chunk, UniqueDim,\
36
- Unique2, SortExt, NonZero, NonZeroExt
35
+ from mindspore.ops.auto_generate import OnesLikeExt, ZerosLikeExt, FillScalar, FillTensor, Arange, Chunk, UniqueDim, \
36
+ Unique2, SortExt, NonZero, NonZeroExt, Scatter, ScatterValue
37
37
  from mindspore.ops.auto_generate.gen_ops_prim import SplitTensor
38
38
  from mindspore.ops.auto_generate.gen_ops_prim import SplitWithSize, RepeatInterleaveInt, RepeatInterleaveTensor
39
-
39
+ from mindspore.ops.auto_generate.pyboost_inner_prim import _PyboostSearchSortedPrim
40
40
  from mindspore.ops.operations.array_ops import (
41
41
  UniqueConsecutive,
42
- SearchSorted,
43
42
  MatrixDiagV3,
44
43
  MatrixDiagPartV3,
45
44
  MatrixSetDiagV3,
@@ -58,7 +57,6 @@ from mindspore.ops.operations.array_ops import (
58
57
  ArgMaxWithValue,
59
58
  ArgMinWithValue
60
59
  )
61
- from mindspore.ops.operations.array_ops import TensorScatterElements
62
60
  from mindspore.common import Tensor
63
61
  from mindspore.ops._primitive_cache import _get_cache_prim
64
62
  from mindspore import _checkparam as validator
@@ -66,10 +64,12 @@ from mindspore._c_expression import Tensor as Tensor_
66
64
  from mindspore.ops._utils.utils import ms_arrange
67
65
 
68
66
  from mindspore.ops.auto_generate import cat, range, scatter_nd, deepcopy, masked_fill, diagonal, expand_dims, \
69
- flip, transpose, triu, unsorted_segment_sum, diag, gather, gather_d, gather_nd, reshape, \
70
- broadcast_to, strided_slice, ones, zeros, max_, min_, select
71
- from mindspore.ops.auto_generate.gen_ops_prim import scatter_add_ext_op, slice_ext_op
67
+ flip, transpose, triu, unsorted_segment_sum, diag, gather, gather_d, gather_nd, reshape, masked_select, \
68
+ broadcast_to, strided_slice, ones, zeros, max_, min_, select, zero_
69
+ from mindspore.ops.auto_generate import tensor_scatter_elements as tensor_scatter_elements_ext
70
+ from mindspore.ops.auto_generate.gen_ops_prim import scatter_add_ext_op, slice_ext_op, gather_d_op
72
71
  from mindspore.ops.operations.manually_defined import tile, rank, scalar_cast
72
+ from mindspore.ops.auto_generate.pyboost_inner_prim import _PyboostOneHotExtPrim, tril_ext_impl
73
73
 
74
74
  arg_max_with_value_ = ArgMaxWithValue()
75
75
  arg_min_with_value_ = ArgMinWithValue()
@@ -87,7 +87,6 @@ gather_nd_ = P.GatherNd()
87
87
  ger_ = P.Ger()
88
88
  index_fill_ = IndexFill()
89
89
  lstsq_ = Lstsq()
90
- masked_select_ = P.MaskedSelect()
91
90
  matrix_band_part_ = P.array_ops.MatrixBandPart()
92
91
  ones_ = P.Ones()
93
92
  population_count_ = P.PopulationCount()
@@ -104,6 +103,7 @@ scatter_min_ = P.ScatterMin()
104
103
  scatter_mul_ = P.ScatterMul()
105
104
  scatter_nd_ = P.ScatterNd()
106
105
  scatter_update_ = P.ScatterUpdate()
106
+ search_sorted_ = _PyboostSearchSortedPrim()
107
107
  shape_ = P.Shape()
108
108
  split_tensor = SplitTensor()
109
109
  split_with_size = SplitWithSize()
@@ -122,18 +122,20 @@ transpose_ = P.Transpose()
122
122
  tuple_to_array_ = P.TupleToArray()
123
123
  tuple_to_tensor_ = TupleToTensor()
124
124
  unique_ = P.Unique()
125
- unique_with_pad_ = P.UniqueWithPad()
126
125
  unsorted_segment_max_ = P.UnsortedSegmentMax()
127
126
  unsorted_segment_min_ = P.UnsortedSegmentMin()
128
127
  unsorted_segment_prod_ = P.UnsortedSegmentProd()
129
128
  unsorted_segment_sum_ = P.UnsortedSegmentSum()
130
129
  ones_like_ = P.OnesLike()
130
+ one_hot_ext_impl = _PyboostOneHotExtPrim()
131
131
  zeros_like_ = P.ZerosLike()
132
132
  ones_like_ext_ = OnesLikeExt()
133
133
  zeros_like_ext_ = ZerosLikeExt()
134
134
  fill_scalar_ = FillScalar()
135
135
  fill_tensor_ = FillTensor()
136
136
  sort_ext_ = SortExt()
137
+ scatter_ = Scatter()
138
+ scatter_value_ = ScatterValue()
137
139
  arange_ = Arange()
138
140
  chunk_ = Chunk()
139
141
  repeat_interleave_int_ = RepeatInterleaveInt()
@@ -199,7 +201,8 @@ def _get_max_type(start, end, step):
199
201
 
200
202
  type_map = {'Float64': '3', 'Float32': '2', "<class 'float'>": '2', 'Int64': '1', "<class 'int'>": '1',
201
203
  'Int32': '0'}
202
- type_map_reverse = {'3': mstype.float64, '2': mstype.float32, '1': mstype.int64, '0': mstype.int32}
204
+ type_map_reverse = {'3': mstype.float64,
205
+ '2': mstype.float32, '1': mstype.int64, '0': mstype.int32}
203
206
  type_level = [type_map.get(i) for i in arg_type_map]
204
207
  max_level = builtins.max(type_level)
205
208
  return type_map_reverse.get(max_level)
@@ -329,7 +332,7 @@ def arange_ext(start=0, end=None, step=1, *, dtype=None):
329
332
  [7 5 3]
330
333
  >>> print(output.dtype)
331
334
  Int64
332
- >>> output = ops.arange_ext(12, 2, -1, dtype=ms.bfloat16))
335
+ >>> output = ops.arange_ext(12, 2, -1, dtype=ms.bfloat16)
333
336
  >>> print(output)
334
337
  [12. 11. 10. 9. 8. 7. 6. 5. 4. 3.]
335
338
  >>> print(output.dtype)
@@ -347,9 +350,9 @@ def concat(tensors, axis=0):
347
350
  Tutorial Examples:
348
351
  - `Tensor - Tensor Operation <https://mindspore.cn/tutorials/en/master/beginner/tensor.html#tensor-operation>`_
349
352
  - `Vision Transformer Image Classification - Building ViT as a whole
350
- <https://mindspore.cn/tutorials/application/en/master/cv/vit.html#building-vit-as-a-whole>`_
353
+ <https://mindspore.cn/tutorials/en/master/cv/vit.html#building-vit-as-a-whole>`_
351
354
  - `Sentiment Classification Implemented by RNN - Dense
352
- <https://mindspore.cn/tutorials/application/en/master/nlp/sentiment_analysis.html#dense>`_
355
+ <https://mindspore.cn/tutorials/en/master/nlp/sentiment_analysis.html#dense>`_
353
356
  """
354
357
  return cat(tensors, axis)
355
358
 
@@ -451,20 +454,25 @@ def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype
451
454
  [0.08 0.39785218 0.91214782 0.91214782 0.39785218 0.08]
452
455
  """
453
456
  if not isinstance(window_length, int):
454
- raise TypeError(f"For array function 'hamming_window', 'window_length' must be int, but got" \
457
+ raise TypeError(f"For array function 'hamming_window', 'window_length' must be int, but got"
455
458
  f" {type(window_length)}.")
456
459
  if window_length < 0:
457
- raise ValueError(f"For array function 'hamming_window', 'window_length' must be non negative number.")
460
+ raise ValueError(
461
+ f"For array function 'hamming_window', 'window_length' must be non negative number.")
458
462
  if not isinstance(periodic, bool):
459
- raise TypeError(f"For array function 'hamming_window', 'periodic' must be bool, but got {type(periodic)}.")
463
+ raise TypeError(
464
+ f"For array function 'hamming_window', 'periodic' must be bool, but got {type(periodic)}.")
460
465
  if not isinstance(alpha, float):
461
- raise TypeError(f"For array function 'hamming_window', 'alpha' must be float, but got {type(alpha)}.")
466
+ raise TypeError(
467
+ f"For array function 'hamming_window', 'alpha' must be float, but got {type(alpha)}.")
462
468
  if not isinstance(beta, float):
463
- raise TypeError(f"For array function 'hamming_window', 'beta' must be float, but got {type(beta)}.")
469
+ raise TypeError(
470
+ f"For array function 'hamming_window', 'beta' must be float, but got {type(beta)}.")
464
471
  if window_length <= 1:
465
472
  return Tensor(np.ones(window_length))
466
473
  if dtype is not None and dtype not in mstype.float_type:
467
- raise TypeError(f"For array function 'hamming_window', 'dtype' must be floating point dtypes, but got {dtype}.")
474
+ raise TypeError(
475
+ f"For array function 'hamming_window', 'dtype' must be floating point dtypes, but got {dtype}.")
468
476
 
469
477
  dtype = mstype.float32 if dtype is None else dtype
470
478
  op = _get_cache_prim(P.HammingWindow)(periodic, alpha, beta, dtype)
@@ -641,7 +649,8 @@ def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True, ops_n
641
649
  if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
642
650
  for ax in axis:
643
651
  if not isinstance(ax, int):
644
- raise TypeError(f"For {ops_name}, each axis must be integer, but got {type(ax)} in {axis}.")
652
+ raise TypeError(
653
+ f"For {ops_name}, each axis must be integer, but got {type(ax)} in {axis}.")
645
654
  return True
646
655
 
647
656
  type_str = ""
@@ -651,7 +660,8 @@ def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True, ops_n
651
660
  type_str += "tuple, "
652
661
  if type_list:
653
662
  type_str += "list, "
654
- raise TypeError(f"For {ops_name}, the axis should be {type_str}, but got {type(axis)}.")
663
+ raise TypeError(
664
+ f"For {ops_name}, the axis should be {type_str}, but got {type(axis)}.")
655
665
 
656
666
 
657
667
  def one_hot(indices, depth, on_value=1, off_value=0, axis=-1):
@@ -720,8 +730,8 @@ def fill(type, shape, value): # pylint: disable=redefined-outer-name
720
730
 
721
731
  Args:
722
732
  type (mindspore.dtype): The specified type of output tensor. The data type only supports
723
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ and
724
- `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
733
+ `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ and
734
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ .
725
735
  shape (Union(Tensor, tuple[int])): The specified shape of output tensor.
726
736
  value (Union(Tensor, number.Number, bool)): Value to fill the returned tensor.
727
737
 
@@ -786,11 +796,13 @@ def full(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-na
786
796
  [0. 0. 0.]]
787
797
  """
788
798
  if not isinstance(size, (list, tuple)):
789
- raise TypeError(f"For 'ops.full', 'size' must be a tuple or list of ints, but got {type(size)}.")
799
+ raise TypeError(
800
+ f"For 'ops.full', 'size' must be a tuple or list of ints, but got {type(size)}.")
790
801
  if dtype is None:
791
802
  dtype = mstype.int64
792
803
  if dtype not in mstype.all_types:
793
- raise TypeError(f"For 'ops.full', 'dtype' must be mindspore.type, but got {dtype}.")
804
+ raise TypeError(
805
+ f"For 'ops.full', 'dtype' must be mindspore.type, but got {dtype}.")
794
806
  if isinstance(size, list):
795
807
  size = tuple(size)
796
808
  return ops.fill(dtype, size, fill_value)
@@ -802,7 +814,8 @@ def full_ext(size, fill_value, *, dtype=None): # pylint: disable=redefined-oute
802
814
 
803
815
  Args:
804
816
  size (Union(tuple[int], list[int])): The specified shape of output tensor.
805
- fill_value (number.Number): Value to fill the returned tensor. Complex numbers are not supported for now.
817
+ fill_value (Union(number.Number, Tensor)): Value to fill the returned tensor. It can be a Scalar number, a 0-D
818
+ Tensor, or a 1-D Tensor with only one element.
806
819
 
807
820
  Keyword Args:
808
821
  dtype (mindspore.dtype): The specified type of output tensor. `bool_` and `number` are supported, for details,
@@ -820,18 +833,16 @@ def full_ext(size, fill_value, *, dtype=None): # pylint: disable=redefined-oute
820
833
 
821
834
  Examples:
822
835
  >>> from mindspore import ops
823
- >>> output = ops.full((2, 2), 1)
836
+ >>> output = ops.full_ext((2, 2), 1)
824
837
  >>> print(output)
825
838
  [[1. 1.]
826
839
  [1. 1.]]
827
- >>> output = ops.full((3, 3), 0)
840
+ >>> output = ops.full_ext((3, 3), 0)
828
841
  >>> print(output)
829
842
  [[0. 0. 0.]
830
843
  [0. 0. 0.]
831
844
  [0. 0. 0.]]
832
845
  """
833
- if isinstance(fill_value, Tensor):
834
- return fill_tensor_(size, fill_value, dtype)
835
846
  return fill_scalar_(size, fill_value, dtype)
836
847
 
837
848
 
@@ -872,7 +883,8 @@ def full_like(input, fill_value, *, dtype=None):
872
883
  [0. 0. 0.]]
873
884
  """
874
885
  if not isinstance(input, Tensor):
875
- raise TypeError(f"For ops.full_like, the argument 'x' must be tensor, but got {type(input)}")
886
+ raise TypeError(
887
+ f"For ops.full_like, the argument 'x' must be tensor, but got {type(input)}")
876
888
  if dtype is None:
877
889
  dtype = input.dtype
878
890
  return full(input.shape, fill_value, dtype=dtype)
@@ -914,19 +926,24 @@ def chunk(input, chunks, axis=0):
914
926
  Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
915
927
  """
916
928
  if not isinstance(input, Tensor):
917
- raise TypeError(f'For ops.chunk parameter `input` must be Tensor, but got {type(input)}')
929
+ raise TypeError(
930
+ f'For ops.chunk parameter `input` must be Tensor, but got {type(input)}')
918
931
  _check_axis_type(axis, True, False, False, "ops.chunk")
919
932
  arr_axis = _canonicalize_axis(axis, input.ndim)
920
933
 
921
934
  if not isinstance(chunks, int):
922
- raise TypeError(f"For ops.chunk type of argument `chunks` should be integer, but got {type(chunks)}")
935
+ raise TypeError(
936
+ f"For ops.chunk type of argument `chunks` should be integer, but got {type(chunks)}")
923
937
  if chunks <= 0:
924
- raise ValueError(f"For ops.chunk parameter 'chunks' must be greater than 0, but got {chunks}")
938
+ raise ValueError(
939
+ f"For ops.chunk parameter 'chunks' must be greater than 0, but got {chunks}")
925
940
 
926
941
  arr_shape = input.shape
927
942
  length_along_dim = arr_shape[arr_axis]
928
943
 
929
- if chunks > length_along_dim:
944
+ if length_along_dim == 0:
945
+ res = _get_cache_prim(P.Split)(arr_axis)(input)
946
+ elif chunks > length_along_dim:
930
947
  res = _get_cache_prim(P.Split)(arr_axis, length_along_dim)(input)
931
948
  elif length_along_dim % chunks == 0:
932
949
  res = _get_cache_prim(P.Split)(arr_axis, chunks)(input)
@@ -939,9 +956,11 @@ def chunk(input, chunks, axis=0):
939
956
  size1 = _tuple_setitem(arr_shape, arr_axis, length1)
940
957
  start2 = _tuple_setitem(start1, arr_axis, length1)
941
958
  size2 = _tuple_setitem(arr_shape, arr_axis, length2)
942
- res = _get_cache_prim(P.Split)(arr_axis, true_chunks)(tensor_slice(input, start1, size1))
959
+ res = _get_cache_prim(P.Split)(arr_axis, true_chunks)(
960
+ tensor_slice(input, start1, size1))
943
961
  if length2:
944
- res += _get_cache_prim(P.Split)(arr_axis, 1)(tensor_slice(input, start2, size2))
962
+ res += _get_cache_prim(P.Split)(arr_axis,
963
+ 1)(tensor_slice(input, start2, size2))
945
964
  return res
946
965
 
947
966
 
@@ -952,6 +971,9 @@ def chunk_ext(input, chunks, dim=0):
952
971
  Note:
953
972
  This function may return less than the specified number of chunks!
954
973
 
974
+ .. warning::
975
+ This is an experimental API that is subject to change or deletion.
976
+
955
977
  Args:
956
978
  input (Tensor): A Tensor to be cut.
957
979
  chunks (int): Number of sub-tensors to cut.
@@ -1260,11 +1282,14 @@ def unique_ext(input, sorted=True, return_inverse=False, return_counts=False, di
1260
1282
  [0 1 2 1]
1261
1283
  """
1262
1284
  if not F.isconstant(return_inverse) or not F.isconstant(return_counts):
1263
- raise ValueError(f"For 'unique_ext', 'return_inverse' and 'return_counts' cannot be mutable")
1285
+ raise ValueError(
1286
+ f"For 'unique_ext', 'return_inverse' and 'return_counts' cannot be mutable")
1264
1287
  if dim is None:
1265
- y, inverse, counts = unique2_(input, sorted, return_inverse, return_counts)
1288
+ y, inverse, counts = unique2_(
1289
+ input, sorted, return_inverse, return_counts)
1266
1290
  else:
1267
- validator.check_value_type("return_counts", return_counts, [bool], "unique_ext")
1291
+ validator.check_value_type(
1292
+ "return_counts", return_counts, [bool], "unique_ext")
1268
1293
  y, inverse, counts = unique_dim_(input, sorted, return_inverse, dim)
1269
1294
  if return_inverse and return_counts:
1270
1295
  return y, inverse, counts
@@ -1285,6 +1310,11 @@ def unique_with_pad(x, pad_num):
1285
1310
  the UniqueWithPad operator will fill the `y` Tensor with the `pad_num` specified by the user
1286
1311
  to make it have the same shape as the Tensor `idx`.
1287
1312
 
1313
+ .. warning::
1314
+ :func:`mindspore.ops.unique_with_pad` is deprecated from version 2.4 and will be removed in a future version.
1315
+ Please use the :func:`mindspore.ops.unique` combined with :func:`mindspore.ops.pad` to realize
1316
+ the same function.
1317
+
1288
1318
  Args:
1289
1319
  x (Tensor): The tensor need to be unique. Must be 1-D vector with types: int32, int64.
1290
1320
  pad_num (int): Pad num. The data type is an int.
@@ -1297,10 +1327,10 @@ def unique_with_pad(x, pad_num):
1297
1327
 
1298
1328
  Raises:
1299
1329
  TypeError: If dtype of `x` is neither int32 nor int64.
1300
- ValueError: If length of shape of `x` is not equal to 1.
1330
+ ValueError: If `x` is not a 1-D Tensor.
1301
1331
 
1302
1332
  Supported Platforms:
1303
- ``Ascend`` ``GPU`` ``CPU``
1333
+ Deprecated
1304
1334
 
1305
1335
  Examples:
1306
1336
  >>> import mindspore
@@ -1319,7 +1349,7 @@ def unique_with_pad(x, pad_num):
1319
1349
  >>> print(idx)
1320
1350
  [0 1 1 2 3 3]
1321
1351
  """
1322
- return unique_with_pad_(x, pad_num)
1352
+ return _get_cache_prim(P.UniqueWithPad)()(x, pad_num)
1323
1353
 
1324
1354
 
1325
1355
  def unique_consecutive(input, return_idx=False, return_counts=False, axis=None):
@@ -1369,7 +1399,8 @@ def unique_consecutive(input, return_idx=False, return_counts=False, axis=None):
1369
1399
 
1370
1400
  if not isinstance(input, (Tensor, Tensor_)):
1371
1401
  raise TypeError("For 'unique_consecutive', 'input' must be Tensor.")
1372
- unique_consecutive_op = _get_cache_prim(UniqueConsecutive)(return_idx, return_counts, axis)
1402
+ unique_consecutive_op = _get_cache_prim(
1403
+ UniqueConsecutive)(return_idx, return_counts, axis)
1373
1404
  output, idx, counts = unique_consecutive_op(input)
1374
1405
  if return_idx and return_counts:
1375
1406
  return output, idx, counts
@@ -1400,7 +1431,7 @@ def searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=
1400
1431
  set to ``"left"`` while `right` is ``True``. Default: ``None`` .
1401
1432
  sorter(Tensor, optional): if provided, a tensor matching the shape of the unsorted sorted_sequence
1402
1433
  containing a sequence of indices that sort it in the ascending order on the innermost
1403
- dimension and type must be int64. Default: ``None`` .
1434
+ dimension and type must be int64. Default: ``None`` . On CPU and GPU, only the default value is supported.
1404
1435
 
1405
1436
  Returns:
1406
1437
  Tensor containing the indices from the innermost dimension of `sorted_sequence` such that,
@@ -1437,8 +1468,7 @@ def searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=
1437
1468
  f"got side of left while right was True.")
1438
1469
  if side == "right":
1439
1470
  right = True
1440
- search_sorted_ = SearchSorted(dtype, right)
1441
- return search_sorted_(sorted_sequence, values, sorter)
1471
+ return search_sorted_(sorted_sequence, values, sorter, dtype, right)
1442
1472
 
1443
1473
 
1444
1474
  def ger(input, vec2):
@@ -1488,7 +1518,7 @@ def size(input_x):
1488
1518
 
1489
1519
  Args:
1490
1520
  input_x (Tensor): Input parameters, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is
1491
- `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
1521
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
1492
1522
 
1493
1523
  Returns:
1494
1524
  int. A scalar representing the elements' size of `input_x`, tensor is the number of elements
@@ -1681,7 +1711,8 @@ def flatten(input, order='C', *, start_dim=1, end_dim=-1):
1681
1711
 
1682
1712
  def check_dim_valid(start_dim, end_dim):
1683
1713
  if start_dim > end_dim:
1684
- raise ValueError("For 'flatten', 'start_dim' cannot come after 'end_dim'.")
1714
+ raise ValueError(
1715
+ "For 'flatten', 'start_dim' cannot come after 'end_dim'.")
1685
1716
 
1686
1717
  def canonicalize_axis(axis, x_rank):
1687
1718
  ndim = x_rank if x_rank != 0 else 1
@@ -1693,7 +1724,8 @@ def flatten(input, order='C', *, start_dim=1, end_dim=-1):
1693
1724
  raise TypeError(f"For 'flatten', argument 'input' must be Tensor.")
1694
1725
  if not isinstance(start_dim, int) or not isinstance(end_dim, int) or \
1695
1726
  isinstance(start_dim, bool) or isinstance(end_dim, bool):
1696
- raise TypeError(f"For 'flatten', both 'start_dim' and 'end_dim' must be int.")
1727
+ raise TypeError(
1728
+ f"For 'flatten', both 'start_dim' and 'end_dim' must be int.")
1697
1729
  check_flatten_order_const(order)
1698
1730
  if order == 'F':
1699
1731
  x_rank = rank_(input)
@@ -3269,13 +3301,13 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
3269
3301
  Note:
3270
3302
  If some values of the `indices` exceed the upper or lower bounds of the index of `input_x`, instead of raising
3271
3303
  an index error, the corresponding `updates` will not be updated to `input_x`.
3304
+ The backward is supported only for the case `updates.shape == indices.shape`.
3272
3305
 
3273
3306
  Args:
3274
3307
  input_x (Tensor): The target tensor. The rank must be at least 1.
3275
3308
  indices (Tensor): The index of `input_x` to do scatter operation whose data type must be mindspore.int32 or
3276
3309
  mindspore.int64. Same rank as `input_x`. And accepted range is [-s, s) where s is the size along axis.
3277
- updates (Tensor): The tensor doing the scatter operation with `input_x`, has the same type as `input_x` and
3278
- the same shape as `indices`.
3310
+ updates (Tensor): The tensor doing the scatter operation with `input_x`.
3279
3311
  axis (int): Which axis to scatter. Accepted range is [-r, r) where r = rank(input_x). Default: ``0``.
3280
3312
  reduction (str): Which reduction operation to scatter, supports ``"none"`` , ``"add"`` . Default: ``"none"``.
3281
3313
  When `reduction` is set to ``"none"``, `updates` will be assigned to `input_x` according to `indices`.
@@ -3287,7 +3319,6 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
3287
3319
  Raises:
3288
3320
  TypeError: If `indices` is neither int32 nor int64.
3289
3321
  ValueError: If the rank of any of `input_x`, `indices` and `updates` is less than 1.
3290
- ValueError: If the shape of `updates` is not equal to the shape of `indices`.
3291
3322
  ValueError: If the rank of `updates` is not equal to the rank of `input_x`.
3292
3323
  RuntimeError: If converting between the data types of `input_x` and `updates` requires a
3293
3324
  data type conversion of Parameter, which is not supported.
@@ -3319,8 +3350,7 @@ def tensor_scatter_elements(input_x, indices, updates, axis=0, reduction="none")
3319
3350
  [ 5 5 14]
3320
3351
  [ 7 15 11]]
3321
3352
  """
3322
- _tensor_scatter_elements = _get_cache_prim(TensorScatterElements)(axis, reduction)
3323
- return _tensor_scatter_elements(input_x, indices, updates)
3353
+ return tensor_scatter_elements_ext(input_x, indices, updates, axis, reduction)
3324
3354
 
3325
3355
 
3326
3356
  def scatter(input, axis, index, src):
@@ -3328,24 +3358,26 @@ def scatter(input, axis, index, src):
3328
3358
  Update the value in `src` to `input` according to the specified index.
3329
3359
  Refer to :func:`mindspore.ops.tensor_scatter_elements` for more details.
3330
3360
 
3361
+ .. note::
3362
+ The backward is supported only for the case `src.shape == index.shape`.
3363
+
3331
3364
  Args:
3332
3365
  input (Tensor): The target tensor. The rank of `input` must be at least 1.
3333
3366
  axis (int): Which axis to scatter. Accepted range is [-r, r) where r = rank(input).
3334
- index (Tensor): The index to do update operation whose data type must be mindspore.int32 or
3335
- mindspore.int64. Same rank as `input` . And accepted range is [-s, s) where s is the size along axis.
3336
- src (Tensor): The tensor doing the update operation with `input` , has the same type as `input` ,
3337
- and the shape of `src` should be equal to the shape of `index` .
3367
+ index (Tensor): The index to do update operation whose data must be non-negative numbers with type of mindspore.int32
3368
+ or mindspore.int64. Same rank as `input` . And accepted range is [-s, s) where s is the size along axis.
3369
+ src (Tensor, float): The data doing the update operation with `input`. Can be a tensor with the same data type
3370
+ as `input` or a float number to scatter.
3338
3371
 
3339
3372
  Returns:
3340
- Tensor, has the same shape and type as `input` .
3373
+ Tensor, has the same shape and type as `input` . The backward is supported only for the case `src.shape == index.shape` when `src` is a tensor.
3341
3374
 
3342
3375
  Raises:
3343
3376
  TypeError: If `index` is neither int32 nor int64.
3344
- ValueError: If anyone of the rank among `input` , `index` and `src` less than 1.
3345
- ValueError: If the shape of `src` is not equal to the shape of `index` .
3377
+ ValueError: If the rank of any of `input` , `index` and `src` is less than 1.
3346
3378
  ValueError: If the rank of `src` is not equal to the rank of `input` .
3347
- RuntimeError: If the data type of `input` and `src` conversion of Parameter
3348
- is required when data type conversion of Parameter is not supported.
3379
+ TypeError: If `input` and `src` have different data types.
3380
+ RuntimeError: If `index` has negative elements.
3349
3381
 
3350
3382
  Supported Platforms:
3351
3383
  ``Ascend`` ``GPU`` ``CPU``
@@ -3381,7 +3413,9 @@ def scatter(input, axis, index, src):
3381
3413
  [0. 0. 0. 0. 0.]
3382
3414
  [0. 0. 0. 0. 0.]]
3383
3415
  """
3384
- return ops.tensor_scatter_elements(input_x=input, indices=index, updates=src, axis=axis)
3416
+ if isinstance(src, Tensor):
3417
+ return scatter_(input, axis, index, src)
3418
+ return scatter_value_(input, axis, index, src)
3385
3419
 
3386
3420
 
3387
3421
  def scatter_add_ext(input, dim, index, src):
@@ -3516,7 +3550,8 @@ def slice_scatter(input, src, axis=0, start=None, end=None, step=1):
3516
3550
  _check_is_tensor("input", input, "slice_scatter")
3517
3551
  _check_is_tensor("src", src, "slice_scatter")
3518
3552
  input_shape = input.shape
3519
- input_rank, index, axis = _get_slice_scatter_const(input_shape, axis, start, end, step)
3553
+ input_rank, index, axis = _get_slice_scatter_const(
3554
+ input_shape, axis, start, end, step)
3520
3555
 
3521
3556
  src_shape = src.shape
3522
3557
  index_shape = input_shape[:axis] + (len(index),) + input_shape[axis + 1:]
@@ -3638,7 +3673,8 @@ def space_to_batch_nd(input_x, block_size, paddings):
3638
3673
  [[[3.]]]
3639
3674
  [[[4.]]]]
3640
3675
  """
3641
- _space_to_batch_nd = _get_cache_prim(P.SpaceToBatchND)(block_size, paddings)
3676
+ _space_to_batch_nd = _get_cache_prim(
3677
+ P.SpaceToBatchND)(block_size, paddings)
3642
3678
  return _space_to_batch_nd(input_x)
3643
3679
 
3644
3680
 
@@ -4330,9 +4366,11 @@ def index_select(input, axis, index):
4330
4366
  [[ 8. 9. 10. 11.]]]
4331
4367
  """
4332
4368
  if not (isinstance(input, Tensor) and isinstance(index, Tensor)):
4333
- raise TypeError(f"For 'index_select', `input` and `index` must be all tensors.")
4369
+ raise TypeError(
4370
+ f"For 'index_select', `input` and `index` must be all tensors.")
4334
4371
  if index.ndim != 1:
4335
- raise ValueError(f"For 'index_select', the dimension of `index` must be 1, but got {index.ndim}")
4372
+ raise ValueError(
4373
+ f"For 'index_select', the dimension of `index` must be 1, but got {index.ndim}")
4336
4374
  axis = _check_check_axis_in_range(axis, input.ndim)
4337
4375
  return gather_(input, index, axis)
4338
4376
 
@@ -4425,9 +4463,11 @@ def is_nonzero(input):
4425
4463
  True
4426
4464
  """
4427
4465
  if not isinstance(input, Tensor):
4428
- raise TypeError(f'For is_nonzero, the input must be a Tensor, but got {type(input)}.')
4466
+ raise TypeError(
4467
+ f'For is_nonzero, the input must be a Tensor, but got {type(input)}.')
4429
4468
  if input.numel() != 1:
4430
- raise ValueError(f"For is_nonzero, the numel of input must be 1, but got {input.numel()}.")
4469
+ raise ValueError(
4470
+ f"For is_nonzero, the numel of input must be 1, but got {input.numel()}.")
4431
4471
  out = ops.squeeze(input)
4432
4472
  return bool(out)
4433
4473
 
@@ -4622,38 +4662,6 @@ def tuple_to_array(input_x):
4622
4662
  return tuple_to_tensor_(input_x, dtype)
4623
4663
 
4624
4664
 
4625
- def masked_select(input, mask):
4626
- """
4627
- Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
4628
- The shapes of the `mask` tensor and the `x` tensor don't need to match, but they must be broadcastable.
4629
-
4630
- Args:
4631
- input (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
4632
- mask (Tensor[bool]): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
4633
-
4634
- Returns:
4635
- A 1-D Tensor, with the same type as `input`.
4636
-
4637
- Raises:
4638
- TypeError: If `input` or `mask` is not a Tensor.
4639
- TypeError: If dtype of `mask` is not bool.
4640
-
4641
- Supported Platforms:
4642
- ``Ascend`` ``GPU`` ``CPU``
4643
-
4644
- Examples:
4645
- >>> import numpy as np
4646
- >>> import mindspore
4647
- >>> from mindspore import Tensor, ops
4648
- >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
4649
- >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
4650
- >>> output = ops.masked_select(x, mask)
4651
- >>> print(output)
4652
- [1 3]
4653
- """
4654
- return masked_select_(input, mask)
4655
-
4656
-
4657
4665
  def diagflat(input, offset=0):
4658
4666
  r"""
4659
4667
  Create a 2-D Tensor whose diagonal is the flattened `input` .
@@ -4687,9 +4695,11 @@ def diagflat(input, offset=0):
4687
4695
  [0. 0. 0.]]
4688
4696
  """
4689
4697
  if not isinstance(input, Tensor):
4690
- raise TypeError(f"For diagflat, the input x must be tensor, but got {type(input)}")
4698
+ raise TypeError(
4699
+ f"For diagflat, the input x must be tensor, but got {type(input)}")
4691
4700
  if not isinstance(offset, int):
4692
- raise TypeError(f"For diagflat, the offset must be int, but got {type(offset)}")
4701
+ raise TypeError(
4702
+ f"For diagflat, the offset must be int, but got {type(offset)}")
4693
4703
  offset_abs = abs(offset)
4694
4704
  if input.size == 0:
4695
4705
  return zeros((offset_abs, offset_abs), input.dtype)
@@ -4759,7 +4769,9 @@ def _split_int(x, split_size_or_sections, axis):
4759
4769
  """
4760
4770
  arr_shape = x.shape
4761
4771
  length_along_dim = arr_shape[axis]
4762
- if split_size_or_sections > length_along_dim:
4772
+ if length_along_dim == 0:
4773
+ res = _get_cache_prim(P.Split)(axis)(x)
4774
+ elif split_size_or_sections > length_along_dim:
4763
4775
  res = _get_cache_prim(P.Split)(axis, 1)(x)
4764
4776
  elif length_along_dim % split_size_or_sections == 0:
4765
4777
  sections = length_along_dim // split_size_or_sections
@@ -4773,7 +4785,7 @@ def _split_int(x, split_size_or_sections, axis):
4773
4785
  start2 = _tuple_setitem(start1, axis, length1)
4774
4786
  size2 = _tuple_setitem(arr_shape, axis, length2)
4775
4787
  res = _get_cache_prim(P.Split)(axis, num_sections)(tensor_slice(x, start1, size1)) + \
4776
- _get_cache_prim(P.Split)(axis, 1)(tensor_slice(x, start2, size2))
4788
+ _get_cache_prim(P.Split)(axis, 1)(tensor_slice(x, start2, size2))
4777
4789
  return res
4778
4790
 
4779
4791
 
@@ -4798,6 +4810,7 @@ def _split_sub_tensors(x, split_size_or_sections, axis):
4798
4810
  sub_tensors.append(sliced_tensor)
4799
4811
  return sub_tensors
4800
4812
 
4813
+
4801
4814
  def split(tensor, split_size_or_sections, axis=0):
4802
4815
  """
4803
4816
  Splits the Tensor into chunks along the given axis.
@@ -4839,7 +4852,8 @@ def split(tensor, split_size_or_sections, axis=0):
4839
4852
  if not isinstance(tensor, Tensor):
4840
4853
  raise TypeError(f'expect `tensor` is a Tensor, but got {type(tensor)}')
4841
4854
  if type(axis) is not int:
4842
- raise TypeError(f"Type of Argument `axis` should be integer but got {type(axis)}")
4855
+ raise TypeError(
4856
+ f"Type of Argument `axis` should be integer but got {type(axis)}")
4843
4857
  arr_axis = _canonicalize_axis(axis, tensor.ndim)
4844
4858
 
4845
4859
  if type(split_size_or_sections) is int:
@@ -4851,7 +4865,8 @@ def split(tensor, split_size_or_sections, axis=0):
4851
4865
  elif isinstance(split_size_or_sections, (list, tuple)):
4852
4866
  for item in split_size_or_sections:
4853
4867
  if type(item) is not int:
4854
- raise TypeError(f"Each element in 'split_size_or_sections' should be integer, but got {type(item)}.")
4868
+ raise TypeError(
4869
+ f"Each element in 'split_size_or_sections' should be integer, but got {type(item)}.")
4855
4870
  if item < 0:
4856
4871
  raise TypeError(f"Each element in 'split_size_or_sections' should be non-negative, "
4857
4872
  f"but got {split_size_or_sections}.")
@@ -4861,10 +4876,11 @@ def split(tensor, split_size_or_sections, axis=0):
4861
4876
  f"but got {sum(split_size_or_sections)}.")
4862
4877
  res = _split_sub_tensors(tensor, split_size_or_sections, arr_axis)
4863
4878
  else:
4864
- raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), " \
4879
+ raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), "
4865
4880
  f"but got {type(split_size_or_sections)}")
4866
4881
  return tuple(res)
4867
4882
 
4883
+
4868
4884
  def split_ext(tensor, split_size_or_sections, axis=0):
4869
4885
  """
4870
4886
  Splits the Tensor into chunks along the given axis.
@@ -4908,14 +4924,14 @@ def split_ext(tensor, split_size_or_sections, axis=0):
4908
4924
  elif isinstance(split_size_or_sections, (list, tuple)):
4909
4925
  res = split_with_size(tensor, split_size_or_sections, axis)
4910
4926
  else:
4911
- raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), " \
4927
+ raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), "
4912
4928
  f"but got {type(split_size_or_sections)}")
4913
4929
  return res
4914
4930
 
4915
4931
 
4916
4932
  def tril(input, diagonal=0): # pylint: disable=redefined-outer-name
4917
4933
  """
4918
- Returns the lower triangle part of 'input' (elements that contain the diagonal and below),
4934
+ Returns the lower triangle part of `input` (elements that contain the diagonal and below),
4919
4935
  and set the other elements to zeros.
4920
4936
 
4921
4937
  Args:
@@ -4925,13 +4941,13 @@ def tril(input, diagonal=0): # pylint: disable=redefined-outer-name
4925
4941
  indicating the main diagonal.
4926
4942
 
4927
4943
  Returns:
4928
- Tensor, the same shape and data type as the input `x`.
4944
+ Tensor, the same shape and data type as the `input`.
4929
4945
 
4930
4946
  Raises:
4931
- TypeError: If `x` is not a Tensor.
4947
+ TypeError: If `input` is not a Tensor.
4932
4948
  TypeError: If `diagonal` is not an int.
4933
- TypeError: If the type of `x` is neither number nor bool.
4934
- ValueError: If the rank of `x` is less than 2.
4949
+ TypeError: If the type of `input` is neither number nor bool.
4950
+ ValueError: If the rank of `input` is less than 2.
4935
4951
 
4936
4952
  Supported Platforms:
4937
4953
  ``Ascend`` ``GPU`` ``CPU``
@@ -4970,10 +4986,70 @@ def tril(input, diagonal=0): # pylint: disable=redefined-outer-name
4970
4986
  [10 11 0 0]
4971
4987
  [14 15 16 0]]
4972
4988
  """
4973
- tril_ = Tril(diagonal)
4989
+ tril_ = _get_cache_prim(Tril)(diagonal)
4974
4990
  return tril_(input)
4975
4991
 
4976
4992
 
4993
+ def tril_ext(input, diagonal=0):
4994
+ """
4995
+ Returns the lower triangle part of `input` (elements that contain the diagonal and below),
4996
+ and set the other elements to zeros.
4997
+
4998
+ Args:
4999
+ input (Tensor): A Tensor with shape :math:`(x_1, x_2, ..., x_R)`. The rank must be at least 2.
5000
+ Supporting all number types including bool.
5001
+ diagonal (int, optional): An optional attribute indicates the diagonal to consider, default: 0,
5002
+ indicating the main diagonal.
5003
+
5004
+ Returns:
5005
+ Tensor, the same shape and data type as the `input`.
5006
+
5007
+ Raises:
5008
+ TypeError: If `input` is not a Tensor.
5009
+ TypeError: If `diagonal` is not an int.
5010
+ TypeError: If the type of `input` is neither number nor bool.
5011
+ ValueError: If the rank of `input` is less than 2.
5012
+
5013
+ Supported Platforms:
5014
+ ``Ascend``
5015
+
5016
+ Examples:
5017
+ >>> import numpy as np
5018
+ >>> from mindspore import Tensor, ops
5019
+ >>> x = Tensor(np.array([[ 1, 2, 3, 4],
5020
+ ... [ 5, 6, 7, 8],
5021
+ ... [10, 11, 12, 13],
5022
+ ... [14, 15, 16, 17]]))
5023
+ >>> result = ops.function.array_func.tril_ext(x)
5024
+ >>> print(result)
5025
+ [[ 1 0 0 0]
5026
+ [ 5 6 0 0]
5027
+ [10 11 12 0]
5028
+ [14 15 16 17]]
5029
+ >>> x = Tensor(np.array([[ 1, 2, 3, 4],
5030
+ ... [ 5, 6, 7, 8],
5031
+ ... [10, 11, 12, 13],
5032
+ ... [14, 15, 16, 17]]))
5033
+ >>> result = ops.function.array_func.tril_ext(x, diagonal=1)
5034
+ >>> print(result)
5035
+ [[ 1 2 0 0]
5036
+ [ 5 6 7 0]
5037
+ [10 11 12 13]
5038
+ [14 15 16 17]]
5039
+ >>> x = Tensor(np.array([[ 1, 2, 3, 4],
5040
+ ... [ 5, 6, 7, 8],
5041
+ ... [10, 11, 12, 13],
5042
+ ... [14, 15, 16, 17]]))
5043
+ >>> result = ops.function.array_func.tril_ext(x, diagonal=-1)
5044
+ >>> print(result)
5045
+ [[ 0 0 0 0]
5046
+ [ 5 0 0 0]
5047
+ [10 11 0 0]
5048
+ [14 15 16 0]]
5049
+ """
5050
+ return tril_ext_impl(input, diagonal)
5051
+
5052
+
4977
5053
  @_primexpr
4978
5054
  def _canonicalize_axis(axis, ndim):
4979
5055
  """
@@ -4992,7 +5068,8 @@ def _canonicalize_axis(axis, ndim):
4992
5068
  if not isinstance(ax, int):
4993
5069
  raise TypeError(f'axis should be integers, not {type(ax)}')
4994
5070
  if not -ndim <= ax < ndim:
4995
- raise ValueError(f'axis {ax} is out of bounds for array of dimension {ndim}')
5071
+ raise ValueError(
5072
+ f'axis {ax} is out of bounds for array of dimension {ndim}')
4996
5073
 
4997
5074
  def canonicalizer(ax):
4998
5075
  return ax + ndim if ax < 0 else ax
@@ -5072,7 +5149,9 @@ def _tensor_split_sub_int(x, indices_or_sections, axis):
5072
5149
  """
5073
5150
  arr_shape = x.shape
5074
5151
  length_along_dim = arr_shape[axis]
5075
- if indices_or_sections > length_along_dim:
5152
+ if length_along_dim == 0:
5153
+ res = _get_cache_prim(P.Split)(axis)(x)
5154
+ elif indices_or_sections > length_along_dim:
5076
5155
  res = _get_cache_prim(P.Split)(axis, length_along_dim)(x)
5077
5156
  indices_or_sections_n = [length_along_dim, length_along_dim + 1]
5078
5157
  res2 = _tensor_split_sub_tensors(x, indices_or_sections_n, axis)
@@ -5083,14 +5162,16 @@ def _tensor_split_sub_int(x, indices_or_sections, axis):
5083
5162
  else:
5084
5163
  num_long_tensor = length_along_dim % indices_or_sections
5085
5164
  num_short_tensor = indices_or_sections - num_long_tensor
5086
- length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1)
5165
+ length1 = num_long_tensor * \
5166
+ (length_along_dim // indices_or_sections + 1)
5087
5167
  length2 = length_along_dim - length1
5088
5168
  start1 = _list_comprehensions(rank_(x), 0, True)
5089
5169
  size1 = _tuple_setitem(arr_shape, axis, length1)
5090
5170
  start2 = _tuple_setitem(start1, axis, length1)
5091
5171
  size2 = _tuple_setitem(arr_shape, axis, length2)
5092
5172
  res = _get_cache_prim(P.Split)(axis, num_long_tensor)(tensor_slice(x, start1, size1)) + \
5093
- _get_cache_prim(P.Split)(axis, num_short_tensor)(tensor_slice(x, start2, size2))
5173
+ _get_cache_prim(P.Split)(axis, num_short_tensor)(
5174
+ tensor_slice(x, start2, size2))
5094
5175
  return res
5095
5176
 
5096
5177
 
@@ -5143,21 +5224,25 @@ def tensor_split(input, indices_or_sections, axis=0):
5143
5224
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
5144
5225
 
5145
5226
  if type(axis) is not int:
5146
- raise TypeError(f"Type of Argument `axis` should be integer but got {type(axis)}")
5227
+ raise TypeError(
5228
+ f"Type of Argument `axis` should be integer but got {type(axis)}")
5147
5229
  handle_axis = _canonicalize_axis(axis, input.ndim)
5148
5230
  if type(indices_or_sections) is int:
5149
5231
  if indices_or_sections > 0:
5150
- res = _tensor_split_sub_int(input, indices_or_sections, handle_axis)
5232
+ res = _tensor_split_sub_int(
5233
+ input, indices_or_sections, handle_axis)
5151
5234
  else:
5152
5235
  raise ValueError(f"For tensor_split, the value of 'indices_or_sections' must be more than zero "
5153
5236
  f"but got {indices_or_sections}")
5154
5237
  elif isinstance(indices_or_sections, (list, tuple)):
5155
5238
  for item in indices_or_sections:
5156
5239
  if type(item) is not int:
5157
- raise TypeError(f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.")
5158
- res = _tensor_split_sub_tensors(input, indices_or_sections, handle_axis)
5240
+ raise TypeError(
5241
+ f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.")
5242
+ res = _tensor_split_sub_tensors(
5243
+ input, indices_or_sections, handle_axis)
5159
5244
  else:
5160
- raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), " \
5245
+ raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), "
5161
5246
  f"but got {type(indices_or_sections)}")
5162
5247
 
5163
5248
  return res
@@ -5193,7 +5278,8 @@ def vsplit(input, indices_or_sections):
5193
5278
  if not isinstance(input, Tensor):
5194
5279
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
5195
5280
  if input.ndim < 1:
5196
- raise ValueError(f'vsplit expect `x` is a Tensor with at least 1 dimension, but got {input.ndim}')
5281
+ raise ValueError(
5282
+ f'vsplit expect `x` is a Tensor with at least 1 dimension, but got {input.ndim}')
5197
5283
  return tensor_split(input, indices_or_sections, 0)
5198
5284
 
5199
5285
 
@@ -5229,7 +5315,8 @@ def hsplit(input, indices_or_sections):
5229
5315
  if not isinstance(input, Tensor):
5230
5316
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
5231
5317
  if input.ndim < 2:
5232
- raise ValueError(f'hsplit expect `x` is a Tensor with at least 2 dimension, but got {input.ndim}')
5318
+ raise ValueError(
5319
+ f'hsplit expect `x` is a Tensor with at least 2 dimension, but got {input.ndim}')
5233
5320
 
5234
5321
  return tensor_split(input, indices_or_sections, 1)
5235
5322
 
@@ -5262,7 +5349,8 @@ def dsplit(input, indices_or_sections):
5262
5349
  if not isinstance(input, Tensor):
5263
5350
  raise TypeError(f'expect `x` is a Tensor, but got {type(input)}')
5264
5351
  if input.ndim < 3:
5265
- raise ValueError(f'dsplit expect `x` is a Tensor with at least 3 dimension, but got {input.ndim}')
5352
+ raise ValueError(
5353
+ f'dsplit expect `x` is a Tensor with at least 3 dimension, but got {input.ndim}')
5266
5354
 
5267
5355
  return tensor_split(input, indices_or_sections, 2)
5268
5356
 
@@ -5354,7 +5442,8 @@ def max(input, axis=None, keepdims=False, *, initial=None, where=None): # pylin
5354
5442
  if axis is None:
5355
5443
  return (max_(input), Tensor(0, dtype=mstype.int64))
5356
5444
  if initial is not None and not isinstance(initial, numbers.Number):
5357
- raise TypeError(f"For 'max', 'initial' must be a scalar, but got {type(initial)}")
5445
+ raise TypeError(
5446
+ f"For 'max', 'initial' must be a scalar, but got {type(initial)}")
5358
5447
  if axis is not None and not isinstance(axis, int):
5359
5448
  raise TypeError(f"For 'max', 'axis' must be int, but got {type(axis)}")
5360
5449
  input = _init_and_select_elem(input, initial, where, ops.maximum)
@@ -5408,7 +5497,6 @@ def argmax(input, dim=None, keepdim=False):
5408
5497
  return out
5409
5498
 
5410
5499
 
5411
-
5412
5500
  def min(input, axis=None, keepdims=False, *, initial=None, where=None): # pylint: disable=redefined-outer-name
5413
5501
  """
5414
5502
  Calculates the minimum value along with the given axis for the input tensor. It returns the minimum values and
@@ -5471,7 +5559,8 @@ def min(input, axis=None, keepdims=False, *, initial=None, where=None): # pylin
5471
5559
  if axis is None:
5472
5560
  return (min_(input), Tensor(0, dtype=mstype.int64))
5473
5561
  if initial is not None and not isinstance(initial, numbers.Number):
5474
- raise TypeError(f"For 'min', 'initial' must be a scalar, but got {type(initial)}")
5562
+ raise TypeError(
5563
+ f"For 'min', 'initial' must be a scalar, but got {type(initial)}")
5475
5564
  if axis is not None and not isinstance(axis, int):
5476
5565
  raise TypeError(f"For 'min', 'axis' must be int, but got {type(axis)}")
5477
5566
  input = _init_and_select_elem(input, initial, where, ops.minimum)
@@ -5584,7 +5673,8 @@ def narrow(input, axis, start, length):
5584
5673
  validator.check_value_type("input", input, Tensor, "narrow")
5585
5674
  validator.check_axis_in_range(axis, input.ndim)
5586
5675
  validator.check_int_range(start, 0, input.shape[axis], validator.INC_LEFT)
5587
- validator.check_int_range(length, 1, input.shape[axis] - start, validator.INC_BOTH)
5676
+ validator.check_int_range(
5677
+ length, 1, input.shape[axis] - start, validator.INC_BOTH)
5588
5678
 
5589
5679
  begins = [0] * input.ndim
5590
5680
  begins[axis] = start
@@ -5631,7 +5721,7 @@ def narrow_ext(input, dim, start, length):
5631
5721
  [ 8 9]]
5632
5722
  """
5633
5723
  validator.check_value_type("input", input, Tensor, "narrow")
5634
- return slice_ext_op(input, dim, start, start+length, 1)
5724
+ return slice_ext_op(input, dim, start, start + length, 1)
5635
5725
 
5636
5726
 
5637
5727
  def topk(input, k, dim=None, largest=True, sorted=True):
@@ -5825,7 +5915,8 @@ def _check_unfold_params(param, param_name, param_size):
5825
5915
  """Check the parameters of unfold op."""
5826
5916
  validator.check_value_type(param_name, param, [int, tuple, list], 'unfold')
5827
5917
  param = (param, param) if isinstance(param, int) else param
5828
- validator.check(param_name + " size", len(param), "", param_size, validator.IN, 'unfold')
5918
+ validator.check(param_name + " size", len(param), "",
5919
+ param_size, validator.IN, 'unfold')
5829
5920
  if param_name == "padding":
5830
5921
  validator.check_non_negative_int_sequence(param, param_name, 'unfold')
5831
5922
  else:
@@ -5928,7 +6019,8 @@ def _check_diagonal_axes(dim1, dim2, x_ndim):
5928
6019
  def _check_is_tensor(param_name, input, cls_name):
5929
6020
  """Raise TypeError if `input` is not a Tensor."""
5930
6021
  if not isinstance(input, Tensor):
5931
- raise TypeError(f"For {cls_name}, {param_name} must be a Tensor, but got {type(input)}.")
6022
+ raise TypeError(
6023
+ f"For {cls_name}, {param_name} must be a Tensor, but got {type(input)}.")
5932
6024
 
5933
6025
 
5934
6026
  @_primexpr
@@ -6241,19 +6333,22 @@ def column_stack(tensors):
6241
6333
  [1 2]]
6242
6334
  """
6243
6335
  if not isinstance(tensors, (list, tuple)):
6244
- raise TypeError(f"For column_stack, the input must be list or tuple of tensors, but got {type(tensors)}.")
6336
+ raise TypeError(
6337
+ f"For column_stack, the input must be list or tuple of tensors, but got {type(tensors)}.")
6245
6338
 
6246
6339
  trans_x = ()
6247
6340
  for tensor in tensors:
6248
6341
  if not isinstance(tensor, Tensor):
6249
- raise TypeError(f"For column_stack, the input element must be tensor, but got {type(tensor)}.")
6342
+ raise TypeError(
6343
+ f"For column_stack, the input element must be tensor, but got {type(tensor)}.")
6250
6344
  if tensor.ndim < 1:
6251
6345
  tensor = expand_dims(tensor, 0)
6252
6346
  if tensor.ndim == 1:
6253
6347
  tensor = expand_dims(tensor, 1)
6254
6348
  trans_x += (tensor,)
6255
6349
  if not trans_x:
6256
- raise ValueError(f"For column_stack, the input must have at least 1 tensor, but got 0.")
6350
+ raise ValueError(
6351
+ f"For column_stack, the input must have at least 1 tensor, but got 0.")
6257
6352
  _concat = _get_cache_prim(P.Concat)(1)
6258
6353
  return _concat(trans_x)
6259
6354
 
@@ -6289,17 +6384,20 @@ def hstack(tensors):
6289
6384
  [1. 1. 1. 2. 2. 2.]
6290
6385
  """
6291
6386
  if not isinstance(tensors, (list, tuple)):
6292
- raise TypeError(f"For hstack, the input must be list or tuple, but got {type(tensors)}.")
6387
+ raise TypeError(
6388
+ f"For hstack, the input must be list or tuple, but got {type(tensors)}.")
6293
6389
 
6294
6390
  tuple_of_tensor = ()
6295
6391
  for tensor in tensors:
6296
6392
  if not isinstance(tensor, Tensor):
6297
- raise TypeError(f"For hstack, the input element must be tensor, but got {type(tensor)}.")
6393
+ raise TypeError(
6394
+ f"For hstack, the input element must be tensor, but got {type(tensor)}.")
6298
6395
  if tensor.ndim < 1:
6299
6396
  tensor = expand_dims(tensor, 0)
6300
6397
  tuple_of_tensor += (tensor,)
6301
6398
  if not tuple_of_tensor:
6302
- raise ValueError("For hstack, the input must have at least 1 tensor, but got 0.")
6399
+ raise ValueError(
6400
+ "For hstack, the input must have at least 1 tensor, but got 0.")
6303
6401
  if tuple_of_tensor[0].ndim <= 1:
6304
6402
  _concat = _get_cache_prim(P.Concat)(0)
6305
6403
  return _concat(tuple_of_tensor)
@@ -6328,7 +6426,8 @@ def _get_moved_perm(ndim, source, destination):
6328
6426
  Helper function for movedim, returns permutation after moving axis
6329
6427
  from source to destination.
6330
6428
  """
6331
- dest_sorted_idx = [i for i, _ in sorted(enumerate(destination), key=operator.itemgetter(1))]
6429
+ dest_sorted_idx = [i for i, _ in sorted(
6430
+ enumerate(destination), key=operator.itemgetter(1))]
6332
6431
  axis_orig = [i for i in builtins.range(0, ndim) if i not in source]
6333
6432
 
6334
6433
  k = 0
@@ -6455,7 +6554,8 @@ def swapaxes(input, axis0, axis1):
6455
6554
  (4, 3, 2)
6456
6555
  '''
6457
6556
  if not isinstance(input, Tensor):
6458
- raise TypeError(f'For ops.swapaxes, parameter `input` must be Tensor, but got {type(input)}')
6557
+ raise TypeError(
6558
+ f'For ops.swapaxes, parameter `input` must be Tensor, but got {type(input)}')
6459
6559
 
6460
6560
  axis0, axis1 = _check_swapaxes_axis((axis0, axis1), input.ndim)
6461
6561
  if axis0 == axis1:
@@ -6466,10 +6566,10 @@ def swapaxes(input, axis0, axis1):
6466
6566
  perm = ops.make_range(0, input.ndim)
6467
6567
  if axis1 + 1 < input.ndim:
6468
6568
  new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
6469
- perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + perm[axis1 + 1:]
6569
+ perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + perm[axis1 + 1:]
6470
6570
  else:
6471
6571
  new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
6472
- perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1]
6572
+ perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1]
6473
6573
 
6474
6574
  return transpose_(input, new_perm)
6475
6575
 
@@ -6515,13 +6615,15 @@ def _check_is_int(arg_value, arg_name, op_name):
6515
6615
 
6516
6616
  @_primexpr
6517
6617
  def _check_positive_int(arg_value, arg_name, op_name):
6518
- arg_value = validator.check_int_range(arg_value, 0, 2147483647, validator.INC_RIGHT, arg_name, op_name)
6618
+ arg_value = validator.check_int_range(
6619
+ arg_value, 0, 2147483647, validator.INC_RIGHT, arg_name, op_name)
6519
6620
  return arg_value
6520
6621
 
6521
6622
 
6522
6623
  @constexpr
6523
6624
  def _check_axis_range(arg_value, limit, arg_name, op_name):
6524
- arg_value = validator.check_int_range(arg_value, -limit, limit, validator.INC_LEFT, arg_name, op_name)
6625
+ arg_value = validator.check_int_range(
6626
+ arg_value, -limit, limit, validator.INC_LEFT, arg_name, op_name)
6525
6627
  return arg_value
6526
6628
 
6527
6629
 
@@ -6539,6 +6641,14 @@ def _cal_reshape(x_shape, rep, axis):
6539
6641
  return tuple(x_reshape)
6540
6642
 
6541
6643
 
6644
+ @_primexpr
6645
+ def _check_rank_range(x_rank, limit, arg_name, op_name):
6646
+ if x_rank > limit:
6647
+ raise ValueError(
6648
+ f"For {op_name}, the rank of {arg_name} should be less than or equal to {limit}, but got {x_rank}.")
6649
+ return x_rank
6650
+
6651
+
6542
6652
  def repeat_interleave(input, repeats, axis=None):
6543
6653
  """
6544
6654
  Repeat elements of a tensor along an axis, like `numpy.repeat`.
@@ -6583,6 +6693,9 @@ def repeat_interleave_ext(input, repeats, dim=None, output_size=None):
6583
6693
  r"""
6584
6694
  Repeat elements of a tensor along an axis, like `numpy.repeat`.
6585
6695
 
6696
+ .. warning::
6697
+ Only supported on Atlas A2 training series.
6698
+
6586
6699
  Args:
6587
6700
  input (Tensor): The tensor to repeat values for. Must be of type: float16,
6588
6701
  float32, int8, uint8, int16, int32, or int64.
@@ -6621,9 +6734,13 @@ def repeat_elements(x, rep, axis=0):
6621
6734
  """
6622
6735
  Repeat elements of a tensor along an axis, like `numpy.repeat` .
6623
6736
 
6737
+ Note:
6738
+ It is recommended to use :func:`mindspore.mint.repeat_interleave` instead, which supports an input `x` with
6739
+ up to 8 dimensions and offers better performance.
6740
+
6624
6741
  Args:
6625
- x (Tensor): The tensor to repeat values for. Must be of type: float16,
6626
- float32, int8, uint8, int16, int32, or int64.
6742
+ x (Tensor): The tensor to repeat values for. Must be of type: float16, float32, int8, uint8, int16, int32,
6743
+ or int64. The rank of `x` must be less than or equal to 7.
6627
6744
  rep (int): The number of times to repeat, must be positive.
6628
6745
  axis (int): The axis along which to repeat. Default: 0.
6629
6746
 
@@ -6632,6 +6749,9 @@ def repeat_elements(x, rep, axis=0):
6632
6749
  :math:`(s1, s2, ..., sn)` and axis is i, the output will have shape :math:`(s1, s2, ..., si * rep, ..., sn)`.
6633
6750
  The output type will be the same as the type of `x`.
6634
6751
 
6752
+ Raises:
6753
+ ValueError: If the rank of `x` is greater than 7.
6754
+
6635
6755
  Supported Platforms:
6636
6756
  ``Ascend`` ``GPU`` ``CPU``
6637
6757
 
@@ -6658,6 +6778,7 @@ def repeat_elements(x, rep, axis=0):
6658
6778
  rep = _check_positive_int(rep, "rep", "repeat_elements")
6659
6779
  axis = _check_is_int(axis, "axis", "repeat_elements")
6660
6780
  x_rank = rank_(x)
6781
+ x_rank = _check_rank_range(x_rank, 7, "x", "repeat_elements")
6661
6782
  axis = _check_axis_range(axis, x_rank, "axis", "repeat_elements")
6662
6783
  axis = axis + x.ndim if axis < 0 else axis
6663
6784
  expand_axis = axis + 1
@@ -6722,7 +6843,8 @@ def sequence_mask(lengths, maxlen=None):
6722
6843
  [[ True True False False ]
6723
6844
  [ True True True True ]]]
6724
6845
  """
6725
- const_utils.check_type_valid(ops.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
6846
+ const_utils.check_type_valid(
6847
+ ops.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
6726
6848
 
6727
6849
  if maxlen is None:
6728
6850
  flatten_data = reshape_(lengths, (-1,))
@@ -6733,7 +6855,8 @@ def sequence_mask(lengths, maxlen=None):
6733
6855
  maxlen = _check_positive_int(maxlen, "maxlen", "sequence_mask")
6734
6856
  maxlen = scalar_to_tensor_(maxlen, mstype.int32)
6735
6857
 
6736
- range_vector = range_(scalar_to_tensor_(0, mstype.int32), maxlen, scalar_to_tensor_(1, mstype.int32))
6858
+ range_vector = range_(scalar_to_tensor_(0, mstype.int32),
6859
+ maxlen, scalar_to_tensor_(1, mstype.int32))
6737
6860
  mask = expand_dims(lengths, -1)
6738
6861
  result = range_vector < mask
6739
6862
  return result
@@ -6747,6 +6870,221 @@ def top_k(input_x, k, sorted=True):
6747
6870
  return top_k_(input_x, k)
6748
6871
 
6749
6872
 
6873
+ def gather_ext(input, dim, index):
6874
+ r"""
6875
+ Gather data from a tensor by indices.
6876
+
6877
+ .. math::
6878
+ output[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)] =
6879
+ input[(i_0, i_1, ..., index[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)], i_{dim+1}, ..., i_n)]
6880
+
6881
+ .. warning::
6882
+ On Ascend, the behavior is unpredictable in the following cases:
6883
+
6884
+ - the value of `index` is not in the range `[-input.shape[dim], input.shape[dim])` in forward;
6885
+ - the value of `index` is not in the range `[0, input.shape[dim])` in backward.
6886
+
6887
+ Args:
6888
+ input (Tensor): The target tensor to gather values.
6889
+ dim (int): the axis to index along, must be in range `[-input.rank, input.rank)`.
6890
+ index (Tensor): The index tensor, with int32 or int64 data type. A valid `index` should be:
6891
+
6892
+ - `index.rank == input.rank`;
6893
+ - for `axis != dim`, `index.shape[axis] <= input.shape[axis]`;
6894
+ - the value of `index` is in range `[-input.shape[dim], input.shape[dim])`.
6895
+
6896
+ Returns:
6897
+ Tensor, has the same type as `input` and the same shape as `index`.
6898
+
6899
+ Raises:
6900
+ ValueError: If the shape of `index` is illegal.
6901
+ ValueError: If `dim` is not in `[-input.rank, input.rank)`.
6902
+ ValueError: If the value of `index` is out of the valid range.
6903
+ TypeError: If the type of `index` is illegal.
6904
+
6905
+ Supported Platforms:
6906
+ ``Ascend`` ``GPU`` ``CPU``
6907
+
6908
+ Examples:
6909
+ >>> import mindspore
6910
+ >>> import numpy as np
6911
+ >>> from mindspore import Tensor, ops
6912
+ >>> from mindspore.ops.function.array_func import gather_ext
6913
+ >>> input = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
6914
+ >>> index = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
6915
+ >>> output = gather_ext(input, 1, index)
6916
+ >>> print(output)
6917
+ [[-0.1 -0.1]
6918
+ [0.5 0.5]]
6919
+ """
6920
+ return gather_d_op(input, dim, index)
6921
+
6922
+
6923
+ def max_ext(input, dim=None, keepdim=False):
6924
+ """
6925
+ Calculates the maximum value of the input tensor along the given dimension.
6926
+
6927
+ Args:
6928
+ input (Tensor): The input tensor, can be any dimension. Complex tensor is not supported for now.
6929
+ dim (int, optional): The dimension to reduce. Default: ``None`` .
6930
+ keepdim (bool, optional): Whether to keep the reduced dimension. If ``True`` , the output keeps the same
6931
+ number of dimensions as the input; if ``False`` , the reduced dimension is removed. Default: ``False`` .
6932
+
6933
+ Returns:
6934
+ Tensor if `dim` is the default value ``None`` , the maximum value of input tensor, with the shape :math:`()` ,
6935
+ and same dtype as `input`.
6936
+
6937
+ tuple (Tensor) if `dim` is not the default value ``None`` , tuple of 2 tensors, containing the maximum
6938
+ value of the input tensor along the given dimension `dim` and the corresponding index.
6939
+
6940
+ - **values (Tensor)** - The maximum value of input tensor along the given dimension `dim`, with same dtype as
6941
+ `input`. If `keepdim` is ``True`` , the shape of output tensors is :math:`(input_1, input_2, ...,
6942
+ input_{axis-1}, 1, input_{axis+1}, ..., input_N)` . Otherwise, the shape is :math:`(input_1, input_2, ...,
6943
+ input_{axis-1}, input_{axis+1}, ..., input_N)` .
6944
+ - **index (Tensor)** - The index for the maximum value of the input tensor along the given dimension `dim`, with
6945
+ the same shape as `values`.
6946
+
6947
+ Raises:
6948
+ ValueError: If `dim` is the default value ``None`` and `keepdim` is not ``False`` .
6949
+
6950
+ Supported Platforms:
6951
+ ``Ascend`` ``GPU`` ``CPU``
6952
+
6953
+ Examples:
6954
+ >>> import mindspore
6955
+ >>> import numpy as np
6956
+ >>> from mindspore import Tensor, ops
6957
+ >>> from mindspore.ops.function.array_func import max_ext
6958
+ >>> y = Tensor(np.array([[0.0, 0.3, 0.4, 0.5, 0.1],
6959
+ ... [3.2, 0.4, 0.1, 2.9, 4.0]]), mindspore.float32)
6960
+ >>> output, index = max_ext(y, 0, True)
6961
+ >>> print(output, index)
6962
+ [[3.2 0.4 0.4 2.9 4. ]] [[1 1 0 1 1]]
6963
+ """
6964
+ if dim is None:
6965
+ if keepdim is not False:
6966
+ raise ValueError(
6967
+ f"For 'max', the `keepdim` must be False when the `dim` is None, but got {keepdim}")
6968
+ return max_(input)
6969
+ argmax_with_value_op = _get_cache_prim(ArgMaxWithValue)(dim, keepdim)
6970
+ indices, values = argmax_with_value_op(input)
6971
+ return values, indices
6972
+
6973
+
6974
+ def min_ext(input, dim=None, keepdim=False):
6975
+ """
6976
+ Calculates the minimum value of the input tensor along the given dimension.
6977
+
6978
+ Args:
6979
+ input (Tensor): The input tensor, can be any dimension. Complex tensor is not supported for now.
6980
+ dim (int, optional): The dimension to reduce. Default: ``None`` .
6981
+ keepdim (bool, optional): Whether to keep the reduced dimension. If ``True`` , the output keeps the same
6982
+ number of dimensions as the input; if ``False`` , the reduced dimension is removed. Default: ``False`` .
6983
+
6984
+ Returns:
6985
+ Tensor if `dim` is the default value ``None`` , the minimum value of input tensor, with the shape :math:`()` ,
6986
+ and same dtype as `input`.
6987
+
6988
+ tuple (Tensor) if `dim` is not the default value ``None`` , tuple of 2 tensors, containing the minimum value
6989
+ of the input tensor along the given dimension `dim` and the corresponding index.
6990
+
6991
+ - **values (Tensor)** - The minimum value of input tensor along the given dimension `dim`, with same dtype as
6992
+ `input`. If `keepdim` is ``True`` , the shape of output tensors is :math:`(input_1, input_2, ...,
6993
+ input_{axis-1}, 1, input_{axis+1}, ..., input_N)` . Otherwise, the shape is :math:`(input_1, input_2, ...,
6994
+ input_{axis-1}, input_{axis+1}, ..., input_N)` .
6995
+ - **index (Tensor)** - The index for the minimum value of the input tensor along the given dimension `dim`,
6996
+ with the same shape as `values`.
6997
+
6998
+ Raises:
6999
+ ValueError: If `dim` is the default value ``None`` and `keepdim` is not ``False`` .
7000
+
7001
+ Supported Platforms:
7002
+ ``Ascend`` ``GPU`` ``CPU``
7003
+
7004
+ Examples:
7005
+ >>> import mindspore
7006
+ >>> import numpy as np
7007
+ >>> from mindspore import Tensor, ops
7008
+ >>> from mindspore.ops.function.array_func import min_ext
7009
+ >>> x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
7010
+ >>> output, index = min_ext(x, 0, keepdim=True)
7011
+ >>> print(output, index)
7012
+ [0.0] [0]
7013
+ """
7014
+ if dim is None:
7015
+ if keepdim is not False:
7016
+ raise ValueError(
7017
+ f"For 'min', the `keepdim` must be False when the `dim` is None, but got {keepdim}")
7018
+ return min_(input)
7019
+ argmin_with_value_op = _get_cache_prim(ArgMinWithValue)(dim, keepdim)
7020
+ indices, values = argmin_with_value_op(input)
7021
+ return values, indices
7022
+
7023
+
7024
+ def one_hot_ext(tensor, num_classes):
7025
+ r"""
7026
+ Computes a one-hot tensor.
7027
+
7028
+ The locations represented by the indices in `tensor` take value `1`, while all
7029
+ other locations take value `0`.
7030
+
7031
+ Args:
7032
+ tensor (Tensor): A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
7033
+ Data type must be int32 or int64.
7034
+ num_classes (int): A scalar defining the depth of the one-hot dimension.
7035
+
7036
+ Returns:
7037
+ Tensor, one-hot tensor.
7038
+
7039
+ Raises:
7040
+ TypeError: If `num_classes` is not an int.
7041
+ TypeError: If dtype of `tensor` is not int32 or int64.
7042
+ ValueError: If `num_classes` is less than 0.
7043
+
7044
+ Supported Platforms:
7045
+ ``Ascend`` ``GPU`` ``CPU``
7046
+
7047
+ Examples:
7048
+ >>> import mindspore
7049
+ >>> import numpy as np
7050
+ >>> from mindspore import ops
7051
+ >>> from mindspore import Tensor
7052
+ >>> from mindspore.ops.function.array_func import one_hot_ext
7053
+ >>> tensor = Tensor(np.array([0, 1, 2]), mindspore.int32)
7054
+ >>> num_classes = 3
7055
+ >>> output = one_hot_ext(tensor, num_classes)
7056
+ >>> print(output)
7057
+ [[1. 0. 0.]
7058
+ [0. 1. 0.]
7059
+ [0. 0. 1.]]
7060
+ """
7061
+ on_value = Tensor(1, dtype=tensor.dtype)
7062
+ off_value = Tensor(0, dtype=tensor.dtype)
7063
+ return one_hot_ext_impl(tensor, num_classes, on_value, off_value, -1)
7064
+
7065
+
7066
+ def from_numpy(array):
7067
+ r"""
7068
+ Convert numpy array to Tensor.
7069
+ If the data is not C contiguous, the data will be copied to C contiguous to construct the tensor.
7070
+ Otherwise, the tensor will be constructed using this numpy array without copy.
7071
+
7072
+ Args:
7073
+ array (numpy.array): The input array.
7074
+
7075
+ Returns:
7076
+ Tensor, has the same data type as input array.
7077
+
7078
+ Examples:
7079
+ >>> import numpy as np
7080
+ >>> import mindspore as ms
7081
+ >>> x = np.array([1, 2])
7082
+ >>> output = ms.from_numpy(x)
7083
+ >>> print(output)
7084
+ [1 2]
7085
+ """
7086
+ return Tensor.from_numpy(array)
7087
+
6750
7088
  __all__ = [
6751
7089
  'unique',
6752
7090
  'unique_with_pad',
@@ -6763,6 +7101,7 @@ __all__ = [
6763
7101
  'ones_like',
6764
7102
  'zeros',
6765
7103
  'zeros_like',
7104
+ 'zero_',
6766
7105
  'shape',
6767
7106
  'shape_',
6768
7107
  'reverse',