mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.1-cp39-cp39-win_amd64.whl
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/array_ops.py

@@ -39,11 +39,12 @@ from ..auto_generate import (ExpandDims, Reshape, TensorShape, Transpose, Gather
                              OnesLike, ZerosLike, Argmax, ArgMaxExt,
                              ReverseV2, Diag, Eye, ScatterNd, ResizeNearestNeighborV2,
                              GatherNd, GatherD, Range, MaskedFill, RightShift, NonZero,
-                             ResizeNearestNeighbor, Identity, Split, CumSum, CumProd,
+                             ResizeNearestNeighbor, Identity, Split, CumSum, CumProd, MaskedSelect,
                              Cummax, Cummin, Argmin, Concat, UnsortedSegmentSum, ScalarToTensor,
                              Triu, BroadcastTo, StridedSlice, Select, TopkExt, SearchSorted)
 from .manually_defined import Rank, Shape, Tile, Cast, Ones, Zeros
 from ..auto_generate import ArgMaxWithValue, ArgMinWithValue
+from ..auto_generate import TensorScatterElements as TensorScatterElementsExt
 
 class _ScatterOp(PrimitiveWithInfer):
     """
@@ -769,41 +770,15 @@ class Padding(Primitive):
 
 class UniqueWithPad(Primitive):
     """
-
-
-
-    The returned tuple(`y`, `idx`) after the input Tensor `x` is processed by the unique operator,
-    in which the shapes of `y` and `idx` are mostly not equal. Therefore, in order to solve the above situation,
-    the UniqueWithPad operator will fill the `y` Tensor with the `pad_num` specified by the user
-    to make it have the same shape as the Tensor `idx`.
-
-    Refer to :func:`mindspore.ops.unique_with_pad` for more details.
-
-    Inputs:
-        - **x** (Tensor) - The tensor need to be unique. Must be 1-D vector with types: int32, int64.
-        - **pad_num** (int) - Pad num. The data type is an int.
-
-    Outputs:
-        tuple(Tensor), tuple of 2 tensors, `y` and `idx`.
-
-        - y (Tensor) - The unique elements filled with pad_num, the shape and data type same as `x`.
-        - idx (Tensor) - The index of each value of `x` in the unique output `y`, the shape and data type same as `x`.
+    'ops.UniqueWithPad' is deprecated from version 2.4 and will be removed in a future version.
+    Please use the :func:`mindspore.ops.unique` combined with :func:`mindspore.ops.pad` to realize
+    the same function.
 
     Supported Platforms:
-
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> x = Tensor(np.array([1, 1, 2, 2, 3, 3, 4, 5]), mindspore.int32)
-        >>> pad_num = 8
-        >>> output = ops.UniqueWithPad()(x, pad_num)
-        >>> print(output)
-        (Tensor(shape=[8], dtype=Int32, value= [1, 2, 3, 4, 5, 8, 8, 8]),
-        Tensor(shape=[8], dtype=Int32, value= [0, 0, 1, 1, 2, 2, 3, 4]))
+        Deprecated
     """
 
+    @deprecated("2.4", "ops.unique and ops.pad", False)
     @prim_attr_register
     def __init__(self):
         """init UniqueWithPad"""
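The replacement path named in the new docstring (ops.unique plus ops.pad) can reproduce the removed doctest. A minimal sketch of that migration; the right-padding length and the pad value are now explicit choices made by the caller:

import mindspore
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.array([1, 1, 2, 2, 3, 3, 4, 5]), mindspore.int32)
pad_num = 8

y, idx = ops.unique(x)  # idx: position of each element of x in y
# Right-pad y with pad_num up to the shape of x, as UniqueWithPad did.
y = ops.pad(y, (0, x.shape[0] - y.shape[0]), mode="constant", value=pad_num)
print(y)    # [1 2 3 4 5 8 8 8]
print(idx)  # [0 0 1 1 2 2 3 4]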
@@ -819,7 +794,7 @@ class Size(Primitive):
 
     Inputs:
         - **input_x** (Tensor) - Input parameters, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is
-          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore
+          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
 
     Outputs:
         int. A scalar representing the elements' size of `input_x`, tensor is the number of elements
@@ -2112,60 +2087,6 @@ class Rint(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
 
 
-class StridedSliceV2(Primitive):
-    r"""
-    StridedSliceV2 will be deprecated by StridedSlice in the future.
-    Extracts a strided slice of a tensor.
-    Refer to class StridedSlice for more details.
-
-    Args:
-        begin_mask (int): Starting index of the slice. Default: ``0`` .
-        end_mask (int): Ending index of the slice. Default: ``0`` .
-        ellipsis_mask (int): An int mask. Default: ``0`` .
-        new_axis_mask (int): An int mask. Default: ``0`` .
-        shrink_axis_mask (int): An int mask. Default: ``0`` .
-
-    Inputs:
-        - **input_x** (Tensor) - The input Tensor.
-        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
-          constant value is allowed.
-        - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
-          Only constant value is allowed.
-        - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
-          before reaching the maximum location. Only constant value is allowed.
-
-    Outputs:
-        Tensor, The output is explained by following example.
-
-    Raises:
-        TypeError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is not an int.
-        TypeError: If `begin`, `end` or `strides` is not a tuple.
-        ValueError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is less than 0.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU``
-
-    Examples:
-        >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
-        ...                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
-        >>> strided_slice_v2 = ops.StridedSliceV2()
-        >>> output = strided_slice_v2(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
-        >>> print(output)
-        [[[3.]]
-         [[5.]]]
-    """
-
-    @prim_attr_register
-    def __init__(self,
-                 begin_mask=0,
-                 end_mask=0,
-                 ellipsis_mask=0,
-                 new_axis_mask=0,
-                 shrink_axis_mask=0):
-        """Initialize StridedSliceV2"""
-        self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
-
-
 class DiagPart(PrimitiveWithCheck):
     r"""
 
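The removed docstring itself pointed at StridedSlice as the successor. A minimal migration sketch reproducing the removed example with the surviving ops.StridedSlice, assuming identical slicing semantics for this call:

import mindspore
from mindspore import Tensor, ops

input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
                  [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
output = ops.StridedSlice()(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
print(output)
# [[[3.]]
#  [[5.]]]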
@@ -4356,53 +4277,6 @@ class MaskedScatter(Primitive):
         self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
 
 
-class MaskedSelect(PrimitiveWithCheck):
-    """
-    Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
-    The shapes of the `mask` tensor and the `x` tensor don't need to match, but they must be broadcastable.
-
-    Inputs:
-        - **x** (Tensor) - Input Tensor of any dimension.
-        - **mask** (Tensor[bool]) - Boolean mask Tensor, has the same shape as `x`.
-
-    Outputs:
-        A 1-D Tensor, with the same type as x.
-
-    Raises:
-        TypeError: If `x` or `mask` is not a Tensor.
-        TypeError: If dtype of `mask` is not bool.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)
-        >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
-        >>> output = ops.MaskedSelect()(x, mask)
-        >>> print(output)
-        [1 3]
-        >>> x = Tensor(2.1, mindspore.float32)
-        >>> mask = Tensor(True, mindspore.bool_)
-        >>> output = ops.MaskedSelect()(x, mask)
-        >>> print(output)
-        [2.1]
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])
-
-    def check_shape(self, x_shape, mask_shape):
-        get_broadcast_shape(x_shape, mask_shape, self.name, arg_name1="x", arg_name2="mask")
-
-    def check_dtype(self, x_dtype, mask_dtype):
-        validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)
-        validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
-
-
 class _TensorScatterOp(PrimitiveWithInfer):
     """
     Defines TensorScatter Base Operators
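Only the handwritten primitive is deleted here; the first array_ops.py hunk adds MaskedSelect to the ..auto_generate import, so ops.MaskedSelect should keep working. A minimal sketch mirroring the removed doctest, with the functional form as an alternative:

import mindspore
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)
mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
print(ops.MaskedSelect()(x, mask))  # [1 3]
print(ops.masked_select(x, mask))   # functional form, same result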
@@ -4962,7 +4836,7 @@ class SplitV(Primitive):
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
 
 
-class TensorScatterElements(Primitive):
+class TensorScatterElements(TensorScatterElementsExt):
     """
     Write all elements in `updates` to the index specified by `indices` in `input_x` according to the reduction
     operation specified by `reduction`.
@@ -4977,6 +4851,9 @@ class TensorScatterElements(Primitive):
     .. warning::
         This is an experimental API that is subject to change or deletion.
 
+    Note:
+        The backward is supported only for the case `updates.shape == indices.shape`.
+
     Args:
         axis (int, optional): Specify which axis to do scatter operation. Default: ``0`` .
         reduction (str, optional): Which reduction operation to scatter, default is ``"none"`` . Other option: "add".
@@ -4986,7 +4863,7 @@ class TensorScatterElements(Primitive):
         - **indices** (Tensor) - The index of `input_x` to do scatter operation whose data type must be int32 or
           int64. It has the same rank as `data`. And accepted range is [-s, s) where s is the size along axis.
         - **updates** (Tensor) - The tensor doing the scatter operation with `data`,
-          it has the same type as `data
+          it has the same type as `data`.
 
     Outputs:
         Tensor, has the same shape and type as `data`.
@@ -5021,16 +4898,7 @@ class TensorScatterElements(Primitive):
 
     @prim_attr_register
     def __init__(self, axis=0, reduction="none"):
-
-        validator.check_value_type("axis", axis, [int], self.name)
-        validator.check_value_type("reduction", reduction, [str], self.name)
-        validator.check_string(reduction, ["none", "add"], "reduction", self.name)
-        self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
-        target = context.get_context("device_target")
-        if reduction != 'none' and target.lower() == "ascend":
-            raise ValueError(f"For '{self.name}', "
-                             f"Currently Ascend device_target only support `reduction`='none', "
-                             f"but got {reduction}")
+        super().__init__(axis, reduce=reduction)
 
 
 class ExtractVolumePatches(Primitive):
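TensorScatterElements now delegates to the generated TensorScatterElementsExt base (note the reduce= keyword), and the old Python-side check that rejected reduction != 'none' on Ascend is gone with it. A minimal usage sketch, assuming the documented scatter semantics above; updates.shape equals indices.shape here, matching the new backward note:

import mindspore
import numpy as np
from mindspore import Tensor, ops

data = Tensor(np.array([[1., 2., 3., 4., 5.]]), mindspore.float32)
indices = Tensor(np.array([[2, 4]]), mindspore.int32)
updates = Tensor(np.array([[8., 8.]]), mindspore.float32)
# Along axis 1, reduction "none" overwrites: data[0][2] <- 8, data[0][4] <- 8.
op = ops.TensorScatterElements(axis=1, reduction="none")
print(op(data, indices, updates))  # [[1. 2. 8. 4. 8.]]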
mindspore/ops/operations/comm_ops.py

@@ -54,7 +54,7 @@ class ReduceOp:
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with multiple devices.
@@ -141,7 +141,7 @@ class AllReduce(Primitive):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -178,14 +178,15 @@ class AllReduce(Primitive):
     @prim_attr_register
     def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
         """Initialize AllReduce."""
+        self.group = _get_group(group)
         if not isinstance(op, type(ReduceOp.SUM)):
             raise TypeError(f"For '{self.name}', the 'op' must be str, but got {type(op).__name__}.")
-        if not isinstance(
+        if not isinstance(self.group, str):
             raise TypeError(f"For '{self.name}', the 'group' must be str, "
-                            f"but got {type(
-        check_hcom_group_valid(group, prim_name=self.name)
+                            f"but got {type(self.group).__name__}.")
+        check_hcom_group_valid(self.group, prim_name=self.name)
         self.op = op
-        self.add_prim_attr('group',
+        self.add_prim_attr('group', self.group)
         self.add_prim_attr('fusion', 0)
         self.add_prim_attr('index', 0)
         self.add_prim_attr('no_eliminate', True)
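Every communication primitive in this file now funnels group through _get_group before validating it, rather than using the raw argument, so a caller no longer has to pass a plain string name. A hedged sketch of what such a normalization step typically looks like; this helper is illustrative only, not MindSpore's actual _get_group:

def _get_group(group):
    """Illustrative stand-in: resolve a group argument to its string name."""
    if isinstance(group, str):
        return group
    # Hypothetical unwrap of a group handle that carries its own name;
    # anything else falls through and fails the caller's isinstance check.
    return getattr(group, "name", group)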
@@ -230,7 +231,7 @@ class Reduce(PrimitiveWithInfer):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method without any third-party
         or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 4 devices.
|
@@ -314,7 +315,7 @@ class AllGather(PrimitiveWithInfer):
|
|
|
314
315
|
For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
|
|
315
316
|
without any third-party or configuration file dependencies.
|
|
316
317
|
Please see the `msrun start up
|
|
317
|
-
<https://www.mindspore.cn/
|
|
318
|
+
<https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
|
|
318
319
|
for more details.
|
|
319
320
|
|
|
320
321
|
This example should be run with 2 devices.
|
|
@@ -354,12 +355,13 @@ class AllGather(PrimitiveWithInfer):
|
|
|
354
355
|
@prim_attr_register
|
|
355
356
|
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
|
|
356
357
|
"""Initialize AllGather."""
|
|
357
|
-
|
|
358
|
-
self.
|
|
359
|
-
self.
|
|
358
|
+
self.group = _get_group(group)
|
|
359
|
+
validator.check_value_type('group', self.group, (str,), self.name)
|
|
360
|
+
self.rank = get_rank(self.group)
|
|
361
|
+
self.rank_size = get_group_size(self.group)
|
|
360
362
|
validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
|
|
361
363
|
self.add_prim_attr('rank_size', self.rank_size)
|
|
362
|
-
self.add_prim_attr('group',
|
|
364
|
+
self.add_prim_attr('group', self.group)
|
|
363
365
|
self.add_prim_attr('fusion', 0)
|
|
364
366
|
self.add_prim_attr('mean_flag', False)
|
|
365
367
|
self.add_prim_attr('no_eliminate', True)
|
|
@@ -375,25 +377,6 @@ class AllGather(PrimitiveWithInfer):
         return x_dtype
 
 
-class AShardIdentity(PrimitiveWithInfer):
-    """
-    Auto parallel virtual operator. Identity operator only for shard function.
-    Do nothing in terms of infer_shape, infer_dtype, and the tensor.
-
-    It is only for internal use of parallel modules and cannot be called by users.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        pass
-
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        return x_dtype
-
-
 class _MiniStepAllGather(PrimitiveWithInfer):
     """
     Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
@@ -555,7 +538,7 @@ class ReduceScatter(Primitive):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
        for more details.
 
         This example should be run with 2 devices.
@@ -597,11 +580,12 @@ class ReduceScatter(Primitive):
     def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
         """Initialize ReduceScatter."""
         validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
-
+        self.group = _get_group(group)
+        validator.check_value_type('group', self.group, (str,), self.name)
         self.op = op
-        self.rank_size = get_group_size(
+        self.rank_size = get_group_size(self.group)
         self.add_prim_attr('rank_size', self.rank_size)
-        self.add_prim_attr('group',
+        self.add_prim_attr('group', self.group)
         self.add_prim_attr('fusion', 0)
         self.add_prim_attr('no_eliminate', True)
 
@@ -692,7 +676,7 @@ class Broadcast(PrimitiveWithInfer):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -922,7 +906,7 @@ class AlltoAll(PrimitiveWithInfer):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 8 devices.
@@ -1041,7 +1025,7 @@ class NeighborExchangeV2(Primitive):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -1158,7 +1142,7 @@ class CollectiveScatter(Primitive):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -1243,7 +1227,7 @@ class CollectiveGather(Primitive):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 4 devices.
@@ -1308,8 +1292,6 @@ class Barrier(PrimitiveWithInfer):
     Raises:
         TypeError: If `group` is not a str.
         RuntimeError: If backend is invalid, or distributed initialization fails.
-        ValueError: If the local rank id of the calling process in the group
-            is larger than the group's rank size.
 
     Supported Platforms:
         ``Ascend``
@@ -1321,7 +1303,7 @@ class Barrier(PrimitiveWithInfer):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -1395,7 +1377,7 @@ class Send(PrimitiveWithInfer):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -1431,7 +1413,7 @@ class Send(PrimitiveWithInfer):
     def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
         self.rank = dest_rank
         self.sr_tag = sr_tag
-        self.group = group
+        self.group = _get_group(group)
         self.add_prim_attr("no_eliminate", True)
 
     def infer_shape(self, x_shape):
@@ -1479,7 +1461,7 @@ class Receive(PrimitiveWithInfer):
         For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
         without any third-party or configuration file dependencies.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -1517,7 +1499,7 @@ class Receive(PrimitiveWithInfer):
         self.tag = sr_tag
         self.shape = shape
         self.dtype = dtype
-        self.group = group
+        self.group = _get_group(group)
         self.add_prim_attr("no_eliminate", True)
         valid_type = [mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16,
                       mstype.int8, mstype.int16, mstype.int32, mstype.int64,
@@ -1695,9 +1677,32 @@ class _VirtualAssignAdd(PrimitiveWithInfer):
 
     def infer_dtype(self, x_dtype, y_dtype):
         return x_dtype
+
+
 virtual_assign_add = _VirtualAssignAdd()
 
 
+class _VirtualAssignKvCache(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Do nothing in forward, do Assign kv cache in backward. It is only for
+    internal use of parallel modules and cannot be called by users.
+
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize _VirtualAssignAdd."""
+        self.add_prim_attr('order_enforce_skip', True)
+        self.add_prim_attr('side_effect_backprop_mem', True)
+
+    def infer_shape(self, x_shape, y_shape, kv_equal_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_dtype, kv_equal_dtype):
+        return x_dtype
+virtual_assign_kv_cache = _VirtualAssignKvCache()
+
+
 class _VirtualAccuGrad(PrimitiveWithInfer):
     """
     Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
@@ -1834,7 +1839,7 @@ class BatchISendIRecv(PrimitiveWithInfer):
         without any third-party or configuration file dependencies.
 
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -1924,6 +1929,7 @@ class AlltoAllV(PrimitiveWithInfer):
         recv_numel_list(Union[tuple[int], list[int]]): split numel to gather from different remote rank.
         group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
             means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
+        TODO:
 
     Inputs:
         - **input_x** (Tensor) - flatten tensor to scatter. The shape of tensor is :math:`(x_1)`.
@@ -1946,7 +1952,7 @@ class AlltoAllV(PrimitiveWithInfer):
         without any third-party or configuration file dependencies.
 
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
         for more details.
 
         This example should be run with 2 devices.
@@ -1986,11 +1992,15 @@ class AlltoAllV(PrimitiveWithInfer):
|
|
|
1986
1992
|
"""
|
|
1987
1993
|
|
|
1988
1994
|
@prim_attr_register
|
|
1989
|
-
def __init__(self, send_numel_list, recv_numel_list, group=None):
|
|
1995
|
+
def __init__(self, send_numel_list, recv_numel_list, group=None, split_sizes_empty=False):
|
|
1990
1996
|
validator.check_value_type("send_numel_list", send_numel_list, [tuple, list], self.name)
|
|
1991
1997
|
validator.check_value_type("recv_numel_list", recv_numel_list, [tuple, list], self.name)
|
|
1992
|
-
if group is None
|
|
1993
|
-
|
|
1994
|
-
self.
|
|
1998
|
+
self.group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
|
|
1999
|
+
self.send_numel_list = send_numel_list
|
|
2000
|
+
self.recv_numel_list = recv_numel_list
|
|
2001
|
+
self.split_sizes_empty = split_sizes_empty
|
|
2002
|
+
self.rank_size = get_group_size(self.group)
|
|
2003
|
+
|
|
2004
|
+
self.add_prim_attr('group', self.group)
|
|
1995
2005
|
self.add_prim_attr('send_numel_list', send_numel_list)
|
|
1996
2006
|
self.add_prim_attr('recv_numel_list', recv_numel_list)
|