mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's advisory for this release for more details.

Files changed (287)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -15,45 +15,57 @@
15
15
  """mint module."""
16
16
  from __future__ import absolute_import
17
17
  import mindspore.ops as ops
18
- from mindspore.ops.extend import gather, conv2d, max, min
19
- from mindspore.ops.extend import array_func, math_func, nn_func
20
- from mindspore.mint.nn.functional import *
18
+ from mindspore.ops.primitive import constexpr
19
+ from mindspore.common._register_for_tensor import tensor_operator_registry_for_mint
20
+ from mindspore.common.tensor import Tensor
21
+ from mindspore.ops.function.array_func import gather_ext as gather, max_ext as max, min_ext as min
22
+ from mindspore.ops.function.nn_func import conv2d_ext as conv2d
23
+ from mindspore.mint.nn.functional import sigmoid
21
24
  from mindspore.mint.nn import functional
22
25
  from mindspore.mint import linalg
23
- from mindspore.ops import erf, where, triu
26
+ from mindspore.mint import special
27
+ from mindspore.mint import distributed
28
+ from mindspore.ops import erf, where
24
29
  from mindspore.ops.function.math_func import linspace_ext as linspace
25
- from mindspore.ops.function.array_func import full_ext as full
30
+ from mindspore.ops.function.math_func import median_ext as median
26
31
  from mindspore.ops.function.array_func import ones_like_ext as ones_like
32
+ from mindspore.ops.function.array_func import full_ext as full
27
33
  from mindspore.ops.function.array_func import zeros_like_ext as zeros_like
28
34
  from mindspore.ops.function.array_func import unique_ext as unique
35
+ from mindspore.ops.function.array_func import chunk_ext as chunk
29
36
  from mindspore.ops.function.math_func import isclose
30
37
  from mindspore.ops.auto_generate import abs
31
38
  # 1
32
39
  from mindspore.ops.function.math_func import divide, div
33
40
  from mindspore.ops.auto_generate import topk_ext as topk
41
+ from mindspore.ops.function.math_func import roll
34
42
  # 2
35
43
  from mindspore.ops.function.math_func import sin
36
44
  # 3
37
45
  from mindspore.ops.function.clip_func import clamp
38
46
  # 4
39
-
47
+ from mindspore.ops.auto_generate import sinc
48
+ from mindspore.ops.auto_generate import sinh
49
+ from mindspore.ops.auto_generate import cosh
50
+ from mindspore.ops.function.math_func import xlogy_ext as xlogy
40
51
  # 5
41
52
  from mindspore.ops.auto_generate import cumsum_ext as cumsum
42
53
  # 6
43
54
  from mindspore.ops.auto_generate import stack_ext as stack
44
55
 
45
56
  # 7
46
-
57
+ from mindspore.ops.function.array_func import unsqueeze
47
58
  # 8
48
-
59
+ from mindspore.ops.auto_generate import transpose_ext as transpose
49
60
  # 9
50
-
61
+ from mindspore.ops.auto_generate import masked_select
62
+ from mindspore.ops.function.math_func import cross
51
63
  # 10
52
64
  from mindspore.ops.function.math_func import ne
53
65
  # 11
54
66
 
55
67
  # 12
56
-
68
+ from mindspore.ops.function.array_func import repeat_interleave_ext as repeat_interleave
57
69
  # 13
58
70
  from mindspore.ops.functional import flip
59
71
  # 14
@@ -122,13 +134,13 @@ from mindspore.ops.functional import cos
122
134
  # 45
123
135
 
124
136
  # 46
125
-
137
+ from mindspore.ops.function.math_func import bitwise_and_ext as bitwise_and
126
138
  # 47
127
-
139
+ from mindspore.ops.function.math_func import bitwise_or_ext as bitwise_or
128
140
  # 48
129
-
141
+ from mindspore.ops.function.math_func import bitwise_xor_ext as bitwise_xor
130
142
  # 49
131
-
143
+ from mindspore.ops.function.math_func import baddbmm_ext as baddbmm
132
144
  # 50
133
145
  from mindspore.ops.functional import tile
134
146
  # 51
@@ -142,7 +154,7 @@ from mindspore.ops.function.random_func import normal_ext as normal
142
154
  # 55
143
155
 
144
156
  # 56
145
-
157
+ from mindspore.ops.function.math_func import norm_ext as norm
146
158
  # 57
147
159
  from mindspore.ops.functional import broadcast_to
148
160
  # 58
@@ -166,7 +178,7 @@ from mindspore.ops.functional import logical_not
166
178
  # 67
167
179
  from mindspore.ops.functional import logical_or
168
180
  # 68
169
-
181
+ from mindspore.ops.functional import logical_xor
170
182
  # 69
171
183
  from mindspore.ops.functional import less_equal, le
172
184
  # 70
@@ -194,7 +206,7 @@ from mindspore.ops.function import arange_ext as arange
194
206
  # 81
195
207
  from mindspore.ops.auto_generate import index_select_ext as index_select
196
208
  # 82
197
-
209
+ from mindspore.ops.auto_generate import cummin_ext as cummin
198
210
  # 83
199
211
  from mindspore.ops.function.array_func import narrow_ext as narrow
200
212
  # 84
@@ -204,9 +216,9 @@ from mindspore.mint import nn, optim
204
216
  # 86
205
217
 
206
218
  # 87
207
-
219
+ from mindspore.ops.auto_generate import trunc
208
220
  # 88
209
- from mindspore.ops.function.array_func import chunk_ext as chunk
221
+
210
222
  # 89
211
223
 
212
224
  # 90
@@ -231,24 +243,135 @@ from mindspore.ops.function.math_func import tanh
231
243
 
232
244
  # 100
233
245
 
246
+ # 101
247
+
248
+ # 102
249
+
250
+ # 103
251
+
252
+ # 104
253
+
254
+ # 105
255
+
256
+ # 106
257
+
258
+ # 107
259
+
260
+ # 108
261
+
262
+ # 109
263
+ from mindspore.ops.auto_generate import argmin_ext as argmin
264
+ # 110
265
+
266
+ # 111
267
+
268
+ # 112
269
+
270
+ # 113
271
+
272
+ # 114
273
+
274
+ # 115
275
+
276
+ # 116
277
+
278
+ # 117
279
+
280
+ # 118
281
+
282
+ # 119
283
+
284
+ # 120
285
+
286
+ # 121
287
+
234
288
  # 122
235
289
 
290
+ # 151
291
+ from mindspore.ops.function.math_func import acos_ext as acos
292
+ from mindspore.ops.function.math_func import arccos_ext as arccos
293
+ # 152
294
+ from mindspore.ops.function.math_func import acosh_ext as acosh
295
+ from mindspore.ops.function.math_func import arccosh_ext as arccosh
296
+ # 172
297
+ from mindspore.ops.function.math_func import asin_ext as asin
298
+ from mindspore.ops.function.math_func import arcsin_ext as arcsin
299
+ # 173
300
+ from mindspore.ops.function.math_func import asinh_ext as asinh
301
+ from mindspore.ops.function.math_func import arcsinh_ext as arcsinh
302
+ # 174
303
+ from mindspore.ops.function.math_func import atan_ext as atan
304
+ from mindspore.ops.function.math_func import arctan_ext as arctan
305
+ # 175
306
+ from mindspore.ops.function.math_func import atanh
307
+ from mindspore.ops.function.math_func import arctanh
236
308
  # 176
237
309
  from mindspore.ops.function.math_func import atan2_ext as atan2
238
310
  from mindspore.ops.function.math_func import arctan2_ext as arctan2
239
311
 
312
+ # 177
313
+ from mindspore.ops.function.math_func import round
240
314
 
315
+ # 182
316
+ from mindspore.ops.function.math_func import bernoulli_ext as bernoulli
317
+
318
+ # 204
319
+ from mindspore.ops.auto_generate import erfc
320
+ # 207
321
+ from mindspore.ops.auto_generate import expm1
241
322
  # 208
242
323
  from mindspore.ops.function.array_func import eye
324
+ from mindspore.ops.function.random_func import randperm_ext as randperm
243
325
  from mindspore.ops.function.random_func import rand_ext as rand
244
326
  from mindspore.ops.function.random_func import rand_like_ext as rand_like
327
+ from mindspore.ops.function.random_func import randn_ext as randn
328
+ from mindspore.ops.function.random_func import randn_like_ext as randn_like
329
+ from mindspore.ops.function.random_func import randint_ext as randint
330
+ from mindspore.ops.function.random_func import randint_like_ext as randint_like
245
331
  # 210
246
332
  from mindspore.ops.auto_generate import floor
247
333
  # 231
248
334
  from mindspore.ops.function.math_func import inverse_ext as inverse
249
-
335
+ # 244
336
+ from mindspore.ops.auto_generate import log1p
337
+ # 261
338
+ from mindspore.ops.function.random_func import multinomial_ext as multinomial
339
+ # 275
340
+ from mindspore.ops.function.math_func import remainder_ext as remainder
250
341
  # 285
251
342
  from mindspore.ops.function.array_func import scatter_add_ext as scatter_add
343
+ # 289
344
+ from mindspore.ops.auto_generate import sign
345
+
346
+ from mindspore.ops.auto_generate import select_ext as select
347
+
348
+ # 301
349
+ from mindspore.ops.function.math_func import tan
350
+
351
+ # 303
352
+ from mindspore.ops.auto_generate import trace_ext as trace
353
+
354
+ from mindspore.ops.function.array_func import reshape
355
+
356
+ from mindspore.ops.auto_generate import outer_ext as outer
357
+
358
+ # 304
359
+ from mindspore.ops.function.array_func import tril_ext as tril
360
+
361
+ # 305
362
+ from mindspore.ops import triu
363
+
364
+ # 538
365
+ from mindspore.ops.function.math_func import histc_ext as histc
366
+
367
+ # 553
368
+ from mindspore.ops.auto_generate import logaddexp_ext as logaddexp
369
+
370
+ # 610
371
+ from mindspore.ops.function.math_func import nan_to_num
372
+
373
+ # 695
374
+ from mindspore.ops.auto_generate import count_nonzero
252
375
 
253
376
 
254
377
  def add(input, other, *, alpha=1):
@@ -268,12 +391,12 @@ def add(input, other, *, alpha=1):
268
391
  Args:
269
392
  input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
270
393
  a bool or a tensor whose data type is
271
- `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
272
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
394
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
395
+ `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
273
396
  other (Union[Tensor, number.Number, bool]): The second input, is a number.Number or
274
397
  a bool or a tensor whose data type is
275
- `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
276
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
398
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
399
+ `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
277
400
 
278
401
  Keyword Args:
279
402
  alpha (number.Number): A scaling factor applied to `other`, default 1.
@@ -429,7 +552,6 @@ def all(input, dim=None, keepdim=False):
429
552
  return ops.function.math_func.all(input, dim, keepdim)
430
553
 
431
554
 
432
-
433
555
  def cat(tensors, dim=0):
434
556
  r"""
435
557
  Connect input tensors along with the given dimension.
@@ -487,6 +609,573 @@ def cat(tensors, dim=0):
487
609
  return ops.auto_generate.cat(tensors, dim)
488
610
 
489
611
 
612
def concat(tensors, dim=0):
    r"""
    .. warning::
        This is an experimental API that is subject to change or deletion.

    Alias of mint.cat(): connects the input tensors along dimension `dim`.

    Args:
        tensors: Forwarded unchanged to mint.cat().
        dim (int): Forwarded unchanged to mint.cat(). Default: ``0``.

    Returns:
        Whatever mint.cat(tensors, dim) returns.
    """
    joined = cat(tensors, dim)
    return joined
620
+
621
+
622
def cummax(input, dim):
    r"""
    Returns a tuple (values, indices) where `values` is the cumulative maximum value of input Tensor `input`
    along the dimension `dim`, and `indices` is the index location of each maximum value.

    .. math::
        \begin{array}{ll} \\
            y_{i} = \max(x_{1}, x_{2}, ... , x_{i})
        \end{array}

    Args:
        input (Tensor): The input Tensor. Rank of `input` must be greater than 0.
        dim (int): The dimension to do the operation over. The value of `dim` must be in the range
            `[-input.ndim, input.ndim - 1]`.

    Returns:
        tuple [Tensor], tuple of 2 Tensors, containing the cumulative maximum of elements and the index.
        The shape of each output tensor is the same as that of input `input`.

    Raises:
        TypeError: If `input` is not a Tensor.
        TypeError: If `dim` is not an int.
        ValueError: If `dim` is out of the range of `[-input.ndim, input.ndim - 1]`.

    .. note::
        O2 mode is not supported in Ascend.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, mint
        >>> x = Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float32))
        >>> output = mint.cummax(x, dim=0)
        >>> print(output[0])
        [[ 3. 4. 6. 10.]
         [ 3. 6. 7. 10.]
         [ 4. 6. 8. 10.]
         [ 4. 6. 8. 10.]]
        >>> print(output[1])
        [[0 0 0 0]
         [0 1 1 0]
         [2 1 2 0]
         [2 1 2 0]]
    """
    # No validation here: `input` and `dim` are forwarded unchanged to the generated primitive.
    return ops.auto_generate.cummax(input, dim)
671
+
672
+
673
+ def _einsum_convert_sublist_to_label(num, ell_num=False):
674
+ """Convert sublist to label."""
675
+ if num == Ellipsis or ell_num and num == 52:
676
+ return '...'
677
+ if 0 <= num < 26:
678
+ return chr(num + ord('A'))
679
+ if 26 <= num < 52:
680
+ return chr(num + ord('a') - 26)
681
+ raise ValueError(f'For einsum, the number in sublist must be in range [0, 52), but got {num}')
682
+
683
+
684
+ def _einsum_convert_label_to_index(label):
685
+ """Convert label to index."""
686
+ label_num = ord(label)
687
+ if ord('A') <= label_num <= ord('Z'):
688
+ return label_num - ord('A')
689
+ if ord('a') <= label_num <= ord('z'):
690
+ return label_num - ord('a') + 26
691
+ if label_num == ord('.'):
692
+ return 52
693
+ raise ValueError(f'For einsum, the label in equation must be in [a-zA-Z] or ., but got {label}')
694
+
695
+
696
def _einsum_convert_sublist(equation, *operands):
    """Normalize the sublist calling convention into an (equation, operands) pair.

    In sublist format the call is ``einsum(op0, sub0, op1, sub1, ..., [sub_out])``: `equation`
    is then the first Tensor and `operands` alternates sublist / tensor. An even number of
    trailing arguments means the last one is the output sublist. The sublists are translated
    into an equation string and the tensors are gathered into a tuple.

    Returns:
        tuple: (equation, operands). When `equation` is not a Tensor both are returned as given.

    Raises:
        ValueError: If there are no operands, or a sublist entry is out of range.
    """
    if isinstance(equation, Tensor):
        first_tensor = equation
        has_output_sublist = len(operands) % 2 == 0
        # Input sublists sit at the even positions; an output sublist (if any) is at an odd position.
        input_terms = [
            ''.join(_einsum_convert_sublist_to_label(num) for num in sublist)
            for sublist in operands[0::2]
        ]
        converted = ','.join(input_terms)
        if has_output_sublist:
            converted += '->' + ''.join(_einsum_convert_sublist_to_label(num) for num in operands[-1])
            other_tensors = operands[1:-1:2]
        else:
            other_tensors = operands[1::2]
        equation = converted
        operands = (first_tensor,) + tuple(other_tensors)
    if not operands:
        raise ValueError("For einsum, the 'operands' must have at least one operand.")
    return equation, operands
719
+
720
+
721
+ def _einsum_check_inputargs(equation, operands):
722
+ """Check equation and operands."""
723
+ if not isinstance(equation, str):
724
+ raise TypeError(f"For einsum, 'equation' must be a str, but got {type(equation)}.")
725
+ for operand in operands:
726
+ if not isinstance(operand, Tensor):
727
+ raise TypeError(f"For einsum, members of 'operands' must be Tensor, but got {type(operand)}.")
728
+
729
+
730
@constexpr
def _einsum_parse_equation(equation):
    """Parse an einsum equation string into lists of label indices.

    Spaces are ignored. Each character is mapped with `_einsum_convert_label_to_index`;
    an ellipsis ('...') is first collapsed to a single '.', i.e. label index 52.

    Args:
        equation (str): e.g. ``"ij,jk->ik"`` or ``"...i"``.

    Returns:
        tuple: (l_equationlst, r_equationlst, arrow_exist)
            - l_equationlst: one list of label indices per comma-separated input subscript.
            - r_equationlst: label indices of the output subscript (empty without an arrow).
            - arrow_exist: whether '->' appears in the equation.

    Raises:
        ValueError: If nothing precedes the arrow, an ellipsis is not exactly three
            consecutive dots or appears more than once in a subscript, or a character
            is not in [a-zA-Z] or '.'.
    """
    equation = equation.replace(' ', '')
    arrow_exist = '->' in equation

    if arrow_exist:
        l_equation, r_equation = equation.split('->', 1)
        if l_equation == '':
            raise ValueError('For einsum, equation must contain characters to the left of the arrow.')
    else:
        l_equation, r_equation = equation, ''

    # Validate and convert every input subscript, then the output subscript (checked last).
    parsed = []
    for subequation in l_equation.split(',') + [r_equation]:
        if '.' in subequation and ('...' not in subequation or subequation.count('.') != 3):
            raise ValueError("For einsum, an ellipsis in the equation must include three continuous \'.\', "
                             "and can only be found once.")
        parsed.append([_einsum_convert_label_to_index(label) for label in subequation.replace('...', '.')])

    l_equationlst, r_equationlst = parsed[:-1], parsed[-1]
    return l_equationlst, r_equationlst, arrow_exist
764
+
765
+
766
+ def _einsum_parse_labels(l_equationlst, operands):
767
+ """Parse left script of equation."""
768
+ align_rank = 0
769
+ max_labels = 53
770
+ labels_count = [0] * max_labels
771
+ labels2dimlst = [None] * max_labels
772
+
773
+ if len(operands) != len(l_equationlst):
774
+ raise ValueError(f"For einsum, 'operands' is not equal to specified in the 'equation', "
775
+ f"but got {len(operands)} and {len(l_equationlst)}.")
776
+
777
+ for idx, sub_equ in enumerate(l_equationlst):
778
+ start_dim = 0
779
+ label_num = 0
780
+ operand_shape = list(operands[idx].shape)
781
+ for label in sub_equ:
782
+ label_num += 1
783
+ end_dim = start_dim + 1
784
+
785
+ # Label is ellipsis
786
+ if label == 52:
787
+ end_dim = len(operand_shape) - len(sub_equ) + label_num
788
+ if labels2dimlst[label] is None:
789
+ labels2dimlst[label] = operand_shape[start_dim:end_dim]
790
+ align_rank += (end_dim - start_dim)
791
+ else:
792
+ if labels2dimlst[label] != operand_shape[start_dim:end_dim]:
793
+ raise ValueError(f"For einsum, one label in 'equation' can only represent the same dimension "
794
+ f"in 'operands', but '{_einsum_convert_sublist_to_label(label, True)}' "
795
+ f"represented different dimensions.")
796
+ labels_count[label] += 1
797
+ start_dim = end_dim
798
+ if label_num != len(sub_equ) or start_dim != len(operand_shape):
799
+ raise ValueError(f"For einsum, the numbers of labels specified in the 'equation' does not match "
800
+ f"'operands[{idx}]'.")
801
+ return labels2dimlst, labels_count, align_rank
802
+
803
+
804
+ def _einsum_infer_output(r_equationlst, arrow_exist, labels2dimlst, labels_count):
805
+ """Parse right script of equation and infer output shape."""
806
+ idx = 0
807
+ idle_idx = -1
808
+ output_shape = []
809
+ labels_perm_idx = [idle_idx] * 53
810
+
811
+ if arrow_exist:
812
+ for label in r_equationlst:
813
+ if labels_count[label] != 0:
814
+ output_shape += labels2dimlst[label]
815
+ if labels_perm_idx[label] != idle_idx:
816
+ raise ValueError(f"For einsum, '{_einsum_convert_sublist_to_label(label, True)}' or {label} in "
817
+ f"sublist format has appears more than once in output subscript.")
818
+ labels_perm_idx[label] = idx
819
+ idx += len(labels2dimlst[label])
820
+ else:
821
+ raise ValueError(f"For einsum, the label to the right of arrow in the 'equation' must appear on "
822
+ f"left, but '{_einsum_convert_sublist_to_label(label, True)}' does not.")
823
+ else:
824
+ if labels_count[52] != 0:
825
+ output_shape += labels2dimlst[52]
826
+ labels_perm_idx[52] = idx
827
+ idx += len(labels2dimlst[52])
828
+ for label, count in enumerate(labels_count):
829
+ if count == 1:
830
+ output_shape += labels2dimlst[label]
831
+ labels_perm_idx[label] = idx
832
+ idx += len(labels2dimlst[label])
833
+
834
+ for label, count in enumerate(labels_count):
835
+ if count != 0 and labels_perm_idx[label] == idle_idx:
836
+ labels_perm_idx[label] = idx
837
+ idx += 1
838
+
839
+ return output_shape, labels_perm_idx
840
+
841
+
842
def _einsum_adjust_operands(operands, l_equationlst, labels2dimlst, labels_perm_idx, align_rank):
    """Align operands to output as possible.

    For every operand: take the diagonal over repeated labels, pad the rank to
    `align_rank` with size-1 dimensions, and permute so that each label sits at the
    aligned position given by `labels_perm_idx`.

    Args:
        operands: Input tensors.
        l_equationlst (list[list[int]]): Label indices per operand (52 is the ellipsis).
        labels2dimlst (list): Per-label dimension sizes; read here for the ellipsis span length.
        labels_perm_idx (list[int]): First aligned position of each label.
        align_rank (int): Rank every adjusted operand ends up with.

    Returns:
        list: The adjusted operands, in the same order as `operands`.
    """
    # Unsqueeze miss dimensions to make all operands has same rank, compute diagonal if operand has same label.
    # Then use _labels_perm_idx to transpose all operands to align dimensions with output.
    adjust_operands = []
    for idx, operand in enumerate(operands):
        idle_dim = -1
        # align_axis[p] = which dim of the current operand goes to aligned position p.
        align_axis = [idle_dim] * align_rank
        # label_dims[label] = dim of the current operand where `label` was first seen.
        label_dims = [idle_dim] * 53
        # Current dim of the operand being matched against the subscript.
        dim = 0

        for label in l_equationlst[idx]:
            if label_dims[label] != idle_dim:
                # Repeated label: take the diagonal over the two dims. The permute below moves the
                # last dim back to position label_dims[label], which presumes ops.diagonal appends the
                # diagonal as the last dim (confirm against ops.diagonal docs). `dim` is not advanced
                # because the current dim was consumed by the diagonal.
                operand = ops.diagonal(operand, 0, label_dims[label], dim)
                diag_perm = []
                diag_dim = 0
                for i in range(len(operand.shape)):
                    if i == label_dims[label]:
                        diag_perm.append(len(operand.shape) - 1)
                    else:
                        diag_perm.append(diag_dim)
                        diag_dim += 1
                operand = permute(operand, tuple(diag_perm))
            else:
                label_dims[label] = dim
                if label == 52:
                    # The ellipsis occupies len(labels2dimlst[52]) consecutive aligned positions.
                    for ell_idx in range(len(labels2dimlst[label])):
                        align_axis[labels_perm_idx[label] + ell_idx] = dim
                        dim += 1
                else:
                    align_axis[labels_perm_idx[label]] = dim
                    dim += 1
        if len(operand.shape) < align_rank:
            # Labels this operand does not use: map them to new trailing size-1 dims.
            for i, axis in enumerate(align_axis):
                if axis == idle_dim:
                    align_axis[i] = dim
                    dim += 1
            missing_dims = [1] * (align_rank - len(operand.shape))
            operand_shape = list(operand.shape) + missing_dims
            operand = reshape(operand, operand_shape)
        operand = permute(operand, tuple(align_axis))
        adjust_operands.append(operand)
    return adjust_operands
885
+
886
+
887
+ def _einsum_find_dimlastop(align_rank, operands, adjust_operands):
888
+ """Find dim last operand."""
889
+ dim_last_op = [0 for _ in range(align_rank)]
890
+ has_zero_dim = False
891
+ for dim in range(align_rank):
892
+ broadcast_dim = adjust_operands[0].shape[dim]
893
+ for idx in range(1, len(adjust_operands)):
894
+ other_dim = adjust_operands[idx].shape[dim]
895
+ if broadcast_dim != other_dim and broadcast_dim != 1 and other_dim != 1:
896
+ err_msg = "For einsum, operands do not broadcast after align to output [shapes :origin -> adjust]:"
897
+ for i in range(len(operands)):
898
+ err_msg += f" {operands[i].shape} -> {adjust_operands[i].shape}"
899
+ raise ValueError(err_msg)
900
+ if other_dim != 1:
901
+ dim_last_op[dim] = idx
902
+ broadcast_dim = other_dim
903
+ has_zero_dim = has_zero_dim or broadcast_dim == 0
904
+ return dim_last_op, has_zero_dim
905
+
906
+
907
def _einsum_multiplication(sum_dims, l_tensor, r_tensor):
    """Compute bmm for einsum.

    Contracts `l_tensor` and `r_tensor` over `sum_dims` with one batched matmul.
    Dimensions are classified as batch (size > 1 in both), left-only, right-only, or summed.
    Summed dims of size > 1 in only one tensor are reduced beforehand.

    Args:
        sum_dims (list[int]): Dimensions to contract. Both tensors are indexed with the same
            dim numbers, so they are presumably of equal rank (as produced by the alignment step).
        l_tensor: Left operand.
        r_tensor: Right operand.

    Returns:
        The product with the dims in `sum_dims` removed and the other dims kept in their
        original order.
    """
    batch_dims = []
    lonly_dims = []
    ronly_dims = []
    batch_size = 1
    lonly_size = 1
    ronly_size = 1
    sum_size = 1

    # Shapes are captured before any per-dim reduction below.
    l_shape = l_tensor.shape
    r_shape = r_tensor.shape

    # Compute sum if dim is in sum_dims and get shapes for bmm
    # NOTE(review): `sum` is called as sum(tensor, dim, keepdim), so it is presumably the
    # module-level tensor reduction shadowing the builtin; confirm.
    for i in range(len(l_shape)):
        sum_l = l_shape[i] > 1
        sum_r = r_shape[i] > 1
        if i in sum_dims:
            if sum_l and sum_r:
                sum_size *= l_shape[i]
            elif sum_l:
                l_tensor = sum(l_tensor, i, True)
            elif sum_r:
                r_tensor = sum(r_tensor, i, True)
        elif sum_l and sum_r:
            batch_dims.append(i)
            batch_size *= l_shape[i]
        elif sum_l:
            lonly_dims.append(i)
            lonly_size *= l_shape[i]
        else:
            # Also receives dims that are size 1 in both tensors.
            ronly_dims.append(i)
            ronly_size *= r_shape[i]

    # Compute the einsum bmm operators pipeline.
    # The whole operators pipline is transpose(in) -> reshape(in) -> bmm(in) -> reshape(out) -> transpose(out).
    l_reshape_shape = (batch_size, lonly_size, sum_size)
    r_reshape_shape = (batch_size, sum_size, ronly_size)

    out_reshape_shape = [l_shape[dim] for dim in batch_dims]
    out_reshape_shape += [l_shape[dim] for dim in lonly_dims]
    out_reshape_shape += [1 for _ in sum_dims]
    out_reshape_shape += [r_shape[dim] for dim in ronly_dims]

    l_perm_axis = batch_dims + lonly_dims + sum_dims + ronly_dims
    r_perm_axis = batch_dims + sum_dims + ronly_dims + lonly_dims
    out_perm_axis = [-1] * len(out_reshape_shape)

    # out_perm_axis is the inverse permutation of l_perm_axis; it restores the original dim order.
    out_dim = 0
    for idx in range(len(l_perm_axis)):
        out_perm_axis[l_perm_axis[idx]] = out_dim
        out_dim += 1

    l_tensor = permute(l_tensor, tuple(l_perm_axis))
    l_tensor = reshape(l_tensor, l_reshape_shape)

    r_tensor = permute(r_tensor, tuple(r_perm_axis))
    r_tensor = reshape(r_tensor, r_reshape_shape)

    output = bmm(l_tensor, r_tensor)
    output = reshape(output, out_reshape_shape)
    output = permute(output, tuple(out_perm_axis))

    # Drop the size-1 placeholders left at the summed dims.
    output_origin_shape = output.shape
    output_squeeze_shape = []
    for dim in range(len(output_origin_shape)):
        if dim not in sum_dims:
            output_squeeze_shape.append(output_origin_shape[dim])

    return reshape(output, output_squeeze_shape)
977
+
978
+
979
def _einsum_squeeze(operand, dim):
    '''Remove axis `dim` from `operand` by reshaping to the remaining sizes.

    Because this is a reshape, `dim` is presumably a size-1 axis.
    Will be replaced by mint.squeeze in the future.
    '''
    kept_sizes = [size for axis, size in enumerate(operand.shape) if axis != dim]
    return reshape(operand, kept_sizes)
987
+
988
+
989
def _einsum(equation, operands):
    '''Einsum main process.

    Parses `equation`, aligns all operands to a common layout (output dims first, then
    summed dims), then folds the operands left to right: dims no later operand needs are
    reduced as early as possible, and each step is a mul, a dot, or a bmm contraction.

    Args:
        equation (str): Validated einsum equation.
        operands (tuple): Input tensors matching the equation.

    Returns:
        The einsum result.
    '''
    _l_equationlst, _r_equationlst, _arrow_exist = _einsum_parse_equation(equation)
    _labels2dimlst, _labels_count, _align_rank = _einsum_parse_labels(_l_equationlst, operands)
    _output_shape, _labels_perm_idx = _einsum_infer_output(_r_equationlst, _arrow_exist, _labels2dimlst, _labels_count)
    _output_rank = len(_output_shape)

    _adjust_operands = _einsum_adjust_operands(operands, _l_equationlst, _labels2dimlst, _labels_perm_idx, _align_rank)
    _dim_last_op, _has_zero_dim = _einsum_find_dimlastop(_align_rank, operands, _adjust_operands)
    _result = _adjust_operands[0]

    # Fast path if operands has zero dim.
    if _has_zero_dim:
        return zeros(_output_shape, dtype=_result.dtype)

    # Sum or squeeze dimensions that is 1 for all rest operands.
    # Only the dims after the output dims (summed-out labels) are candidates. _reduce_dim tracks
    # the index in _result, which shrinks as dims are removed.
    _reduce_dim = _output_rank
    for dim in range(_output_rank, _align_rank):
        if _dim_last_op[dim] == 0:
            if _result.shape[_reduce_dim] == 1:
                _result = _einsum_squeeze(_result, _reduce_dim)
            else:
                _result = sum(_result, _reduce_dim)
        else:
            _reduce_dim += 1

    # Compute multiplication if there is more than one operand.
    for i in range(1, len(_adjust_operands)):
        operand = _adjust_operands[i]
        dim = _output_rank
        sum_dims = []
        for j in range(_output_rank, _align_rank):
            if _dim_last_op[j] < i:
                # Already reduced out of _result; this operand is size 1 there, so drop the axis.
                operand = _einsum_squeeze(operand, dim)
            elif _dim_last_op[j] == i:
                # This operand is the last one with a non-1 size here: contract it now.
                if _result.shape[dim] == 1:
                    operand = sum(operand, dim)
                    _result = _einsum_squeeze(_result, dim)
                else:
                    sum_dims.append(dim)
                    dim += 1
            else:
                # A later operand still needs this dim; keep it.
                dim += 1

        if sum_dims == []:
            _result = mul(_result, operand)
        elif len(sum_dims) == len(_result.shape):
            # Every remaining dim is contracted: a plain dot product of the flattened tensors.
            _result = ops.auto_generate.dot(flatten(_result), flatten(operand))
        else:
            _result = _einsum_multiplication(sum_dims, _result, operand)

    return _result
1041
+
1042
+
1043
def einsum(equation, *operands):
    r"""
    According to the Einstein summation Convention (Einsum),
    the product of the input tensor elements is summed along the specified dimension.
    You can use this operator to perform diagonal, reducesum, transpose, matmul, mul, inner product operations, etc.

    Note:
        The sublist format is also supported. For example, mint.einsum(op1, sublist1, op2, sublist2, ..., sublist_out).
        In this format, equation can be derived by the sublists which are made up of Python's Ellipsis and list of
        integers in [0, 52). Each operand is followed by a sublist, and an optional output sublist is at the end.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        equation (str): Notation based on the Einstein summation convention, represent the operation you want to do.
            The value can contain only letters, commas, ellipsis and arrow.
            The letters represent input tensor dimension, commas represent separate tensors, ellipsis indicates
            the tensor dimension that you do not care about, the left of the arrow indicates the input tensors,
            and the right of it indicates the desired output dimension.
        operands (Tensor): Input tensor used for calculation. The dtype of the tensor must be the same.

    Returns:
        Tensor, the shape of it can be obtained from the `equation` , and the dtype is the same as input tensors.

    Raises:
        TypeError: If `equation` is not a str (after sublist conversion), or any element of `operands` is not a
            Tensor.
        ValueError: If `equation` is invalid or does not match the input tensors, if the number in a sublist is
            not in [0, 52) in sublist format, or if any operand has a dynamic shape or dynamic rank.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, mint
        >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
        >>> equation = "i->"
        >>> output = mint.einsum(equation, x)
        >>> print(output)
        [7.]
        >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
        >>> y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32)
        >>> equation = "i,i->i"
        >>> output = mint.einsum(equation, x, y)
        >>> print(output)
        [ 2. 8. 12.]
        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
        >>> y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), mindspore.float32)
        >>> equation = "ij,jk->ik"
        >>> output = mint.einsum(equation, x, y)
        >>> print(output)
        [[16. 22.]
         [37. 52.]]
        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
        >>> equation = "ij->ji"
        >>> output = mint.einsum(equation, x)
        >>> print(output)
        [[1. 4.]
         [2. 5.]
         [3. 6.]]
        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
        >>> equation = "ij->j"
        >>> output = mint.einsum(equation, x)
        >>> print(output)
        [5. 7. 9.]
        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
        >>> equation = "...->"
        >>> output = mint.einsum(equation, x)
        >>> print(output)
        [21.]
        >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
        >>> y = Tensor(np.array([2.0, 4.0, 1.0]), mindspore.float32)
        >>> equation = "j,i->ji"
        >>> output = mint.einsum(equation, x, y)
        >>> print(output)
        [[ 2. 4. 1.]
         [ 4. 8. 2.]
         [ 6. 12. 3.]]
        >>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32)
        >>> y = mindspore.Tensor([1, 2], mindspore.float32)
        >>> output = mint.einsum(x, [..., 1], y, [..., 2], [..., 1, 2])
        >>> print(output)
        [[1. 2.]
         [2. 4.]
         [3. 6.]
         [4. 8.]]
    """
    # Sublist format (einsum(op0, sub0, op1, sub1, ...)) is rewritten into an equation string first.
    _equation, _operands = _einsum_convert_sublist(equation, *operands)
    _einsum_check_inputargs(_equation, _operands)

    # Shape analysis is done in Python on concrete shapes, so dynamic shapes are rejected.
    for operand in _operands:
        if ops.is_sequence_shape_unknown(operand.shape) or ops.is_sequence_value_unknown(operand.shape):
            raise ValueError("For einsum, the element of 'operands' can't be dynamic shape or dynamic rank.")

    return _einsum(_equation, _operands)
1139
+
1140
+
1141
def item(input):
    r"""
    Converts a single-element tensor into a standard Python number.

    Note:
        Only tensors holding exactly one element are accepted.

    Args:
        input (Tensor[Number]): The tensor to convert. Its dtype is a number type and it must
            hold exactly one element.

    Returns:
        number, the value stored in `input`.

    Raises:
        TypeError: If `input` is not a Tensor.
        RuntimeError: If the number of `input` elements is not 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, mint
        >>> x = Tensor(np.array([1]).astype(np.float32))
        >>> result = mint.item(x)
        >>> print(result)
        1.0
    """
    if isinstance(input, Tensor):
        element_count = input.size
        if element_count == 1:
            return input.asnumpy().item()
        raise RuntimeError(f"a Tensor with {element_count} elements cannot be converted to Scalar")
    raise TypeError(f"the input must be a Tensor, but got {type(input)}")
1177
+
1178
+
490
1179
  def mean(input, dim=None, keepdim=False, *, dtype=None):
491
1180
  r"""
492
1181
  Reduces all dimension of a tensor by averaging all elements in the dimension, by default.
@@ -818,12 +1507,12 @@ def sub(input, other, *, alpha=1):
818
1507
  Args:
819
1508
  input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
820
1509
  a bool or a tensor whose data type is
821
- `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
822
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
1510
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
1511
+ `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
823
1512
  other (Union[Tensor, number.Number, bool]): The second input, is a number.Number or
824
1513
  a bool or a tensor whose data type is
825
- `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
826
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
1514
+ `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
1515
+ `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
827
1516
 
828
1517
  Keyword Args:
829
1518
  alpha (number.Number): A scaling factor applied to `other`, default 1.
@@ -859,6 +1548,22 @@ def sub(input, other, *, alpha=1):
859
1548
  return ops.auto_generate.sub_ext(input, other, alpha)
860
1549
 
861
1550
 
1551
def swapaxes(input, axis0, axis1):
    '''
    Interchange two axes of a tensor, alias for mint.transpose().

    Args:
        input (Tensor): The tensor whose axes are interchanged.
        axis0 (int): The first axis to swap.
        axis1 (int): The second axis to swap.

    Returns:
        The result of ``mint.transpose(input, axis0, axis1)``.

    Examples:
        >>> import numpy as np
        >>> from mindspore import mint
        >>> from mindspore import Tensor
        >>> input = Tensor(np.ones((2,3,4), dtype=np.float32))
        >>> output = mint.swapaxes(input, 0, 2)
        >>> print(output.shape)
        (4, 3, 2)
    '''
    swapped = transpose(input, axis0, axis1)
    return swapped
1565
+
1566
+
862
1567
  def zeros(size, *, dtype=None):
863
1568
  """
864
1569
  Creates a tensor filled with 0 with shape described by `size` and fills it with value 0 in type of `dtype`.
@@ -892,25 +1597,112 @@ def zeros(size, *, dtype=None):
892
1597
  return ops.auto_generate.zeros(size, dtype)
893
1598
 
894
1599
 
1600
def fix(input):
    """
    Alias for :func:`mindspore.mint.trunc` .

    For more details, see :func:`mindspore.mint.trunc` .

    Supported Platforms:
        ``Ascend``
    """
    truncated = trunc(input)
    return truncated
1610
+
1611
+
1612
def scatter(input, dim, index, src):
    """
    Update the value in `src` to `input` according to the specified index.
    For a 3-D tensor, the output will be:

    .. code-block::

        output[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0

        output[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1

        output[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

    .. note::
        The backward is supported only for the case `src.shape == index.shape` when `src` is a tensor.

    Args:
        input (Tensor): The target tensor. The rank of `input` must be at least 1.
        dim (int): Which axis to scatter. Accepted range is [-r, r) where r = rank(input).
        index (Tensor): The index to do update operation whose data must be non-negative numbers with type of
            mindspore.int32 or mindspore.int64. Same rank as `input` . And accepted range is [0, s) where s is
            the size along axis.
        src (Tensor, float): The data doing the update operation with `input`. Can be a tensor with the same data type
            as `input` or a float number to scatter.

    Returns:
        Tensor, has the same shape and type as `input` .

    Raises:
        TypeError: If `index` is neither int32 nor int64.
        ValueError: If the rank of any of `input` , `index` and `src` is less than 1.
        ValueError: If the rank of `src` is not equal to the rank of `input` .
        TypeError: If the data type of `input` and `src` have different dtypes.
        RuntimeError: If `index` has negative elements.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import Tensor, mint
        >>> input = Tensor(np.array([[1, 2, 3, 4, 5]]), dtype=ms.float32)
        >>> src = Tensor(np.array([[8, 8]]), dtype=ms.float32)
        >>> index = Tensor(np.array([[2, 4]]), dtype=ms.int64)
        >>> out = mint.scatter(input=input, dim=1, index=index, src=src)
        >>> print(out)
        [[1. 2. 8. 4. 8.]]
        >>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
        >>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
        >>> index = Tensor(np.array([[0, 0, 0], [2, 2, 2], [4, 4, 4]]), dtype=ms.int64)
        >>> out = mint.scatter(input=input, dim=0, index=index, src=src)
        >>> print(out)
        [[1. 2. 3. 0. 0.]
         [0. 0. 0. 0. 0.]
         [4. 5. 6. 0. 0.]
         [0. 0. 0. 0. 0.]
         [7. 8. 9. 0. 0.]]
        >>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
        >>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
        >>> index = Tensor(np.array([[0, 2, 4], [0, 2, 4], [0, 2, 4]]), dtype=ms.int64)
        >>> out = mint.scatter(input=input, dim=1, index=index, src=src)
        >>> print(out)
        [[1. 0. 2. 0. 3.]
         [4. 0. 5. 0. 6.]
         [7. 0. 8. 0. 9.]
         [0. 0. 0. 0. 0.]
         [0. 0. 0. 0. 0.]]
    """
    # Arguments are forwarded unchanged; validation happens in ops.function.array_func.scatter.
    return ops.function.array_func.scatter(input, dim, index, src)
1681
+
1682
+
895
1683
  __all__ = [
1684
+ 'conv2d',
896
1685
  'full',
897
1686
  'ones_like',
898
1687
  'zeros_like',
899
1688
  'abs',
900
1689
  'erf',
901
1690
  'where',
902
- 'linspace',
903
1691
  'isclose',
904
1692
  # 1
905
1693
  'div',
906
1694
  'divide',
907
1695
  'topk',
1696
+ 'roll',
908
1697
  # 2
909
1698
  'sin',
910
1699
  # 3
911
1700
  'clamp',
1701
+ 'xlogy',
912
1702
  # 4
913
-
1703
+ 'sinc',
1704
+ 'sinh',
1705
+ 'cosh',
914
1706
  # 5
915
1707
  'cumsum',
916
1708
  # 6
@@ -918,15 +1710,16 @@ __all__ = [
918
1710
  # 7
919
1711
  'zeros',
920
1712
  # 8
921
-
1713
+ 'transpose',
1714
+ 'swapaxes',
922
1715
  # 9
923
1716
 
924
1717
  # 10
925
1718
  'ne',
926
1719
  # 11
927
-
1720
+ 'unsqueeze',
928
1721
  # 12
929
-
1722
+ "repeat_interleave",
930
1723
  # 13
931
1724
  "flip",
932
1725
  # 14
@@ -966,8 +1759,9 @@ __all__ = [
966
1759
  # 30
967
1760
  'searchsorted',
968
1761
  # 31
969
-
970
- # 32
1762
+ 'cummax',
1763
+ 'cummin',
1764
+ 'einsum',
971
1765
  'sub',
972
1766
  # 33
973
1767
  'split',
@@ -994,15 +1788,17 @@ __all__ = [
994
1788
  # 44
995
1789
  'cos',
996
1790
  # 45
997
-
1791
+ 'concat',
998
1792
  # 46
999
-
1793
+ 'bitwise_and',
1794
+ 'bitwise_or',
1795
+ 'bitwise_xor',
1000
1796
  # 47
1001
1797
  'max',
1002
1798
  # 48
1003
1799
  'min',
1004
1800
  # 49
1005
-
1801
+ 'baddbmm',
1006
1802
  # 50
1007
1803
  'tile',
1008
1804
  # 51
@@ -1014,9 +1810,9 @@ __all__ = [
1014
1810
  # 54
1015
1811
  'normal',
1016
1812
  # 55
1017
-
1813
+ 'cross',
1018
1814
  # 56
1019
-
1815
+ 'norm',
1020
1816
  # 57
1021
1817
  'broadcast_to',
1022
1818
  # 58
@@ -1040,7 +1836,7 @@ __all__ = [
1040
1836
  # 67
1041
1837
  'logical_or',
1042
1838
  # 68
1043
-
1839
+ 'logical_xor',
1044
1840
  # 69
1045
1841
  'less_equal',
1046
1842
  'le',
@@ -1077,9 +1873,10 @@ __all__ = [
1077
1873
  'narrow',
1078
1874
  # 84
1079
1875
 
1080
- # 85
1876
+ 'masked_select',
1081
1877
 
1082
1878
  # 86
1879
+ 'select',
1083
1880
 
1084
1881
  # 87
1085
1882
 
@@ -1109,29 +1906,169 @@ __all__ = [
1109
1906
 
1110
1907
  # 100
1111
1908
 
1909
+ # 101
1910
+
1911
+ # 102
1912
+
1913
+ # 103
1914
+
1915
+ # 104
1916
+
1917
+ # 105
1918
+
1919
+ # 106
1920
+
1921
+ # 107
1922
+
1923
+ # 108
1924
+
1925
+ # 109
1926
+ 'argmin',
1927
+ # 110
1928
+
1929
+ # 111
1930
+
1931
+ # 112
1932
+
1933
+ # 113
1934
+
1935
+ # 114
1936
+
1937
+ # 115
1938
+
1939
+ # 116
1940
+
1941
+ # 117
1942
+
1943
+ # 118
1944
+
1945
+ # 119
1946
+
1947
+ # 120
1948
+
1949
+ # 121
1950
+
1951
+ # 122
1952
+
1953
+ # 151
1954
+ 'acos',
1955
+ 'arccos',
1956
+ # 152
1957
+ 'acosh',
1958
+ 'arccosh',
1959
+ # 153
1960
+
1961
+ # 154
1962
+
1963
+ # 155
1964
+
1965
+ # 156
1966
+
1967
+ # 157
1968
+ 'scatter',
1969
+ # 172
1970
+ 'asin',
1971
+ 'arcsin',
1972
+ # 173
1973
+ 'asinh',
1974
+ 'arcsinh',
1975
+ # 174
1976
+ 'atan',
1977
+ 'arctan',
1978
+ # 175
1979
+ 'atanh',
1980
+ 'arctanh',
1112
1981
  # 176
1113
1982
  'atan2',
1114
1983
  'arctan2',
1115
1984
 
1985
+ # 177
1986
+ 'round',
1987
+
1988
+ # 182
1989
+ 'bernoulli',
1990
+
1991
+ # 207
1992
+ 'expm1',
1993
+ # 204
1994
+ 'erfc',
1116
1995
  # 208
1117
1996
  'eye',
1997
+
1998
+ # 256
1999
+ 'median',
2000
+ 'randperm',
1118
2001
  'rand',
1119
2002
  'rand_like',
2003
+ 'randn',
2004
+ 'randn_like',
2005
+ 'randint',
2006
+ 'randint_like',
1120
2007
  # 210
1121
2008
  'floor',
1122
2009
  # 231
1123
2010
  'inverse',
2011
+ # 244
2012
+ 'log1p',
2013
+ # 261
2014
+ 'multinomial',
2015
+ # 275
2016
+ 'remainder',
1124
2017
  # 285
1125
2018
  'scatter_add',
2019
+ # 289
2020
+ 'sign',
2021
+ # 301
2022
+ 'tan',
2023
+ # 303
2024
+ 'trace',
2025
+ 'reshape',
2026
+ 'outer',
1126
2027
  # 304
2028
+ 'tril',
1127
2029
 
1128
2030
  # 305
1129
2031
  'triu',
2032
+
2033
+ # 538
2034
+ 'histc',
2035
+
2036
+ # 553
2037
+ 'logaddexp',
2038
+
2039
+ # 610
2040
+ 'nan_to_num',
2041
+
2042
+ # 695
2043
+ 'count_nonzero',
1130
2044
  ]
1131
- __all__.extend(array_func.__all__)
1132
- __all__.extend(math_func.__all__)
1133
- __all__.extend(nn_func.__all__)
2045
+
2046
+ setattr(tensor_operator_registry_for_mint, 'add', add)
2047
+ setattr(tensor_operator_registry_for_mint, 'all', all)
2048
+ setattr(tensor_operator_registry_for_mint, 'any', any)
2049
+ setattr(tensor_operator_registry_for_mint, 'log', log)
2050
+ setattr(tensor_operator_registry_for_mint, 'ceil', ceil)
2051
+ setattr(tensor_operator_registry_for_mint, 'clamp', clamp)
2052
+ setattr(tensor_operator_registry_for_mint, 'cos', cos)
2053
+ setattr(tensor_operator_registry_for_mint, 'flatten', flatten)
2054
+ setattr(tensor_operator_registry_for_mint, 'item', item)
2055
+ setattr(tensor_operator_registry_for_mint, 'max', max)
2056
+ setattr(tensor_operator_registry_for_mint, 'mean', mean)
2057
+ setattr(tensor_operator_registry_for_mint, 'min', min)
2058
+ setattr(tensor_operator_registry_for_mint,
2059
+ 'repeat_interleave', repeat_interleave)
2060
+ setattr(tensor_operator_registry_for_mint, 'ne', ne)
2061
+ setattr(tensor_operator_registry_for_mint, 'round', round)
2062
+ setattr(tensor_operator_registry_for_mint, 'sin', sin)
2063
+ setattr(tensor_operator_registry_for_mint, 'split', split)
2064
+ setattr(tensor_operator_registry_for_mint, 'sqrt', sqrt)
2065
+ setattr(tensor_operator_registry_for_mint, 'square', square)
2066
+ setattr(tensor_operator_registry_for_mint, 'sub', sub)
2067
+ setattr(tensor_operator_registry_for_mint, 'sum', sum)
2068
+
1134
2069
  __all__.extend(functional.__all__)
1135
2070
  __all__.extend(nn.__all__)
1136
2071
  __all__.extend(optim.__all__)
1137
2072
  __all__.extend(linalg.__all__)
2073
+ __all__.extend(special.__all__)
2074
+ __all__.extend(distributed.__all__)