mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/mint/__init__.py
CHANGED

@@ -15,45 +15,57 @@
 """mint module."""
 from __future__ import absolute_import
 import mindspore.ops as ops
-from mindspore.ops.
-from mindspore.
-from mindspore.
+from mindspore.ops.primitive import constexpr
+from mindspore.common._register_for_tensor import tensor_operator_registry_for_mint
+from mindspore.common.tensor import Tensor
+from mindspore.ops.function.array_func import gather_ext as gather, max_ext as max, min_ext as min
+from mindspore.ops.function.nn_func import conv2d_ext as conv2d
+from mindspore.mint.nn.functional import sigmoid
 from mindspore.mint.nn import functional
 from mindspore.mint import linalg
-from mindspore.
+from mindspore.mint import special
+from mindspore.mint import distributed
+from mindspore.ops import erf, where
 from mindspore.ops.function.math_func import linspace_ext as linspace
-from mindspore.ops.function.
+from mindspore.ops.function.math_func import median_ext as median
 from mindspore.ops.function.array_func import ones_like_ext as ones_like
+from mindspore.ops.function.array_func import full_ext as full
 from mindspore.ops.function.array_func import zeros_like_ext as zeros_like
 from mindspore.ops.function.array_func import unique_ext as unique
+from mindspore.ops.function.array_func import chunk_ext as chunk
 from mindspore.ops.function.math_func import isclose
 from mindspore.ops.auto_generate import abs
 # 1
 from mindspore.ops.function.math_func import divide, div
 from mindspore.ops.auto_generate import topk_ext as topk
+from mindspore.ops.function.math_func import roll
 # 2
 from mindspore.ops.function.math_func import sin
 # 3
 from mindspore.ops.function.clip_func import clamp
 # 4
-
+from mindspore.ops.auto_generate import sinc
+from mindspore.ops.auto_generate import sinh
+from mindspore.ops.auto_generate import cosh
+from mindspore.ops.function.math_func import xlogy_ext as xlogy
 # 5
 from mindspore.ops.auto_generate import cumsum_ext as cumsum
 # 6
 from mindspore.ops.auto_generate import stack_ext as stack
 
 # 7
-
+from mindspore.ops.function.array_func import unsqueeze
 # 8
-
+from mindspore.ops.auto_generate import transpose_ext as transpose
 # 9
-
+from mindspore.ops.auto_generate import masked_select
+from mindspore.ops.function.math_func import cross
 # 10
 from mindspore.ops.function.math_func import ne
 # 11
 
 # 12
-
+from mindspore.ops.function.array_func import repeat_interleave_ext as repeat_interleave
 # 13
 from mindspore.ops.functional import flip
 # 14

@@ -122,13 +134,13 @@ from mindspore.ops.functional import cos
 # 45
 
 # 46
-
+from mindspore.ops.function.math_func import bitwise_and_ext as bitwise_and
 # 47
-
+from mindspore.ops.function.math_func import bitwise_or_ext as bitwise_or
 # 48
-
+from mindspore.ops.function.math_func import bitwise_xor_ext as bitwise_xor
 # 49
-
+from mindspore.ops.function.math_func import baddbmm_ext as baddbmm
 # 50
 from mindspore.ops.functional import tile
 # 51

@@ -142,7 +154,7 @@ from mindspore.ops.function.random_func import normal_ext as normal
 # 55
 
 # 56
-
+from mindspore.ops.function.math_func import norm_ext as norm
 # 57
 from mindspore.ops.functional import broadcast_to
 # 58

@@ -166,7 +178,7 @@ from mindspore.ops.functional import logical_not
 # 67
 from mindspore.ops.functional import logical_or
 # 68
-
+from mindspore.ops.functional import logical_xor
 # 69
 from mindspore.ops.functional import less_equal, le
 # 70

@@ -194,7 +206,7 @@ from mindspore.ops.function import arange_ext as arange
 # 81
 from mindspore.ops.auto_generate import index_select_ext as index_select
 # 82
-
+from mindspore.ops.auto_generate import cummin_ext as cummin
 # 83
 from mindspore.ops.function.array_func import narrow_ext as narrow
 # 84

@@ -204,9 +216,9 @@ from mindspore.mint import nn, optim
 # 86
 
 # 87
-
+from mindspore.ops.auto_generate import trunc
 # 88
-
+
 # 89
 
 # 90

@@ -231,24 +243,135 @@ from mindspore.ops.function.math_func import tanh
 
 # 100
 
+# 101
+
+# 102
+
+# 103
+
+# 104
+
+# 105
+
+# 106
+
+# 107
+
+# 108
+
+# 109
+from mindspore.ops.auto_generate import argmin_ext as argmin
+# 110
+
+# 111
+
+# 112
+
+# 113
+
+# 114
+
+# 115
+
+# 116
+
+# 117
+
+# 118
+
+# 119
+
+# 120
+
+# 121
+
 # 122
 
+# 151
+from mindspore.ops.function.math_func import acos_ext as acos
+from mindspore.ops.function.math_func import arccos_ext as arccos
+# 152
+from mindspore.ops.function.math_func import acosh_ext as acosh
+from mindspore.ops.function.math_func import arccosh_ext as arccosh
+# 172
+from mindspore.ops.function.math_func import asin_ext as asin
+from mindspore.ops.function.math_func import arcsin_ext as arcsin
+# 173
+from mindspore.ops.function.math_func import asinh_ext as asinh
+from mindspore.ops.function.math_func import arcsinh_ext as arcsinh
+# 174
+from mindspore.ops.function.math_func import atan_ext as atan
+from mindspore.ops.function.math_func import arctan_ext as arctan
+# 175
+from mindspore.ops.function.math_func import atanh
+from mindspore.ops.function.math_func import arctanh
 # 176
 from mindspore.ops.function.math_func import atan2_ext as atan2
 from mindspore.ops.function.math_func import arctan2_ext as arctan2
 
+# 177
+from mindspore.ops.function.math_func import round
 
+# 182
+from mindspore.ops.function.math_func import bernoulli_ext as bernoulli
+
+# 204
+from mindspore.ops.auto_generate import erfc
+# 207
+from mindspore.ops.auto_generate import expm1
 # 208
 from mindspore.ops.function.array_func import eye
+from mindspore.ops.function.random_func import randperm_ext as randperm
 from mindspore.ops.function.random_func import rand_ext as rand
 from mindspore.ops.function.random_func import rand_like_ext as rand_like
+from mindspore.ops.function.random_func import randn_ext as randn
+from mindspore.ops.function.random_func import randn_like_ext as randn_like
+from mindspore.ops.function.random_func import randint_ext as randint
+from mindspore.ops.function.random_func import randint_like_ext as randint_like
 # 210
 from mindspore.ops.auto_generate import floor
 # 231
 from mindspore.ops.function.math_func import inverse_ext as inverse
-
+# 244
+from mindspore.ops.auto_generate import log1p
+# 261
+from mindspore.ops.function.random_func import multinomial_ext as multinomial
+# 275
+from mindspore.ops.function.math_func import remainder_ext as remainder
 # 285
 from mindspore.ops.function.array_func import scatter_add_ext as scatter_add
+# 289
+from mindspore.ops.auto_generate import sign
+
+from mindspore.ops.auto_generate import select_ext as select
+
+# 301
+from mindspore.ops.function.math_func import tan
+
+# 303
+from mindspore.ops.auto_generate import trace_ext as trace
+
+from mindspore.ops.function.array_func import reshape
+
+from mindspore.ops.auto_generate import outer_ext as outer
+
+# 304
+from mindspore.ops.function.array_func import tril_ext as tril
+
+# 305
+from mindspore.ops import triu
+
+# 538
+from mindspore.ops.function.math_func import histc_ext as histc
+
+# 553
+from mindspore.ops.auto_generate import logaddexp_ext as logaddexp
+
+# 610
+from mindspore.ops.function.math_func import nan_to_num
+
+# 695
+from mindspore.ops.auto_generate import count_nonzero
 
 
 def add(input, other, *, alpha=1):

@@ -268,12 +391,12 @@ def add(input, other, *, alpha=1):
     Args:
         input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
             a bool or a tensor whose data type is
-            `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore
-            `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore
+            `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
+            `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
         other (Union[Tensor, number.Number, bool]): The second input, is a number.Number or
             a bool or a tensor whose data type is
-            `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore
-            `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore
+            `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
+            `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
 
     Keyword Args:
         alpha (number.Number): A scaling factor applied to `other`, default 1.

@@ -429,7 +552,6 @@ def all(input, dim=None, keepdim=False):
     return ops.function.math_func.all(input, dim, keepdim)
 
 
-
 def cat(tensors, dim=0):
     r"""
     Connect input tensors along with the given dimension.

@@ -487,6 +609,573 @@ def cat(tensors, dim=0):
     return ops.auto_generate.cat(tensors, dim)
 
 
+def concat(tensors, dim=0):
+    r"""
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Alias of mint.cat().
+    """
+    return cat(tensors, dim)
+
+
+def cummax(input, dim):
+    r"""
+    Returns a tuple (values, indices) where `values` is the cumulative maximum value of input Tensor `input`
+    along the dimension `dim`, and `indices` is the index location of each maximum value.
+
+    .. math::
+        \begin{array}{ll} \\
+            y_{i} = \max(x_{1}, x_{2}, ... , x_{i})
+        \end{array}
+
+    Args:
+        input (Tensor): The input Tensor. Rank of `input` must be greater than 0.
+        dim (int): The dimension to do the operation over. The value of `dim` must be in the range
+            `[-input.ndim, input.ndim - 1]`.
+
+    Returns:
+        tuple [Tensor], tuple of 2 Tensors, containing the cumulative maximum of elements and the index.
+        The shape of each output tensor is the same as that of input `input`.
+
+    Raises:
+        TypeError: If `input` is not a Tensor.
+        TypeError: If `dim` is not an int.
+        ValueError: If `dim` is out the range of `[-input.ndim, input.ndim - 1]`.
+
+    .. note::
+        O2 mode is not supported in Ascend.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import ops
+        >>> x = Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float32))
+        >>> output = mint.cummax(x, dim=0)
+        >>> print(output[0])
+        [[ 3. 4. 6. 10.]
+         [ 3. 6. 7. 10.]
+         [ 4. 6. 8. 10.]
+         [ 4. 6. 8. 10.]]
+        >>> print(output[1])
+        [[0 0 0 0]
+         [0 1 1 0]
+         [2 1 2 0]
+         [2 1 2 0]]
+    """
+    return ops.auto_generate.cummax(input, dim)
+
+
+def _einsum_convert_sublist_to_label(num, ell_num=False):
+    """Convert sublist to label."""
+    if num == Ellipsis or ell_num and num == 52:
+        return '...'
+    if 0 <= num < 26:
+        return chr(num + ord('A'))
+    if 26 <= num < 52:
+        return chr(num + ord('a') - 26)
+    raise ValueError(f'For einsum, the number in sublist must be in range [0, 52), but got {num}')
+
+
+def _einsum_convert_label_to_index(label):
+    """Convert label to index."""
+    label_num = ord(label)
+    if ord('A') <= label_num <= ord('Z'):
+        return label_num - ord('A')
+    if ord('a') <= label_num <= ord('z'):
+        return label_num - ord('a') + 26
+    if label_num == ord('.'):
+        return 52
+    raise ValueError(f'For einsum, the label in equation must be in [a-zA-Z] or ., but got {label}')
+
+
+def _einsum_convert_sublist(equation, *operands):
+    """Convert the sublist to an equation operand if the received input is a sublist format."""
+    if isinstance(equation, Tensor):
+        equation_tmp = ''
+        for i, lst in enumerate(operands):
+            if i % 2 == 0:
+                for _, num in enumerate(lst):
+                    equation_tmp += _einsum_convert_sublist_to_label(num)
+                if i in (len(operands) - 1, len(operands) - 2):
+                    continue
+                equation_tmp += ','
+        if len(operands) % 2 == 0:
+            equation_tmp += '->'
+            for _, num in enumerate(operands[-1]):
+                equation_tmp += _einsum_convert_sublist_to_label(num)
+            operands_tmp = list([equation]) + list(operands[1:-1:2])
+        else:
+            operands_tmp = list([equation]) + list(operands[1::2])
+        equation = equation_tmp
+        operands = tuple(operands_tmp)
+    if len(operands) == 0:  # pylint: disable=len-as-condition
+        raise ValueError("For einsum, the 'operands' must have at least one operand.")
+    return equation, operands
+
+
+def _einsum_check_inputargs(equation, operands):
+    """Check equation and operands."""
+    if not isinstance(equation, str):
+        raise TypeError(f"For einsum, 'equation' must be a str, but got {type(equation)}.")
+    for operand in operands:
+        if not isinstance(operand, Tensor):
+            raise TypeError(f"For einsum, members of 'operands' must be Tensor, but got {type(operand)}.")
+
+
+@constexpr
+def _einsum_parse_equation(equation):
+    """Parse equation."""
+    l_equation = ''
+    r_equation = ''
+    equation = equation.replace(' ', '')
+
+    if '->' in equation:
+        l_equation, r_equation = equation.split('->', 1)
+        if l_equation == '':
+            raise ValueError('For einsum, equation must contain characters to the left fo the arrow.')
+    else:
+        l_equation = equation
+
+    if ',' in l_equation:
+        l_equationlst = l_equation.split(",")
+    else:
+        l_equationlst = [l_equation]
+
+    l_equationlst = []
+
+    for subequation in l_equation.split(','):
+        if '.' in subequation and ('...' not in subequation or subequation.count('.') != 3):
+            raise ValueError(f"For einsum, an ellipsis in the equation must include three continuous \'.\', "
+                             f"and can only be found once.")
+        subequation_lst = [_einsum_convert_label_to_index(label) for label in subequation.replace('...', '.')]
+        l_equationlst.append(subequation_lst)
+
+    if "." in r_equation and ('...' not in r_equation or r_equation.count('.') != 3):
+        raise ValueError(f"For einsum, an ellipsis in the equation must include three continuous \'.\', "
+                         f"and can only be found once.")
+    r_equationlst = [_einsum_convert_label_to_index(label) for label in r_equation.replace('...', '.')]
+
+    return l_equationlst, r_equationlst, ('->' in equation)
+
+
+def _einsum_parse_labels(l_equationlst, operands):
+    """Parse left script of equation."""
+    align_rank = 0
+    max_labels = 53
+    labels_count = [0] * max_labels
+    labels2dimlst = [None] * max_labels
+
+    if len(operands) != len(l_equationlst):
+        raise ValueError(f"For einsum, 'operands' is not equal to specified in the 'equation', "
+                         f"but got {len(operands)} and {len(l_equationlst)}.")
+
+    for idx, sub_equ in enumerate(l_equationlst):
+        start_dim = 0
+        label_num = 0
+        operand_shape = list(operands[idx].shape)
+        for label in sub_equ:
+            label_num += 1
+            end_dim = start_dim + 1
+
+            # Label is ellipsis
+            if label == 52:
+                end_dim = len(operand_shape) - len(sub_equ) + label_num
+            if labels2dimlst[label] is None:
+                labels2dimlst[label] = operand_shape[start_dim:end_dim]
+                align_rank += (end_dim - start_dim)
+            else:
+                if labels2dimlst[label] != operand_shape[start_dim:end_dim]:
+                    raise ValueError(f"For einsum, one label in 'equation' can only represent the same dimension "
+                                     f"in 'operands', but '{_einsum_convert_sublist_to_label(label, True)}' "
+                                     f"represented different dimensions.")
+            labels_count[label] += 1
+            start_dim = end_dim
+        if label_num != len(sub_equ) or start_dim != len(operand_shape):
+            raise ValueError(f"For einsum, the numbers of labels specified in the 'equation' does not match "
+                             f"'operands[{idx}]'.")
+    return labels2dimlst, labels_count, align_rank
+
+
+def _einsum_infer_output(r_equationlst, arrow_exist, labels2dimlst, labels_count):
+    """Parse right script of equation and infer output shape."""
+    idx = 0
+    idle_idx = -1
+    output_shape = []
+    labels_perm_idx = [idle_idx] * 53
+
+    if arrow_exist:
+        for label in r_equationlst:
+            if labels_count[label] != 0:
+                output_shape += labels2dimlst[label]
+                if labels_perm_idx[label] != idle_idx:
+                    raise ValueError(f"For einsum, '{_einsum_convert_sublist_to_label(label, True)}' or {label} in "
+                                     f"sublist format has appears more than once in output subscript.")
+                labels_perm_idx[label] = idx
+                idx += len(labels2dimlst[label])
+            else:
+                raise ValueError(f"For einsum, the label to the right of arrow in the 'equation' must appear on "
+                                 f"left, but '{_einsum_convert_sublist_to_label(label, True)}' does not.")
+    else:
+        if labels_count[52] != 0:
+            output_shape += labels2dimlst[52]
+            labels_perm_idx[52] = idx
+            idx += len(labels2dimlst[52])
+        for label, count in enumerate(labels_count):
+            if count == 1:
+                output_shape += labels2dimlst[label]
+                labels_perm_idx[label] = idx
+                idx += len(labels2dimlst[label])
+
+    for label, count in enumerate(labels_count):
+        if count != 0 and labels_perm_idx[label] == idle_idx:
+            labels_perm_idx[label] = idx
+            idx += 1
+
+    return output_shape, labels_perm_idx
+
+
+def _einsum_adjust_operands(operands, l_equationlst, labels2dimlst, labels_perm_idx, align_rank):
+    """Align operands to output as possible."""
+    # Unsqueeze miss dimensions to make all operands has same rank, compute diagonal if operand has same label.
+    # Then use _labels_perm_idx to transpose all operands to align dimensions with output.
+    adjust_operands = []
+    for idx, operand in enumerate(operands):
+        idle_dim = -1
+        align_axis = [idle_dim] * align_rank
+        label_dims = [idle_dim] * 53
+        dim = 0
+
+        for label in l_equationlst[idx]:
+            if label_dims[label] != idle_dim:
+                operand = ops.diagonal(operand, 0, label_dims[label], dim)
+                diag_perm = []
+                diag_dim = 0
+                for i in range(len(operand.shape)):
+                    if i == label_dims[label]:
+                        diag_perm.append(len(operand.shape) - 1)
+                    else:
+                        diag_perm.append(diag_dim)
+                        diag_dim += 1
+                operand = permute(operand, tuple(diag_perm))
+            else:
+                label_dims[label] = dim
+                if label == 52:
+                    for ell_idx in range(len(labels2dimlst[label])):
+                        align_axis[labels_perm_idx[label] + ell_idx] = dim
+                        dim += 1
+                else:
+                    align_axis[labels_perm_idx[label]] = dim
+                    dim += 1
+        if len(operand.shape) < align_rank:
+            for i, axis in enumerate(align_axis):
+                if axis == idle_dim:
+                    align_axis[i] = dim
+                    dim += 1
+            missing_dims = [1] * (align_rank - len(operand.shape))
+            operand_shape = list(operand.shape) + missing_dims
+            operand = reshape(operand, operand_shape)
+        operand = permute(operand, tuple(align_axis))
+        adjust_operands.append(operand)
+    return adjust_operands
+
+
+def _einsum_find_dimlastop(align_rank, operands, adjust_operands):
+    """Find dim last operand."""
+    dim_last_op = [0 for _ in range(align_rank)]
+    has_zero_dim = False
+    for dim in range(align_rank):
+        broadcast_dim = adjust_operands[0].shape[dim]
+        for idx in range(1, len(adjust_operands)):
+            other_dim = adjust_operands[idx].shape[dim]
+            if broadcast_dim != other_dim and broadcast_dim != 1 and other_dim != 1:
+                err_msg = "For einsum, operands do not broadcast after align to output [shapes :origin -> adjust]:"
+                for i in range(len(operands)):
+                    err_msg += f" {operands[i].shape} -> {adjust_operands[i].shape}"
+                raise ValueError(err_msg)
+            if other_dim != 1:
+                dim_last_op[dim] = idx
+                broadcast_dim = other_dim
+        has_zero_dim = has_zero_dim or broadcast_dim == 0
+    return dim_last_op, has_zero_dim
+
+
+def _einsum_multiplication(sum_dims, l_tensor, r_tensor):
+    """Compute bmm for einsum."""
+    batch_dims = []
+    lonly_dims = []
+    ronly_dims = []
+    batch_size = 1
+    lonly_size = 1
+    ronly_size = 1
+    sum_size = 1
+
+    l_shape = l_tensor.shape
+    r_shape = r_tensor.shape
+
+    # Compute sum if dim is in sum_dims and get shapes for bmm
+    for i in range(len(l_shape)):
+        sum_l = l_shape[i] > 1
+        sum_r = r_shape[i] > 1
+        if i in sum_dims:
+            if sum_l and sum_r:
+                sum_size *= l_shape[i]
+            elif sum_l:
+                l_tensor = sum(l_tensor, i, True)
+            elif sum_r:
+                r_tensor = sum(r_tensor, i, True)
+        elif sum_l and sum_r:
+            batch_dims.append(i)
+            batch_size *= l_shape[i]
+        elif sum_l:
+            lonly_dims.append(i)
+            lonly_size *= l_shape[i]
+        else:
+            ronly_dims.append(i)
+            ronly_size *= r_shape[i]
+
+    # Compute the einsum bmm operators pipeline.
+    # The whole operators pipline is transpose(in) -> reshape(in) -> bmm(in) -> reshape(out) -> transpose(out).
+    l_reshape_shape = (batch_size, lonly_size, sum_size)
+    r_reshape_shape = (batch_size, sum_size, ronly_size)
+
+    out_reshape_shape = [l_shape[dim] for dim in batch_dims]
+    out_reshape_shape += [l_shape[dim] for dim in lonly_dims]
+    out_reshape_shape += [1 for _ in sum_dims]
+    out_reshape_shape += [r_shape[dim] for dim in ronly_dims]
+
+    l_perm_axis = batch_dims + lonly_dims + sum_dims + ronly_dims
+    r_perm_axis = batch_dims + sum_dims + ronly_dims + lonly_dims
+    out_perm_axis = [-1] * len(out_reshape_shape)
+
+    out_dim = 0
+    for idx in range(len(l_perm_axis)):
+        out_perm_axis[l_perm_axis[idx]] = out_dim
+        out_dim += 1
+
+    l_tensor = permute(l_tensor, tuple(l_perm_axis))
+    l_tensor = reshape(l_tensor, l_reshape_shape)
+
+    r_tensor = permute(r_tensor, tuple(r_perm_axis))
+    r_tensor = reshape(r_tensor, r_reshape_shape)
+
+    output = bmm(l_tensor, r_tensor)
+    output = reshape(output, out_reshape_shape)
+    output = permute(output, tuple(out_perm_axis))
+
+    output_origin_shape = output.shape
+    output_squeeze_shape = []
+    for dim in range(len(output_origin_shape)):
+        if dim not in sum_dims:
+            output_squeeze_shape.append(output_origin_shape[dim])
+
+    return reshape(output, output_squeeze_shape)
+
+
+def _einsum_squeeze(operand, dim):
+    '''Will be replaced by mint.squeeze in the future'''
+    operand_shape = operand.shape
+    squeeze_shape = []
+    for idx in range(len(operand_shape)):
+        if idx != dim:
+            squeeze_shape.append(operand_shape[idx])
+    return reshape(operand, squeeze_shape)
+
+
+def _einsum(equation, operands):
+    '''Einsum main process'''
+    _l_equationlst, _r_equationlst, _arrow_exist = _einsum_parse_equation(equation)
+    _labels2dimlst, _labels_count, _align_rank = _einsum_parse_labels(_l_equationlst, operands)
+    _output_shape, _labels_perm_idx = _einsum_infer_output(_r_equationlst, _arrow_exist, _labels2dimlst, _labels_count)
+    _output_rank = len(_output_shape)
+
+    _adjust_operands = _einsum_adjust_operands(operands, _l_equationlst, _labels2dimlst, _labels_perm_idx, _align_rank)
+    _dim_last_op, _has_zero_dim = _einsum_find_dimlastop(_align_rank, operands, _adjust_operands)
+    _result = _adjust_operands[0]
+
+    # Fast path if operands has zero dim.
+    if _has_zero_dim:
+        return zeros(_output_shape, dtype=_result.dtype)
+
+    # Sum or squeeze dimensions that is 1 for all rest operands.
+    _reduce_dim = _output_rank
+    for dim in range(_output_rank, _align_rank):
+        if _dim_last_op[dim] == 0:
+            if _result.shape[_reduce_dim] == 1:
+                _result = _einsum_squeeze(_result, _reduce_dim)
+            else:
+                _result = sum(_result, _reduce_dim)
+        else:
+            _reduce_dim += 1
+
+    # Compute multiplication if operands are more than two.
+    for i in range(1, len(_adjust_operands)):
+        operand = _adjust_operands[i]
+        dim = _output_rank
+        sum_dims = []
+        for j in range(_output_rank, _align_rank):
+            if _dim_last_op[j] < i:
+                operand = _einsum_squeeze(operand, dim)
+            elif _dim_last_op[j] == i:
+                if _result.shape[dim] == 1:
+                    operand = sum(operand, dim)
+                    _result = _einsum_squeeze(_result, dim)
+                else:
+                    sum_dims.append(dim)
+                    dim += 1
+            else:
+                dim += 1
+
+        if sum_dims == []:
+            _result = mul(_result, operand)
+        elif len(sum_dims) == len(_result.shape):
+            _result = ops.auto_generate.dot(flatten(_result), flatten(operand))
+        else:
+            _result = _einsum_multiplication(sum_dims, _result, operand)
+
+    return _result
+
+
+def einsum(equation, *operands):
+    r"""
+    According to the Einstein summation Convention (Einsum),
+    the product of the input tensor elements is summed along the specified dimension.
+    You can use this operator to perform diagonal, reducesum, transpose, matmul, mul, inner product operations, etc.
+
+    Note:
+        The sublist format is also supported. For example, mint.einsum(op1, sublist1, op2, sublist2, ..., sublist_out).
+        In this format, equation can be derived by the sublists which are made up of Python's Ellipsis and list of
+        integers in [0, 52). Each operand is followed by a sublist and an output sublist is at the end.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        equation (str): Notation based on the Einstein summation convention, represent the operation you want to do.
+            the value can contain only letters, commas, ellipsis and arrow.
+            The letters represent input tensor dimension, commas represent separate tensors, ellipsis indicates
+            the tensor dimension that you do not care about, the left of the arrow indicates the input tensors,
+            and the right of it indicates the desired output dimension.
+        operands (Tensor): Input tensor used for calculation. The dtype of the tensor must be the same.
+
+    Returns:
+        Tensor, the shape of it can be obtained from the `equation` , and the dtype is the same as input tensors.
+
+    Raises:
+        TypeError: If `equation` is invalid, or the `equation` does not match the input tensor.
+        ValueError: If the number in sublist is not in [0, 52) in sublist format.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor, mint
+        >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
+        >>> equation = "i->"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [7.]
+        >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
+        >>> y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32)
+        >>> equation = "i,i->i"
+        >>> output = mint.einsum(equation, x, y)
+        >>> print(output)
+        [ 2. 8. 12.]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), mindspore.float32)
+        >>> equation = "ij,jk->ik"
+        >>> output = mint.einsum(equation, x, y)
+        >>> print(output)
+        [[16. 22.]
+         [37. 52.]]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> equation = "ij->ji"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [[1. 4.]
+         [2. 5.]
+         [3. 6.]]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> equation = "ij->j"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [5. 7. 9.]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> equation = "...->"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [21.]
+        >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
+        >>> y = Tensor(np.array([2.0, 4.0, 1.0]), mindspore.float32)
+        >>> equation = "j,i->ji"
+        >>> output = mint.einsum(equation, x, y)
+        >>> print(output)
+        [[ 2. 4. 1.]
+         [ 4. 8. 2.]
+         [ 6. 12. 3.]]
+        >>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32)
+        >>> y = mindspore.Tensor([1, 2], mindspore.float32)
+        >>> output = mint.einsum(x, [..., 1], y, [..., 2], [..., 1, 2])
+        >>> print(output)
+        [[1. 2.]
+         [2. 4.]
+         [3. 6.]
+         [4. 8.]]
+    """
+    _equation, _operands = _einsum_convert_sublist(equation, *operands)
+    _einsum_check_inputargs(_equation, _operands)
+
+    for operand in _operands:
+        if ops.is_sequence_shape_unknown(operand.shape) or ops.is_sequence_value_unknown(operand.shape):
+            raise ValueError(f"For einsum, the element of 'operands' can't be dynamic shape or dynamic rank.")
+
+    return _einsum(_equation, _operands)
+
+
+def item(input):
+    r"""
+    Returns the value of this tensor as a standard Python number.
+
+    Note:
+        This only works for tensors with one element.
+
+    Args:
+        input (Tensor[Number]): The input tensor. The dtype of the tensor to be reduced is number.
+            :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
+
+    Returns:
+        number.
+
+    Raises:
+        TypeError: If `input` is not a Tensor.
+        RuntimeError: If the number of `input` elements is not 1.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor, mint
+        >>> x = Tensor(np.array([1]).astype(np.float32))
+        >>> result = mint.item(x)
+        >>> print(result)
+        1.0
+    """
+    if not isinstance(input, Tensor):
+        raise TypeError(f"the input must be a Tensor, but got {type(input)}")
+    if input.size != 1:
+        raise RuntimeError(
+            "a Tensor with {} elements cannot be converted to Scalar".format(input.size))
+    return input.asnumpy().item()
+
+
 def mean(input, dim=None, keepdim=False, *, dtype=None):
     r"""
     Reduces all dimension of a tensor by averaging all elements in the dimension, by default.
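The `_einsum_multiplication` helper added in the hunk above reduces every pairwise contraction to a single batched matmul; its own comment spells out the pipeline as transpose(in) -> reshape(in) -> bmm(in) -> reshape(out) -> transpose(out). As a minimal sketch of that reduction, in plain NumPy and independent of the MindSpore code, for one contraction "ibk,bkj->bij":

import numpy as np

# Classify each axis as batch (b), left-only (i), summed (k) or right-only (j),
# permute each operand so the groups are contiguous, flatten the groups, and
# run one batched matmul -- the same grouping _einsum_multiplication performs.
rng = np.random.default_rng(0)
l = rng.standard_normal((3, 2, 4))          # axes: i=3, b=2, k=4
r = rng.standard_normal((2, 4, 5))          # axes: b=2, k=4, j=5

lt = l.transpose(1, 0, 2).reshape(2, 3, 4)  # -> (batch, lonly, sum)
rt = r.reshape(2, 4, 5)                     # -> (batch, sum, ronly), already aligned
out = np.matmul(lt, rt)                     # bmm -> (batch, lonly, ronly)

assert np.allclose(out, np.einsum('ibk,bkj->bij', l, r))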
@@ -818,12 +1507,12 @@ def sub(input, other, *, alpha=1):
|
|
|
818
1507
|
Args:
|
|
819
1508
|
input (Union[Tensor, number.Number, bool]): The first input is a number.Number or
|
|
820
1509
|
a bool or a tensor whose data type is
|
|
821
|
-
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore
|
|
822
|
-
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore
|
|
1510
|
+
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
|
|
1511
|
+
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
|
|
823
1512
|
other (Union[Tensor, number.Number, bool]): The second input, is a number.Number or
|
|
824
1513
|
a bool or a tensor whose data type is
|
|
825
|
-
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore
|
|
826
|
-
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore
|
|
1514
|
+
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
|
|
1515
|
+
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
|
|
827
1516
|
|
|
828
1517
|
Keyword Args:
|
|
829
1518
|
alpha (number.Number): A scaling factor applied to `other`, default 1.
|
|
@@ -859,6 +1548,22 @@ def sub(input, other, *, alpha=1):
|
|
|
859
1548
|
return ops.auto_generate.sub_ext(input, other, alpha)
|
|
860
1549
|
|
|
861
1550
|
|
|
1551
|
+
def swapaxes(input, axis0, axis1):
|
|
1552
|
+
'''
|
|
1553
|
+
Interchange two axes of a tensor, alias for mint.transpose()
|
|
1554
|
+
|
|
1555
|
+
Examples:
|
|
1556
|
+
>>> import numpy as np
|
|
1557
|
+
>>> from mindspore import mint
|
|
1558
|
+
>>> from mindspore import Tensor
|
|
1559
|
+
>>> input = Tensor(np.ones((2,3,4), dtype=np.float32))
|
|
1560
|
+
>>> output = mint.swapaxes(input, 0, 2)
|
|
1561
|
+
>>> print(output.shape)
|
|
1562
|
+
(4, 3, 2)
|
|
1563
|
+
'''
|
|
1564
|
+
return transpose(input, axis0, axis1)
|
|
1565
|
+
|
|
1566
|
+
|
|
862
1567
|
def zeros(size, *, dtype=None):
|
|
863
1568
|
"""
|
|
864
1569
|
Creates a tensor filled with 0 with shape described by `size` and fills it with value 0 in type of `dtype`.
|
|
@@ -892,25 +1597,112 @@ def zeros(size, *, dtype=None):
|
|
|
892
1597
|
return ops.auto_generate.zeros(size, dtype)
|
|
893
1598
|
|
|
894
1599
|
|
|
1600
|
+
def fix(input):
|
|
1601
|
+
"""
|
|
1602
|
+
Alias for :func:`mindspore.mint.trunc` .
|
|
1603
|
+
|
|
1604
|
+
For more details, see :func:`mindspore.mint.trunc` .
|
|
1605
|
+
|
|
1606
|
+
Supported Platforms:
|
|
1607
|
+
``Ascend``
|
|
1608
|
+
"""
|
|
1609
|
+
return trunc(input)
|
|
1610
|
+
|
|
1611
|
+
|
|
1612
|
+
+def scatter(input, dim, index, src):
+    """
+    Update `input` with the values from `src` according to the specified `index`.
+
+    For a 3-D tensor, the output will be:
+
+    .. code-block::
+
+        output[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
+
+        output[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
+
+        output[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
+
+    .. note::
+        Backward propagation is supported only when `src.shape == index.shape`, for the case where `src` is a tensor.
+
+    Args:
+        input (Tensor): The target tensor. The rank of `input` must be at least 1.
+        dim (int): The axis along which to scatter. The accepted range is [-r, r), where r = rank(`input`).
+        index (Tensor): The indices for the update operation. Elements must be non-negative integers of type
+            mindspore.int32 or mindspore.int64, in the range [0, s), where s is the size of `input` along `dim`.
+            `index` must have the same rank as `input`.
+        src (Tensor, float): The data used to update `input`. Can be a tensor with the same data type
+            as `input`, or a float number to scatter.
+
+    Returns:
+        Tensor, with the same shape and data type as `input`.
+
+    Raises:
+        TypeError: If the dtype of `index` is neither int32 nor int64.
+        ValueError: If the rank of any of `input`, `index`, or `src` is less than 1.
+        ValueError: If the rank of `src` is not equal to the rank of `input`.
+        TypeError: If `input` and `src` have different dtypes.
+        RuntimeError: If `index` has negative elements.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor, mint
+        >>> input = Tensor(np.array([[1, 2, 3, 4, 5]]), dtype=ms.float32)
+        >>> src = Tensor(np.array([[8, 8]]), dtype=ms.float32)
+        >>> index = Tensor(np.array([[2, 4]]), dtype=ms.int64)
+        >>> out = mint.scatter(input=input, dim=1, index=index, src=src)
+        >>> print(out)
+        [[1. 2. 8. 4. 8.]]
+        >>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
+        >>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
+        >>> index = Tensor(np.array([[0, 0, 0], [2, 2, 2], [4, 4, 4]]), dtype=ms.int64)
+        >>> out = mint.scatter(input=input, dim=0, index=index, src=src)
+        >>> print(out)
+        [[1. 2. 3. 0. 0.]
+         [0. 0. 0. 0. 0.]
+         [4. 5. 6. 0. 0.]
+         [0. 0. 0. 0. 0.]
+         [7. 8. 9. 0. 0.]]
+        >>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
+        >>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
+        >>> index = Tensor(np.array([[0, 2, 4], [0, 2, 4], [0, 2, 4]]), dtype=ms.int64)
+        >>> out = mint.scatter(input=input, dim=1, index=index, src=src)
+        >>> print(out)
+        [[1. 0. 2. 0. 3.]
+         [4. 0. 5. 0. 6.]
+         [7. 0. 8. 0. 9.]
+         [0. 0. 0. 0. 0.]
+         [0. 0. 0. 0. 0.]]
+    """
+    return ops.function.array_func.scatter(input, dim, index, src)
+
+
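As a quick sanity check on the scatter rule added above, the following standalone NumPy loop (illustrative only; `scatter_ref` is a name invented here, not part of the mindspore package) reproduces the first docstring example for the 2-D, dim == 1 case:

    import numpy as np

    def scatter_ref(inp, dim, index, src):
        # Reference loop for 2-D scatter: visit every element of `index` and
        # write the matching `src` element into a copy of the input.
        out = inp.copy()
        for i in range(index.shape[0]):
            for j in range(index.shape[1]):
                if dim == 0:
                    out[index[i, j], j] = src[i, j]   # output[index[i][j]][j] = src[i][j]
                else:
                    out[i, index[i, j]] = src[i, j]   # output[i][index[i][j]] = src[i][j]
        return out

    inp = np.array([[1., 2., 3., 4., 5.]])
    src = np.array([[8., 8.]])
    index = np.array([[2, 4]])
    print(scatter_ref(inp, 1, index, src))  # [[1. 2. 8. 4. 8.]]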
 __all__ = [
+    'conv2d',
     'full',
     'ones_like',
     'zeros_like',
     'abs',
     'erf',
     'where',
-    'linspace',
     'isclose',
     # 1
     'div',
     'divide',
     'topk',
+    'roll',
     # 2
     'sin',
     # 3
     'clamp',
+    'xlogy',
     # 4
-
+    'sinc',
+    'sinh',
+    'cosh',
     # 5
     'cumsum',
     # 6
@@ -918,15 +1710,16 @@ __all__ = [
     # 7
     'zeros',
     # 8
-
+    'transpose',
+    'swapaxes',
     # 9

     # 10
     'ne',
     # 11
-
+    'unsqueeze',
     # 12
-
+    "repeat_interleave",
     # 13
     "flip",
     # 14
@@ -966,8 +1759,9 @@ __all__ = [
     # 30
     'searchsorted',
     # 31
-
-
+    'cummax',
+    'cummin',
+    'einsum',
     'sub',
     # 33
     'split',
@@ -994,15 +1788,17 @@ __all__ = [
     # 44
     'cos',
     # 45
-
+    'concat',
     # 46
-
+    'bitwise_and',
+    'bitwise_or',
+    'bitwise_xor',
     # 47
     'max',
     # 48
     'min',
     # 49
-
+    'baddbmm',
     # 50
     'tile',
     # 51
@@ -1014,9 +1810,9 @@ __all__ = [
     # 54
     'normal',
     # 55
-
+    'cross',
     # 56
-
+    'norm',
     # 57
     'broadcast_to',
     # 58
@@ -1040,7 +1836,7 @@ __all__ = [
     # 67
     'logical_or',
     # 68
-
+    'logical_xor',
     # 69
     'less_equal',
     'le',
@@ -1077,9 +1873,10 @@ __all__ = [
     'narrow',
     # 84

-
+    'masked_select',

     # 86
+    'select',

     # 87

@@ -1109,29 +1906,169 @@ __all__ = [

     # 100

+    # 101
+
+    # 102
+
+    # 103
+
+    # 104
+
+    # 105
+
+    # 106
+
+    # 107
+
+    # 108
+
+    # 109
+    'argmin',
+    # 110
+
+    # 111
+
+    # 112
+
+    # 113
+
+    # 114
+
+    # 115
+
+    # 116
+
+    # 117
+
+    # 118
+
+    # 119
+
+    # 120
+
+    # 121
+
+    # 122
+
+    # 151
+    'acos',
+    'arccos',
+    # 152
+    'acosh',
+    'arccosh',
+    # 153
+
+    # 154
+
+    # 155
+
+    # 156
+
+    # 157
+    'scatter',
+    # 172
+    'asin',
+    'arcsin',
+    # 173
+    'asinh',
+    'arcsinh',
+    # 174
+    'atan',
+    'arctan',
+    # 175
+    'atanh',
+    'arctanh',
     # 176
     'atan2',
     'arctan2',

+    # 177
+    'round',
+
+    # 182
+    'bernoulli',
+
+    # 207
+    'expm1',
+    # 204
+    'erfc',
     # 208
     'eye',
+
+    # 256
+    'median',
+    'randperm',
     'rand',
     'rand_like',
+    'randn',
+    'randn_like',
+    'randint',
+    'randint_like',
     # 210
     'floor',
     # 231
     'inverse',
+    # 244
+    'log1p',
+    # 261
+    'multinomial',
+    # 275
+    'remainder',
     # 285
     'scatter_add',
+    # 289
+    'sign',
+    # 301
+    'tan',
+    # 303
+    'trace',
+    'reshape',
+    'outer',
     # 304
+    'tril',

     # 305
     'triu',
+
+    # 538
+    'histc',
+
+    # 553
+    'logaddexp',
+
+    # 610
+    'nan_to_num',
+
+    # 695
+    'count_nonzero',
 ]
-
-
-
+
+setattr(tensor_operator_registry_for_mint, 'add', add)
+setattr(tensor_operator_registry_for_mint, 'all', all)
+setattr(tensor_operator_registry_for_mint, 'any', any)
+setattr(tensor_operator_registry_for_mint, 'log', log)
+setattr(tensor_operator_registry_for_mint, 'ceil', ceil)
+setattr(tensor_operator_registry_for_mint, 'clamp', clamp)
+setattr(tensor_operator_registry_for_mint, 'cos', cos)
+setattr(tensor_operator_registry_for_mint, 'flatten', flatten)
+setattr(tensor_operator_registry_for_mint, 'item', item)
+setattr(tensor_operator_registry_for_mint, 'max', max)
+setattr(tensor_operator_registry_for_mint, 'mean', mean)
+setattr(tensor_operator_registry_for_mint, 'min', min)
+setattr(tensor_operator_registry_for_mint,
+        'repeat_interleave', repeat_interleave)
+setattr(tensor_operator_registry_for_mint, 'ne', ne)
+setattr(tensor_operator_registry_for_mint, 'round', round)
+setattr(tensor_operator_registry_for_mint, 'sin', sin)
+setattr(tensor_operator_registry_for_mint, 'split', split)
+setattr(tensor_operator_registry_for_mint, 'sqrt', sqrt)
+setattr(tensor_operator_registry_for_mint, 'square', square)
+setattr(tensor_operator_registry_for_mint, 'sub', sub)
+setattr(tensor_operator_registry_for_mint, 'sum', sum)
+
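The new setattr block above registers each mint function on tensor_operator_registry_for_mint so that callers can look the operators up by name. The self-contained sketch below (all names invented here; this is not the MindSpore implementation) shows the general shape of that registry pattern:

    class _Registry:
        """Bare namespace object used purely as a name-to-function table."""

    _registry = _Registry()

    def _add(x, y):
        return x + y

    # Register under a public name, mirroring the setattr calls above.
    setattr(_registry, 'add', _add)

    class MiniTensor:
        def __init__(self, value):
            self.value = value

        def add(self, other):
            # The method resolves the operator at call time, so swapping the
            # registered function changes behavior without touching this class.
            return getattr(_registry, 'add')(self.value, other)

    print(MiniTensor(3).add(4))  # 7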
 __all__.extend(functional.__all__)
 __all__.extend(nn.__all__)
 __all__.extend(optim.__all__)
 __all__.extend(linalg.__all__)
+__all__.extend(special.__all__)
+__all__.extend(distributed.__all__)
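For context, __all__ controls which names a star-import of the module exposes, and the extend calls fold each submodule's export list into the package's own. A minimal standalone illustration (the list contents here are invented, standing in for e.g. functional.__all__):

    # Illustration only: how __all__ plus extend() shapes star-imports.
    functional_all = ['relu', 'gelu']   # stands in for a submodule's __all__

    __all__ = ['scatter', 'trunc']      # names defined in this module
    __all__.extend(functional_all)      # aggregate the submodule's exports

    print(__all__)  # ['scatter', 'trunc', 'relu', 'gelu']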
|