mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/common/tensor.py
CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,6 +31,8 @@ from mindspore.common.hook_handle import _TensorHookHandle
 
 from mindspore.common._utils import get_slice_num
 from mindspore.common._register_for_tensor import tensor_operator_registry
+from mindspore.common._tensor_overload import (repeat_interleave_mint, add_mint, item_mint, isnan_mint, flatten_mint,
+                                               max_mint, mean_mint, min_mint, split_mint, sub_mint)
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore import _checkparam as validator
 from mindspore._checkparam import check_is_number, is_stub_tensor, check_hook_fn
@@ -51,7 +53,7 @@ def _check_input_data_type(input_data):
                     np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
     if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
             input_data.dtype.kind != 'U' and input_data.dtype.kind != 'S' and \
-            input_data.dtype.kind
+            not (input_data.dtype.kind == 'V' and input_data.dtype.char == 'E'):  # Support np.str_ and np.bfloat16
         new_line = '\n'
         for index, x in np.ndenumerate(input_data):
             if np.array(x).dtype not in valid_dtypes:
@@ -85,11 +87,11 @@ def tensor(input_data=None, dtype=None, shape=None, init=None, internal=False, c
     based on the `dtype` argument.
 
     Please refer to `Creating and Using Tensor
-    <https://www.mindspore.cn/docs/en/master/
+    <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#mindspore-user-defined-data-types>`_ .
 
     The difference between it and the Tensor class is that it adds
     `Annotation
-    <https://www.mindspore.cn/docs/en/master/
+    <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#annotation-type>`_
     which can prevent the generation of AnyType compared to the Tensor class.
 
     The arguments and return values are the same as the Tensor class. Also see: :class:`mindspore.Tensor`.
@@ -143,6 +145,8 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             Default: ``False`` .
         const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
             Default: ``False`` .
+        device(str): This parameter is reserved and does not need to be configured.
+            Default: ``None`` .
 
     Outputs:
         Tensor.
@@ -205,7 +209,8 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
     """
     delta_seed = 0
 
-    def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False
+    def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False,
+                 device=None):
         self.init_finished = False
         if isinstance(input_data, (Tensor, Tensor_)) and dtype is not None:
             logger.info("It is suggested to use 'Tensor.astype()' to convert the dtype of a Tensor.")
@@ -264,6 +269,9 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             Tensor_.__init__(self, input_data)
         validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
 
+        if device is not None and device != "CPU":
+            raise ValueError(f"Only 'CPU' is supported for device, but got {device}.")
+
         self.const_arg = const_arg
         self.virtual_flag = False
         self.init = init
@@ -380,6 +388,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
     def __abs__(self):
         return tensor_operator_registry.get('abs')(self)
 
+    @add_mint
     def __add__(self, other):
         return tensor_operator_registry.get('__add__')(self, other)
 
@@ -404,6 +413,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
     def __iadd__(self, other):
         return self.__add__(other)
 
+    @sub_mint
     def __sub__(self, other):
         return tensor_operator_registry.get('__sub__')(self, other)
 
@@ -513,9 +523,12 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         return state
 
     def __setstate__(self, state):
-        value = state.pop("value")
+        if isinstance(state, tuple):
+            value = state
+        else:
+            value = state.pop("value")
+            self.__dict__.update(state)
         Tensor_.__setstate__(self, value)
-        self.__dict__.update(state)
 
     @property
     def shape(self):
@@ -706,8 +719,9 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
 
         Examples:
             >>> from mindspore import Tensor
+            >>> from mindspore import dtype as mstype
             >>> import numpy as np
-            >>> x = Tensor(np.array([[1, 2], [3, 4]]))
+            >>> x = Tensor(np.array([[1, 2], [3, 4]]), dtype=mstype.int64)
            >>> output = x.strides
             >>> print(output)
             (16, 8)
@@ -940,6 +954,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('chunk')(self, chunks, axis)
 
+    @item_mint
     def item(self, index=None):
         """
         Get the item at the specified index of the tensor.
@@ -1054,7 +1069,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             self.init_data()
         return Tensor_.asnumpy(self)
 
-    def numpy(self):
+    def numpy(self, *, force=False):
         """
         Alias for :func:`mindspore.Tensor.asnumpy`.
         """
@@ -1295,12 +1310,48 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('addcmul')(self, tensor1, tensor2, value)
 
+    @add_mint
     def add(self, other):
         r"""
         For details, please refer to :func:`mindspore.ops.add`.
         """
         return tensor_operator_registry.get('add')(self, other)
 
+    def add_(self, other, *, alpha=1):
+        """
+        inplace update self by following compute:
+        self = self + other * alpha.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+            The `other` tensor must be broadcastable with the `self` tensor. It may be of a different data type.
+
+        Args:
+            other (Tensor): the source tensor Add to self Tensor.
+            alpha (Number): no effect currently.
+
+        Returns:
+            Return self Tensor.
+
+        Supported Platforms:
+            ``Ascend``
+
+        Examples:
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> a = Tensor(np.ones((2,3)).astype("float32"))
+            >>> b = Tensor(np.ones((2,3)).astype("float32"))
+            >>> a.add_(b)
+            >>> print(a)
+            [[2. 2. 2.]
+             [2. 2. 2.]]
+        """
+        if isinstance(other, (int, float)):
+            ret = tensor_operator_registry.get("adds_")(self, other, alpha)
+        else:
+            ret = tensor_operator_registry.get("add_")(self, other, alpha)
+        return ret
+
     def subtract(self, other, *, alpha=1):
         r"""
         For details, please refer to :func:`mindspore.ops.subtract`.
@@ -1337,6 +1388,19 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('addmm')(self, mat1, mat2, beta=beta, alpha=alpha)
 
+    def addmm_(self, mat1, mat2, *, beta=1, alpha=1):
+        r"""
+        For details, please refer to :func:`mindspore.ops.addmm`.
+
+        .. note::
+            The output results are directly updated in the Tensor.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+
+        """
+        return tensor_operator_registry.get('addmm_')(self, mat1, mat2, beta=beta, alpha=alpha)
+
     def addr(self, vec1, vec2, beta=1, alpha=1):
         r"""
         For details, please refer to :func:`mindspore.ops.addr`.
@@ -1579,6 +1643,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('square')(self)
 
+    @sub_mint
     def sub(self, y):
         r"""
         For details, please refer to :func:`mindspore.ops.sub`.
@@ -1842,6 +1907,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('log2')(self)
 
+    @mean_mint
     def mean(self, axis=None, keep_dims=False):
         """
         For details, please refer to :func:`mindspore.ops.mean`.
@@ -2012,11 +2078,11 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         reshape_op = tensor_operator_registry.get('reshape')
         return reshape_op(self, (-1,))
 
-    def round(self):
+    def round(self, decimals=0):
         """
         For details, please refer to :func:`mindspore.ops.round`.
         """
-        return tensor_operator_registry.get('round')(self)
+        return tensor_operator_registry.get('round')(self, decimals=decimals)
 
     def roll(self, shifts, dims):
         """
@@ -2091,6 +2157,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('remainder')(self, divisor)
 
+    @flatten_mint
     def flatten(self, order='C', *, start_dim=0, end_dim=-1):
         r"""
         For details, please refer to :func:`mindspore.ops.flatten`.
@@ -2399,6 +2466,38 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             x = x.astype(origin_dtype)
         return x
 
+    def copy_(self, src, non_blocking=False):
+        """
+        Copies the elements from src into self tensor and returns self.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+            The `src` tensor must be broadcastable with the `self` tensor. It may be of a different data type.
+
+        Args:
+            src (Tensor): the source tensor to copy from.
+            non_blocking (bool): no effect currently.
+
+        Returns:
+            Return self Tensor.
+
+        Supported Platforms:
+            ``Ascend``
+
+        Examples:
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> a = Tensor(np.ones((3,3)).astype("float32"))
+            >>> b = Tensor(np.zeros((3,3)).astype("float32"))
+            >>> a.copy_(b)
+            >>> print(a)
+            [[0. 0. 0.]
+             [0. 0. 0.]
+             [0. 0. 0.]]
+        """
+        return tensor_operator_registry.get("copy_")(self, src)
+
+    @max_mint
     def max(self, axis=None, keepdims=False, *, initial=None, where=True, return_indices=False):
         """
         Return the maximum of a tensor or maximum along an axis.
@@ -2467,6 +2566,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             return values
         return values, indices
 
+    @min_mint
     def min(self, axis=None, keepdims=False, *, initial=None, where=True, return_indices=False):
         """
         Return the minimum of a tensor or minimum along an axis.
@@ -2763,7 +2863,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             opt_shard_group(str): Optimizer shard group which is used in auto or semi auto parallel mode
                 to get one shard of a parameter's slice. For more information about optimizer parallel, please refer to:
                 `Optimizer Parallel
-                <https://www.mindspore.cn/
+                <https://www.mindspore.cn/docs/en/master/model_train/parallel/optimizer_parallel.html>`_.
                 Default: ``None``.
 
         Returns:
@@ -2796,8 +2896,13 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             self.slice_shape_of_persistent_data_ = data_shape
             self.slice_num_of_persistent_data_ = slice_num_of_persistent_data
 
+        from mindspore.common.initializer import Zero as ZeroInitializer
+
         try:
-            data = np.ndarray(data_shape, dtype=mstype.dtype_to_nptype(self.dtype))
+            if isinstance(self.init, ZeroInitializer):
+                data = np.zeros(data_shape, dtype=mstype.dtype_to_nptype(self.dtype))
+            else:
+                data = np.ndarray(data_shape, dtype=mstype.dtype_to_nptype(self.dtype))
         except ValueError as e:
             msg = "Error shape={}".format(shape)
             logger.critical(msg)
@@ -2833,7 +2938,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             self.init.seed, _ = self.seed
 
         with seed_context(self.init):
-            if slice_num_of_persistent_data == 1:
+            if not isinstance(self.init, ZeroInitializer) and slice_num_of_persistent_data == 1:
                 self.init(data)
         self.init = None
```
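The `ZeroInitializer` branch in the hunk above is an allocation fast path: `np.zeros` hands back memory that is already zero-filled, so the later `self.init(data)` fill pass is skipped for zero-initialized tensors, while every other initializer still receives an uninitialized `np.ndarray` buffer to fill in place. A standalone sketch of that logic (hypothetical helper name, not MindSpore API):

```python
import numpy as np

def make_initial_data(shape, np_dtype, init_is_zero, initializer=None):
    # Mirrors the hunk above: zero init short-circuits to an already-zeroed
    # buffer and skips the fill pass; other initializers get raw memory.
    if init_is_zero:
        return np.zeros(shape, dtype=np_dtype)
    data = np.ndarray(shape, dtype=np_dtype)  # uninitialized (garbage) memory
    initializer(data)                         # initializer fills it in place
    return data

ones_init = lambda arr: arr.fill(1.0)  # stand-in for a MindSpore initializer
print(make_initial_data((2, 3), np.float32, init_is_zero=True))
print(make_initial_data((2, 3), np.float32, init_is_zero=False, initializer=ones_init))
```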
```diff
@@ -2995,16 +3100,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         >>> print(x.trace())
         3.0
         """
-
-        return tensor_operator_registry.get('trace')(self)
-        d = self.diagonal(offset, axis1=axis1, axis2=axis2)
-        shape = d.shape
-        if dtype is None:
-            dtype = d.dtype
-        if shape[-1] == 0:
-            return tensor_operator_registry.get('fill')(dtype, shape[:-1], 0)
-        res = tensor_operator_registry.get('reduce_sum')(d.astype(mstype.float32), -1)
-        return res.astype(dtype)
+        return tensor_operator_registry.get('tracev2')(self, offset, axis1, axis2, dtype)
 
     def take(self, indices, axis=None, mode='clip'):
         """
@@ -3164,6 +3260,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             sorter (Union[int, list, tuple, Tensor]): optional tensor of
                 integer indices that sort the tensor into ascending order on the innermost dimension
                 and the type must be int64. They are typically the result of argsort. Default: ``None`` .
+                CPU and GPU can only use default values
 
         Returns:
             Tensor, array of insertion points with the same shape as `v`.
@@ -3217,10 +3314,10 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
 
     def uniform(self, from_=0., to=1., generator=None):
         r"""
-        Generates random numbers in the half-open interval [
+        Generates random numbers in the half-open interval [from\_, to).
 
         Args:
-
+            from\_ (number): The lower bound of the interval.
             to (number): The upper bound of the interval.
             generator (Generator, optional): The random seed. Default: None.
 
@@ -3506,6 +3603,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
             repeated_subs.append(tensor_operator_registry.get('repeat_elements')(sub, rep, axis))
         return tensor_operator_registry.get('concatenate')(repeated_subs, axis)
 
+    @repeat_interleave_mint
     def repeat_interleave(self, repeats, dim=None):
         """
         For details, please refer to :func:`mindspore.ops.repeat_interleave`.
@@ -3740,6 +3838,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get("xdivy")(self, y)
 
+    @split_mint
     def split(self, split_size_or_sections, axis=0):
         """
         For details, please refer to :func:`mindspore.ops.split`.
@@ -4039,6 +4138,27 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('int')(self, mstype.int32)
 
+    def byte(self):
+        r"""
+        Converts input tensor dtype to `uint8`.
+
+        Returns:
+            Tensor, converted to the `uint8` dtype.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore
+            >>> from mindspore import Tensor
+            >>> input_x = Tensor(np.ones([2,2]), mindspore.float32)
+            >>> output = input_x.byte()
+            >>> print(output.dtype)
+            uint8
+        """
+        return tensor_operator_registry.get('byte')(self, mstype.uint8)
+
     def long(self):
         r"""
         Converts input tensor dtype to `int64`. If the value in tensor is float or half, the decimal will be discarded.
@@ -4249,6 +4369,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('isinf')(self)
 
+    @isnan_mint
     def isnan(self):
         r"""
         For details, please refer to :func:`mindspore.ops.isnan`.
@@ -4425,7 +4546,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('mul')(self, value)
 
-    def nan_to_num(self, nan=
+    def nan_to_num(self, nan=None, posinf=None, neginf=None):
         """
         For details, please refer to :func:`mindspore.ops.nan_to_num`.
         """
@@ -4482,6 +4603,31 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('zeros')(size, dtype)
 
+    def zero_(self):
+        r"""
+        Return a tensor filled with zeros.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+
+        Returns:
+            Return a tensor. Fill self tensor with zeros.
+
+        Supported Platforms:
+            ``Ascend``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore
+            >>> from mindspore import Tensor
+            >>> x = Tensor(np.array([2, 2]))
+            >>> output = x.zero_()
+            >>> print(output)
+            [[0. 0.]
+             [0. 0.]]
+        """
+        return tensor_operator_registry.get('zero_')(self)
+
     def new_ones(self, size, dtype=None):
         r"""
         Return a tensor of `size` filled with ones.
@@ -4608,7 +4754,6 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('lu_solve')(self, LU_data, LU_pivots)
 
-
     def nextafter(self, other):
         r"""
         For details, please refer to :func:`mindspore.ops.nextafter`.
@@ -4622,7 +4767,6 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         validator.check_value_type('some', some, bool, 'Tensor.qr')
         return tensor_operator_registry.get('qr')(self, 'reduced' if some else 'complete')
 
-
     def ormqr(self, input2, input3, left=True, transpose=False):
         r"""
         For details, please refer to :func:`mindspore.ops.ormqr`,
@@ -4630,7 +4774,6 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('ormqr')(self, input2, input3, left, transpose)
 
-
     def masked_scatter(self, mask, x):
         r"""
         Returns a Tensor. Updates the value in the "self Tensor" with the `tensor` value according to the mask.
@@ -4671,7 +4814,6 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('masked_scatter')()(self, mask, x)
 
-
     def index_put(self, indices, values, accumulate=False):
         r"""
         Returns a Tensor. According to the index number of `indices` ,
@@ -4724,7 +4866,6 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         _index_put = tensor_operator_registry.get('index_put')(0 if accumulate is False else 1)
         return _index_put(self, values, indices)
 
-
     def move_to(self, to, blocking=True):
         r"""
         Copy Tensor to target device synchronously or asynchronously, default synchronously. only support PyNative mode.
@@ -4758,8 +4899,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         mode = context.get_context("mode")
         if mode != context.PYNATIVE_MODE:
             raise ValueError(f"The method of 'move_to' only supported in pynative mode, but got: {mode}.")
-        return
-
+        return Tensor_.move_to(self, to, blocking)
 
     def _offload(self):
         r"""
@@ -4817,9 +4957,15 @@ def _check_tensor_input(input_data=None, dtype=None, shape=None, init=None):
     if input_data is not None:
        if isinstance(input_data, np.ndarray) and input_data.ndim >= 1 and input_data.size == 0:
             raise ValueError("input_data can not contain zero dimension.")
-        if isinstance(input_data, (tuple, list))
-
-
+        if isinstance(input_data, (tuple, list)):
+            try:
+                np_data = np.array(input_data)
+            except ValueError as e:
+                if "The requested array has an inhomogeneous shape" in str(e):
+                    raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
+                raise
+            if np_data.ndim >= 1 and np_data.size == 0:
+                raise ValueError("input_data can not contain zero dimension.")
 
     if shape is not None and not (hasattr(init, "__enable_zero_dim__") and init.__enable_zero_dim__) and 0 in shape:
         raise ValueError("Shape can not contain zero value.")
```
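Taken together, the tensor.py hunks above extend the public `Tensor` surface. A minimal usage sketch against a 2.4.1 install (illustrative only: the in-place methods are documented above as experimental and Ascend-only, and `device` is a reserved parameter):

```python
import numpy as np
import mindspore
from mindspore import Tensor

x = Tensor(np.array([[1.4, 2.5], [3.6, 4.7]]), mindspore.float32)
print(x.round(decimals=1))  # round() now forwards `decimals` to ops.round
print(x.byte().dtype)       # new byte(): converts to uint8

# Experimental in-place methods added above (Ascend-only per their docstrings):
a = Tensor(np.ones((2, 2)), mindspore.float32)
b = Tensor(np.ones((2, 2)), mindspore.float32)
a.add_(b)    # a <- a + b * alpha
a.copy_(b)   # overwrite a's elements with b's
a.zero_()    # fill a with zeros

# The constructor now accepts `device`, but only "CPU" passes validation:
t = Tensor(np.ones((2, 2)), mindspore.float32, device="CPU")
```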
mindspore/communication/__init__.py
CHANGED

```diff
@@ -20,7 +20,7 @@ Note that the APIs in the following list need to preset communication environmen
 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
 without any third-party or configuration file dependencies.
 Please see the `msrun start up
-<https://www.mindspore.cn/
+<https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
 for more details.
 """
 
```
mindspore/communication/_comm_helper.py
CHANGED

```diff
@@ -22,7 +22,7 @@ import sys
 from sys import excepthook
 
 from mindspore import context
-from mindspore.parallel._ps_context import
+from mindspore.parallel._ps_context import _is_role_sched, _is_ps_mode,\
     _get_ps_context
 from mindspore import log as logger
 from mindspore._c_expression import CollectiveManager, set_cluster_exit_with_exception, MSContext
@@ -127,6 +127,7 @@ class _ExistingGroup:
     The communication groups which exist in the progress.
     """
     ITEMS = {}
+    GROUP_RANKS = {}
 
 
 def _hccl_test():
@@ -160,8 +161,7 @@ def _check_bypass_rank_id_and_size():
 
 
 def _set_elegant_exit_handle():
-
-    sys.excepthook = lambda *args: (set_cluster_exit_with_exception(), excepthook(*args))
+    sys.excepthook = lambda *args: (set_cluster_exit_with_exception(), excepthook(*args))
 
 
 def check_parameter_available(func):
@@ -390,6 +390,38 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group):
     return group_rank_id
 
 
+@check_parameter_available
+def _get_group_rank_from_world_rank_from_cache_helper(world_rank_id, group):
+    """
+    The Helper to do get_group_rank_from_world_rank_from_cache.
+
+    Args:
+        world_rank_id (int): A rank id in world communication group.
+        group (str): The user communication group.
+
+    Raises:
+        TypeError: If world_rank_id is not int.
+        KeyError: If group and world_rank_id is not found in cache.
+
+    Returns:
+        Integer. A rank id in user communication group.
+    """
+    if not isinstance(world_rank_id, int):
+        raise TypeError("For 'get_group_rank_from_world_rank_from_cache', the argument 'world_rank_id' must be type of "
+                        "int, but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
+
+    if group == GlobalComm.WORLD_COMM_GROUP:
+        # world_rank_id is same with group_rank_id in WORLD_COMM_GROUP
+        return world_rank_id
+    if group not in _ExistingGroup.GROUP_RANKS:
+        raise KeyError("For 'get_group_rank_from_world_rank_from_cache', the argument 'group' is not "
+                       "found in GROUP_RANKS, 'group' : {}, 'world_rank_id' : {}".format(group, world_rank_id))
+    if world_rank_id not in _ExistingGroup.GROUP_RANKS[group]:
+        raise KeyError("For 'get_group_rank_from_world_rank_from_cache', the argument 'world_rank_id' is not "
+                       "found in GROUP_RANKS, 'group' : {}, 'world_rank_id' : {}".format(group, world_rank_id))
+    return _ExistingGroup.GROUP_RANKS[group][world_rank_id]
+
+
 @check_parameter_available
 def _get_group_ranks(group):
     """
@@ -444,6 +476,9 @@ def _create_group_helper(group, rank_ids):
                        "is suggested before launching jobs.".format(group, rank_ids))
 
     _ExistingGroup.ITEMS[group] = rank_ids
+    sorted_ranks = sorted(rank_ids)
+    _ExistingGroup.GROUP_RANKS[group] = {world_rank_id: group_rank_id
+                                         for group_rank_id, world_rank_id in enumerate(sorted_ranks)}
 
 
 @check_parameter_available
```