mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/common/recompute.py
CHANGED
|
@@ -23,8 +23,10 @@ from mindspore.common.tensor import Tensor
|
|
|
23
23
|
from mindspore import ops
|
|
24
24
|
from mindspore.ops.composite import GradOperation
|
|
25
25
|
from mindspore.common._register_for_recompute import recompute_registry
|
|
26
|
-
from mindspore.common.api import _pynative_executor
|
|
26
|
+
from mindspore.common.api import _pynative_executor, _no_grad
|
|
27
27
|
from mindspore.common.generator import get_rng_state, set_rng_state
|
|
28
|
+
from mindspore.train.amp import amp_decorator
|
|
29
|
+
from mindspore._c_expression.amp import get_curr_amp_strategy
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
class _WrapCell(Cell):
|
|
@@ -34,7 +36,7 @@ class _WrapCell(Cell):
|
|
|
34
36
|
"""
|
|
35
37
|
|
|
36
38
|
def __init__(self, function):
|
|
37
|
-
super(_WrapCell, self).__init__()
|
|
39
|
+
super(_WrapCell, self).__init__(auto_prefix=False)
|
|
38
40
|
self.function = function
|
|
39
41
|
|
|
40
42
|
def construct(self, *args, **kwargs):
|
|
@@ -56,6 +58,7 @@ class _RecomputeCell(Cell):
|
|
|
56
58
|
self.args = []
|
|
57
59
|
self.kwargs = []
|
|
58
60
|
self.wrap_cell = _WrapCell(block)
|
|
61
|
+
self.wrap_cell.set_inputs()
|
|
59
62
|
|
|
60
63
|
self.net = block
|
|
61
64
|
self.internal_params = []
|
|
@@ -64,15 +67,18 @@ class _RecomputeCell(Cell):
|
|
|
64
67
|
self._add_attr("is_cell_recompute", "True")
|
|
65
68
|
self.grad = GradOperation(get_all=True, get_by_list=True, sens_param=True)
|
|
66
69
|
self.init_mixed_precision_type(block)
|
|
70
|
+
self.amp_strategy = None
|
|
67
71
|
|
|
68
72
|
def construct(self, *args, **kwargs):
|
|
69
|
-
_check_input_args_validate(self.net, args)
|
|
73
|
+
_check_input_args_validate(self.net, args, kwargs)
|
|
70
74
|
self.args.append(args)
|
|
71
75
|
self.kwargs.append(kwargs)
|
|
72
76
|
self.save_rng_state = kwargs.pop("save_rng_state", True)
|
|
73
77
|
if self.save_rng_state:
|
|
74
78
|
self.cpu_rng_state = get_rng_state()
|
|
75
|
-
|
|
79
|
+
self.amp_strategy = get_curr_amp_strategy()
|
|
80
|
+
with _no_grad():
|
|
81
|
+
return self.net(*args, **kwargs)
|
|
76
82
|
|
|
77
83
|
def bprop(self, *args):
|
|
78
84
|
"""
|
|
@@ -86,14 +92,23 @@ class _RecomputeCell(Cell):
|
|
|
86
92
|
self.args.pop()
|
|
87
93
|
self.kwargs.pop()
|
|
88
94
|
if kwargs:
|
|
89
|
-
|
|
95
|
+
input_args_for_check = list(input_args) + list(kwargs.values())
|
|
96
|
+
else:
|
|
97
|
+
input_args_for_check = list(input_args)
|
|
90
98
|
# To detach inputs to avoid erasing auto grad meta info of origin inputs.
|
|
91
99
|
input_args = _detach_input(input_args)
|
|
100
|
+
kwargs = _detach_input(kwargs)
|
|
101
|
+
kwargs['sens'] = grad_input
|
|
92
102
|
try:
|
|
93
103
|
pre_rng_state = get_rng_state()
|
|
94
104
|
set_rng_state(self.cpu_rng_state)
|
|
95
105
|
_pynative_executor.set_is_run_recompute(True)
|
|
96
|
-
|
|
106
|
+
if self.amp_strategy:
|
|
107
|
+
with amp_decorator(self.amp_strategy.get_amp_level(), self.amp_strategy.get_amp_dtype(),
|
|
108
|
+
self.amp_strategy.get_white_list(), self.amp_strategy.get_black_list()):
|
|
109
|
+
grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
|
|
110
|
+
else:
|
|
111
|
+
grads = self.grad(self.net, self.internal_params)(*input_args, **kwargs)
|
|
97
112
|
_pynative_executor.set_is_run_recompute(False)
|
|
98
113
|
set_rng_state(pre_rng_state)
|
|
99
114
|
except Exception as err:
|
|
@@ -101,7 +116,7 @@ class _RecomputeCell(Cell):
|
|
|
101
116
|
raise err
|
|
102
117
|
weights = OrderedDict()
|
|
103
118
|
input_grads = list(grads[0])
|
|
104
|
-
_padding_input_grads(
|
|
119
|
+
_padding_input_grads(input_args_for_check, input_grads)
|
|
105
120
|
for i, param in enumerate(self.internal_params):
|
|
106
121
|
weights[param] = grads[1][i]
|
|
107
122
|
return tuple(input_grads), weights
|
|
@@ -121,6 +136,7 @@ class _RecomputeCell(Cell):
|
|
|
121
136
|
# To avoid sub cell same name
|
|
122
137
|
block.__self__.check_names_and_refresh_name()
|
|
123
138
|
self.internal_params = block.__self__.trainable_params()
|
|
139
|
+
self.wrap_cell.mixed_precision_type = block.__self__.get_mixed_precision_type()
|
|
124
140
|
self.wrap_cell.set_mixed_precision_type(block.__self__.get_mixed_precision_type())
|
|
125
141
|
self.net = self.wrap_cell
|
|
126
142
|
else:
|
|
@@ -128,13 +144,14 @@ class _RecomputeCell(Cell):
|
|
|
128
144
|
"only support Cell object or MethodType function!")
|
|
129
145
|
|
|
130
146
|
|
|
131
|
-
def _check_input_args_validate(block, args):
|
|
147
|
+
def _check_input_args_validate(block, args, kwargs):
|
|
132
148
|
"""
|
|
133
149
|
Check recompute input args validate
|
|
134
150
|
:param args:
|
|
135
151
|
:return:
|
|
136
152
|
"""
|
|
137
|
-
if not any([isinstance(arg, Tensor) for arg in args])
|
|
153
|
+
if not (any([isinstance(arg, Tensor) for arg in args]) or \
|
|
154
|
+
any([isinstance(arg, Tensor) for arg in kwargs.values()])):
|
|
138
155
|
logger.warning("None of the inputs of function are tensors, which not need use recompute!")
|
|
139
156
|
for arg in args:
|
|
140
157
|
if isinstance(arg, (tuple, list)):
|
|
@@ -168,6 +185,11 @@ def _padding_input_grads(args, input_grads):
|
|
|
168
185
|
|
|
169
186
|
|
|
170
187
|
def _detach_input(input_arg):
|
|
188
|
+
"""
|
|
189
|
+
Detach input
|
|
190
|
+
:param input_arg:
|
|
191
|
+
:return: detach output
|
|
192
|
+
"""
|
|
171
193
|
if isinstance(input_arg, Tensor):
|
|
172
194
|
return ops.stop_gradient(input_arg)
|
|
173
195
|
if isinstance(input_arg, (list, tuple)):
|
|
@@ -175,6 +197,14 @@ def _detach_input(input_arg):
|
|
|
175
197
|
for arg in input_arg:
|
|
176
198
|
detach_inputs.append(_detach_input(arg))
|
|
177
199
|
return detach_inputs if isinstance(input_arg, list) else tuple(detach_inputs)
|
|
200
|
+
if isinstance(input_arg, dict):
|
|
201
|
+
detach_inputs = {}
|
|
202
|
+
for key, val in input_arg.items():
|
|
203
|
+
if isinstance(val, Tensor):
|
|
204
|
+
detach_inputs[key] = ops.stop_gradient(val)
|
|
205
|
+
else:
|
|
206
|
+
detach_inputs[key] = val
|
|
207
|
+
return detach_inputs
|
|
178
208
|
return input_arg
|
|
179
209
|
|
|
180
210
|
|
|
@@ -97,7 +97,8 @@ class RowTensor(RowTensorInner):
|
|
|
97
97
|
[0, 0]]
|
|
98
98
|
|
|
99
99
|
.. warning::
|
|
100
|
-
This is an experimental API that is subjected to change or deletion.
|
|
100
|
+
- This is an experimental API that is subjected to change or deletion.
|
|
101
|
+
- If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
|
|
101
102
|
|
|
102
103
|
Args:
|
|
103
104
|
indices (Tensor): A 1-D integer Tensor of shape :math:`(d_0)` . Default: ``None``.
|
|
@@ -226,10 +227,11 @@ class COOTensor(COOTensor_):
|
|
|
226
227
|
|
|
227
228
|
Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
|
|
228
229
|
and division (/). For details about operations supported by `COOTensor`, see
|
|
229
|
-
`operators <https://www.mindspore.cn/docs/en/master/
|
|
230
|
+
`operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.
|
|
230
231
|
|
|
231
232
|
.. warning::
|
|
232
233
|
- This is an experimental API that is subject to change or deletion.
|
|
234
|
+
- If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
|
|
233
235
|
- Currently, duplicate coordinates in the indices will not be coalesced.
|
|
234
236
|
If the indices contain out-of-bound values, the result will be undefined.
|
|
235
237
|
|
|
@@ -646,6 +648,7 @@ class CSRTensor(CSRTensor_):
|
|
|
646
648
|
[1., 2., 3., 4., 5., 6.], shape is (3, 5), then the dense representation of the sparse tensor will be:
|
|
647
649
|
|
|
648
650
|
.. code-block::
|
|
651
|
+
|
|
649
652
|
[[1., 0., 0., 2., 0.],
|
|
650
653
|
[0., 3., 4., 0., 5.],
|
|
651
654
|
[0., 0., 6., 0., 0.]]
|
|
@@ -668,10 +671,11 @@ class CSRTensor(CSRTensor_):
|
|
|
668
671
|
|
|
669
672
|
Common arithmetic operations include: addition (+), subtraction (-), multiplication (*),
|
|
670
673
|
and division (/). For details about operations supported by `CSRTensor`, see
|
|
671
|
-
`operators <https://www.mindspore.cn/docs/en/master/
|
|
674
|
+
`operators <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#operators>`_.
|
|
672
675
|
|
|
673
676
|
.. warning::
|
|
674
677
|
- This is an experimental API that is subjected to change.
|
|
678
|
+
- If use PyNative mode, set "export MS_PYNATIVE_CONFIG_STATIC_SHAPE=1".
|
|
675
679
|
- If the values given by `indptr` or `indices` are invalid, the results may be undefined. Invalid values include
|
|
676
680
|
when the length of `values` or `indices` exceeds the range indicated by `indptr`, and when the columns
|
|
677
681
|
indicated by `indices` are repeated on the same row.
|
mindspore/common/tensor.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -31,6 +31,8 @@ from mindspore.common.hook_handle import _TensorHookHandle
|
|
|
31
31
|
|
|
32
32
|
from mindspore.common._utils import get_slice_num
|
|
33
33
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
34
|
+
from mindspore.common._tensor_overload import (repeat_interleave_mint, add_mint, item_mint, isnan_mint, flatten_mint,
|
|
35
|
+
max_mint, mean_mint, min_mint, split_mint, sub_mint)
|
|
34
36
|
from mindspore._c_expression import Tensor as Tensor_
|
|
35
37
|
from mindspore import _checkparam as validator
|
|
36
38
|
from mindspore._checkparam import check_is_number, is_stub_tensor, check_hook_fn
|
|
@@ -51,7 +53,7 @@ def _check_input_data_type(input_data):
|
|
|
51
53
|
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
|
|
52
54
|
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
|
|
53
55
|
input_data.dtype.kind != 'U' and input_data.dtype.kind != 'S' and \
|
|
54
|
-
input_data.dtype.kind
|
|
56
|
+
not (input_data.dtype.kind == 'V' and input_data.dtype.char == 'E'): # Support np.str_ and np.bfloat16
|
|
55
57
|
new_line = '\n'
|
|
56
58
|
for index, x in np.ndenumerate(input_data):
|
|
57
59
|
if np.array(x).dtype not in valid_dtypes:
|
|
@@ -85,11 +87,11 @@ def tensor(input_data=None, dtype=None, shape=None, init=None, internal=False, c
|
|
|
85
87
|
based on the `dtype` argument.
|
|
86
88
|
|
|
87
89
|
Please refer to `Creating and Using Tensor
|
|
88
|
-
<https://www.mindspore.cn/docs/en/master/
|
|
90
|
+
<https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#mindspore-user-defined-data-types>`_ .
|
|
89
91
|
|
|
90
92
|
The difference between it and the Tensor class is that it adds
|
|
91
93
|
`Annotation
|
|
92
|
-
<https://www.mindspore.cn/docs/en/master/
|
|
94
|
+
<https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#annotation-type>`_
|
|
93
95
|
which can prevent the generation of AnyType compared to the Tensor class.
|
|
94
96
|
|
|
95
97
|
The arguments and return values are the same as the Tensor class. Also see: :class:`mindspore.Tensor`.
|
|
@@ -143,6 +145,8 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
143
145
|
Default: ``False`` .
|
|
144
146
|
const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
|
|
145
147
|
Default: ``False`` .
|
|
148
|
+
device(str): This parameter is reserved and does not need to be configured.
|
|
149
|
+
Default: ``None`` .
|
|
146
150
|
|
|
147
151
|
Outputs:
|
|
148
152
|
Tensor.
|
|
@@ -205,7 +209,8 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
205
209
|
"""
|
|
206
210
|
delta_seed = 0
|
|
207
211
|
|
|
208
|
-
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False
|
|
212
|
+
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False,
|
|
213
|
+
device=None):
|
|
209
214
|
self.init_finished = False
|
|
210
215
|
if isinstance(input_data, (Tensor, Tensor_)) and dtype is not None:
|
|
211
216
|
logger.info("It is suggested to use 'Tensor.astype()' to convert the dtype of a Tensor.")
|
|
@@ -264,6 +269,9 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
264
269
|
Tensor_.__init__(self, input_data)
|
|
265
270
|
validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
|
|
266
271
|
|
|
272
|
+
if device is not None and device != "CPU":
|
|
273
|
+
raise ValueError(f"Only 'CPU' is supported for device, but got {device}.")
|
|
274
|
+
|
|
267
275
|
self.const_arg = const_arg
|
|
268
276
|
self.virtual_flag = False
|
|
269
277
|
self.init = init
|
|
@@ -380,6 +388,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
380
388
|
def __abs__(self):
|
|
381
389
|
return tensor_operator_registry.get('abs')(self)
|
|
382
390
|
|
|
391
|
+
@add_mint
|
|
383
392
|
def __add__(self, other):
|
|
384
393
|
return tensor_operator_registry.get('__add__')(self, other)
|
|
385
394
|
|
|
@@ -404,6 +413,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
404
413
|
def __iadd__(self, other):
|
|
405
414
|
return self.__add__(other)
|
|
406
415
|
|
|
416
|
+
@sub_mint
|
|
407
417
|
def __sub__(self, other):
|
|
408
418
|
return tensor_operator_registry.get('__sub__')(self, other)
|
|
409
419
|
|
|
@@ -513,9 +523,12 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
513
523
|
return state
|
|
514
524
|
|
|
515
525
|
def __setstate__(self, state):
|
|
516
|
-
|
|
526
|
+
if isinstance(state, tuple):
|
|
527
|
+
value = state
|
|
528
|
+
else:
|
|
529
|
+
value = state.pop("value")
|
|
530
|
+
self.__dict__.update(state)
|
|
517
531
|
Tensor_.__setstate__(self, value)
|
|
518
|
-
self.__dict__.update(state)
|
|
519
532
|
|
|
520
533
|
@property
|
|
521
534
|
def shape(self):
|
|
@@ -706,8 +719,9 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
706
719
|
|
|
707
720
|
Examples:
|
|
708
721
|
>>> from mindspore import Tensor
|
|
722
|
+
>>> from mindspore import dtype as mstype
|
|
709
723
|
>>> import numpy as np
|
|
710
|
-
>>> x = Tensor(np.array([[1, 2], [3, 4]]))
|
|
724
|
+
>>> x = Tensor(np.array([[1, 2], [3, 4]]), dtype=mstype.int64)
|
|
711
725
|
>>> output = x.strides
|
|
712
726
|
>>> print(output)
|
|
713
727
|
(16, 8)
|
|
@@ -940,6 +954,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
940
954
|
"""
|
|
941
955
|
return tensor_operator_registry.get('chunk')(self, chunks, axis)
|
|
942
956
|
|
|
957
|
+
@item_mint
|
|
943
958
|
def item(self, index=None):
|
|
944
959
|
"""
|
|
945
960
|
Get the item at the specified index of the tensor.
|
|
@@ -1054,7 +1069,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
1054
1069
|
self.init_data()
|
|
1055
1070
|
return Tensor_.asnumpy(self)
|
|
1056
1071
|
|
|
1057
|
-
def numpy(self):
|
|
1072
|
+
def numpy(self, *, force=False):
|
|
1058
1073
|
"""
|
|
1059
1074
|
Alias for :func:`mindspore.Tensor.asnumpy`.
|
|
1060
1075
|
"""
|
|
@@ -1295,12 +1310,48 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
1295
1310
|
"""
|
|
1296
1311
|
return tensor_operator_registry.get('addcmul')(self, tensor1, tensor2, value)
|
|
1297
1312
|
|
|
1313
|
+
@add_mint
|
|
1298
1314
|
def add(self, other):
|
|
1299
1315
|
r"""
|
|
1300
1316
|
For details, please refer to :func:`mindspore.ops.add`.
|
|
1301
1317
|
"""
|
|
1302
1318
|
return tensor_operator_registry.get('add')(self, other)
|
|
1303
1319
|
|
|
1320
|
+
def add_(self, other, *, alpha=1):
|
|
1321
|
+
"""
|
|
1322
|
+
inplace update self by following compute:
|
|
1323
|
+
self = self + other * alpha.
|
|
1324
|
+
|
|
1325
|
+
.. warning::
|
|
1326
|
+
This is an experimental API that is subject to change or deletion.
|
|
1327
|
+
The `other` tensor must be broadcastable with the `self` tensor. It may be of a different data type.
|
|
1328
|
+
|
|
1329
|
+
Args:
|
|
1330
|
+
other (Tensor): the source tensor Add to self Tensor.
|
|
1331
|
+
alpha (Number): no effect currently.
|
|
1332
|
+
|
|
1333
|
+
Returns:
|
|
1334
|
+
Return self Tensor.
|
|
1335
|
+
|
|
1336
|
+
Supported Platforms:
|
|
1337
|
+
``Ascend``
|
|
1338
|
+
|
|
1339
|
+
Examples:
|
|
1340
|
+
>>> import numpy as np
|
|
1341
|
+
>>> from mindspore import Tensor
|
|
1342
|
+
>>> a = Tensor(np.ones((2,3)).astype("float32"))
|
|
1343
|
+
>>> b = Tensor(np.ones((2,3)).astype("float32"))
|
|
1344
|
+
>>> a.add_(b)
|
|
1345
|
+
>>> print(a)
|
|
1346
|
+
[[2. 2. 2.]
|
|
1347
|
+
[2. 2. 2.]]
|
|
1348
|
+
"""
|
|
1349
|
+
if isinstance(other, (int, float)):
|
|
1350
|
+
ret = tensor_operator_registry.get("adds_")(self, other, alpha)
|
|
1351
|
+
else:
|
|
1352
|
+
ret = tensor_operator_registry.get("add_")(self, other, alpha)
|
|
1353
|
+
return ret
|
|
1354
|
+
|
|
1304
1355
|
def subtract(self, other, *, alpha=1):
|
|
1305
1356
|
r"""
|
|
1306
1357
|
For details, please refer to :func:`mindspore.ops.subtract`.
|
|
@@ -1337,6 +1388,19 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
1337
1388
|
"""
|
|
1338
1389
|
return tensor_operator_registry.get('addmm')(self, mat1, mat2, beta=beta, alpha=alpha)
|
|
1339
1390
|
|
|
1391
|
+
def addmm_(self, mat1, mat2, *, beta=1, alpha=1):
|
|
1392
|
+
r"""
|
|
1393
|
+
For details, please refer to :func:`mindspore.ops.addmm`.
|
|
1394
|
+
|
|
1395
|
+
.. note::
|
|
1396
|
+
The output results are directly updated in the Tensor.
|
|
1397
|
+
|
|
1398
|
+
.. warning::
|
|
1399
|
+
This is an experimental API that is subject to change or deletion.
|
|
1400
|
+
|
|
1401
|
+
"""
|
|
1402
|
+
return tensor_operator_registry.get('addmm_')(self, mat1, mat2, beta=beta, alpha=alpha)
|
|
1403
|
+
|
|
1340
1404
|
def addr(self, vec1, vec2, beta=1, alpha=1):
|
|
1341
1405
|
r"""
|
|
1342
1406
|
For details, please refer to :func:`mindspore.ops.addr`.
|
|
@@ -1579,6 +1643,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
1579
1643
|
"""
|
|
1580
1644
|
return tensor_operator_registry.get('square')(self)
|
|
1581
1645
|
|
|
1646
|
+
@sub_mint
|
|
1582
1647
|
def sub(self, y):
|
|
1583
1648
|
r"""
|
|
1584
1649
|
For details, please refer to :func:`mindspore.ops.sub`.
|
|
@@ -1842,6 +1907,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
1842
1907
|
"""
|
|
1843
1908
|
return tensor_operator_registry.get('log2')(self)
|
|
1844
1909
|
|
|
1910
|
+
@mean_mint
|
|
1845
1911
|
def mean(self, axis=None, keep_dims=False):
|
|
1846
1912
|
"""
|
|
1847
1913
|
For details, please refer to :func:`mindspore.ops.mean`.
|
|
@@ -2012,11 +2078,11 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
2012
2078
|
reshape_op = tensor_operator_registry.get('reshape')
|
|
2013
2079
|
return reshape_op(self, (-1,))
|
|
2014
2080
|
|
|
2015
|
-
def round(self):
|
|
2081
|
+
def round(self, decimals=0):
|
|
2016
2082
|
"""
|
|
2017
2083
|
For details, please refer to :func:`mindspore.ops.round`.
|
|
2018
2084
|
"""
|
|
2019
|
-
return tensor_operator_registry.get('round')(self)
|
|
2085
|
+
return tensor_operator_registry.get('round')(self, decimals=decimals)
|
|
2020
2086
|
|
|
2021
2087
|
def roll(self, shifts, dims):
|
|
2022
2088
|
"""
|
|
@@ -2091,6 +2157,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
2091
2157
|
"""
|
|
2092
2158
|
return tensor_operator_registry.get('remainder')(self, divisor)
|
|
2093
2159
|
|
|
2160
|
+
@flatten_mint
|
|
2094
2161
|
def flatten(self, order='C', *, start_dim=0, end_dim=-1):
|
|
2095
2162
|
r"""
|
|
2096
2163
|
For details, please refer to :func:`mindspore.ops.flatten`.
|
|
@@ -2399,6 +2466,38 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
2399
2466
|
x = x.astype(origin_dtype)
|
|
2400
2467
|
return x
|
|
2401
2468
|
|
|
2469
|
+
def copy_(self, src, non_blocking=False):
|
|
2470
|
+
"""
|
|
2471
|
+
Copies the elements from src into self tensor and returns self.
|
|
2472
|
+
|
|
2473
|
+
.. warning::
|
|
2474
|
+
This is an experimental API that is subject to change or deletion.
|
|
2475
|
+
The `src` tensor must be broadcastable with the `self` tensor. It may be of a different data type.
|
|
2476
|
+
|
|
2477
|
+
Args:
|
|
2478
|
+
src (Tensor): the source tensor to copy from.
|
|
2479
|
+
non_blocking (bool): no effect currently.
|
|
2480
|
+
|
|
2481
|
+
Returns:
|
|
2482
|
+
Return self Tensor.
|
|
2483
|
+
|
|
2484
|
+
Supported Platforms:
|
|
2485
|
+
``Ascend``
|
|
2486
|
+
|
|
2487
|
+
Examples:
|
|
2488
|
+
>>> import numpy as np
|
|
2489
|
+
>>> from mindspore import Tensor
|
|
2490
|
+
>>> a = Tensor(np.ones((3,3)).astype("float32"))
|
|
2491
|
+
>>> b = Tensor(np.zeros((3,3)).astype("float32"))
|
|
2492
|
+
>>> a.copy_(b)
|
|
2493
|
+
>>> print(a)
|
|
2494
|
+
[[0. 0. 0.]
|
|
2495
|
+
[0. 0. 0.]
|
|
2496
|
+
[0. 0. 0.]]
|
|
2497
|
+
"""
|
|
2498
|
+
return tensor_operator_registry.get("copy_")(self, src)
|
|
2499
|
+
|
|
2500
|
+
@max_mint
|
|
2402
2501
|
def max(self, axis=None, keepdims=False, *, initial=None, where=True, return_indices=False):
|
|
2403
2502
|
"""
|
|
2404
2503
|
Return the maximum of a tensor or maximum along an axis.
|
|
@@ -2467,6 +2566,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
2467
2566
|
return values
|
|
2468
2567
|
return values, indices
|
|
2469
2568
|
|
|
2569
|
+
@min_mint
|
|
2470
2570
|
def min(self, axis=None, keepdims=False, *, initial=None, where=True, return_indices=False):
|
|
2471
2571
|
"""
|
|
2472
2572
|
Return the minimum of a tensor or minimum along an axis.
|
|
@@ -2763,7 +2863,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
2763
2863
|
opt_shard_group(str): Optimizer shard group which is used in auto or semi auto parallel mode
|
|
2764
2864
|
to get one shard of a parameter's slice. For more information about optimizer parallel, please refer to:
|
|
2765
2865
|
`Optimizer Parallel
|
|
2766
|
-
<https://www.mindspore.cn/
|
|
2866
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/optimizer_parallel.html>`_.
|
|
2767
2867
|
Default: ``None``.
|
|
2768
2868
|
|
|
2769
2869
|
Returns:
|
|
@@ -2995,16 +3095,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
2995
3095
|
>>> print(x.trace())
|
|
2996
3096
|
3.0
|
|
2997
3097
|
"""
|
|
2998
|
-
|
|
2999
|
-
return tensor_operator_registry.get('trace')(self)
|
|
3000
|
-
d = self.diagonal(offset, axis1=axis1, axis2=axis2)
|
|
3001
|
-
shape = d.shape
|
|
3002
|
-
if dtype is None:
|
|
3003
|
-
dtype = d.dtype
|
|
3004
|
-
if shape[-1] == 0:
|
|
3005
|
-
return tensor_operator_registry.get('fill')(dtype, shape[:-1], 0)
|
|
3006
|
-
res = tensor_operator_registry.get('reduce_sum')(d.astype(mstype.float32), -1)
|
|
3007
|
-
return res.astype(dtype)
|
|
3098
|
+
return tensor_operator_registry.get('tracev2')(self, offset, axis1, axis2, dtype)
|
|
3008
3099
|
|
|
3009
3100
|
def take(self, indices, axis=None, mode='clip'):
|
|
3010
3101
|
"""
|
|
@@ -3164,6 +3255,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
3164
3255
|
sorter (Union[int, list, tuple, Tensor]): optional tensor of
|
|
3165
3256
|
integer indices that sort the tensor into ascending order on the innermost dimension
|
|
3166
3257
|
and the type must be int64. They are typically the result of argsort. Default: ``None`` .
|
|
3258
|
+
CPU and GPU can only use default values
|
|
3167
3259
|
|
|
3168
3260
|
Returns:
|
|
3169
3261
|
Tensor, array of insertion points with the same shape as `v`.
|
|
@@ -3217,10 +3309,10 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
3217
3309
|
|
|
3218
3310
|
def uniform(self, from_=0., to=1., generator=None):
|
|
3219
3311
|
r"""
|
|
3220
|
-
Generates random numbers in the half-open interval [
|
|
3312
|
+
Generates random numbers in the half-open interval [from\_, to).
|
|
3221
3313
|
|
|
3222
3314
|
Args:
|
|
3223
|
-
|
|
3315
|
+
from\_ (number): The lower bound of the interval.
|
|
3224
3316
|
to (number): The upper bound of the interval.
|
|
3225
3317
|
generator (Generator, optional): The random seed. Default: None.
|
|
3226
3318
|
|
|
@@ -3506,6 +3598,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
3506
3598
|
repeated_subs.append(tensor_operator_registry.get('repeat_elements')(sub, rep, axis))
|
|
3507
3599
|
return tensor_operator_registry.get('concatenate')(repeated_subs, axis)
|
|
3508
3600
|
|
|
3601
|
+
@repeat_interleave_mint
|
|
3509
3602
|
def repeat_interleave(self, repeats, dim=None):
|
|
3510
3603
|
"""
|
|
3511
3604
|
For details, please refer to :func:`mindspore.ops.repeat_interleave`.
|
|
@@ -3740,6 +3833,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
3740
3833
|
"""
|
|
3741
3834
|
return tensor_operator_registry.get("xdivy")(self, y)
|
|
3742
3835
|
|
|
3836
|
+
@split_mint
|
|
3743
3837
|
def split(self, split_size_or_sections, axis=0):
|
|
3744
3838
|
"""
|
|
3745
3839
|
For details, please refer to :func:`mindspore.ops.split`.
|
|
@@ -4039,6 +4133,27 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
4039
4133
|
"""
|
|
4040
4134
|
return tensor_operator_registry.get('int')(self, mstype.int32)
|
|
4041
4135
|
|
|
4136
|
+
def byte(self):
|
|
4137
|
+
r"""
|
|
4138
|
+
Converts input tensor dtype to `uint8`.
|
|
4139
|
+
|
|
4140
|
+
Returns:
|
|
4141
|
+
Tensor, converted to the `uint8` dtype.
|
|
4142
|
+
|
|
4143
|
+
Supported Platforms:
|
|
4144
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
4145
|
+
|
|
4146
|
+
Examples:
|
|
4147
|
+
>>> import numpy as np
|
|
4148
|
+
>>> import mindspore
|
|
4149
|
+
>>> from mindspore import Tensor
|
|
4150
|
+
>>> input_x = Tensor(np.ones([2,2]), mindspore.float32)
|
|
4151
|
+
>>> output = input_x.byte()
|
|
4152
|
+
>>> print(output.dtype)
|
|
4153
|
+
uint8
|
|
4154
|
+
"""
|
|
4155
|
+
return tensor_operator_registry.get('byte')(self, mstype.uint8)
|
|
4156
|
+
|
|
4042
4157
|
def long(self):
|
|
4043
4158
|
r"""
|
|
4044
4159
|
Converts input tensor dtype to `int64`. If the value in tensor is float or half, the decimal will be discarded.
|
|
@@ -4249,6 +4364,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
4249
4364
|
"""
|
|
4250
4365
|
return tensor_operator_registry.get('isinf')(self)
|
|
4251
4366
|
|
|
4367
|
+
@isnan_mint
|
|
4252
4368
|
def isnan(self):
|
|
4253
4369
|
r"""
|
|
4254
4370
|
For details, please refer to :func:`mindspore.ops.isnan`.
|
|
@@ -4425,7 +4541,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
4425
4541
|
"""
|
|
4426
4542
|
return tensor_operator_registry.get('mul')(self, value)
|
|
4427
4543
|
|
|
4428
|
-
def nan_to_num(self, nan=
|
|
4544
|
+
def nan_to_num(self, nan=None, posinf=None, neginf=None):
|
|
4429
4545
|
"""
|
|
4430
4546
|
For details, please refer to :func:`mindspore.ops.nan_to_num`.
|
|
4431
4547
|
"""
|
|
@@ -4482,6 +4598,31 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
4482
4598
|
"""
|
|
4483
4599
|
return tensor_operator_registry.get('zeros')(size, dtype)
|
|
4484
4600
|
|
|
4601
|
+
def zero_(self):
|
|
4602
|
+
r"""
|
|
4603
|
+
Return a tensor filled with zeros.
|
|
4604
|
+
|
|
4605
|
+
.. warning::
|
|
4606
|
+
This is an experimental API that is subject to change or deletion.
|
|
4607
|
+
|
|
4608
|
+
Returns:
|
|
4609
|
+
Return a tensor. Fill self tensor with zeros.
|
|
4610
|
+
|
|
4611
|
+
Supported Platforms:
|
|
4612
|
+
``Ascend``
|
|
4613
|
+
|
|
4614
|
+
Examples:
|
|
4615
|
+
>>> import numpy as np
|
|
4616
|
+
>>> import mindspore
|
|
4617
|
+
>>> from mindspore import Tensor
|
|
4618
|
+
>>> x = Tensor(np.array([2, 2]))
|
|
4619
|
+
>>> output = x.zero_()
|
|
4620
|
+
>>> print(output)
|
|
4621
|
+
[[0. 0.]
|
|
4622
|
+
[0. 0.]]
|
|
4623
|
+
"""
|
|
4624
|
+
return tensor_operator_registry.get('zero_')(self)
|
|
4625
|
+
|
|
4485
4626
|
def new_ones(self, size, dtype=None):
|
|
4486
4627
|
r"""
|
|
4487
4628
|
Return a tensor of `size` filled with ones.
|
|
@@ -4758,7 +4899,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|
|
4758
4899
|
mode = context.get_context("mode")
|
|
4759
4900
|
if mode != context.PYNATIVE_MODE:
|
|
4760
4901
|
raise ValueError(f"The method of 'move_to' only supported in pynative mode, but got: {mode}.")
|
|
4761
|
-
return Tensor(Tensor_.move_to(self, to, blocking))
|
|
4902
|
+
return Tensor(Tensor_.move_to(self, to, blocking), device="CPU" if to == "CPU" else None)
|
|
4762
4903
|
|
|
4763
4904
|
|
|
4764
4905
|
def _offload(self):
|
|
@@ -4805,6 +4946,44 @@ def _vm_compare(*args):
|
|
|
4805
4946
|
return Tensor(np.array(fn(y)))
|
|
4806
4947
|
|
|
4807
4948
|
|
|
4949
|
+
def _check_sequence_shape(input_data):
|
|
4950
|
+
"""Check the shape of tensor input with type of sequence."""
|
|
4951
|
+
max_dims_reached = False
|
|
4952
|
+
max_ndim = 64 # corresponding to NPY_MAXDIMS
|
|
4953
|
+
out_shape = [0]*max_ndim
|
|
4954
|
+
|
|
4955
|
+
def check_shape_recursive(input_data, curr_ndim):
|
|
4956
|
+
nonlocal max_dims_reached, max_ndim, out_shape
|
|
4957
|
+
if curr_ndim > max_ndim:
|
|
4958
|
+
return False
|
|
4959
|
+
if not isinstance(input_data, (tuple, list)):
|
|
4960
|
+
if max_dims_reached and curr_ndim != max_ndim:
|
|
4961
|
+
max_ndim = curr_ndim
|
|
4962
|
+
return False
|
|
4963
|
+
max_dims_reached = True
|
|
4964
|
+
max_ndim = curr_ndim
|
|
4965
|
+
return True
|
|
4966
|
+
if not max_dims_reached:
|
|
4967
|
+
out_shape[curr_ndim] = len(input_data)
|
|
4968
|
+
else:
|
|
4969
|
+
if out_shape[curr_ndim] != len(input_data):
|
|
4970
|
+
max_ndim = curr_ndim
|
|
4971
|
+
return False
|
|
4972
|
+
if not input_data:
|
|
4973
|
+
# process empty list
|
|
4974
|
+
if not check_shape_recursive(None, curr_ndim + 1):
|
|
4975
|
+
return False
|
|
4976
|
+
for data in input_data:
|
|
4977
|
+
if not check_shape_recursive(data, curr_ndim + 1):
|
|
4978
|
+
return False
|
|
4979
|
+
return True
|
|
4980
|
+
|
|
4981
|
+
if not check_shape_recursive(input_data, 0):
|
|
4982
|
+
raise ValueError(f"When initializing a tensor with a sequence, the sequence has an inhomogeneous shape "
|
|
4983
|
+
f"after {max_ndim} dimensions. The detected shape was {tuple(out_shape[:max_ndim])} "
|
|
4984
|
+
f"+ inhomogeneous part.")
|
|
4985
|
+
|
|
4986
|
+
|
|
4808
4987
|
def _check_tensor_input(input_data=None, dtype=None, shape=None, init=None):
|
|
4809
4988
|
"""Check the tensor input."""
|
|
4810
4989
|
if input_data is not None and shape is not None:
|
|
@@ -4817,9 +4996,10 @@ def _check_tensor_input(input_data=None, dtype=None, shape=None, init=None):
|
|
|
4817
4996
|
if input_data is not None:
|
|
4818
4997
|
if isinstance(input_data, np.ndarray) and input_data.ndim >= 1 and input_data.size == 0:
|
|
4819
4998
|
raise ValueError("input_data can not contain zero dimension.")
|
|
4820
|
-
if isinstance(input_data, (tuple, list))
|
|
4821
|
-
|
|
4822
|
-
|
|
4999
|
+
if isinstance(input_data, (tuple, list)):
|
|
5000
|
+
_check_sequence_shape(input_data)
|
|
5001
|
+
if np.array(input_data).ndim >= 1 and np.array(input_data).size == 0:
|
|
5002
|
+
raise ValueError("input_data can not contain zero dimension.")
|
|
4823
5003
|
|
|
4824
5004
|
if shape is not None and not (hasattr(init, "__enable_zero_dim__") and init.__enable_zero_dim__) and 0 in shape:
|
|
4825
5005
|
raise ValueError("Shape can not contain zero value.")
|
|
@@ -20,7 +20,7 @@ Note that the APIs in the following list need to preset communication environmen
|
|
|
20
20
|
For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
|
|
21
21
|
without any third-party or configuration file dependencies.
|
|
22
22
|
Please see the `msrun start up
|
|
23
|
-
<https://www.mindspore.cn/
|
|
23
|
+
<https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
|
|
24
24
|
for more details.
|
|
25
25
|
"""
|
|
26
26
|
|