mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright 2025 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""comm_func"""
|
|
17
|
+
from ..ops.function.comm_func import (
|
|
18
|
+
TCPStore,
|
|
19
|
+
init_process_group,
|
|
20
|
+
destroy_process_group,
|
|
21
|
+
get_rank,
|
|
22
|
+
get_world_size,
|
|
23
|
+
new_group,
|
|
24
|
+
get_backend,
|
|
25
|
+
get_global_rank,
|
|
26
|
+
get_process_group_ranks,
|
|
27
|
+
get_group_rank,
|
|
28
|
+
all_reduce,
|
|
29
|
+
all_gather_into_tensor,
|
|
30
|
+
all_gather_into_tensor_uneven,
|
|
31
|
+
all_to_all,
|
|
32
|
+
all_to_all_single,
|
|
33
|
+
reduce_scatter_tensor,
|
|
34
|
+
reduce_scatter_tensor_uneven,
|
|
35
|
+
isend,
|
|
36
|
+
irecv,
|
|
37
|
+
send,
|
|
38
|
+
recv,
|
|
39
|
+
barrier,
|
|
40
|
+
broadcast,
|
|
41
|
+
reduce,
|
|
42
|
+
P2POp,
|
|
43
|
+
batch_isend_irecv,
|
|
44
|
+
gather,
|
|
45
|
+
scatter,
|
|
46
|
+
all_gather,
|
|
47
|
+
reduce_scatter,
|
|
48
|
+
all_gather_object,
|
|
49
|
+
broadcast_object_list,
|
|
50
|
+
gather_object,
|
|
51
|
+
scatter_object_list,
|
|
52
|
+
is_available,
|
|
53
|
+
is_initialized,
|
|
54
|
+
set_comm_ops_inplace,
|
|
55
|
+
all_to_all_v_c,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
__all__ = [
|
|
59
|
+
"TCPStore",
|
|
60
|
+
"init_process_group",
|
|
61
|
+
"destroy_process_group",
|
|
62
|
+
"get_rank",
|
|
63
|
+
"get_world_size",
|
|
64
|
+
"new_group",
|
|
65
|
+
"get_backend",
|
|
66
|
+
"get_global_rank",
|
|
67
|
+
"get_process_group_ranks",
|
|
68
|
+
"get_group_rank",
|
|
69
|
+
"all_reduce",
|
|
70
|
+
"all_gather_into_tensor",
|
|
71
|
+
"all_gather_into_tensor_uneven",
|
|
72
|
+
"all_to_all",
|
|
73
|
+
"all_to_all_single",
|
|
74
|
+
"reduce_scatter_tensor",
|
|
75
|
+
"reduce_scatter_tensor_uneven",
|
|
76
|
+
"isend",
|
|
77
|
+
"irecv",
|
|
78
|
+
"send",
|
|
79
|
+
"recv",
|
|
80
|
+
"gather",
|
|
81
|
+
"scatter",
|
|
82
|
+
"all_gather",
|
|
83
|
+
"reduce_scatter",
|
|
84
|
+
"barrier",
|
|
85
|
+
"broadcast",
|
|
86
|
+
"reduce",
|
|
87
|
+
"P2POp",
|
|
88
|
+
"batch_isend_irecv",
|
|
89
|
+
"all_gather_object",
|
|
90
|
+
"broadcast_object_list",
|
|
91
|
+
"gather_object",
|
|
92
|
+
"scatter_object_list",
|
|
93
|
+
"is_available",
|
|
94
|
+
"is_initialized",
|
|
95
|
+
"set_comm_ops_inplace",
|
|
96
|
+
'all_to_all_v_c',
|
|
97
|
+
]
|
|
@@ -21,7 +21,7 @@ Pre-defined combination of operators.
|
|
|
21
21
|
|
|
22
22
|
from __future__ import absolute_import
|
|
23
23
|
from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
|
|
24
|
-
tail, zip_operation, _Vmap, _TaylorOperation, iter_converter, ms_hasnext, ms_next
|
|
24
|
+
tail, zip_operation, _Vmap, _TaylorOperation, iter_converter, ms_hasnext, ms_next, recompute_block
|
|
25
25
|
from mindspore.ops.composite.env_ops import env_get
|
|
26
26
|
from mindspore.ops.function.clip_func import clip_by_global_norm
|
|
27
27
|
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
|
|
@@ -30,6 +30,7 @@ from mindspore.ops.composite.multitype_ops.sub_impl import augassign_sub
|
|
|
30
30
|
from mindspore.ops.composite.multitype_ops.mul_impl import augassign_mul
|
|
31
31
|
from mindspore.ops.composite.multitype_ops.div_impl import augassign_div
|
|
32
32
|
from mindspore.ops.composite.multitype_ops.floordiv_impl import augassign_floordiv
|
|
33
|
+
from mindspore.ops.composite.multitype_ops.mod_impl import augassign_mod
|
|
33
34
|
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like, _ones_like_for_grad
|
|
34
35
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
35
36
|
from mindspore.ops.function.random_func import normal, laplace, uniform, gamma, poisson, multinomial
|
|
@@ -54,6 +55,7 @@ __all__ = [
|
|
|
54
55
|
'augassign_mul',
|
|
55
56
|
'augassign_div',
|
|
56
57
|
'augassign_floordiv',
|
|
58
|
+
'augassign_mod',
|
|
57
59
|
'zeros_like',
|
|
58
60
|
'ones_like',
|
|
59
61
|
'_ones_like_for_grad',
|
|
@@ -79,4 +81,5 @@ __all__ = [
|
|
|
79
81
|
'_Vmap',
|
|
80
82
|
'iter_converter',
|
|
81
83
|
'ms_hasnext',
|
|
82
|
-
'ms_next'
|
|
84
|
+
'ms_next',
|
|
85
|
+
'recompute_block']
|
mindspore/ops/composite/base.py
CHANGED
|
@@ -31,7 +31,7 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
|
|
|
31
31
|
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
|
|
32
32
|
ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
|
|
33
33
|
ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
|
|
34
|
-
HandleBoolTensor_, PreSetitemByTuple_, StarredGetItem_, \
|
|
34
|
+
HandleBoolTensor_, PreSetitemByTuple_, StarredGetItem_, RecomputeBlock_, \
|
|
35
35
|
StarredUnpack_, StarredUnpackMerge_, IterConverter_, HasNext_, Next_, MSContext
|
|
36
36
|
from mindspore.common import dtype as mstype
|
|
37
37
|
from mindspore.common.api import jit, _pynative_executor, _wrap_func
|
|
@@ -1315,3 +1315,17 @@ class _Next(Next_):
|
|
|
1315
1315
|
|
|
1316
1316
|
ms_next = _Next('next')
|
|
1317
1317
|
"""`ms_next` will get next element and res elements for input"""
|
|
1318
|
+
|
|
1319
|
+
|
|
1320
|
+
class _RecomputeBlock(RecomputeBlock_):
|
|
1321
|
+
"""Set the block to be recomputed"""
|
|
1322
|
+
|
|
1323
|
+
def __init__(self, name):
|
|
1324
|
+
"""Initialize RecomputeBlock_."""
|
|
1325
|
+
RecomputeBlock_.__init__(self, name)
|
|
1326
|
+
|
|
1327
|
+
def __call__(self, *args):
|
|
1328
|
+
pass
|
|
1329
|
+
|
|
1330
|
+
|
|
1331
|
+
recompute_block = _RecomputeBlock("recompute_block")
|
|
@@ -50,6 +50,7 @@ from mindspore.ops.composite.multitype_ops.sub_impl import augassign_sub
|
|
|
50
50
|
from mindspore.ops.composite.multitype_ops.mul_impl import augassign_mul
|
|
51
51
|
from mindspore.ops.composite.multitype_ops.div_impl import augassign_div
|
|
52
52
|
from mindspore.ops.composite.multitype_ops.floordiv_impl import augassign_floordiv
|
|
53
|
+
from mindspore.ops.composite.multitype_ops.mod_impl import augassign_mod
|
|
53
54
|
|
|
54
55
|
__all__ = [
|
|
55
56
|
'add',
|
|
@@ -87,5 +88,6 @@ __all__ = [
|
|
|
87
88
|
'augassign_sub',
|
|
88
89
|
'augassign_mul',
|
|
89
90
|
'augassign_div',
|
|
90
|
-
'augassign_floordiv'
|
|
91
|
+
'augassign_floordiv',
|
|
92
|
+
'augassign_mod'
|
|
91
93
|
]
|
|
@@ -344,6 +344,152 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
|
|
|
344
344
|
return self_viewed, remain_indexes, need_index_prim
|
|
345
345
|
|
|
346
346
|
|
|
347
|
+
def _get_need_index_prim(index, need_index_prim):
|
|
348
|
+
if isinstance(index, bool):
|
|
349
|
+
need_index_prim = True
|
|
350
|
+
elif isinstance(index, Tensor):
|
|
351
|
+
if F.rank(index) == 0 and index.dtype in mstype.int_type + mstype.uint_type + (mstype.bool_,):
|
|
352
|
+
if index.dtype not in mstype.int_type + mstype.uint_type:
|
|
353
|
+
need_index_prim = True
|
|
354
|
+
else:
|
|
355
|
+
need_index_prim = True
|
|
356
|
+
return need_index_prim
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _process_with_inplace_index_input(prev_result, orig_tensor, index, dim, dim_index, remain_indexes, orig_dim, value):
|
|
360
|
+
"""Process dim in multi dim index"""
|
|
361
|
+
result = prev_result
|
|
362
|
+
if isinstance(index, bool):
|
|
363
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
364
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
365
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + (index_for_bool,)
|
|
366
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
367
|
+
elif isinstance(index, int):
|
|
368
|
+
result = _do_select(prev_result, dim, index, dim_index, F.shape(orig_tensor)[orig_dim])
|
|
369
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
370
|
+
elif isinstance(index, slice):
|
|
371
|
+
result = _do_slice(prev_result, dim, index, F.shape(orig_tensor)[orig_dim])
|
|
372
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
373
|
+
elif isinstance(index, EllipsisType):
|
|
374
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
375
|
+
elif index is None:
|
|
376
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
377
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
378
|
+
elif isinstance(index, Tensor):
|
|
379
|
+
result = prev_result
|
|
380
|
+
if F.rank(index) == 0 and index.dtype in mstype.int_type + mstype.uint_type + (mstype.bool_,):
|
|
381
|
+
if index.dtype in mstype.int_type + mstype.uint_type:
|
|
382
|
+
index_py = TensorToScalar()(index)
|
|
383
|
+
dim_size = F.shape(orig_tensor)[orig_dim]
|
|
384
|
+
if index_py >= dim_size or index_py < -dim_size:
|
|
385
|
+
raise IndexError("Index is out of bounds.")
|
|
386
|
+
new_index = (index_py + dim_size) % dim_size
|
|
387
|
+
result = select_ext_view_op(prev_result, dim, new_index)
|
|
388
|
+
# in graph mode, remain_indexes in different branch requires same size, so we fill empty tensor to it
|
|
389
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes) + 1)
|
|
390
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
391
|
+
else:
|
|
392
|
+
# process index with Tensor bool type
|
|
393
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
394
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
395
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + \
|
|
396
|
+
(index_for_bool,)
|
|
397
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
398
|
+
else:
|
|
399
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + (index,)
|
|
400
|
+
inplace_index_put_op(result, remain_indexes, value, False)
|
|
401
|
+
else:
|
|
402
|
+
raise IndexError("Invalid tensor index type")
|
|
403
|
+
return orig_tensor
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _process_with_do_copy(prev_result, orig_tensor, index, dim, dim_index, remain_indexes, orig_dim, value):
|
|
407
|
+
"""Process dim in multi dim index"""
|
|
408
|
+
result = prev_result
|
|
409
|
+
if isinstance(index, bool):
|
|
410
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
411
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
412
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + (index_for_bool,)
|
|
413
|
+
do_copy(result, value)
|
|
414
|
+
elif isinstance(index, int):
|
|
415
|
+
result = _do_select(prev_result, dim, index, dim_index, F.shape(orig_tensor)[orig_dim])
|
|
416
|
+
do_copy(result, value)
|
|
417
|
+
elif isinstance(index, slice):
|
|
418
|
+
result = _do_slice(prev_result, dim, index, F.shape(orig_tensor)[orig_dim])
|
|
419
|
+
do_copy(result, value)
|
|
420
|
+
elif isinstance(index, EllipsisType):
|
|
421
|
+
do_copy(result, value)
|
|
422
|
+
elif index is None:
|
|
423
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
424
|
+
do_copy(result, value)
|
|
425
|
+
elif isinstance(index, Tensor):
|
|
426
|
+
result = prev_result
|
|
427
|
+
if F.rank(index) == 0 and index.dtype in mstype.int_type + mstype.uint_type + (mstype.bool_,):
|
|
428
|
+
if index.dtype in mstype.int_type + mstype.uint_type:
|
|
429
|
+
index_py = TensorToScalar()(index)
|
|
430
|
+
dim_size = F.shape(orig_tensor)[orig_dim]
|
|
431
|
+
if index_py >= dim_size or index_py < -dim_size:
|
|
432
|
+
raise IndexError("Index is out of bounds.")
|
|
433
|
+
new_index = (index_py + dim_size) % dim_size
|
|
434
|
+
result = select_ext_view_op(prev_result, dim, new_index)
|
|
435
|
+
# in graph mode, remain_indexes in different branch requires same size, so we fill empty tensor to it
|
|
436
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes) + 1)
|
|
437
|
+
do_copy(result, value)
|
|
438
|
+
else:
|
|
439
|
+
# process index with Tensor bool type
|
|
440
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
441
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
442
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + \
|
|
443
|
+
(index_for_bool,)
|
|
444
|
+
do_copy(result, value)
|
|
445
|
+
else:
|
|
446
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + (index,)
|
|
447
|
+
do_copy(result, value)
|
|
448
|
+
else:
|
|
449
|
+
raise IndexError("Invalid tensor index type")
|
|
450
|
+
return orig_tensor
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def _process_multi_dim_index_for_setitem(self, indexes, remain_indexes, indexed_dims, value):
|
|
454
|
+
"""Process indexes in tuple"""
|
|
455
|
+
self_viewed = self
|
|
456
|
+
dim = 0
|
|
457
|
+
orig_dim = 0
|
|
458
|
+
need_index_prim = False
|
|
459
|
+
preprocessed_index = []
|
|
460
|
+
for index in indexes:
|
|
461
|
+
if isinstance(index, (list, tuple, np.ndarray)):
|
|
462
|
+
if not F.isconstant(index):
|
|
463
|
+
raise IndexError(
|
|
464
|
+
"Current Tensor indexing does not support mutable list/tuple or list containing tensors. "
|
|
465
|
+
"Please use an immutable expression instead.")
|
|
466
|
+
index = Tensor(index)
|
|
467
|
+
if isinstance(index, Tensor) and \
|
|
468
|
+
F.dtype(index) in (mstype.int8, mstype.int16, mstype.uint16, mstype.uint32,
|
|
469
|
+
mstype.uint64, mstype.float16, mstype.float32, mstype.float64):
|
|
470
|
+
# only uint8, int32 and int64 are supported by IndexOp
|
|
471
|
+
index = F.cast(index, mstype.int64)
|
|
472
|
+
preprocessed_index.append(index)
|
|
473
|
+
need_index_prim = _get_need_index_prim(index, need_index_prim)
|
|
474
|
+
|
|
475
|
+
if not preprocessed_index:
|
|
476
|
+
return do_copy(self_viewed, value)
|
|
477
|
+
|
|
478
|
+
result = self
|
|
479
|
+
|
|
480
|
+
for i, index in enumerate(preprocessed_index):
|
|
481
|
+
if i == len(preprocessed_index) - 1:
|
|
482
|
+
if need_index_prim:
|
|
483
|
+
result = _process_with_inplace_index_input(self_viewed,
|
|
484
|
+
self, index, dim, i, remain_indexes, orig_dim, value)
|
|
485
|
+
else:
|
|
486
|
+
result = _process_with_do_copy(self_viewed, self, index, dim, i, remain_indexes, orig_dim, value)
|
|
487
|
+
else:
|
|
488
|
+
self_viewed, dim, remain_indexes, orig_dim, _ = _process_dim_in_multi_dim_index(
|
|
489
|
+
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, orig_dim, True)
|
|
490
|
+
return result
|
|
491
|
+
|
|
492
|
+
|
|
347
493
|
def _check_type_of_list_index(index_list):
|
|
348
494
|
"""Check type of element in list index"""
|
|
349
495
|
for index in index_list:
|
|
@@ -446,13 +592,7 @@ def _tensor_setitem(self, index, value):
|
|
|
446
592
|
if F.rank(self) < indexed_dims:
|
|
447
593
|
raise IndexError("For setitem, there are too many indices")
|
|
448
594
|
remain_indexes = ()
|
|
449
|
-
|
|
450
|
-
if not need_index_prim:
|
|
451
|
-
do_copy(self_viewed, value)
|
|
452
|
-
return self
|
|
453
|
-
inplace_index_put_op(self_viewed, remain_indexes, value, False)
|
|
454
|
-
return self
|
|
455
|
-
|
|
595
|
+
return _process_multi_dim_index_for_setitem(self, indexes, remain_indexes, indexed_dims, value)
|
|
456
596
|
|
|
457
597
|
setattr(tensor_operator_registry, "_tensor_getitem", _tensor_getitem)
|
|
458
598
|
setattr(tensor_operator_registry, "_tensor_setitem", _tensor_setitem)
|
|
@@ -660,7 +800,9 @@ def handle_empty_tensor(arg, data):
|
|
|
660
800
|
if 0 in arg:
|
|
661
801
|
init_func = Zero()
|
|
662
802
|
init_func.__enable_zero_dim__ = True
|
|
663
|
-
|
|
803
|
+
zero_tensor = Tensor(shape=arg, dtype=data.dtype, init=init_func)
|
|
804
|
+
zero_tensor.init_data()
|
|
805
|
+
return zero_tensor
|
|
664
806
|
return const_utils.make_tensor([], data.dtype, arg)
|
|
665
807
|
|
|
666
808
|
|
|
@@ -70,6 +70,13 @@ _tuple_add = _TupleAdd('tuple_add')
|
|
|
70
70
|
"""`_tuple_add` is an metafuncgraph object which will concatenate two tuples to form a tuple."""
|
|
71
71
|
|
|
72
72
|
|
|
73
|
+
def _create_tuple_add(name):
|
|
74
|
+
"""
|
|
75
|
+
Create and return a new `_TupleAdd` instance.
|
|
76
|
+
"""
|
|
77
|
+
return _TupleAdd(name)
|
|
78
|
+
|
|
79
|
+
|
|
73
80
|
class _DictUpdate(base.DictUpdate_):
|
|
74
81
|
"""
|
|
75
82
|
A metafuncgraph class that append another dict to the end of the dict.
|
|
@@ -20,7 +20,15 @@ from mindspore.ops.composite import base
|
|
|
20
20
|
from mindspore.ops import functional as F
|
|
21
21
|
from mindspore.ops.auto_generate import (remainder_tensor_tensor_op, remainder_tensor_scalar_op,
|
|
22
22
|
remainder_scalar_tensor_op)
|
|
23
|
+
from mindspore.ops.auto_generate.gen_ops_prim import InplaceRemainderTensorTensor, InplaceRemainderTensorScalar
|
|
23
24
|
|
|
25
|
+
# x %= y
|
|
26
|
+
augassign_mod = base.MultitypeFuncGraph("augassign_mod", True)
|
|
27
|
+
"""
|
|
28
|
+
`augassign_mod` is a metafuncgraph object which will compute the mod of two objects
|
|
29
|
+
using ".register" decorator.
|
|
30
|
+
"""
|
|
31
|
+
augassign_mod.set_need_raise()
|
|
24
32
|
|
|
25
33
|
mod = base.MultitypeFuncGraph("mod", True)
|
|
26
34
|
"""
|
|
@@ -30,6 +38,7 @@ using ".register" decorator.
|
|
|
30
38
|
mod.set_need_raise()
|
|
31
39
|
|
|
32
40
|
|
|
41
|
+
@augassign_mod.register("Number", "Number")
|
|
33
42
|
@mod.register("Number", "Number")
|
|
34
43
|
def _mod_scalar(x, y):
|
|
35
44
|
"""Returns x % y where x and y are all scalars."""
|
|
@@ -42,18 +51,32 @@ def _mod_tensor(x, y):
|
|
|
42
51
|
return remainder_tensor_tensor_op(x, y)
|
|
43
52
|
|
|
44
53
|
|
|
54
|
+
@augassign_mod.register("Tensor", "Tensor")
|
|
55
|
+
def _mul_tensor_augassign(x, y):
|
|
56
|
+
"""Returns x % y where x and y are all tensors."""
|
|
57
|
+
return InplaceRemainderTensorTensor()(x, y)
|
|
58
|
+
|
|
59
|
+
|
|
45
60
|
@mod.register("Tensor", "Number")
|
|
46
61
|
def _tensor_mod_scalar(x, y):
|
|
47
62
|
"""Returns x % y where x is a tensor and y is a scalar. x and y should have same dtype."""
|
|
48
63
|
return remainder_tensor_scalar_op(x, y)
|
|
49
64
|
|
|
50
65
|
|
|
66
|
+
@augassign_mod.register("Tensor", "Number")
|
|
67
|
+
def _tensor_mod_scalar_augassign(x, y):
|
|
68
|
+
"""Returns x % y where x is a tensor and y is a scalar. x and y should have same dtype."""
|
|
69
|
+
return InplaceRemainderTensorScalar()(x, y)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@augassign_mod.register("Number", "Tensor")
|
|
51
73
|
@mod.register("Number", "Tensor")
|
|
52
74
|
def _scalar_mod_tensor(x, y):
|
|
53
75
|
"""Returns x % y where x is a scalar and y is a tensor. x and y should have same dtype."""
|
|
54
76
|
return remainder_scalar_tensor_op(x, y)
|
|
55
77
|
|
|
56
78
|
|
|
79
|
+
@augassign_mod.register("Tuple", "Tensor")
|
|
57
80
|
@mod.register("Tuple", "Tensor")
|
|
58
81
|
def _tuple_mod_tensor(x, y):
|
|
59
82
|
"""Returns x % y where x is a tuple and y is a tensor. """
|
|
@@ -61,6 +84,7 @@ def _tuple_mod_tensor(x, y):
|
|
|
61
84
|
return F.tensor_mod(x, y)
|
|
62
85
|
|
|
63
86
|
|
|
87
|
+
@augassign_mod.register("Tensor", "Tuple")
|
|
64
88
|
@mod.register("Tensor", "Tuple")
|
|
65
89
|
def _tensor_mod_tuple(x, y):
|
|
66
90
|
"""Returns x % y where x is a tensor and y is a tuple. """
|
|
@@ -68,6 +92,7 @@ def _tensor_mod_tuple(x, y):
|
|
|
68
92
|
return F.tensor_mod(x, y)
|
|
69
93
|
|
|
70
94
|
|
|
95
|
+
@augassign_mod.register("List", "Tensor")
|
|
71
96
|
@mod.register("List", "Tensor")
|
|
72
97
|
def _list_mod_tensor(x, y):
|
|
73
98
|
"""Returns x % y where x is a list and y is a tensor. """
|
|
@@ -75,6 +100,7 @@ def _list_mod_tensor(x, y):
|
|
|
75
100
|
return F.tensor_mod(x, y)
|
|
76
101
|
|
|
77
102
|
|
|
103
|
+
@augassign_mod.register("Tensor", "List")
|
|
78
104
|
@mod.register("Tensor", "List")
|
|
79
105
|
def _tensor_mod_list(x, y):
|
|
80
106
|
"""Returns x % y where x is a tensor and y is a list. """
|
|
@@ -83,6 +109,7 @@ def _tensor_mod_list(x, y):
|
|
|
83
109
|
|
|
84
110
|
|
|
85
111
|
# pylint: disable=protected-access
|
|
112
|
+
@augassign_mod._register_default()
|
|
86
113
|
@mod._register_default()
|
|
87
114
|
def default_mod(x, y):
|
|
88
115
|
"""Default function for mod."""
|
|
@@ -33,10 +33,10 @@ from mindspore.ops.operations._sequence_ops import TupleToTensor
|
|
|
33
33
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
34
34
|
from mindspore.ops.operations._sequence_ops import TensorToList
|
|
35
35
|
# 1
|
|
36
|
-
from mindspore.ops.auto_generate import OnesLikeExt, ZerosLikeExt, FillScalar, FillTensor, Arange,
|
|
36
|
+
from mindspore.ops.auto_generate import OnesLikeExt, ZerosLikeExt, FillScalar, FillTensor, Arange, UniqueDim, \
|
|
37
37
|
Unique2, SortExt, NonZero, NonZeroExt, Scatter, ScatterValue, NewOnes, NewZeros
|
|
38
38
|
# 2
|
|
39
|
-
|
|
39
|
+
from mindspore.ops.auto_generate.pyboost_inner_prim import squeeze_impl
|
|
40
40
|
# 3
|
|
41
41
|
|
|
42
42
|
# 4
|
|
@@ -108,7 +108,7 @@ from mindspore.ops.auto_generate import cat, range, scatter_nd, deepcopy, masked
|
|
|
108
108
|
index_fill_scalar, index_fill_tensor
|
|
109
109
|
from mindspore.ops.auto_generate import take, tensor_scatter_elements as tensor_scatter_elements_ext
|
|
110
110
|
from mindspore.ops.auto_generate.gen_ops_prim import scatter_add_ext_op, gather_d_op, slice_op, tril_ext_op, \
|
|
111
|
-
split_tensor_op, split_with_size_op
|
|
111
|
+
split_tensor_op, split_with_size_op, chunk_op
|
|
112
112
|
from mindspore.ops.operations.manually_defined import tile, rank, scalar_cast
|
|
113
113
|
from mindspore.ops.auto_generate.pyboost_inner_prim import _PyboostOneHotExtPrim
|
|
114
114
|
|
|
@@ -181,7 +181,6 @@ sort_ext_ = SortExt()
|
|
|
181
181
|
scatter_prim = Scatter()
|
|
182
182
|
scatter_value_ = ScatterValue()
|
|
183
183
|
arange_ = Arange()
|
|
184
|
-
chunk_ = Chunk()
|
|
185
184
|
repeat_interleave_int_ = RepeatInterleaveInt()
|
|
186
185
|
repeat_interleave_tensor_ = RepeatInterleaveTensor()
|
|
187
186
|
unique_dim_ = UniqueDim()
|
|
@@ -994,7 +993,7 @@ def chunk_ext(input, chunks, dim=0):
|
|
|
994
993
|
Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
|
|
995
994
|
Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
|
|
996
995
|
"""
|
|
997
|
-
return
|
|
996
|
+
return chunk_op(input, chunks, dim)
|
|
998
997
|
|
|
999
998
|
|
|
1000
999
|
def fills(x, value):
|
|
@@ -1960,10 +1959,7 @@ def squeeze(input, axis=None):
|
|
|
1960
1959
|
"""
|
|
1961
1960
|
if axis is None:
|
|
1962
1961
|
axis = ()
|
|
1963
|
-
|
|
1964
|
-
axis = tuple(axis)
|
|
1965
|
-
squeeze_ = _get_cache_prim(P.Squeeze)(axis)
|
|
1966
|
-
return squeeze_(input)
|
|
1962
|
+
return squeeze_impl(input, axis)
|
|
1967
1963
|
|
|
1968
1964
|
|
|
1969
1965
|
def scatter_mul(input_x, indices, updates):
|
|
@@ -2835,6 +2831,11 @@ def tensor_scatter_add(input_x, indices, updates):
|
|
|
2835
2831
|
.. math::
|
|
2836
2832
|
output\left [indices \right ] = input\_x + update
|
|
2837
2833
|
|
|
2834
|
+
The figure below shows an example of the computational process of tensor_scatter_add:
|
|
2835
|
+
|
|
2836
|
+
.. image:: ../images/TensorScatterAdd.png
|
|
2837
|
+
:align: center
|
|
2838
|
+
|
|
2838
2839
|
Note:
|
|
2839
2840
|
- On GPU, if some values of the `indices` are out of bound, instead of raising an index error,
|
|
2840
2841
|
the corresponding `updates` will not be updated to self tensor.
|
|
@@ -3555,7 +3556,7 @@ def matrix_set_diag(x, diagonal, k=0, align="RIGHT_LEFT"): # pylint: disable=re
|
|
|
3555
3556
|
return matrix_set_diag_v3_op(x, diagonal, k)
|
|
3556
3557
|
|
|
3557
3558
|
|
|
3558
|
-
def meshgrid_ext(*tensors, indexing=
|
|
3559
|
+
def meshgrid_ext(*tensors, indexing=None):
|
|
3559
3560
|
"""
|
|
3560
3561
|
Generates coordinate matrices from given coordinate tensors.
|
|
3561
3562
|
|
|
@@ -3577,7 +3578,7 @@ def meshgrid_ext(*tensors, indexing='ij'):
|
|
|
3577
3578
|
for ``'ij'`` indexing, the shape of outputs is :math:`(M, N)`. In the 3-D
|
|
3578
3579
|
case with inputs of length `M`, `N` and `P`, for ``'xy'`` indexing, the shape of outputs is
|
|
3579
3580
|
:math:`(N, M, P)` and for ``'ij'`` indexing, the shape of outputs is :math:`(M, N, P)`.
|
|
3580
|
-
Default: ``'ij'`` .
|
|
3581
|
+
Default: ``None`` , which is equivalent to the value ``'ij'`` .
|
|
3581
3582
|
|
|
3582
3583
|
Returns:
|
|
3583
3584
|
Tensors, a Tuple of N N-D Tensor objects. The data type is the same as that of the inputs.
|
|
@@ -5816,6 +5817,7 @@ def nonzero(input, *, as_tuple=False):
|
|
|
5816
5817
|
.. note::
|
|
5817
5818
|
- Ascend: Rank of Input tensor can be equal to 0 except GE backend.
|
|
5818
5819
|
- CPU/GPU: Rank of Input tensor should be greater than or equal to 1.
|
|
5820
|
+
- Currently, only the Ascend backend is supported when `as_tuple` is ``True``.
|
|
5819
5821
|
|
|
5820
5822
|
Keyword Args:
|
|
5821
5823
|
as_tuple (bool, optional): Whether the output is tuple. Default ``False`` .
|
|
@@ -6305,7 +6307,7 @@ def repeat_elements(x, rep, axis=0):
|
|
|
6305
6307
|
Repeat elements of a tensor along an axis, like :func:`mindspore.numpy.repeat` .
|
|
6306
6308
|
|
|
6307
6309
|
Note:
|
|
6308
|
-
It is recommended to use :func
|
|
6310
|
+
It is recommended to use :func:`mindspore.mint.repeat_interleave`, the dimension of input 'x' can support
|
|
6309
6311
|
a maximum of 8, and get better performance.
|
|
6310
6312
|
|
|
6311
6313
|
Args:
|