mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
```diff
@@ -18,13 +18,16 @@ import hashlib
 import builtins
 import io
 import pickle
+from datetime import timedelta
 import numpy as np
 from mindspore import log as logger
 from mindspore.common import dtype as mstype
+from mindspore._checkparam import args_type_check
 from mindspore.ops import ReduceOp, cat
 from mindspore.common.tensor import Tensor
 from mindspore._c_expression import TensorPy as Tensor_
 from mindspore.ops.primitive import _primexpr
+from mindspore.common.api import _pynative_executor
 from mindspore.communication._comm_helper import (
     _destroy_group_helper,
     _get_rank_helper,
@@ -33,10 +36,11 @@ from mindspore.communication._comm_helper import (
     _get_group_ranks,
     _is_available,
     _is_initialized,
+    _ExistingGroup,
 )
+from mindspore.communication.management import _init_without_sched
 from mindspore.communication import (
     init,
-    release,
     get_group_size,
     get_world_rank_from_group_rank,
     create_group,
@@ -58,9 +62,11 @@ from mindspore.ops.auto_generate.gen_ops_prim import (
     dist_comm_isend_op,
     dist_comm_all_to_all_v_op,
     dist_comm_reduce_scatter_tensor_op,
+    dist_comm_reduce_scatter_tensor_uneven_op,
     dist_comm_all_to_all_v_single_op,
     dist_comm_broadcast_op,
     dist_comm_all_gather_into_tensor_op,
+    dist_comm_all_gather_into_tensor_uneven_op,
     dist_comm_irecv_op,
     dist_comm_scatter_tensor_op,
     dist_comm_gather_into_tensor_op,
@@ -70,7 +76,7 @@ from mindspore.ops.auto_generate.gen_ops_prim import (
     dist_comm_barrier_op,
     dist_comm_batch_isend_irecv_op,
 )
-from mindspore._c_expression import TCPStoreClient, GroupOptions
+from mindspore._c_expression import TCPStoreClient, GroupOptions, _finalize_collective
 
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -144,28 +150,26 @@ class TCPStore:
 
     Note:
         - The function is implemented by CPU and does not involve any hardware operations related to Ascend.
-        - Currently, all parameters provided by the TCPStore class constructor are not supported …
-          …
-          …
-        - The current …
+        - Currently, all parameters provided by the TCPStore class constructor are not supported
+          except for `host_name`, `port`, `world_size`, `is_master`, `timeout` and `wait_for_workers`,
+          which are reserved parameters and invalid settings.
+        - The current TCPStore function is limited and only supports scenarios where the key is
           less than 4k and the value is less than 1G. Complex scenarios are to be supported.
-        - The timeout interval for message sending and receiving in the TcpStore function is controlled by
-          the `MS_RECEIVE_MSG_TIMEOUT` environment variable, in seconds, with a default value of ``15``.
-          If a timeout occurs, the user needs to increase the configuration value.
 
     Args:
-        host_name (str…
-          …
-        port (int…
-          …
-          …
-          …
-        is_master (bool, …
+        host_name (str): The hostname or IP Address the server store should run on.
+            Currently only supports user input IP addresses.
+        port (int): The port on which the server store should listen for incoming requests.
+        world_size (int, optional): The total number of store users (number of clients + 1 for the server).
+            Default is ``None``, indicates a non-fixed number of store users. This parameter is
+            only valid for the server.
+        is_master (bool, optional): True when initializing the server store and False for client stores.
            Default is ``False``.
-        timeout (timedelta, …
-          …
-        wait_for_workers (bool, …
-            store. This is only applicable when `world_size` is a fixed value. Default is ``True``.
+        timeout (timedelta, optional): Timeout used by the store during initialization. Default is
+            ``timedelta(seconds=300)``.
+        wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server
+            store. This is only applicable when `world_size` is a fixed value. Default is ``True``. This
+            parameter is only valid for the server.
         multi_tenant (bool, invalid, optional): If ``True``, all ``TCPStore`` instances in the current process with
             the same host/port will use the same underlying ``TCPServer``. Default is ``False``.
         master_listen_fd (int, invalid, optional): If specified, the underlying ``TCPServer`` will listen on this file
@@ -191,12 +195,106 @@ class TCPStore:
            for more details.
 
        >>> from mindspore.mint.distributed import TCPStore
-        >>> store = TCPStore()
+        >>> store = TCPStore("127.0.0.1", 1234)
    """
 
-    def __init__(self, host_name…
+    def __init__(self, host_name, port, world_size=None, is_master=False, timeout=timedelta(seconds=300),
                  wait_for_workers=True, multi_tenant=False, master_listen_fd=None, use_libuv=True):
-        …
+        if not isinstance(host_name, str):
+            raise TypeError(
+                "For 'TCPStore', the argument 'host_name' must be type of string, "
+                "but got 'host_name' type : {}.".format(type(host_name))
+            )
+        if not isinstance(port, int):
+            raise TypeError(
+                "For 'TCPStore', the argument 'port' must be type of int, "
+                "but got 'port' type : {}.".format(type(port))
+            )
+        if not isinstance(is_master, bool):
+            raise TypeError(
+                "For 'TCPStore', the argument 'is_master' must be type of bool, "
+                "but got 'is_master' type : {}.".format(type(is_master))
+            )
+        if not isinstance(timeout, timedelta):
+            raise TypeError(
+                "For 'TCPStore', the argument 'timeout' must be type of timedelta, "
+                "but got 'timeout' type : {}.".format(type(timeout))
+            )
+        if not isinstance(wait_for_workers, bool):
+            raise TypeError(
+                "For 'TCPStore', the argument 'wait_for_workers' must be type of bool, "
+                "but got 'wait_for_workers' type : {}.".format(type(wait_for_workers))
+            )
+        if world_size is None:
+            world_size = 1
+        if not isinstance(world_size, int):
+            raise TypeError(
+                "For 'TCPStore', the argument 'world_size' must be type of int, "
+                "but got 'world_size' type : {}.".format(type(world_size))
+            )
+        if port < 0 or port > 65535:
+            raise ValueError(
+                "For 'TCPStore', the argument 'port' must be legal, "
+                f"but got {port}."
+            )
+        if world_size <= 0:
+            raise ValueError(
+                "For 'TCPStore', the argument 'world_size' must be legal, "
+                f"but got {world_size}."
+            )
+        timeout_ms = int(timeout.total_seconds() * 1000)
+        self.instance = TCPStoreClient(host_name, port, is_master, timeout_ms, world_size, wait_for_workers)
+        self.host = host_name
+        self.port = port
+
+
+    def add(self, key, amount):
+        """
+        When the `add` function is called for the first time with a given key, it creates a counter in
+        the storage corresponding to that key, with the initial value set to `amount`. Subsequent calls
+        to `add` with the same key increment the counter by amount.
+
+        Args:
+            key (str): The key whose counter value will be incremented.
+            amount (int): The amount by which the counter will be incremented.
+
+        Returns:
+            int, value of counter with `key`.
+
+        Raises:
+            TypeError: If `key` is not string.
+            TypeError: If `amount` is not int.
+            RuntimeError: If the `add` and `set` pass the same `key` and the `value` passed by `set` cannot
+                be correctly converted to a numerical value, calling `add` will result in an error.
+
+        Supported Platforms:
+            ``Ascend``
+
+        Examples:
+            .. note::
+                Before running the following examples, you need to configure the communication environment variables.
+
+                For Ascend devices, it is recommended to use the msrun startup method
+                without any third-party or configuration file dependencies.
+                Please see the `msrun start up
+                <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+                for more details.
+
+            >>> from mindspore.mint.distributed import TCPStore
+            >>> store = TCPStore("127.0.0.1", 1234)
+            >>> store.add("first_key", 1)
+        """
+        if not isinstance(key, str):
+            raise TypeError(
+                "For 'TCPStore.add', the argument 'key' must be type of string, "
+                "but got 'key' type : {}.".format(type(key))
+            )
+        if not isinstance(amount, int):
+            raise TypeError(
+                "For 'TCPStore.add', the argument 'amount' must be type of string or int, "
+                "but got 'amount' type : {}.".format(type(amount))
+            )
+        return self.instance.add(key, amount)
 
 
     def set(self, key, value):
```
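The hunks above appear to come from `mindspore/mint/distributed/distributed.py` (the doctests import from `mindspore.mint.distributed`): `TCPStore` now requires `host_name` and `port`, validates its arguments, and gains an `add` counter method. A minimal usage sketch under those assumptions; the host, port, and rank values are illustrative and not taken from the diff:

```python
# Sketch of the updated TCPStore constructor and the new add() method.
# Assumptions: two cooperating processes, rank 0 hosting the store at an
# illustrative address; adjust host/port/world_size for a real setup.
from datetime import timedelta
from mindspore.mint.distributed import TCPStore

rank = 0  # assume this process is rank 0, i.e. the server side
store = TCPStore("127.0.0.1", 1234, world_size=2, is_master=(rank == 0),
                 timeout=timedelta(seconds=300))

store.set("stage", "warmup")      # plain key/value set
print(store.get("stage"))         # get() now blocks up to `timeout` if the key is absent
store.add("steps", 1)             # first call creates the counter, later calls increment it
```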
```diff
@@ -227,7 +325,7 @@ class TCPStore:
                 for more details.
 
             >>> from mindspore.mint.distributed import TCPStore
-            >>> store = TCPStore()
+            >>> store = TCPStore("127.0.0.1", 1234)
             >>> store.set("first_key", "first_value")
         """
         if not isinstance(key, str):
@@ -245,8 +343,9 @@ class TCPStore:
 
     def get(self, key):
         """
-        Retrieves the value associated with the given `key` in the store. If `key`
-        …
+        Retrieves the value associated with the given `key` in the store. If the `key` does not exist
+        in the storage, this function will wait for the `timeout` set by the class initialization and then
+        throw an exception.
 
         Args:
             key (str): The function will return the value associated with this key.
@@ -256,6 +355,7 @@ class TCPStore:
 
         Raises:
             TypeError: If `key` is not string.
+            RuntimeError: If `get` runs out of time.
 
         Supported Platforms:
             ``Ascend``
@@ -271,7 +371,7 @@ class TCPStore:
                 for more details.
 
             >>> from mindspore.mint.distributed import TCPStore
-            >>> store = TCPStore()
+            >>> store = TCPStore("127.0.0.1", 1234)
             >>> store.set("first_key", "first_value")
             >>> data = store.get("first_key")
             >>> print(data)
@@ -299,7 +399,7 @@ class TCPStore:
             TypeError: If `key` is not string.
 
         Supported Platforms:
-            ``
+            ``Ascend``
 
         Examples:
             .. note::
@@ -312,7 +412,7 @@ class TCPStore:
                 for more details.
 
             >>> from mindspore.mint.distributed import TCPStore
-            >>> store = TCPStore()
+            >>> store = TCPStore("127.0.0.1", 1234)
             >>> store.set("first_key", "first_value")
             >>> # This should return true
             >>> store.delete_key("first_key")
@@ -387,6 +487,7 @@ def is_initialized():
     return _is_initialized()
 
 
+@args_type_check(init_method=str, timeout=timedelta, world_size=int, rank=int, store=TCPStore)
 def init_process_group(backend="hccl",
                        init_method=None,
                        timeout=None,
@@ -404,26 +505,29 @@ def init_process_group(backend="hccl",
     and the instantiation and execution of any operation and net.
 
     Args:
-        backend (str, optional): The backend to ues. …
-        init_method (str, …
-          …
-          …
-          …
-          …
-          …
-          …
-            information. Provides parameters consistent with pytorch, but is not currently support,
-            setting is invalid.
+        backend (str, optional): The backend to ues. Default is ``"hccl"`` and now only support hccl.
+        init_method (str, optional): URL specifying how to init collective communication group. Default is ``None``.
+        timeout (timedelta, optional): Timeout for API executed. Default is ``None``. Currently, this parameter is
+            only supported for host-side cluster network configuration using `init_method` or `store`.
+        world_size (int, optional): Number of the processes participating in the job. Default is ``-1``.
+        rank (int, optional): Rank of the current process. Default is ``-1``.
+        store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process
+            communication addresses and connection information. Default is ``None``. Currently, only the
+            ``TCPStore`` type is supported.
         pg_options (ProcessGroupOptions, invalid): process group options specifying what additional options need to be
-            passed in during the construction of specific process group.
-            …
-        device_id (int, invalid): the device id to exeute.
-            …
+            passed in during the construction of specific process group. The provided parameter is a reserved
+            parameter, and the current setting does not take effect.
+        device_id (int, invalid): the device id to exeute. The provided parameter is a reserved parameter,
+            and the current setting does not take effect.
 
     Raises:
         ValueError: If `backend` is not hccl.
         ValueError: If `world_size` is not equal to -1 or process group number.
+        ValueError: If both `init_method` and `store` are set.
+        ValueError: `world_size` is not correctly set as a positive integer value, when using the initialization
+            method `init_method` or `store`.
+        ValueError: `rank` is not correctly set as a non-negative integer, when using the initialization method
+            `init_method` or `store`.
         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails,
             or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
             have not been exported when backend is HCCL.
@@ -447,25 +551,34 @@ def init_process_group(backend="hccl",
         >>> init_process_group()
         >>> destroy_process_group()
     """
-    if init_method is not None:
-        logger.warning("init_method is ignored, setting is invalid")
-    if timeout is not None:
-        logger.warning("timeout is ignored, setting is invalid")
-    if store is not None:
-        logger.warning("store is ignored, setting is invalid")
     if pg_options is not None:
         logger.warning("pg_options is ignored, setting is invalid")
     if device_id is not None:
         logger.warning("device_id is ignored, setting is invalid")
-    if rank != -1:
-        logger.warning("rank is ignored, setting is invalid")
     if backend != "hccl":
         raise ValueError(
             "Only support hccl now, please setting backend to hccl or using default value"
         )
-    …
-    …
+    if init_method is not None and store is not None:
+        raise ValueError(
+            "Only one of init_method and store is supported."
+        )
+    if init_method is not None or store is not None:
+        if world_size <= 0:
+            raise ValueError(
+                "Specified world_size must be a positive integer."
+            )
+        if rank < 0:
+            raise ValueError(
+                "Specified rank must be a non-negative integer."
+            )
+        if timeout is None:
+            timeout = timedelta(seconds=300)
+        timeout_ms = int(timeout.total_seconds() * 1000)
+        _init_without_sched(backend, init_method, timeout_ms, world_size, rank, store)
+    else:
+        init(backend)
 
     if world_size != -1 and world_size != get_group_size():
         raise ValueError(
```
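With the changes above, `init_process_group` gains a working `init_method`/`store` path: exactly one of the two may be set, `world_size` and `rank` must then be given explicitly, and initialization goes through `_init_without_sched` instead of `init(backend)`. A hedged sketch of store-based initialization; the host, port, and the 2-process layout are placeholders, not values from the diff:

```python
# Store-based process-group initialization enabled by this release.
# Placeholder values: a real job would take rank/world_size from its launcher.
from datetime import timedelta
from mindspore.mint.distributed import (TCPStore, init_process_group,
                                        destroy_process_group)

rank, world_size = 0, 2
store = TCPStore("127.0.0.1", 1234, world_size=world_size, is_master=(rank == 0))

# When `store` (or `init_method`) is given, world_size and rank are mandatory;
# passing both init_method and store, or omitting rank/world_size, raises ValueError.
init_process_group(backend="hccl", store=store, world_size=world_size, rank=rank,
                   timeout=timedelta(seconds=300))
# ... run collectives ...
destroy_process_group()   # see the destroy_process_group hunk below for the new cleanup steps
```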
```diff
@@ -513,7 +626,10 @@ def destroy_process_group(group=None):
     """
 
     if group == GlobalComm.WORLD_COMM_GROUP or group is None:
-        …
+        _pynative_executor.sync()
+        _finalize_collective()
+        _ExistingGroup.ITEMS.clear()
+        _ExistingGroup.GROUP_RANKS.clear()
     elif not isinstance(group, str):
         raise TypeError(
             "For 'destroy_group', the argument 'group' must be type of string or None, "
@@ -671,6 +787,12 @@ def new_group(ranks=None,
                 hccl_config(dict)
             }
 
+        `hccl_config` currently only supports "hccl_buffer_size" or "hccl_comm".
+
+        - hccl_buffer_size (uint32): specifies the size of the HCCL communication buffer.
+        - hccl_comm (int64): specifies an existing HcclComm pointer. If "hccl_comm" is set,
+          "hccl_buffer_size" will be ignored.
+
         use_local_synchronization (bool, invalid): Currently it is a reserved parameter.
         group_desc (str, invalid): Currently it is a reserved parameter.
 
```
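The `new_group` docstring now spells out which `hccl_config` keys take effect. A hedged sketch of passing an HCCL buffer size when creating a group; the `GroupOptions` plumbing and attribute name are inferred from the `GroupOptions` import and the surrounding docstring fragment, so treat them as assumptions rather than confirmed API:

```python
# Assumed usage of hccl_config via GroupOptions (not spelled out in this hunk).
from mindspore._c_expression import GroupOptions
from mindspore.mint.distributed import init_process_group, new_group

init_process_group()
opts = GroupOptions()
# Only "hccl_buffer_size" and "hccl_comm" are recognized; if "hccl_comm" is set,
# "hccl_buffer_size" is ignored.
opts.hccl_config = {"hccl_buffer_size": 200}   # illustrative value
new_group(ranks=[0, 1], pg_options=opts)
```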
```diff
@@ -989,6 +1111,22 @@ def _check_all_tensor_same_dtype_and_shape(*tensor_lists):
     )
 
 
+@_primexpr
+def _check_output_shape(output, expected_shape, op_name):
+    if output.shape != expected_shape:
+        raise TypeError(
+            f"For {op_name}, the output shape should be {expected_shape}, "
+            f"but got {output.shape}.")
+
+
+@_primexpr
+def _check_output_dtype(output, expected_dtype, op_name):
+    if output.dtype != expected_dtype:
+        raise TypeError(
+            f"For {op_name}, the output dtype should be {expected_dtype}, "
+            f"but got {output.dtype}.")
+
+
 def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
     """
     Reduce tensors across all devices in such a way that all deviceswill get the same final result,
@@ -1153,6 +1291,91 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
     return handle
 
 
+def all_gather_into_tensor_uneven(output, input, output_split_sizes=None, group=None, async_op=False):
+    r"""
+    Gathers and concatenates tensors across devices with uneven first dimensions.
+
+    Note:
+        - Input tensors must have identical shapes except for the first dimension.
+        - Output tensor's first dimension should equal to the sum of all devices' input first dimensions.
+
+    Args:
+        output (Tensor): Concatenated output tensor with shape :math:`(\sum_{i=0}^{N-1} x_{i1}, x_2, ..., x_R)`,
+            where N is the number of devices in the group.
+        input (Tensor): Local input tensor with shape :math:`(x_{k1}, x_2, ..., x_R)`, where k is current device's rank.
+        output_split_sizes (list[int], optional): Specifies first dimension sizes from each device.
+            Must match actual input dimensions when provided.
+            If ``None``, assumes equal split sizes across devices. Default: ``None``.
+        group (str, optional): The communication group to work on. If ``None``,
+            which means ``"hccl_world_group"`` in Ascend. Default: ``None``.
+        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False``.
+
+    Returns:
+        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
+        CommHandle will be None, when `async_op` is False.
+
+    Raises:
+        ValueError: If the shape of `input` does not match the constraints of `output_split_sizes`.
+        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+            for more details.
+
+        This example should be run with 2 devices.
+
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import ops
+        >>> from mindspore.mint.distributed import init_process_group, get_rank
+        >>> from mindspore.mint.distributed import all_gather_into_tensor_uneven
+        >>> from mindspore import Tensor
+        >>>
+        >>> ms.set_device(device_target="Ascend")
+        >>> init_process_group()
+        >>> if get_rank() == 0:
+        ...     input_tensor = Tensor(np.ones([3, 4]).astype(np.float32))
+        ... else:
+        ...     input_tensor = Tensor(np.ones([2, 4]).astype(np.float32))
+        >>> out_tensor = Tensor(np.zeros([5, 4]).astype(np.float32))
+        >>> output_split_sizes = [3, 2]
+        >>> output = all_gather_into_tensor_uneven(out_tensor, input_tensor, output_split_sizes)
+        >>> print(out_tensor)
+        [[1. 1. 1. 1.]
+         [1. 1. 1. 1.]
+         [1. 1. 1. 1.]
+         [1. 1. 1. 1.]
+         [1. 1. 1. 1.]]
+    """
+    if group is None:
+        group = GlobalComm.WORLD_COMM_GROUP
+    if not isinstance(group, str):
+        raise TypeError(
+            "The argument 'group' must be type of string, "
+            "but got 'group' type : {}.".format(type(group))
+        )
+    if not isinstance(async_op, bool):
+        raise TypeError(
+            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
+        )
+    group_size = get_cache_group_size(group)
+    output_split_sizes = [] if output_split_sizes is None else output_split_sizes
+    result = dist_comm_all_gather_into_tensor_uneven_op(
+        output, input, output_split_sizes, group_size, group
+    )
+    _, handle = _deal_comm_outputs(result, async_op)
+    return handle
+
+
 def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
     r"""
     Reduces and scatters tensors from the specified communication group and
```
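The key contract of `all_gather_into_tensor_uneven` is that the output's first dimension equals `sum(output_split_sizes)` while each rank contributes `output_split_sizes[rank]` rows. A small sizing sketch under that contract, assuming the default process group is already initialized on two devices; the split sizes are illustrative:

```python
# Sizing sketch for all_gather_into_tensor_uneven; split sizes are illustrative
# and the default process group is assumed to be initialized already.
import numpy as np
from mindspore import Tensor
from mindspore.mint.distributed import all_gather_into_tensor_uneven, get_rank

output_split_sizes = [3, 2]                    # rows contributed by rank 0 and rank 1
rows = output_split_sizes[get_rank()]          # this rank's share of the first dimension
local = Tensor(np.ones([rows, 4]).astype(np.float32))
gathered = Tensor(np.zeros([sum(output_split_sizes), 4]).astype(np.float32))
all_gather_into_tensor_uneven(gathered, local, output_split_sizes)
```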
```diff
@@ -1243,6 +1466,101 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
     return handle
 
 
+def reduce_scatter_tensor_uneven(output, input, input_split_sizes=None, op=ReduceOp.SUM, group=None, async_op=False):
+    r"""
+    Reduce tensors from the specified communication group and scatter to the output tensor
+    according to `input_split_sizes`.
+
+    Note:
+        - The input tensor must have identical shape and format across all processes.
+        - The first dimension of input tensor should equal to the sum of `input_split_sizes`.
+
+    Args:
+        output(Tensor): the output tensor has the same dtype as `input` with a shape of
+            :math:`(input\_split\_sizes[rank], *)`, where rank is the local rank id of the device.
+        input(Tensor): The input tensor to be reduced and scattered, Expected shape :math:`(N, *)`, where `*`
+            means any number of additional dimensions. N must equal the sum of `input_split_sizes` across ranks.
+        input_split_sizes (list[int], optional): List specifying how to split the first dimension of input tensor.
+            If ``None``, splits evenly according to group size. Default: ``None``.
+        op (str, optional): Specifies an operation used for element-wise reductions,
+            One of ReduceOp: 'SUM', 'MIN', 'MAX'. Default: ``ReduceOp.SUM``.
+        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
+            Ascend. Default: ``None``.
+        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False``.
+
+    Returns:
+        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
+        CommHandle will be None, when `async_op` is False.
+
+    Raises:
+        ValueError: If the shape of `output` does not match the constraints of `input_split_sizes`.
+        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+            for more details.
+
+        This example should be run with 2 devices.
+
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> from mindspore.mint.distributed import init_process_group, get_rank
+        >>> from mindspore.mint.distributed import reduce_scatter_tensor_uneven
+        >>> import numpy as np
+        >>>
+        >>> ms.set_device(device_target="Ascend")
+        >>> init_process_group()
+        >>> input_tensor = Tensor(np.ones([5, 8]).astype(np.float32))
+        >>> if get_rank() == 0:
+        ...     output_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
+        ... else:
+        ...     output_tensor = Tensor(np.ones([3, 8]).astype(np.float32))
+        >>> input_split_sizes = [2, 3]
+        >>> output = reduce_scatter_tensor_uneven(output_tensor, input_tensor, input_split_sizes)
+        >>> print(output_tensor)
+        rank 0:
+        [[2. 2. 2. 2. 2. 2. 2. 2.]
+         [2. 2. 2. 2. 2. 2. 2. 2.]]
+        rank 1:
+        [[2. 2. 2. 2. 2. 2. 2. 2.]
+         [2. 2. 2. 2. 2. 2. 2. 2.]
+         [2. 2. 2. 2. 2. 2. 2. 2.]]
+    """
+    if not isinstance(op, str):
+        raise TypeError("For reduce_scatter_tensor_uneven, the input op type must be str")
+    if op not in ("sum", "min", "max"):
+        raise TypeError(
+            "For reduce_scatter_tensor_uneven, the input op value must be one of sum, prod, min, max"
+        )
+    if group is None:
+        group = GlobalComm.WORLD_COMM_GROUP
+    if not isinstance(group, str):
+        raise TypeError(
+            "The argument 'group' must be type of string, "
+            "but got 'group' type : {}.".format(type(group))
+        )
+    if not isinstance(async_op, bool):
+        raise TypeError(
+            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
+        )
+    input_split_sizes = [] if input_split_sizes is None else input_split_sizes
+    rank_size = get_cache_group_size(group)
+    result = dist_comm_reduce_scatter_tensor_uneven_op(
+        output, input, input_split_sizes, rank_size, op, group
+    )
+    _, handle = _deal_comm_outputs(result, async_op)
+    return handle
+
+
 def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
     """
     Reduces tensors across the processes in the specified communication group, sends the result
```
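`reduce_scatter_tensor_uneven` is the mirror image: every rank feeds the full concatenated tensor (first dimension `sum(input_split_sizes)`), and after the element-wise reduction each rank keeps only its `input_split_sizes[rank]` rows. A shape sketch under the same 2-device assumption, with illustrative values:

```python
# Shape sketch for reduce_scatter_tensor_uneven; values are illustrative and the
# default process group is assumed to be initialized on two devices.
import numpy as np
from mindspore import Tensor
from mindspore.mint.distributed import reduce_scatter_tensor_uneven, get_rank

input_split_sizes = [2, 3]
full = Tensor(np.ones([sum(input_split_sizes), 8]).astype(np.float32))
part = Tensor(np.zeros([input_split_sizes[get_rank()], 8]).astype(np.float32))
reduce_scatter_tensor_uneven(part, full, input_split_sizes)   # part holds the summed rows
```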
```diff
@@ -2386,10 +2704,7 @@ def all_to_all_single(output,
 
 def _check_tensor_list(tensor_list, tensor, group_size):
     """check all elements in tensor_list are type of Tensor or tuple or list"""
-    …
-        raise TypeError(
-            f"The argument list tensor len must be equal to group rank size, but got {len(tensor_list)}."
-        )
+    _check_group_tensor_list(tensor_list, group_size)
     if tensor.dtype != tensor_list[0].dtype:
         raise TypeError(
             f"The argument list tensor type must be equal to tensor type, but got {tensor_list[0].dtype}."
@@ -2400,13 +2715,17 @@ def _check_tensor_list(tensor_list, tensor, group_size):
     )
 
 
+def _check_group_tensor_list(tensor_list, group_size):
+    if not tensor_list or len(tensor_list) != group_size:
+        raise TypeError(
+            f"The argument list tensor len must be equal to group rank size, but got {len(tensor_list)}."
+        )
+
+
 def all_gather(tensor_list, tensor, group=None, async_op=False):
     """
     Gathers tensors from the specified communication group and returns the tensor list which is all gathered.
 
-    Note:
-        The tensors must have the same shape and format in all processes of the collection.
-
     Args:
         tensor_list (list[Tensor]): Output list.
         tensor (Tensor): The input tensor to be all gathered into tensor.
@@ -2461,7 +2780,7 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
 
     """
     _check_all_tensors(tensor_list)
-    …
+    _check_all_tensor_same_dtype(tensor_list)
     if not isinstance(tensor, (Tensor, Tensor_)):
         raise TypeError("For all_gather_into_tensor, the input tensor must be tensor")
     if group is None:
@@ -2476,7 +2795,10 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
             f"The argument 'async_op' must be a bool, but got {type(async_op)}."
         )
     group_size = get_cache_group_size(group)
-    …
+    _check_group_tensor_list(tensor_list, group_size)
+    rank_id = get_group_rank_from_world_rank(get_rank(), group)
+    _check_output_shape(tensor, tensor_list[rank_id].shape, "all_gather")
+    _check_output_dtype(tensor, tensor_list[0].dtype, "all_gather")
     result = dist_comm_all_gather_op(tensor_list, tensor, group_size, group)
     _, handle = _deal_comm_outputs(result, async_op)
     return handle
@@ -2487,9 +2809,6 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
     Reduces and scatters tensors from the specified communication group and
     returns the tensor which is reduced and scattered.
 
-    Note:
-        The tensors must have the same shape and format in all processes of the collection.
-
     Args:
         output (Tensor): the output tensor.
         input_list (list[Tensor]): List of tensors to reduce and scatter.
@@ -2543,7 +2862,7 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
     """
 
     _check_all_tensors(input_list)
-    …
+    _check_all_tensor_same_dtype(input_list)
     if not isinstance(output, (Tensor, Tensor_)):
         raise TypeError("For reduce_scatter, the output tensor must be tensor")
     if group is None:
@@ -2564,7 +2883,11 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
             "For reduce_scatter, the input op value must be one of sum, prod, min, max"
         )
     rank_size = get_cache_group_size(group)
-    …
+    _check_group_tensor_list(input_list, rank_size)
+
+    rank_id = get_group_rank_from_world_rank(get_rank(), group)
+    _check_output_shape(output, input_list[rank_id].shape, "reduce_scatter")
+    _check_output_dtype(output, input_list[0].dtype, "reduce_scatter")
     result = dist_comm_reduce_scatter_op(output, input_list, rank_size, op, group)
     _, handle = _deal_comm_outputs(result, async_op)
     return handle
```
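The final hunks route `all_gather` and `reduce_scatter` through the shared `_check_group_tensor_list`, `_check_output_shape`, and `_check_output_dtype` helpers, so a wrong-length tensor list or a shape/dtype mismatch now fails up front with a `TypeError` instead of surfacing at kernel launch. A sketch of a call that satisfies the tightened checks; the 2-rank group is an assumption and is expected to be initialized already:

```python
# all_gather under the stricter validation: the list length must equal the group
# size and every entry must match the input's shape and dtype. A group size of 2
# is assumed here for illustration.
import numpy as np
from mindspore import Tensor
from mindspore.mint.distributed import all_gather

group_size = 2
local = Tensor(np.ones([3, 4]).astype(np.float32))
tensor_list = [Tensor(np.zeros([3, 4]).astype(np.float32)) for _ in range(group_size)]
all_gather(tensor_list, local)   # mismatched length, shape, or dtype raises TypeError
```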