mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,9 +13,190 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Communication management API"""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
import hashlib
|
|
18
|
+
import builtins
|
|
19
|
+
import io
|
|
20
|
+
import pickle
|
|
21
|
+
import numpy as np
|
|
16
22
|
from mindspore import log as logger
|
|
17
|
-
from mindspore.
|
|
18
|
-
from mindspore.
|
|
23
|
+
from mindspore.common import dtype as mstype
|
|
24
|
+
from mindspore.ops import ReduceOp, cat
|
|
25
|
+
from mindspore.common.tensor import Tensor
|
|
26
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
27
|
+
from mindspore.ops.primitive import _primexpr
|
|
28
|
+
from mindspore.communication._comm_helper import (
|
|
29
|
+
_destroy_group_helper,
|
|
30
|
+
_get_rank_helper,
|
|
31
|
+
_get_size_helper,
|
|
32
|
+
_get_backend,
|
|
33
|
+
_get_group_ranks,
|
|
34
|
+
_is_available,
|
|
35
|
+
_is_initialized,
|
|
36
|
+
)
|
|
37
|
+
from mindspore.communication import (
|
|
38
|
+
init,
|
|
39
|
+
release,
|
|
40
|
+
get_group_size,
|
|
41
|
+
get_world_rank_from_group_rank,
|
|
42
|
+
create_group,
|
|
43
|
+
GlobalComm,
|
|
44
|
+
get_group_rank_from_world_rank,
|
|
45
|
+
)
|
|
46
|
+
from mindspore.communication.comm_func import (
|
|
47
|
+
_deal_comm_outputs,
|
|
48
|
+
_check_all_tensors,
|
|
49
|
+
_check_all_tensor_same_dtype,
|
|
50
|
+
_is_split_sizes_empty,
|
|
51
|
+
_get_size,
|
|
52
|
+
_get_group_rank_from_world_rank_from_cache_helper,
|
|
53
|
+
)
|
|
54
|
+
from mindspore.ops.auto_generate.gen_ops_prim import (
|
|
55
|
+
dist_comm_all_gather_op,
|
|
56
|
+
dist_comm_all_reduce_op,
|
|
57
|
+
dist_comm_reduce_scatter_op,
|
|
58
|
+
dist_comm_isend_op,
|
|
59
|
+
dist_comm_all_to_all_v_op,
|
|
60
|
+
dist_comm_reduce_scatter_tensor_op,
|
|
61
|
+
dist_comm_all_to_all_v_single_op,
|
|
62
|
+
dist_comm_broadcast_op,
|
|
63
|
+
dist_comm_all_gather_into_tensor_op,
|
|
64
|
+
dist_comm_irecv_op,
|
|
65
|
+
dist_comm_scatter_tensor_op,
|
|
66
|
+
dist_comm_gather_into_tensor_op,
|
|
67
|
+
dist_comm_gather_op,
|
|
68
|
+
dist_comm_reduce_op,
|
|
69
|
+
dist_comm_scatter_op,
|
|
70
|
+
dist_comm_barrier_op,
|
|
71
|
+
dist_comm_batch_isend_irecv_op,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
_pickler = pickle.Pickler
|
|
75
|
+
_unpickler = pickle.Unpickler
|
|
76
|
+
BACKEND_HCCL = "hccl"
|
|
77
|
+
BACKEND_MCCL = "mccl"
|
|
78
|
+
_GROPU_SIZE_CACHE = {}
|
|
79
|
+
_GROPU_RANK_CACHE = {}
|
|
80
|
+
|
|
81
|
+
safe_builtins = {
|
|
82
|
+
'range',
|
|
83
|
+
'complex',
|
|
84
|
+
'set',
|
|
85
|
+
'frozenset',
|
|
86
|
+
'slice',
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_cache_group_size(group=GlobalComm.WORLD_COMM_GROUP):
|
|
91
|
+
"""get cache group size."""
|
|
92
|
+
global _GROPU_SIZE_CACHE
|
|
93
|
+
if group not in _GROPU_SIZE_CACHE:
|
|
94
|
+
_GROPU_SIZE_CACHE[group] = _get_size_helper(group)
|
|
95
|
+
group_size = _GROPU_SIZE_CACHE[group]
|
|
96
|
+
return group_size
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def get_cache_group_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
100
|
+
"""get cache rank id."""
|
|
101
|
+
global _GROPU_RANK_CACHE
|
|
102
|
+
if group not in _GROPU_RANK_CACHE:
|
|
103
|
+
_GROPU_RANK_CACHE[group] = _get_rank_helper(group)
|
|
104
|
+
group_rank = _GROPU_RANK_CACHE[group]
|
|
105
|
+
return group_rank
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class RestrictedUnpickler(pickle.Unpickler):
|
|
109
|
+
# Override find_class method.
|
|
110
|
+
def find_class(self, module, name):
|
|
111
|
+
# Only allow safe classes from builtins.
|
|
112
|
+
if module == "builtins" and name in safe_builtins:
|
|
113
|
+
return getattr(builtins, name)
|
|
114
|
+
# Forbid everything else.
|
|
115
|
+
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
|
116
|
+
(module, name))
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def restricted_loads(s):
|
|
120
|
+
"""Helper function analogous to pickle.loads()."""
|
|
121
|
+
return RestrictedUnpickler(io.BytesIO(s)).load()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _object_to_tensor(obj, size=0):
|
|
125
|
+
f = io.BytesIO()
|
|
126
|
+
_pickler(f).dump(obj)
|
|
127
|
+
buf = np.frombuffer(f.getvalue(), dtype=np.int8)
|
|
128
|
+
tensor_size = buf.size
|
|
129
|
+
if size > tensor_size:
|
|
130
|
+
buf = np.resize(buf, size)
|
|
131
|
+
tensor_size = size
|
|
132
|
+
return Tensor(buf), tensor_size
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _tensor_to_object(tensor, tensor_size):
|
|
136
|
+
buf = tensor.asnumpy().tobytes()[:tensor_size]
|
|
137
|
+
return restricted_loads(buf)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def is_available():
|
|
141
|
+
"""
|
|
142
|
+
Checks if distributed module is available.
|
|
143
|
+
|
|
144
|
+
Note:
|
|
145
|
+
Always returns `True` because MindSpore always has distributed ability on all platforms.
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
bool, whether this distributed module is available.
|
|
149
|
+
|
|
150
|
+
Supported Platforms:
|
|
151
|
+
``Ascend``
|
|
152
|
+
|
|
153
|
+
Examples:
|
|
154
|
+
.. note::
|
|
155
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
156
|
+
|
|
157
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
158
|
+
without any third-party or configuration file dependencies.
|
|
159
|
+
Please see the `msrun start up
|
|
160
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
161
|
+
for more details.
|
|
162
|
+
|
|
163
|
+
>>> import mindspore as ms
|
|
164
|
+
>>> from mindspore.mint.distributed import is_available
|
|
165
|
+
>>> ms.set_device(device_target="Ascend")
|
|
166
|
+
>>> is_available()
|
|
167
|
+
True
|
|
168
|
+
"""
|
|
169
|
+
return _is_available()
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def is_initialized():
|
|
173
|
+
"""
|
|
174
|
+
Checks if default process group has been initialized.
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
bool, whether the default process group has been initialized.
|
|
178
|
+
|
|
179
|
+
Supported Platforms:
|
|
180
|
+
``Ascend``
|
|
181
|
+
|
|
182
|
+
Examples:
|
|
183
|
+
.. note::
|
|
184
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
185
|
+
|
|
186
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
187
|
+
without any third-party or configuration file dependencies.
|
|
188
|
+
Please see the `msrun start up
|
|
189
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
190
|
+
for more details.
|
|
191
|
+
|
|
192
|
+
>>> import mindspore as ms
|
|
193
|
+
>>> from mindspore.mint.distributed import init_process_group, is_initialized
|
|
194
|
+
>>> ms.set_device(device_target="Ascend")
|
|
195
|
+
>>> init_process_group()
|
|
196
|
+
>>> print(is_initialized())
|
|
197
|
+
True
|
|
198
|
+
"""
|
|
199
|
+
return _is_initialized()
|
|
19
200
|
|
|
20
201
|
|
|
21
202
|
def init_process_group(backend="hccl",
|
|
@@ -37,28 +218,27 @@ def init_process_group(backend="hccl",
|
|
|
37
218
|
Args:
|
|
38
219
|
backend (str, optional): The backend to use. Default is hccl; currently only hccl is supported.
|
|
39
220
|
init_method (str, invalid): URL specifying how to init collective communication group. Provides parameters
|
|
40
|
-
|
|
221
|
+
consistent with pytorch, but is not currently supported, setting is invalid.
|
|
41
222
|
timeout (timedelta, invalid): Timeout for API executed. Provides parameters consistent with pytorch, but is not
|
|
42
|
-
|
|
223
|
+
currently supported, setting is invalid.
|
|
43
224
|
world_size (int, optional): Number of the processes participating in the job.
|
|
44
225
|
rank (int, invalid): Rank of the current process. Provides parameters consistent with pytorch, but is not
|
|
45
|
-
|
|
226
|
+
currently supported, setting is invalid.
|
|
46
227
|
store (Store, invalid): Key/Value store accessible to all workers, used to exchange connection/address
|
|
47
|
-
|
|
48
|
-
|
|
228
|
+
information. Provides parameters consistent with pytorch, but is not currently supported,
|
|
229
|
+
setting is invalid.
|
|
49
230
|
pg_options (ProcessGroupOptions, invalid): process group options specifying what additional options need to be
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
setting is invalid.
|
|
231
|
+
passed in during the construction of specific process group. Provides parameters consistent with pytorch,
|
|
232
|
+
but is not currently supported, setting is invalid.
|
|
53
233
|
device_id (int, invalid): The device id to execute on. Provides parameters consistent with pytorch, but is not
|
|
54
|
-
|
|
234
|
+
currently supported, setting is invalid.
|
|
55
235
|
|
|
56
236
|
Raises:
|
|
57
237
|
ValueError: If `backend` is not hccl.
|
|
58
238
|
ValueError: If `world_size` is not equal to -1 or process group number.
|
|
59
239
|
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails,
|
|
60
|
-
|
|
61
|
-
|
|
240
|
+
or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
|
|
241
|
+
have not been exported when backend is HCCL.
|
|
62
242
|
|
|
63
243
|
Supported Platforms:
|
|
64
244
|
``Ascend``
|
|
@@ -70,13 +250,12 @@ def init_process_group(backend="hccl",
|
|
|
70
250
|
For Ascend devices, it is recommended to use the msrun startup method
|
|
71
251
|
without any third-party or configuration file dependencies.
|
|
72
252
|
Please see the `msrun start up
|
|
73
|
-
<https://www.mindspore.cn/
|
|
253
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
74
254
|
for more details.
|
|
75
255
|
|
|
76
256
|
>>> import mindspore as ms
|
|
77
|
-
>>> from mindspore import set_context
|
|
78
257
|
>>> from mindspore.mint.distributed import init_process_group, destroy_process_group
|
|
79
|
-
>>>
|
|
258
|
+
>>> ms.set_device(device_target="Ascend")
|
|
80
259
|
>>> init_process_group()
|
|
81
260
|
>>> destroy_process_group()
|
|
82
261
|
"""
|
|
@@ -93,13 +272,18 @@ def init_process_group(backend="hccl",
|
|
|
93
272
|
if rank != -1:
|
|
94
273
|
logger.warning("rank is ignored, setting is invalid")
|
|
95
274
|
if backend != "hccl":
|
|
96
|
-
raise ValueError(
|
|
275
|
+
raise ValueError(
|
|
276
|
+
"Only support hccl now, please setting backend to hccl or using default value"
|
|
277
|
+
)
|
|
97
278
|
|
|
98
|
-
#init hccl & create world group
|
|
279
|
+
# init hccl & create world group
|
|
99
280
|
init(backend)
|
|
100
281
|
|
|
101
282
|
if world_size != -1 and world_size != get_group_size():
|
|
102
|
-
raise ValueError(
|
|
283
|
+
raise ValueError(
|
|
284
|
+
"world_size is wrong, please using default value or setting: ",
|
|
285
|
+
get_group_size(),
|
|
286
|
+
)
|
|
103
287
|
|
|
104
288
|
|
|
105
289
|
def destroy_process_group(group=None):
|
|
@@ -108,11 +292,13 @@ def destroy_process_group(group=None):
|
|
|
108
292
|
If group is None or "hccl_world_group", destroy all groups and release the collective communication lib.
|
|
109
293
|
|
|
110
294
|
Note:
|
|
111
|
-
This method isn't supported in GPU and CPU versions of MindSpore.
|
|
112
|
-
This method should be used after init_process_group
|
|
295
|
+
- This method isn't supported in GPU and CPU versions of MindSpore.
|
|
296
|
+
- This method should be used after :func:`mindspore.mint.distributed.init_process_group`.
|
|
113
297
|
|
|
114
298
|
Args:
|
|
115
|
-
group (str): The communication group to
|
|
299
|
+
group (str, optional): The communication group to work on. Normally, the group should be created by
|
|
300
|
+
:func:`mindspore.mint.distributed.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
|
|
301
|
+
Default: ``None``.
|
|
116
302
|
|
|
117
303
|
Raises:
|
|
118
304
|
TypeError: If group is not a string.
|
|
@@ -128,13 +314,12 @@ def destroy_process_group(group=None):
|
|
|
128
314
|
For Ascend devices, it is recommended to use the msrun startup method
|
|
129
315
|
without any third-party or configuration file dependencies.
|
|
130
316
|
Please see the `msrun start up
|
|
131
|
-
<https://www.mindspore.cn/
|
|
317
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
132
318
|
for more details.
|
|
133
319
|
|
|
134
320
|
>>> import mindspore as ms
|
|
135
|
-
>>> from mindspore import set_context
|
|
136
321
|
>>> from mindspore.mint.distributed import init_process_group, destroy_process_group
|
|
137
|
-
>>>
|
|
322
|
+
>>> ms.set_device(device_target="Ascend")
|
|
138
323
|
>>> init_process_group()
|
|
139
324
|
>>> destroy_process_group()
|
|
140
325
|
"""
|
|
@@ -142,8 +327,10 @@ def destroy_process_group(group=None):
|
|
|
142
327
|
if group == GlobalComm.WORLD_COMM_GROUP or group is None:
|
|
143
328
|
release()
|
|
144
329
|
elif not isinstance(group, str):
|
|
145
|
-
raise TypeError(
|
|
146
|
-
|
|
330
|
+
raise TypeError(
|
|
331
|
+
"For 'destroy_group', the argument 'group' must be type of string or None, "
|
|
332
|
+
"but got 'group' type : {}.".format(type(group))
|
|
333
|
+
)
|
|
147
334
|
else:
|
|
148
335
|
_destroy_group_helper(group)
|
|
149
336
|
|
|
@@ -153,11 +340,12 @@ def get_rank(group=None):
|
|
|
153
340
|
Get the rank ID for the current device in the specified collective communication group.
|
|
154
341
|
|
|
155
342
|
Note:
|
|
156
|
-
This method should be used after
|
|
343
|
+
This method should be used after :func:`mindspore.mint.distributed.init_process_group`.
|
|
157
344
|
|
|
158
345
|
Args:
|
|
159
|
-
group (str): The communication group to work on. Normally, the group should be created by
|
|
160
|
-
|
|
346
|
+
group (str, optional): The communication group to work on. Normally, the group should be created by
|
|
347
|
+
:func:`mindspore.mint.distributed.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
|
|
348
|
+
Default: ``None``.
|
|
161
349
|
|
|
162
350
|
Returns:
|
|
163
351
|
int, the rank ID of the calling process within the group.
|
|
@@ -167,7 +355,7 @@ def get_rank(group=None):
|
|
|
167
355
|
TypeError: If group is not a string.
|
|
168
356
|
|
|
169
357
|
Supported Platforms:
|
|
170
|
-
``Ascend``
|
|
358
|
+
``Ascend`` ``CPU``
|
|
171
359
|
|
|
172
360
|
Examples:
|
|
173
361
|
.. note::
|
|
@@ -176,22 +364,26 @@ def get_rank(group=None):
|
|
|
176
364
|
For Ascend devices, it is recommended to use the msrun startup method
|
|
177
365
|
without any third-party or configuration file dependencies.
|
|
178
366
|
Please see the `msrun start up
|
|
179
|
-
<https://www.mindspore.cn/
|
|
367
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
180
368
|
for more details.
|
|
181
369
|
|
|
182
|
-
>>>
|
|
370
|
+
>>> import mindspore as ms
|
|
183
371
|
>>> from mindspore.mint.distributed import init_process_group, get_rank
|
|
184
|
-
>>>
|
|
372
|
+
>>> ms.set_device(device_target="Ascend")
|
|
185
373
|
>>> init_process_group()
|
|
186
374
|
>>> rank_id = get_rank()
|
|
187
375
|
>>> print(rank_id)
|
|
188
376
|
>>> # the result is the rank_id in world_group
|
|
377
|
+
#rank 0: 0
|
|
378
|
+
#rank 1: 1
|
|
189
379
|
"""
|
|
190
380
|
if group is None:
|
|
191
381
|
group = GlobalComm.WORLD_COMM_GROUP
|
|
192
382
|
if not isinstance(group, str):
|
|
193
|
-
raise TypeError(
|
|
194
|
-
|
|
383
|
+
raise TypeError(
|
|
384
|
+
"For 'get_rank', the argument 'group' must be type of string, "
|
|
385
|
+
"but got 'group' type : {}.".format(type(group))
|
|
386
|
+
)
|
|
195
387
|
try:
|
|
196
388
|
ret = _get_rank_helper(group)
|
|
197
389
|
except RuntimeError as e:
|
|
@@ -205,11 +397,12 @@ def get_world_size(group=None):
|
|
|
205
397
|
Get the rank size of the specified collective communication group.
|
|
206
398
|
|
|
207
399
|
Note:
|
|
208
|
-
This method should be used after
|
|
400
|
+
This method should be used after :func:`mindspore.mint.distributed.init_process_group`.
|
|
209
401
|
|
|
210
402
|
Args:
|
|
211
|
-
group (str): The communication group to work on. Normally, the group should be created by
|
|
212
|
-
|
|
403
|
+
group (str, optional): The communication group to work on. Normally, the group should be created by
|
|
404
|
+
:func:`mindspore.mint.distributed.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
|
|
405
|
+
Default: ``None``.
|
|
213
406
|
|
|
214
407
|
Returns:
|
|
215
408
|
int, the rank size of the group.
|
|
@@ -228,13 +421,14 @@ def get_world_size(group=None):
|
|
|
228
421
|
For Ascend devices, it is recommended to use the msrun startup method
|
|
229
422
|
without any third-party or configuration file dependencies.
|
|
230
423
|
Please see the `msrun start up
|
|
231
|
-
<https://www.mindspore.cn/
|
|
424
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
232
425
|
for more details.
|
|
233
426
|
|
|
427
|
+
This example should be run with 8 devices.
|
|
428
|
+
|
|
234
429
|
>>> import mindspore as ms
|
|
235
|
-
>>> from mindspore import set_context
|
|
236
430
|
>>> from mindspore.mint.distributed import init_process_group, get_world_size
|
|
237
|
-
>>>
|
|
431
|
+
>>> ms.set_device(device_target="Ascend")
|
|
238
432
|
>>> init_process_group()
|
|
239
433
|
>>> group_size = get_world_size()
|
|
240
434
|
>>> print("group_size is: ", group_size)
|
|
@@ -244,11 +438,2452 @@ def get_world_size(group=None):
|
|
|
244
438
|
if group is None:
|
|
245
439
|
group = GlobalComm.WORLD_COMM_GROUP
|
|
246
440
|
if not isinstance(group, str):
|
|
247
|
-
raise TypeError(
|
|
248
|
-
|
|
441
|
+
raise TypeError(
|
|
442
|
+
"For 'get_world_size', the argument 'group' must be type of string, "
|
|
443
|
+
"but got 'group' type : {}.".format(type(group))
|
|
444
|
+
)
|
|
249
445
|
try:
|
|
250
446
|
ret = _get_size_helper(group)
|
|
251
447
|
except RuntimeError as e:
|
|
252
448
|
logger.warning(e)
|
|
253
449
|
ret = -1
|
|
254
450
|
return ret
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def new_group(ranks=None,
|
|
454
|
+
timeout=None,
|
|
455
|
+
backend=None,
|
|
456
|
+
pg_options=None,
|
|
457
|
+
use_local_synchronization=False,
|
|
458
|
+
group_desc=None):
|
|
459
|
+
"""
|
|
460
|
+
Create a new distributed group.
|
|
461
|
+
|
|
462
|
+
Note:
|
|
463
|
+
This method should be used after :func:`mindspore.mint.distributed.init_process_group`.
|
|
464
|
+
|
|
465
|
+
Args:
|
|
466
|
+
ranks (list[int], optional): List of ranks of group members. If ``None``,
|
|
467
|
+
the world group is returned. Default is ``None``.
|
|
468
|
+
timeout (int, invalid): Currently it is a reserved parameter.
|
|
469
|
+
backend (str, optional): Backend library to use. Currently supports ``"hccl"`` and ``"mccl"``.
|
|
470
|
+
when backend is ``"hccl"`` will use Huawei Collective Communication Library(HCCL).
|
|
471
|
+
when backend is ``"mccl"`` will use MindSpore Collective Communication Library(MCCL).
|
|
472
|
+
If ``None``, which means ``"hccl"`` in Ascend. Default is ``None``.
|
|
473
|
+
pg_options (str, invalid): Currently it is a reserved parameter.
|
|
474
|
+
use_local_synchronization (bool, invalid): Currently it is a reserved parameter.
|
|
475
|
+
group_desc (str, invalid): Currently it is a reserved parameter.
|
|
476
|
+
|
|
477
|
+
Returns:
|
|
478
|
+
str, the group name. Returns "" if creating the group raises a RuntimeError.
|
|
479
|
+
|
|
480
|
+
Raises:
|
|
481
|
+
TypeError: If `ranks` is not a list, if `backend` is not ``"hccl"`` or ``"mccl"``, or if the list ranks in Group has duplicate rank id.
|
|
482
|
+
|
|
483
|
+
Supported Platforms:
|
|
484
|
+
``Ascend`` ``CPU``
|
|
485
|
+
|
|
486
|
+
Examples:
|
|
487
|
+
.. note::
|
|
488
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
489
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
490
|
+
without any third-party or configuration file dependencies.
|
|
491
|
+
Please see the `msrun start up
|
|
492
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
493
|
+
for more details.
|
|
494
|
+
|
|
495
|
+
>>> import mindspore as ms
|
|
496
|
+
>>> from mindspore.mint.distributed import init_process_group, new_group
|
|
497
|
+
>>> ms.set_device(device_target="Ascend")
|
|
498
|
+
>>> init_process_group()
|
|
499
|
+
>>> group = new_group()
|
|
500
|
+
>>> print("group is: ", group)
|
|
501
|
+
group is: hccl_world_group
|
|
502
|
+
"""
|
|
503
|
+
if ranks is not None:
|
|
504
|
+
if not isinstance(ranks, list):
|
|
505
|
+
raise TypeError("ranks must be list, but got {}".format(type(ranks)))
|
|
506
|
+
ranks = sorted(ranks)
|
|
507
|
+
else:
|
|
508
|
+
return GlobalComm.WORLD_COMM_GROUP
|
|
509
|
+
if backend is None:
|
|
510
|
+
backend = "hccl"
|
|
511
|
+
if not isinstance(backend, str) or backend not in ("hccl", "mccl"):
|
|
512
|
+
raise TypeError(f"the input backend must be hccl or mccl, but got {backend}")
|
|
513
|
+
group = backend + "_" + str(len(ranks)) + "_" + hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
|
|
514
|
+
try:
|
|
515
|
+
create_group(group, ranks)
|
|
516
|
+
except RuntimeError as e:
|
|
517
|
+
logger.warning(e)
|
|
518
|
+
group = ""
|
|
519
|
+
return group
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def get_backend(group=None):
|
|
523
|
+
"""
|
|
524
|
+
Get the backend of communication process groups.
|
|
525
|
+
|
|
526
|
+
Note:
|
|
527
|
+
Only one communication backend is supported by MindSpore for each process.
|
|
528
|
+
It should be one of `hccl`/`nccl`/`mccl`. Currently only hccl and mccl are supported.
|
|
529
|
+
|
|
530
|
+
Args:
|
|
531
|
+
group (str, optional): The communication group to work on.
|
|
532
|
+
Normally, the group should be created by :func:`mindspore.mint.distributed.new_group`. If ``None``,
|
|
533
|
+
which means ``"hccl_world_group"`` in Ascend. Default: ``None``.
|
|
534
|
+
|
|
535
|
+
Returns:
|
|
536
|
+
string, the backend of the group.
|
|
537
|
+
|
|
538
|
+
Raises:
|
|
539
|
+
TypeError: If the `group` is not a str.
|
|
540
|
+
|
|
541
|
+
Supported Platforms:
|
|
542
|
+
``Ascend`` ``CPU``
|
|
543
|
+
|
|
544
|
+
Examples:
|
|
545
|
+
.. note::
|
|
546
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
547
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
548
|
+
without any third-party or configuration file dependencies.
|
|
549
|
+
Please see the `msrun start up
|
|
550
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
551
|
+
for more details.
|
|
552
|
+
|
|
553
|
+
>>> import mindspore as ms
|
|
554
|
+
>>> from mindspore.mint.distributed import init_process_group, get_backend
|
|
555
|
+
>>> ms.set_device(device_target="Ascend")
|
|
556
|
+
>>> init_process_group()
|
|
557
|
+
>>> backend = get_backend()
|
|
558
|
+
>>> print("backend is: ", backend)
|
|
559
|
+
backend is: hccl
|
|
560
|
+
"""
|
|
561
|
+
if group is None:
|
|
562
|
+
return BACKEND_HCCL
|
|
563
|
+
if not isinstance(group, str):
|
|
564
|
+
raise TypeError(
|
|
565
|
+
"For 'get_backend', the argument 'group' must be type of string or None, "
|
|
566
|
+
"but got 'group' type : {}.".format(type(group))
|
|
567
|
+
)
|
|
568
|
+
if BACKEND_HCCL in group:
|
|
569
|
+
return BACKEND_HCCL
|
|
570
|
+
if BACKEND_MCCL in group:
|
|
571
|
+
return BACKEND_MCCL
|
|
572
|
+
return _get_backend()
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
def get_global_rank(group, group_rank):
|
|
576
|
+
"""
|
|
577
|
+
A function that returns the rank id in the world group corresponding to the
|
|
578
|
+
rank whose id is 'group_rank' in the user group.
|
|
579
|
+
|
|
580
|
+
Note:
|
|
581
|
+
This method should be used after :func:`mindspore.mint.distributed.init_process_group`.
|
|
582
|
+
|
|
583
|
+
Args:
|
|
584
|
+
group (str): The communication group to work on. Normally, the group should
|
|
585
|
+
be created by :func:`mindspore.mint.distributed.new_group`. If ``None``, which
|
|
586
|
+
means ``"hccl_world_group"`` in Ascend.
|
|
587
|
+
group_rank (int): Group rank to query.
|
|
588
|
+
|
|
589
|
+
Returns:
|
|
590
|
+
An integer scalar with the rank id in the world group.
|
|
591
|
+
|
|
592
|
+
Raises:
|
|
593
|
+
TypeError: If the `group` is not a str.
|
|
594
|
+
TypeError: If the `group_rank` is not an integer.
|
|
595
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
596
|
+
|
|
597
|
+
Supported Platforms:
|
|
598
|
+
``Ascend``
|
|
599
|
+
|
|
600
|
+
Examples:
|
|
601
|
+
.. note::
|
|
602
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
603
|
+
|
|
604
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
605
|
+
without any third-party or configuration file dependencies.
|
|
606
|
+
|
|
607
|
+
Please see the `msrun start up
|
|
608
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
609
|
+
for more details.
|
|
610
|
+
|
|
611
|
+
This example should be run with 8 devices.
|
|
612
|
+
|
|
613
|
+
>>> import mindspore as ms
|
|
614
|
+
>>> from mindspore.mint.distributed import init_process_group, get_global_rank, new_group, get_rank
|
|
615
|
+
>>> ms.set_device(device_target="Ascend")
|
|
616
|
+
>>> # Launch 8 processes.
|
|
617
|
+
>>> init_process_group()
|
|
618
|
+
>>> rank_ids = [0,4]
|
|
619
|
+
>>> if get_rank() in rank_ids:
|
|
620
|
+
... group = new_group(rank_ids)
|
|
621
|
+
... world_rank_id = get_global_rank(group, 1)
|
|
622
|
+
... print("world_rank_id is: ", world_rank_id)
|
|
623
|
+
#rank 0 and 4:
|
|
624
|
+
world_rank_id is: 4
|
|
625
|
+
"""
|
|
626
|
+
if not isinstance(group_rank, int):
|
|
627
|
+
raise TypeError(
|
|
628
|
+
f"The group_rank argument must be integer, but got {type(group_rank)}."
|
|
629
|
+
)
|
|
630
|
+
|
|
631
|
+
if group is None or group is GlobalComm.WORLD_COMM_GROUP:
|
|
632
|
+
return group_rank
|
|
633
|
+
|
|
634
|
+
if not isinstance(group, str):
|
|
635
|
+
raise TypeError(
|
|
636
|
+
"For 'get_global_rank', the argument 'group' must be type of string or None, "
|
|
637
|
+
"but got 'group' type : {}.".format(type(group))
|
|
638
|
+
)
|
|
639
|
+
return get_world_rank_from_group_rank(group, group_rank)
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
def get_group_rank(group, global_rank):
|
|
643
|
+
"""
|
|
644
|
+
Get the rank ID in the specified user communication group corresponding to
|
|
645
|
+
the rank ID in the world communication group.
|
|
646
|
+
|
|
647
|
+
Note:
|
|
648
|
+
This method should be used after :func:`mindspore.mint.distributed.init_process_group`.
|
|
649
|
+
|
|
650
|
+
Args:
|
|
651
|
+
group (str): The communication group to work on. Normally, the group should be
|
|
652
|
+
created by :func:`mindspore.mint.distributed.new_group`. If ``None``, which means
|
|
653
|
+
``"hccl_world_group"`` in Ascend.
|
|
654
|
+
global_rank (int): A rank ID in the world communication group.
|
|
655
|
+
|
|
656
|
+
Returns:
|
|
657
|
+
int, the rank ID in the user communication group.
|
|
658
|
+
|
|
659
|
+
Raises:
|
|
660
|
+
TypeError: If global_rank is not an integer or the group is not a string.
|
|
661
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
662
|
+
|
|
663
|
+
Supported Platforms:
|
|
664
|
+
``Ascend``
|
|
665
|
+
|
|
666
|
+
Examples:
|
|
667
|
+
.. note::
|
|
668
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
669
|
+
|
|
670
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
671
|
+
without any third-party or configuration file dependencies.
|
|
672
|
+
Please see the `msrun start up
|
|
673
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
674
|
+
for more details.
|
|
675
|
+
|
|
676
|
+
This example should be run with 8 devices.
|
|
677
|
+
|
|
678
|
+
>>> import mindspore as ms
|
|
679
|
+
>>> from mindspore.mint.distributed import init_process_group, new_group, get_group_rank, get_rank
|
|
680
|
+
>>> ms.set_device(device_target="Ascend")
|
|
681
|
+
>>> # Launch 8 processes.
|
|
682
|
+
>>> init_process_group()
|
|
683
|
+
>>> rank_ids = [0,4]
|
|
684
|
+
>>> if get_rank() in rank_ids:
|
|
685
|
+
... group = new_group(rank_ids)
|
|
686
|
+
... group_rank_id = get_group_rank(group, 4)
|
|
687
|
+
... print("group_rank_id is: ", group_rank_id)
|
|
688
|
+
#rank 0 and 4:
|
|
689
|
+
group_rank_id is: 1
|
|
690
|
+
"""
|
|
691
|
+
if not isinstance(global_rank, int):
|
|
692
|
+
raise TypeError(
|
|
693
|
+
f"The global_rank argument must be integer, but got {type(global_rank)}."
|
|
694
|
+
)
|
|
695
|
+
if group is None:
|
|
696
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
697
|
+
if not isinstance(group, str):
|
|
698
|
+
raise TypeError(
|
|
699
|
+
"For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, "
|
|
700
|
+
"but got 'group' type : {}.".format(type(group))
|
|
701
|
+
)
|
|
702
|
+
return _get_group_rank_from_world_rank_from_cache_helper(
|
|
703
|
+
world_rank_id=global_rank, group=group
|
|
704
|
+
)
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
def get_process_group_ranks(group=None):
|
|
708
|
+
"""
|
|
709
|
+
Gets the ranks of the specific group and returns the process ranks in the communication group as a list.
|
|
710
|
+
|
|
711
|
+
Args:
|
|
712
|
+
group (str, optional): The communication group to work on. Normally, the group should be created by
|
|
713
|
+
:func:`mindspore.mint.distributed.new_group`. If ``None``, which means ``"hccl_world_group"`` in Ascend.
|
|
714
|
+
Default: ``None``.
|
|
715
|
+
|
|
716
|
+
Returns:
|
|
717
|
+
List (List[int]), List of process ranks in the specified communication group.
|
|
718
|
+
|
|
719
|
+
Raises:
|
|
720
|
+
TypeError: If the `group` is not a str.
|
|
721
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
722
|
+
|
|
723
|
+
Supported Platforms:
|
|
724
|
+
``Ascend`` ``CPU``
|
|
725
|
+
|
|
726
|
+
Examples:
|
|
727
|
+
.. note::
|
|
728
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
729
|
+
|
|
730
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
731
|
+
without any third-party or configuration file dependencies.
|
|
732
|
+
|
|
733
|
+
Please see the `msrun start up
|
|
734
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
735
|
+
for more details.
|
|
736
|
+
|
|
737
|
+
This example should be run with 4 devices.
|
|
738
|
+
|
|
739
|
+
>>> import mindspore as ms
|
|
740
|
+
>>> from mindspore.mint.distributed import init_process_group, get_process_group_ranks
|
|
741
|
+
>>> # Launch 4 processes.
|
|
742
|
+
>>> ms.set_device(device_target="Ascend")
|
|
743
|
+
>>> init_process_group()
|
|
744
|
+
>>> output = get_process_group_ranks()
|
|
745
|
+
>>> print(output)
|
|
746
|
+
[0, 1, 2, 3]
|
|
747
|
+
|
|
748
|
+
"""
|
|
749
|
+
if group is None:
|
|
750
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
751
|
+
|
|
752
|
+
if not isinstance(group, str):
|
|
753
|
+
raise TypeError(
|
|
754
|
+
"For 'get_process_group_ranks', the argument 'group' must be type of string or None, "
|
|
755
|
+
"but got 'group' type : {}.".format(type(group))
|
|
756
|
+
)
|
|
757
|
+
return _get_group_ranks(group)
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
@_primexpr
|
|
761
|
+
def _check_all_tensor_same_dtype_and_shape(*tensor_lists):
|
|
762
|
+
"""check all the input tensor has same dtype and shape"""
|
|
763
|
+
consistent_dtype = None
|
|
764
|
+
consistent_shape = None
|
|
765
|
+
for list_ in tensor_lists:
|
|
766
|
+
if not isinstance(list_, (list, tuple)):
|
|
767
|
+
list_ = [list_]
|
|
768
|
+
for tensor_ in list_:
|
|
769
|
+
if not isinstance(tensor_, Tensor):
|
|
770
|
+
continue
|
|
771
|
+
dtype = tensor_.dtype
|
|
772
|
+
shape = tensor_.shape
|
|
773
|
+
if consistent_dtype is None:
|
|
774
|
+
consistent_dtype = dtype
|
|
775
|
+
consistent_shape = shape
|
|
776
|
+
else:
|
|
777
|
+
if dtype != consistent_dtype:
|
|
778
|
+
raise TypeError(
|
|
779
|
+
"tensor_lists dtype must be the same, "
|
|
780
|
+
f"but got {consistent_dtype} and {dtype}."
|
|
781
|
+
)
|
|
782
|
+
if shape != consistent_shape:
|
|
783
|
+
raise TypeError(
|
|
784
|
+
"tensor_lists shape must be the same, "
|
|
785
|
+
f"but got {consistent_shape} and {shape}."
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
|
|
790
|
+
"""
|
|
791
|
+
Reduce tensors across all devices in such a way that all devices will get the same final result,
|
|
792
|
+
returns the tensor which is all reduced.
|
|
793
|
+
|
|
794
|
+
Note:
|
|
795
|
+
The tensors must have the same shape and format in all processes of the collection.
|
|
796
|
+
|
|
797
|
+
Args:
|
|
798
|
+
tensor (Tensor): The input and output tensor of collective. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
799
|
+
The function operates in-place.
|
|
800
|
+
op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
|
|
801
|
+
Default: ``ReduceOp.SUM`` .
|
|
802
|
+
group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
|
|
803
|
+
Ascend. Default: ``None``.
|
|
804
|
+
async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
|
|
805
|
+
|
|
806
|
+
Returns:
|
|
807
|
+
CommHandle, CommHandle is an async work handle, if `async_op` is set to True. CommHandle will be None,
|
|
808
|
+
when `async_op` is False.
|
|
809
|
+
|
|
810
|
+
Raises:
|
|
811
|
+
TypeError: If the type of the first input parameter is not Tensor, or any of `op` and `group` is not a str,
|
|
812
|
+
`op` range is illegal or async_op is not bool.
|
|
813
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
814
|
+
|
|
815
|
+
Supported Platforms:
|
|
816
|
+
``Ascend``
|
|
817
|
+
|
|
818
|
+
Examples:
|
|
819
|
+
.. note::
|
|
820
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
821
|
+
|
|
822
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
823
|
+
without any third-party or configuration file dependencies.
|
|
824
|
+
Please see the `msrun start up
|
|
825
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
826
|
+
for more details.
|
|
827
|
+
|
|
828
|
+
This example should be run with 2 devices.
|
|
829
|
+
|
|
830
|
+
>>> import numpy as np
|
|
831
|
+
>>> from mindspore.mint.distributed import init_process_group
|
|
832
|
+
>>> from mindspore.mint.distributed import all_reduce
|
|
833
|
+
>>> from mindspore import Tensor
|
|
834
|
+
>>>
|
|
835
|
+
>>> init_process_group()
|
|
836
|
+
>>> tensor = Tensor(np.ones([2, 8]).astype(np.float32))
|
|
837
|
+
>>> output = all_reduce(tensor)
|
|
838
|
+
>>> print(tensor)
|
|
839
|
+
[[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
840
|
+
[2. 2. 2. 2. 2. 2. 2. 2.]]
|
|
841
|
+
|
|
842
|
+
"""
|
|
843
|
+
if not isinstance(tensor, (Tensor, Tensor_)):
|
|
844
|
+
raise TypeError("For all_reduce, the input tensor must be tensor")
|
|
845
|
+
if not isinstance(op, str):
|
|
846
|
+
raise TypeError("For all_reduce, the input op type must be str")
|
|
847
|
+
if op not in ("sum", "prod", "min", "max"):
|
|
848
|
+
raise TypeError(
|
|
849
|
+
"For all_reduce, the input op value must be one of sum, prod, min, max"
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
if group is None:
|
|
853
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
854
|
+
|
|
855
|
+
if not isinstance(group, str):
|
|
856
|
+
raise TypeError(
|
|
857
|
+
"The argument 'group' must be type of string, "
|
|
858
|
+
"but got 'group' type : {}.".format(type(group))
|
|
859
|
+
)
|
|
860
|
+
if not isinstance(async_op, bool):
|
|
861
|
+
raise TypeError(
|
|
862
|
+
f"The argument 'async_op' must be a bool, but got {type(async_op)}."
|
|
863
|
+
)
|
|
864
|
+
|
|
865
|
+
output = dist_comm_all_reduce_op(tensor, op, group)
|
|
866
|
+
_, handle = _deal_comm_outputs(output, async_op)
|
|
867
|
+
return handle
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
|
|
871
|
+
"""
|
|
872
|
+
Gathers tensors from the specified communication group and returns the tensor which is all gathered.
|
|
873
|
+
|
|
874
|
+
Note:
|
|
875
|
+
The tensors must have the same shape and format in all processes of the collection.
|
|
876
|
+
|
|
877
|
+
Args:
|
|
878
|
+
output_tensor (Tensor): The output tensor to be all gathered into tensor. If the number of devices
|
|
879
|
+
in the group is N, then the shape of output tensor is :math:`(N*x_1, x_2, ..., x_R)`.
|
|
880
|
+
input_tensor (Tensor): The input tensor to be all gathered into tensor.
|
|
881
|
+
The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
882
|
+
group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
|
|
883
|
+
Ascend. Default: ``None``.
|
|
884
|
+
async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
|
|
885
|
+
|
|
886
|
+
Returns:
|
|
887
|
+
CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
|
|
888
|
+
CommHandle will be None, when `async_op` is False.
|
|
889
|
+
|
|
890
|
+
Raises:
|
|
891
|
+
TypeError: If the type of the input_tensor or output_tensor parameter is not Tensor,
|
|
892
|
+
`group` is not a str, or async_op is not bool.
|
|
893
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
894
|
+
|
|
895
|
+
Supported Platforms:
|
|
896
|
+
``Ascend``
|
|
897
|
+
|
|
898
|
+
Examples:
|
|
899
|
+
.. note::
|
|
900
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
901
|
+
|
|
902
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
903
|
+
without any third-party or configuration file dependencies.
|
|
904
|
+
Please see the `msrun start up
|
|
905
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
906
|
+
for more details.
|
|
907
|
+
|
|
908
|
+
This example should be run with 2 devices.
|
|
909
|
+
|
|
910
|
+
>>> import numpy as np
|
|
911
|
+
>>> import mindspore as ms
|
|
912
|
+
>>> from mindspore import ops
|
|
913
|
+
>>> from mindspore.mint.distributed import init_process_group
|
|
914
|
+
>>> from mindspore.mint.distributed import all_gather_into_tensor
|
|
915
|
+
>>> from mindspore import Tensor
|
|
916
|
+
>>>
|
|
917
|
+
>>> ms.set_device(device_target="Ascend")
|
|
918
|
+
>>> init_process_group()
|
|
919
|
+
>>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
|
|
920
|
+
>>> out_tensor = Tensor(np.zeros([4, 8]).astype(np.float32))
|
|
921
|
+
>>> output = all_gather_into_tensor(out_tensor, input_tensor)
|
|
922
|
+
>>> print(out_tensor)
|
|
923
|
+
[[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
924
|
+
[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
925
|
+
[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
926
|
+
[1. 1. 1. 1. 1. 1. 1. 1.]]
|
|
927
|
+
|
|
928
|
+
"""
|
|
929
|
+
|
|
930
|
+
if not isinstance(input_tensor, (Tensor, Tensor_)):
|
|
931
|
+
raise TypeError("For all_gather_into_tensor, the input tensor must be tensor")
|
|
932
|
+
if not isinstance(output_tensor, (Tensor, Tensor_)):
|
|
933
|
+
raise TypeError("For all_gather_into_tensor, the output tensor must be tensor")
|
|
934
|
+
if group is None:
|
|
935
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
936
|
+
if not isinstance(group, str):
|
|
937
|
+
raise TypeError(
|
|
938
|
+
"The argument 'group' must be type of string, "
|
|
939
|
+
"but got 'group' type : {}.".format(type(group))
|
|
940
|
+
)
|
|
941
|
+
if not isinstance(async_op, bool):
|
|
942
|
+
raise TypeError(
|
|
943
|
+
f"The argument 'async_op' must be a bool, but got {type(async_op)}."
|
|
944
|
+
)
|
|
945
|
+
group_size = get_cache_group_size(group)
|
|
946
|
+
result = dist_comm_all_gather_into_tensor_op(
|
|
947
|
+
output_tensor, input_tensor, group_size, group
|
|
948
|
+
)
|
|
949
|
+
_, handle = _deal_comm_outputs(result, async_op)
|
|
950
|
+
return handle
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
|
|
954
|
+
r"""
|
|
955
|
+
Reduces and scatters tensors from the specified communication group and
|
|
956
|
+
returns the tensor which is reduced and scattered.
|
|
957
|
+
|
|
958
|
+
Note:
|
|
959
|
+
The tensors must have the same shape and format in all processes of the collection.
|
|
960
|
+
|
|
961
|
+
Args:
|
|
962
|
+
output(Tensor): The output tensor, which has the same dtype as `input` with a shape of :math:`(N/rank\_size, *)`.
|
|
963
|
+
input(Tensor): The input tensor to be reduced and scattered, suppose it has a shape :math:`(N, *)`, where `*`
|
|
964
|
+
means any number of additional dimensions. N must be divisible by rank_size.
|
|
965
|
+
rank_size refers to the number of cards in the communication group.
|
|
966
|
+
op (str, optional): Specifies an operation used for element-wise reductions,
|
|
967
|
+
like SUM and MAX. Default: ``ReduceOp.SUM`` .
|
|
968
|
+
group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
|
|
969
|
+
Ascend. Default: ``None``.
|
|
970
|
+
async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
|
|
971
|
+
|
|
972
|
+
Returns:
|
|
973
|
+
CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
|
|
974
|
+
CommHandle will be None, when `async_op` is False.
|
|
975
|
+
|
|
976
|
+
Raises:
|
|
977
|
+
TypeError: If the type of the input and output parameter is not Tensor, any of `op` and `group` is not a str,
|
|
978
|
+
async_op is not bool or 'op' is invalid.
|
|
979
|
+
ValueError: If the first dimension of the input cannot be divided by the rank_size.
|
|
980
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
981
|
+
|
|
982
|
+
Supported Platforms:
|
|
983
|
+
``Ascend``
|
|
984
|
+
|
|
985
|
+
Examples:
|
|
986
|
+
.. note::
|
|
987
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
988
|
+
|
|
989
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
990
|
+
without any third-party or configuration file dependencies.
|
|
991
|
+
Please see the `msrun start up
|
|
992
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
993
|
+
for more details.
|
|
994
|
+
|
|
995
|
+
This example should be run with 2 devices.
|
|
996
|
+
|
|
997
|
+
>>> import mindspore as ms
|
|
998
|
+
>>> from mindspore import Tensor
|
|
999
|
+
>>> from mindspore.mint.distributed import init_process_group
|
|
1000
|
+
>>> from mindspore.mint.distributed import reduce_scatter_tensor
|
|
1001
|
+
>>> import numpy as np
|
|
1002
|
+
>>>
|
|
1003
|
+
>>> ms.set_device(device_target="Ascend")
|
|
1004
|
+
>>> init_process_group()
|
|
1005
|
+
>>> input_tensor = Tensor(np.ones([8, 8]).astype(np.float32))
|
|
1006
|
+
>>> output_tensor = Tensor(np.ones([4, 8]).astype(np.float32))
|
|
1007
|
+
>>> output = reduce_scatter_tensor(output_tensor, input_tensor)
|
|
1008
|
+
>>> print(output_tensor)
|
|
1009
|
+
[[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
1010
|
+
[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
1011
|
+
[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
1012
|
+
[2. 2. 2. 2. 2. 2. 2. 2.]]
|
|
1013
|
+
|
|
1014
|
+
"""
|
|
1015
|
+
|
|
1016
|
+
if not isinstance(input, (Tensor, Tensor_)):
|
|
1017
|
+
raise TypeError("For reduce_scatter_tensor, the input tensor must be tensor")
|
|
1018
|
+
if not isinstance(output, (Tensor, Tensor_)):
|
|
1019
|
+
raise TypeError("For reduce_scatter_tensor, the output tensor must be tensor")
|
|
1020
|
+
if not isinstance(op, str):
|
|
1021
|
+
raise TypeError("For reduce_scatter_tensor, the input op type must be str")
|
|
1022
|
+
if op not in ("sum", "prod", "min", "max"):
|
|
1023
|
+
raise TypeError(
|
|
1024
|
+
"For reduce_scatter_tensor, the input op value must be one of sum, prod, min, max"
|
|
1025
|
+
)
|
|
1026
|
+
if group is None:
|
|
1027
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
1028
|
+
if not isinstance(group, str):
|
|
1029
|
+
raise TypeError(
|
|
1030
|
+
"The argument 'group' must be type of string, "
|
|
1031
|
+
"but got 'group' type : {}.".format(type(group))
|
|
1032
|
+
)
|
|
1033
|
+
if not isinstance(async_op, bool):
|
|
1034
|
+
raise TypeError(
|
|
1035
|
+
f"The argument 'async_op' must be a bool, but got {type(async_op)}."
|
|
1036
|
+
)
|
|
1037
|
+
rank_size = get_cache_group_size(group)
|
|
1038
|
+
result = dist_comm_reduce_scatter_tensor_op(output, input, rank_size, op, group)
|
|
1039
|
+
_, handle = _deal_comm_outputs(result, async_op)
|
|
1040
|
+
return handle
|
|
1041
|
+
|
|
1042
|
+
|
|
1043
|
+
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
|
|
1044
|
+
"""
|
|
1045
|
+
Reduces tensors across the processes in the specified communication group, sends the result
|
|
1046
|
+
to the target dst(global rank), and returns the tensor which is sent to the target process.
|
|
1047
|
+
|
|
1048
|
+
Note:
|
|
1049
|
+
- Only process with destination rank receives the reduced output.
|
|
1050
|
+
- Only support PyNative mode, Graph mode is not currently supported.
|
|
1051
|
+
- Other processes only get a tensor with shape [1], which has no mathematical meaning.
|
|
1052
|
+
|
|
1053
|
+
Args:
|
|
1054
|
+
tensor (Tensor): Input and output of the collective. The function operates in-place.
|
|
1055
|
+
dst (int): The target rank of the process(global rank) that receives the reduced output.
|
|
1056
|
+
op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
|
|
1057
|
+
Default: ``ReduceOp.SUM`` .
|
|
1058
|
+
group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
|
|
1059
|
+
Ascend. Default: ``None``.
|
|
1060
|
+
async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
|
|
1061
|
+
|
|
1062
|
+
Returns:
|
|
1063
|
+
CommHandle, CommHandle is an async work handle, if `async_op` is set to ``True``.
|
|
1064
|
+
CommHandle will be None, when `async_op` is ``False``.
|
|
1065
|
+
|
|
1066
|
+
Raises:
|
|
1067
|
+
TypeError: If the type of `tensor` is not Tensor, any of `op` and `group` is not a str,
|
|
1068
|
+
async_op is not bool or 'op' is invalid.
|
|
1069
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
1070
|
+
|
|
1071
|
+
Supported Platforms:
|
|
1072
|
+
``Ascend``
|
|
1073
|
+
|
|
1074
|
+
Examples:
|
|
1075
|
+
.. note::
|
|
1076
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
1077
|
+
|
|
1078
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
1079
|
+
without any third-party or configuration file dependencies.
|
|
1080
|
+
|
|
1081
|
+
Please see the `msrun start up
|
|
1082
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
1083
|
+
for more details.
|
|
1084
|
+
|
|
1085
|
+
This example should be run with 2 devices.
|
|
1086
|
+
|
|
1087
|
+
>>> from mindspore import ops
|
|
1088
|
+
>>> import mindspore.nn as nn
|
|
1089
|
+
>>> from mindspore.mint.distributed import init_process_group, reduce
|
|
1090
|
+
>>> from mindspore import Tensor
|
|
1091
|
+
>>> import numpy as np
|
|
1092
|
+
>>> # Launch 2 processes.
|
|
1093
|
+
>>> init_process_group()
|
|
1094
|
+
>>> dest_rank=1
|
|
1095
|
+
>>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
|
|
1096
|
+
>>> output = reduce(input_tensor, dest_rank)
|
|
1097
|
+
>>> print(input_tensor)
|
|
1098
|
+
Process with rank 0: [[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
1099
|
+
[1. 1. 1. 1. 1. 1. 1. 1.]],
|
|
1100
|
+
Process with rank 1: [[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
1101
|
+
[2. 2. 2. 2. 2. 2. 2. 2.]],
|
|
1102
|
+
"""
|
|
1103
|
+
|
|
1104
|
+
if not isinstance(tensor, (Tensor, Tensor_)):
|
|
1105
|
+
raise TypeError("For reduce, the input tensor must be tensor")
|
|
1106
|
+
if not isinstance(dst, int):
|
|
1107
|
+
raise TypeError("For reduce, the dst must be int")
|
|
1108
|
+
if not isinstance(op, str):
|
|
1109
|
+
raise TypeError("For reduce, the input op type must be str")
|
|
1110
|
+
if op not in ("sum", "prod", "min", "max"):
|
|
1111
|
+
raise TypeError(
|
|
1112
|
+
"For reduce, the input op value must be one of sum, prod, min, max"
|
|
1113
|
+
)
|
|
1114
|
+
if group is None:
|
|
1115
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
1116
|
+
if not isinstance(group, str):
|
|
1117
|
+
raise TypeError(
|
|
1118
|
+
"The argument 'group' must be type of string, "
|
|
1119
|
+
"but got 'group' type : {}.".format(type(group))
|
|
1120
|
+
)
|
|
1121
|
+
if not isinstance(async_op, bool):
|
|
1122
|
+
raise TypeError(
|
|
1123
|
+
f"The argument 'async_op' must be a bool, but got {type(async_op)}."
|
|
1124
|
+
)
|
|
1125
|
+
result = dist_comm_reduce_op(tensor, op, dst, group)
|
|
1126
|
+
_, handle = _deal_comm_outputs(result, async_op)
|
|
1127
|
+
return handle
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
class P2POp:
|
|
1131
|
+
"""
|
|
1132
|
+
Object for `batch_isend_irecv` input, to store information of ``"isend"`` and ``"irecv"``.
|
|
1133
|
+
|
|
1134
|
+
Note:
|
|
1135
|
+
`tensor` will be modified in-place by final result when `op` is ``"irecv"``.
|
|
1136
|
+
|
|
1137
|
+
Args:
|
|
1138
|
+
op(Union[str, function]): Only string of ``"isend"`` and ``"irecv"`` are allowed.
|
|
1139
|
+
Or function of ``distributed.isend`` and ``distributed.irecv`` are allowed.
|
|
1140
|
+
tensor(Tensor): tensor for sending/receiving.
|
|
1141
|
+
peer(int): remote global rank for send/receive.
|
|
1142
|
+
group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
|
|
1143
|
+
Ascend. Default: ``None``.
|
|
1144
|
+
tag(int, optional): currently not supported yet. Default: ``0``.
|
|
1145
|
+
|
|
1146
|
+
Returns:
|
|
1147
|
+
P2POp Object.
|
|
1148
|
+
|
|
1149
|
+
Raises:
|
|
1150
|
+
TypeError: when `op` is not string or function of 'isend' and 'irecv'.
|
|
1151
|
+
TypeError: when `tensor` is not type of Tensor or 'peer' is not int.
|
|
1152
|
+
NotImplementedError: when `tag` is not 0.
|
|
1153
|
+
|
|
1154
|
+
Supported Platforms:
|
|
1155
|
+
``Ascend``
|
|
1156
|
+
|
|
1157
|
+
Examples:
|
|
1158
|
+
>>> import numpy as np
|
|
1159
|
+
>>> import mindspore
|
|
1160
|
+
>>> from mindspore.mint.distributed import P2POp, isend, irecv
|
|
1161
|
+
>>> from mindspore import Tensor
|
|
1162
|
+
>>> # Launch 2 processes.
|
|
1163
|
+
>>> send_tensor = Tensor(1.)
|
|
1164
|
+
>>> send_op = P2POp('isend', send_tensor, 1)
|
|
1165
|
+
>>> send_op = P2POp(isend, send_tensor, 1)
|
|
1166
|
+
>>> recv_tensor = Tensor(0.)
|
|
1167
|
+
>>> recv_op = P2POp('irecv', recv_tensor, 0)
|
|
1168
|
+
>>> recv_op = P2POp(irecv, recv_tensor, 0)
|
|
1169
|
+
"""
|
|
1170
|
+
|
|
1171
|
+
def __init__(self, op, tensor, peer, group=None, tag=0):
|
|
1172
|
+
self.op = op
|
|
1173
|
+
self.tensor = tensor
|
|
1174
|
+
self.peer = peer
|
|
1175
|
+
self.group = group
|
|
1176
|
+
self.tag = tag
|
|
1177
|
+
|
|
1178
|
+
def __new__(cls, op, tensor, peer, group=None, tag=0):
|
|
1179
|
+
if isinstance(op, str):
|
|
1180
|
+
op_name = op
|
|
1181
|
+
if op_name not in ["isend", "irecv"]:
|
|
1182
|
+
raise TypeError(
|
|
1183
|
+
f"Expected op to be of type isend or irecv, but got {op_name}"
|
|
1184
|
+
)
|
|
1185
|
+
else:
|
|
1186
|
+
if op not in [isend, irecv]:
|
|
1187
|
+
raise TypeError(
|
|
1188
|
+
f"Expected op to be of type isend or irecv, but got {op}"
|
|
1189
|
+
)
|
|
1190
|
+
op_name = op.__name__
|
|
1191
|
+
|
|
1192
|
+
if not isinstance(tensor, (Tensor, Tensor_)):
|
|
1193
|
+
raise TypeError(
|
|
1194
|
+
f"Expected tensor to be Tensor, but got {type(tensor)}."
|
|
1195
|
+
)
|
|
1196
|
+
if not isinstance(peer, int):
|
|
1197
|
+
raise TypeError("For P2POp, the peer must be int")
|
|
1198
|
+
if tag != 0:
|
|
1199
|
+
raise NotImplementedError("tag is not support yet.")
|
|
1200
|
+
return object.__new__(cls)
|
|
1201
|
+
|
|
1202
|
+
|
|
1203
|
+
TYPE_ISEND = 0
|
|
1204
|
+
TYPE_IRECV = 1
|
|
1205
|
+
|
|
1206
|
+
|
|
1207
|
+
def batch_isend_irecv(p2p_op_list):
|
|
1208
|
+
"""
|
|
1209
|
+
Batch send and recv tensors asynchronously.
|
|
1210
|
+
|
|
1211
|
+
Note:
|
|
1212
|
+
- The 'isend' and 'irecv' of `P2POp` in `p2p_op_list` between ranks need to match each other.
|
|
1213
|
+
- `P2POp` in `p2p_op_list` can only use the same communication group.
|
|
1214
|
+
- `tag` of `P2POp` in `p2p_op_list` is not supported yet.
|
|
1215
|
+
- `tensor` of `P2POp` in `p2p_op_list` will be modified in-place by the result when `op` is ``"irecv"``.
|
|
1216
|
+
- Only support PyNative mode, Graph mode is not currently supported.
|
|
1217
|
+
|
|
1218
|
+
Args:
|
|
1219
|
+
p2p_op_list(list[P2POp]): list contains `P2POp`. `P2POp` is type of :class:`mindspore.mint.distributed.P2POp`
|
|
1220
|
+
|
|
1221
|
+
Returns:
|
|
1222
|
+
list[CommHandle], CommHandle is an async work handle. Currently only one packaging handle is supported.
|
|
1223
|
+
|
|
1224
|
+
Raises:
|
|
1225
|
+
TypeError: If `p2p_op_list` is empty or `p2p_op_list` are not all type of `P2POp`.
|
|
1226
|
+
TypeError: The group name in `p2p_op_list` are not consistent.
|
|
1227
|
+
TypeError: The `tensor` in `p2p_op_list` are not Tensor.
|
|
1228
|
+
TypeError: The `op` in `p2p_op_list` are not isend or irecv.
|
|
1229
|
+
|
|
1230
|
+
Supported Platforms:
|
|
1231
|
+
``Ascend``
|
|
1232
|
+
|
|
1233
|
+
Examples:
|
|
1234
|
+
.. note::
|
|
1235
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
1236
|
+
|
|
1237
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
1238
|
+
without any third-party or configuration file dependencies.
|
|
1239
|
+
Please see the `msrun start up
|
|
1240
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
1241
|
+
for more details.
|
|
1242
|
+
|
|
1243
|
+
This example should be run with 2 devices.
|
|
1244
|
+
|
|
1245
|
+
>>> import numpy as np
|
|
1246
|
+
>>> import mindspore
|
|
1247
|
+
>>> from mindspore.mint.distributed import init_process_group, get_rank, get_world_size
|
|
1248
|
+
>>> from mindspore.mint.distributed import batch_isend_irecv, P2POp
|
|
1249
|
+
>>> from mindspore import Tensor
|
|
1250
|
+
>>>
|
|
1251
|
+
>>> init_process_group()
|
|
1252
|
+
>>> this_rank = get_rank()
|
|
1253
|
+
>>> world_size = get_world_size()
|
|
1254
|
+
>>> next_rank = (this_rank + 1) % world_size
|
|
1255
|
+
>>> prev_rank = (this_rank + world_size - 1) % world_size
|
|
1256
|
+
>>>
|
|
1257
|
+
>>> send_tensor = Tensor(this_rank + 1, dtype=mindspore.float32)
|
|
1258
|
+
>>> recv_tensor = Tensor(0., dtype=mindspore.float32)
|
|
1259
|
+
>>>
|
|
1260
|
+
>>> send_op = P2POp('isend', send_tensor, next_rank)
|
|
1261
|
+
>>> recv_op = P2POp('irecv', recv_tensor, prev_rank)
|
|
1262
|
+
>>>
|
|
1263
|
+
>>> p2p_op_list = [send_op, recv_op]
|
|
1264
|
+
>>> output = batch_isend_irecv(p2p_op_list)
|
|
1265
|
+
>>> print(recv_tensor)
|
|
1266
|
+
rank 0:
|
|
1267
|
+
2.0
|
|
1268
|
+
rank 1:
|
|
1269
|
+
1.0
|
|
1270
|
+
"""
|
|
1271
|
+
tensors = []
|
|
1272
|
+
op_types = []
|
|
1273
|
+
remotes_ranks = []
|
|
1274
|
+
tags = []
|
|
1275
|
+
if not p2p_op_list:
|
|
1276
|
+
raise TypeError(f"p2p_op_list can not be empty list.")
|
|
1277
|
+
for _, p2p_op in enumerate(p2p_op_list):
|
|
1278
|
+
if not isinstance(p2p_op, P2POp):
|
|
1279
|
+
raise TypeError("The elements in p2p_op_list must be type of P2POp.")
|
|
1280
|
+
group = p2p_op_list[0].group
|
|
1281
|
+
|
|
1282
|
+
type_ = None
|
|
1283
|
+
for _, p2p_op in enumerate(p2p_op_list):
|
|
1284
|
+
if group != p2p_op.group:
|
|
1285
|
+
raise TypeError("The group name in p2p_op_list must be consistent.")
|
|
1286
|
+
if isinstance(p2p_op.op, str):
|
|
1287
|
+
type_ = p2p_op.op
|
|
1288
|
+
else:
|
|
1289
|
+
type_ = p2p_op.op.__name__
|
|
1290
|
+
rank_ = (
|
|
1291
|
+
p2p_op.peer
|
|
1292
|
+
if p2p_op.group is None
|
|
1293
|
+
else get_group_rank_from_world_rank(p2p_op.peer, p2p_op.group)
|
|
1294
|
+
)
|
|
1295
|
+
remotes_ranks.append(rank_)
|
|
1296
|
+
tags.append(p2p_op.tag)
|
|
1297
|
+
if type_ == "isend":
|
|
1298
|
+
tensors.append(p2p_op.tensor)
|
|
1299
|
+
op_types.append(TYPE_ISEND)
|
|
1300
|
+
elif type_ == "irecv":
|
|
1301
|
+
if isinstance(p2p_op.tensor, Tensor):
|
|
1302
|
+
tensors.append(p2p_op.tensor)
|
|
1303
|
+
op_types.append(TYPE_IRECV)
|
|
1304
|
+
else:
|
|
1305
|
+
raise TypeError("p2p_op.tensor must be tensor")
|
|
1306
|
+
else:
|
|
1307
|
+
raise TypeError("p2p_op.op must be isend or irecv")
|
|
1308
|
+
|
|
1309
|
+
if group is None:
|
|
1310
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
1311
|
+
output = dist_comm_batch_isend_irecv_op(tensors, group, op_types, remotes_ranks)
|
|
1312
|
+
_, handle = _deal_comm_outputs(output, True)
|
|
1313
|
+
return [handle]
|
|
1314
|
+
|
|
1315
|
+
|
|
1316
|
+
def scatter_tensor(output_tensor, input_tensor, src=0, group=None, async_op=False):
|
|
1317
|
+
r"""
|
|
1318
|
+
Scatter tensor evenly across the processes in the specified communication group.
|
|
1319
|
+
|
|
1320
|
+
Note:
|
|
1321
|
+
- The interface only supports Tensor input and scatters evenly, which
|
|
1322
|
+
is different from that of `torch.distributed.scatter`.
|
|
1323
|
+
- Only the tensor in process `src` (global rank) will do scatter.
|
|
1324
|
+
- Only support PyNative mode, Graph mode is not currently supported.
|
|
1325
|
+
|
|
1326
|
+
Args:
|
|
1327
|
+
output_tensor (Tensor): Output tensor. It should have the same size across all ranks.
|
|
1328
|
+
input_tensor (Tensor): The input tensor to be scattered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
1329
|
+
src (int, optional): Specifies the rank(global rank) of the process that sends the tensor.
|
|
1330
|
+
And only process `src` will send the tensor. Default is 0.
|
|
1331
|
+
group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
|
|
1332
|
+
Ascend. Default: ``None``.
|
|
1333
|
+
async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
|
|
1334
|
+
|
|
1335
|
+
Returns:
|
|
1336
|
+
CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
|
|
1337
|
+
CommHandle will be None, when `async_op` is False.
|
|
1338
|
+
|
|
1339
|
+
Raises:
|
|
1340
|
+
TypeError: If `input_tensor` or `output_tensor` is not a Tensor, `src` is not an int, `group` is not a str, or `async_op` is not a bool.
|
|
1341
|
+
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
1342
|
+
|
|
1343
|
+
Supported Platforms:
|
|
1344
|
+
``Ascend``
|
|
1345
|
+
|
|
1346
|
+
Examples:
|
|
1347
|
+
.. note::
|
|
1348
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
1349
|
+
|
|
1350
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
1351
|
+
without any third-party or configuration file dependencies.
|
|
1352
|
+
Please see the `msrun start up
|
|
1353
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
1354
|
+
for more details.
|
|
1355
|
+
|
|
1356
|
+
This example should be run with 2 devices.
|
|
1357
|
+
|
|
1358
|
+
>>> import mindspore as ms
|
|
1359
|
+
>>> from mindspore.mint.distributed import init_process_group
|
|
1360
|
+
>>> from mindspore.communication.comm_func import scatter_tensor
|
|
1361
|
+
>>> import numpy as np
|
|
1362
|
+
>>> # Launch 2 processes.
|
|
1363
|
+
>>>
|
|
1364
|
+
>>> init_process_group()
|
|
1365
|
+
>>> input = ms.Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
|
|
1366
|
+
>>> output = ms.Tensor(np.zeros([2, 2]).astype(np.float32))
|
|
1367
|
+
>>> out = scatter_tensor(output, input, src=0)
|
|
1368
|
+
>>> print(output)
|
|
1369
|
+
# rank_0
|
|
1370
|
+
[[0. 1.]
|
|
1371
|
+
[2. 3.]]
|
|
1372
|
+
# rank_1
|
|
1373
|
+
[[4. 5.]
|
|
1374
|
+
[6. 7.]]
|
|
1375
|
+
"""
|
|
1376
|
+
if not isinstance(input_tensor, (Tensor, Tensor_)):
|
|
1377
|
+
raise TypeError("For scatter_tensor, the input tensor must be tensor")
|
|
1378
|
+
if not isinstance(output_tensor, (Tensor, Tensor_)):
|
|
1379
|
+
raise TypeError("For scatter_tensor, the output tensor must be tensor")
|
|
1380
|
+
if not isinstance(src, int):
|
|
1381
|
+
raise TypeError("For scatter_tensor, the src must be int")
|
|
1382
|
+
if group is None:
|
|
1383
|
+
group = GlobalComm.WORLD_COMM_GROUP
|
|
1384
|
+
if not isinstance(group, str):
|
|
1385
|
+
raise TypeError(
|
|
1386
|
+
"The argument 'group' must be type of string, "
|
|
1387
|
+
"but got 'group' type : {}.".format(type(group))
|
|
1388
|
+
)
|
|
1389
|
+
if not isinstance(async_op, bool):
|
|
1390
|
+
raise TypeError(
|
|
1391
|
+
f"The argument 'async_op' must be a bool, but got {type(async_op)}."
|
|
1392
|
+
)
|
|
1393
|
+
src = get_group_rank_from_world_rank(src, group)
|
|
1394
|
+
rank_size = get_cache_group_size(group)
|
|
1395
|
+
rank_id = get_cache_group_rank(group)
|
|
1396
|
+
output = dist_comm_scatter_tensor_op(
|
|
1397
|
+
output_tensor, input_tensor, rank_size, src, rank_id, group
|
|
1398
|
+
)
|
|
1399
|
+
_, handle = _deal_comm_outputs(output, async_op)
|
|
1400
|
+
return handle
|
|
1401
|
+
|
|
1402
|
+
|
|
1403
|
+
def gather_into_tensor(output_tensor, input_tensor, dst=0, group=None, async_op=False):
    r"""
    Gather the tensors of every process in the communication group onto process `dst`.
    The per-process tensors are gathered along dimension 0.

    Note:
        - Only the tensor in process `dst` (global rank) will keep the gathered tensor. The other process
          will keep a tensor with shape [1], which has no mathematical meaning.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        output_tensor (Tensor): Output tensor to accommodate tensor elements from all ranks.
        input_tensor (Tensor): The tensor to be gathered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
            The input tensors in this API must have the same size across all ranks.
        dst (int, optional): Specifies the rank (global rank) of the process that receives the tensor.
            Only process `dst` will receive the gathered tensor. Default: 0.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If `input_tensor` or `output_tensor` is not a Tensor, `dst` is not an int,
            `group` is not a str or `async_op` is not a bool.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore import Tensor
        >>> from mindspore.communication.comm_func import gather_into_tensor
        >>> # Launch 2 processes.
        >>>
        >>> init_process_group()
        >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
        >>> output = Tensor(np.zeros([4, 2]).astype(np.float32))
        >>> handle = gather_into_tensor(output, input, dst=0)
        >>> print(output)
        Process with rank 0: [[0. 1.],
                              [2. 3.],
                              [0. 1.],
                              [2. 3.]]
        Process with rank 1: [[0. 0.],
                              [0. 0.],
                              [0. 0.],
                              [0. 0.]]
    """
    tensor_types = (Tensor, Tensor_)
    # Argument validation; the order decides which error surfaces first.
    if not isinstance(input_tensor, tensor_types):
        raise TypeError("For gather_into_tensor, the input tensor must be tensor")
    if not isinstance(output_tensor, tensor_types):
        raise TypeError("For gather_into_tensor, the output tensor must be tensor")
    if not isinstance(dst, int):
        raise TypeError("For gather_into_tensor, the dst must be int")
    # Fall back to the world group when the caller does not name one.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        group_msg = "The argument 'group' must be type of string, but got 'group' type : {}."
        raise TypeError(group_msg.format(type(group)))
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )

    world_size = get_cache_group_size(group)
    # `dst` arrives as a global rank; the operator receives the rank inside `group`.
    dst_in_group = get_group_rank_from_world_rank(dst, group)
    local_rank = get_cache_group_rank(group)
    comm_result = dist_comm_gather_into_tensor_op(
        output_tensor, input_tensor, world_size, dst_in_group, local_rank, group
    )
    _unused, work_handle = _deal_comm_outputs(comm_result, async_op)
    return work_handle
|
|
1494
|
+
|
|
1495
|
+
|
|
1496
|
+
def broadcast(tensor, src, group=None, async_op=False):
    """
    Broadcast the tensor held by process `src` to every process of the group.

    Note:
        - The tensors must have the same shape and format in all processes of the collection.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        tensor (Tensor): Data to be sent if `src` is the rank of the current process,
            and tensor to be used to save received data otherwise.
        src (int): Specifies the rank (global rank) of the process that broadcasts the tensor.
            Only process `src` will broadcast the tensor.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If the type of the `tensor` parameter is not Tensor, `src` is not an integer,
            `group` is not a string or `async_op` is not bool.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> import mindspore as ms
        >>> from mindspore.mint.distributed import init_process_group, broadcast
        >>> import numpy as np
        >>> # Launch 2 processes.
        >>>
        >>> init_process_group()
        >>> data = ms.Tensor(np.arange(8).reshape([2, 4]).astype(np.float32))
        >>> handle = broadcast(tensor=data, src=0)
        >>> print(data)
        [[0. 1. 2. 3.]
         [4. 5. 6. 7.]]
    """
    # Argument validation; the order decides which error surfaces first.
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For broadcast, the input tensor must be tensor")
    if not isinstance(src, int):
        raise TypeError("For broadcast, the src must be int")
    # Fall back to the world group when the caller does not name one.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        group_msg = "The argument 'group' must be type of string, but got 'group' type : {}."
        raise TypeError(group_msg.format(type(group)))
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )

    # `src` arrives as a global rank; the operator receives the rank inside `group`.
    root_in_group = get_group_rank_from_world_rank(src, group)
    local_rank = get_cache_group_rank(group)
    comm_result = dist_comm_broadcast_op(tensor, root_in_group, local_rank, group)
    _unused, work_handle = _deal_comm_outputs(comm_result, async_op)
    return work_handle
|
|
1569
|
+
|
|
1570
|
+
|
|
1571
|
+
def barrier(group=None, async_op=False, device_ids=None):
    """
    Synchronize all processes in the specified group. A process that calls this operation is blocked until
    every process of the group has called it; the blocked processes are then woken and continue their task.

    Args:
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .
        device_ids (list[int], optional): Currently it is a reserved parameter; the function body does not
            read it.

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: `group` is not a str or `async_op` is not a bool.
        RuntimeError: If backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore.communication.comm_func import barrier
        >>> # Launch 2 processes.
        >>> init_process_group()
        >>> barrier()
        >>> print("barrier finish!")
        barrier finish!
    """
    # Fall back to the world group when the caller does not name one.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        group_msg = "The argument 'group' must be type of string, but got 'group' type : {}."
        raise TypeError(group_msg.format(type(group)))
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )

    comm_result = dist_comm_barrier_op(group)
    # The third positional argument (True) is forwarded unchanged to the shared output helper.
    _unused, work_handle = _deal_comm_outputs(comm_result, async_op, True)
    return work_handle
|
|
1628
|
+
|
|
1629
|
+
|
|
1630
|
+
def send(tensor, dst=0, group=None, tag=0):
    """
    Send a tensor to the process identified by `dst`.

    Note:
        Only support PyNative mode, Graph mode is not currently supported.

    Args:
        tensor (Tensor): Tensor to send.
        dst (int, optional): An integer identifying the destination rank (global rank). Default: 0.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        tag (int, optional): An integer identifying the send/recv message tag. The message will
            be received by the Receive op with the same "tag". Default: 0. It is a reserved parameter currently.

    Raises:
        TypeError: If the `tensor` is not Tensor, `dst` is not an int or `group` is not a str.
        ValueError: If the `dst` process rank id is same as the current process.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore.mint.distributed import send, recv, get_rank
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> # Launch 2 processes, Process 0 sends the array to Process 1.
        >>> init_process_group()
        >>> this_rank = get_rank()
        >>> if this_rank == 0:
        ...     input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        ...     send(input_, 1)
        >>> if this_rank == 1:
        ...     x = Tensor(np.zeros([2, 8]).astype(np.float32))
        ...     out = recv(x, src=0)
        ...     print(x)
        rank 1:
        [[1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]]
    """
    # Argument validation; the order decides which error surfaces first.
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For send, the input tensor must be tensor")
    if not isinstance(dst, int):
        raise TypeError("For send, the dst must be int")
    # Fall back to the world group when the caller does not name one.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        group_msg = "The argument 'group' must be type of string, but got 'group' type : {}."
        raise TypeError(group_msg.format(type(group)))
    # Sending to oneself is rejected before the rank is translated into the group.
    if dst == get_cache_group_rank():
        raise ValueError(
            "Invalid destination rank: destination rank should not be the same as "
            "the rank of the current process."
        )

    dst_in_group = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
    comm_result = dist_comm_isend_op(tensor, dst_in_group, group, tag)
    # Handled with async disabled (False); nothing is returned to the caller.
    _deal_comm_outputs(comm_result, False)
|
|
1702
|
+
|
|
1703
|
+
|
|
1704
|
+
|
|
1705
|
+
def recv(tensor, src=0, group=None, tag=0):
    """
    Receive a tensor from the process identified by `src`.

    Note:
        Only support PyNative mode, Graph mode is not currently supported.

    Args:
        tensor (Tensor): Tensor to fill with received data.
        src (int, optional): An integer identifying the source rank (global rank). Default: ``0``.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        tag (int, optional): An integer identifying the send/recv message tag. The message will
            be received by the Send op with the same "tag". Default: ``0``. It is a reserved parameter currently.

    Returns:
        int, If success, return ``0``.

    Raises:
        TypeError: If the `tensor` is not Tensor, `src` is not an int or `group` is not a str.
        ValueError: If the rank ID of the process is greater than the rank size of the communication group.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore.mint.distributed import send, recv, get_rank
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> # Launch 2 processes, Process 0 sends the array to Process 1.
        >>> init_process_group()
        >>> this_rank = get_rank()
        >>> if this_rank == 0:
        ...     input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        ...     send(input_, 1)
        >>> if this_rank == 1:
        ...     x = Tensor(np.zeros([2, 8]).astype(np.float32))
        ...     out = recv(x, src=0)
        ...     print(x)
        rank 1:
        [[1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]]
    """
    # Argument validation; the order decides which error surfaces first.
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For recv, the input tensor must be tensor")
    if not isinstance(src, int):
        raise TypeError("For recv, the src must be int")
    # Fall back to the world group when the caller does not name one.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        group_msg = "The argument 'group' must be type of string, but got 'group' type : {}."
        raise TypeError(group_msg.format(type(group)))

    src_in_group = _get_group_rank_from_world_rank_from_cache_helper(src, group)
    comm_result = dist_comm_irecv_op(tensor, tag, src_in_group, group)
    # Handled with async disabled (False); the received data lands in `tensor`.
    _deal_comm_outputs(comm_result, False)
    return 0
|
|
1777
|
+
|
|
1778
|
+
|
|
1779
|
+
def isend(tensor, dst=0, group=None, tag=0):
    """
    Send a tensor to the process identified by `dst` asynchronously.

    Note:
        Only support PyNative mode, Graph mode is not currently supported.

    Args:
        tensor (Tensor): Tensor to send.
        dst (int, optional): An integer identifying the destination rank (global rank). Default: 0.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        tag (int, optional): An integer identifying the send/recv message tag. The message will
            be received by the Receive op with the same "tag". Default: 0. It is a reserved parameter currently.

    Returns:
        CommHandle, it is an async work handle.

    Raises:
        TypeError: If the `tensor` is not Tensor, `dst` is not an int or `group` is not a str.
        ValueError: If the `dst` process rank id is same as the current process.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore.mint.distributed import isend, irecv, get_rank
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> # Launch 2 processes, Process 0 sends the array to Process 1.
        >>> init_process_group()
        >>> this_rank = get_rank()
        >>> if this_rank == 0:
        ...     input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        ...     handle = isend(input_, 1)
        ...     handle.wait()
        >>> if this_rank == 1:
        ...     x = Tensor(np.zeros([2, 8]).astype(np.float32))
        ...     handle = irecv(x, src=0)
        ...     handle.wait()
        ...     print(x)
        rank 1:
        [[1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]]
    """
    # Argument validation; the order decides which error surfaces first.
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For isend, the input tensor must be tensor")
    if not isinstance(dst, int):
        raise TypeError("For isend, the dst must be int")
    # Fall back to the world group when the caller does not name one.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        group_msg = "The argument 'group' must be type of string, but got 'group' type : {}."
        raise TypeError(group_msg.format(type(group)))
    # Sending to oneself is rejected before the rank is translated into the group.
    if dst == get_cache_group_rank():
        raise ValueError(
            "Invalid destination rank: destination rank should not be the same as "
            "the rank of the current process."
        )

    dst_in_group = _get_group_rank_from_world_rank_from_cache_helper(dst, group)
    comm_result = dist_comm_isend_op(tensor, dst_in_group, group, tag)
    # Always handled as an async operation (True), so a handle is handed back.
    _unused, work_handle = _deal_comm_outputs(comm_result, True)
    return work_handle
|
|
1857
|
+
|
|
1858
|
+
|
|
1859
|
+
def irecv(tensor, src=0, group=None, tag=0):
    """
    Receive a tensor from the process identified by `src` asynchronously.

    Note:
        Only support PyNative mode, Graph mode is not currently supported.

    Args:
        tensor (Tensor): Tensor to fill with received data.
        src (int, optional): An integer identifying the source rank (global rank). Default: ``0``.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        tag (int, optional): An integer identifying the send/recv message tag. The message will
            be received by the Send op with the same "tag". Default: ``0``. It is a reserved parameter currently.

    Returns:
        CommHandle, it is an async work handle.

    Raises:
        TypeError: If the type of `tensor` is not Tensor, if `src` is not an int or `group` is not a str.
        ValueError: If the rank ID of the process is greater than the rank size of the communication group.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore.mint.distributed import isend, irecv, get_rank
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> # Launch 2 processes, Process 0 sends the array to Process 1.
        >>> init_process_group()
        >>> this_rank = get_rank()
        >>> if this_rank == 0:
        ...     input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        ...     handle = isend(input_, 1)
        ...     handle.wait()
        >>> if this_rank == 1:
        ...     x = Tensor(np.zeros([2, 8]).astype(np.float32))
        ...     handle = irecv(x, src=0)
        ...     handle.wait()
        ...     print(x)
        rank 1:
        [[1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]]
    """
    # Argument validation; note that `group` is checked before `src` here.
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For irecv, the input tensor must be tensor")
    # Fall back to the world group when the caller does not name one.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        group_msg = "The argument 'group' must be type of string, but got 'group' type : {}."
        raise TypeError(group_msg.format(type(group)))
    if not isinstance(src, int):
        raise TypeError("For irecv, the src must be int")

    src_in_group = _get_group_rank_from_world_rank_from_cache_helper(src, group)
    comm_result = dist_comm_irecv_op(tensor, tag, src_in_group, group)
    # Always handled as an async operation (True), so a handle is handed back.
    _unused, work_handle = _deal_comm_outputs(comm_result, True)
    return work_handle
|
|
1933
|
+
|
|
1934
|
+
|
|
1935
|
+
def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
    """
    Scatter and gather lists of tensors to/from all ranks according to the input/output tensor lists.

    Note:
        - Tensor shapes in `output_tensor_list` and `input_tensor_list` should match across ranks.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        output_tensor_list (List[Tensor]): List of tensors that receive the data gathered from remote ranks.
        input_tensor_list (List[Tensor]): List of tensors to scatter to the remote ranks.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If not all elements in `input_tensor_list` or `output_tensor_list` are Tensor.
        TypeError: If tensors in `input_tensor_list` or `output_tensor_list` are not the same type.
        TypeError: If `group` is not str or `async_op` is not bool.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> import mindspore as ms
        >>> from mindspore.mint.distributed import init_process_group, get_rank
        >>> from mindspore.mint.distributed import all_to_all
        >>> from mindspore import Tensor
        >>>
        >>> init_process_group()
        >>> this_rank = get_rank()
        >>> if this_rank == 0:
        ...     send_tensor_list = [Tensor(1.), Tensor([[2, 3], [4, 5.]])]
        ...     recv_tensor_list = [Tensor((0), dtype=ms.float32), Tensor([0, 0.])]
        >>> if this_rank == 1:
        ...     send_tensor_list = [Tensor([2, 2.]), Tensor([4, 5, 6, 7.])]
        ...     recv_tensor_list = [Tensor([[0, 0.],[0, 0]]), Tensor([0, 0, 0, 0.])]
        >>> handle = all_to_all(recv_tensor_list, send_tensor_list)
        >>> print(recv_tensor_list)
        rank 0:
        (Tensor(shape=[], dtype=Float32, value= 1),
         Tensor(shape=[2], dtype=Float32, value= [2.00000000e+00, 2.00000000e+00]))
        rank 1:
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[2.00000000e+00, 3.00000000e+00],
         [4.00000000e+00, 5.00000000e+00]]),
         Tensor(shape=[4], dtype=Float32, value=[4.00000000e+00, 5.00000000e+00, 6.00000000e+00, 7.00000000e+00]))

    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        raise TypeError(
            "The argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )

    _check_all_tensors(input_tensor_list)
    _check_all_tensors(output_tensor_list)
    _check_all_tensor_same_dtype(input_tensor_list)
    _check_all_tensor_same_dtype(output_tensor_list)

    # Element counts per tensor tell the operator how to slice the flat buffers.
    send_numel_list = [tensor.size for tensor in input_tensor_list]
    recv_numel_list = [tensor.size for tensor in output_tensor_list]
    # All outgoing tensors are flattened and concatenated into one send buffer.
    send_flatten_tensor = cat([tensor.reshape(-1) for tensor in input_tensor_list])

    rank_size = get_cache_group_size(group)
    output = dist_comm_all_to_all_v_op(
        output_tensor_list,
        send_flatten_tensor,
        group,
        send_numel_list,
        recv_numel_list,
        rank_size,
    )
    _, handle = _deal_comm_outputs(output, async_op)
    return handle
|
|
2040
|
+
|
|
2041
|
+
|
|
2042
|
+
def _get_all_to_all_single_numel_list(tensor, output, output_split_sizes,
                                      input_split_sizes, group):
    """
    Compute the per-rank element counts used by all_to_all_single.

    When a split-size argument is empty, dimension 0 of the matching tensor is split
    evenly across the group; a ValueError is raised if it is not divisible by the group size.

    Returns:
        tuple, ``(send_numel_list, recv_numel_list, recv_shape_without_first_dim)`` where the last
        item is ``output.shape[1:]``.
    """
    if _is_split_sizes_empty(input_split_sizes):
        group_size = get_cache_group_size(group)
        in_dim0 = tensor.shape[0]
        if in_dim0 % group_size != 0:
            raise ValueError(
                "input shape at dim 0 must be divided by world_size, "
                f"but got {in_dim0} and {group_size}."
            )
        input_split_sizes = (in_dim0 // group_size,) * group_size

    if _is_split_sizes_empty(output_split_sizes):
        group_size = get_cache_group_size(group)
        out_dim0 = output.shape[0]
        if out_dim0 % group_size != 0:
            raise ValueError(
                "output shape at dim 0 must be divided by world_size, "
                f"but got {out_dim0} and {group_size}."
            )
        output_split_sizes = (out_dim0 // group_size,) * group_size

    # Each split covers `rows * (elements in one row)` elements of the flattened tensor.
    send_row_numel = _get_size(tensor.shape[1:])
    send_numel_list = [rows * send_row_numel for rows in input_split_sizes]

    recv_shape_without_first_dim = output.shape[1:]
    recv_row_numel = _get_size(recv_shape_without_first_dim)
    recv_numel_list = [rows * recv_row_numel for rows in output_split_sizes]

    return send_numel_list, recv_numel_list, recv_shape_without_first_dim
|
|
2075
|
+
|
|
2076
|
+
|
|
2077
|
+
def all_to_all_single(output,
                      input,
                      output_split_sizes=None,
                      input_split_sizes=None,
                      group=None,
                      async_op=False):
    """
    Split `input` along dim 0, exchange the pieces with all ranks of the group, and collect the received
    pieces in the single tensor `output`.

    Note:
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        output (Tensor): the tensor that receives the data gathered (concatenated) from remote ranks.
        input (Tensor): the tensor whose pieces are scattered to remote ranks.
        output_split_sizes (Union(Tuple(int), List(int)), optional): output split size at dim 0. If set to None,
            it means equally split by ``world_size``. Default: ``None``.
        input_split_sizes (Union(Tuple(int), List(int)), optional): input split size at dim 0. If set to None,
            it means equally split by ``world_size``. Default: ``None``.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If `input` or `output` is not tensor. `group` is not a str, or async_op is not bool.
        ValueError: When `input_split_sizes` is empty, input dim 0 can not be divided by ``world_size``.
        ValueError: When `output_split_sizes` is empty, output dim 0 can not be divided by ``world_size``.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore.mint.distributed import init_process_group, get_rank
        >>> from mindspore.mint.distributed import all_to_all_single
        >>> from mindspore import Tensor
        >>> from mindspore.ops import zeros
        >>>
        >>> init_process_group()
        >>> this_rank = get_rank()
        >>> if this_rank == 0:
        ...     output = Tensor(np.zeros([3, 3]).astype(np.float32))
        ...     tensor = Tensor([[0, 1, 2.], [3, 4, 5], [6, 7, 8]])
        ...     result = all_to_all_single(output, tensor, [2, 1], [2, 1])
        ...     print(output)
        >>> if this_rank == 1:
        ...     output = Tensor(np.zeros([2, 3]).astype(np.float32))
        ...     tensor = Tensor([[9, 10., 11], [12, 13, 14]])
        ...     result = all_to_all_single(output, tensor, [1, 1], [1, 1])
        ...     print(output)
        rank 0:
        [[ 0.  1.  2.]
         [ 3.  4.  5.]
         [ 9. 10. 11.]]
        rank 1:
        [[ 6.  7.  8.]
         [12. 13. 14.]]

    """
    # Argument validation, in the same order as the checks are documented.
    _check_all_tensors([input])
    _check_all_tensors([output])
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        raise TypeError(
            "The argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )

    # Whether the caller left both split-size arguments empty; forwarded to the op as-is.
    no_explicit_splits = (
        _is_split_sizes_empty(output_split_sizes)
        and _is_split_sizes_empty(input_split_sizes)
    )
    send_counts, recv_counts, _unused_tail_shape = _get_all_to_all_single_numel_list(
        input, output, output_split_sizes, input_split_sizes, group
    )

    # The op is handed the input flattened to one dimension.
    flat_input = input.reshape(-1)
    world = get_cache_group_size(group)
    comm_output = dist_comm_all_to_all_v_single_op(
        output,
        flat_input,
        group,
        send_counts,
        recv_counts,
        world,
        no_explicit_splits,
    )
    _unused_result, handle = _deal_comm_outputs(comm_output, async_op)
    return handle
|
|
2182
|
+
|
|
2183
|
+
|
|
2184
|
+
def _check_tensor_list(tensor_list, tensor, group_size):
|
|
2185
|
+
"""check all elements in tensor_list are type of Tensor or tuple or list"""
|
|
2186
|
+
if not tensor_list or len(tensor_list) != group_size:
|
|
2187
|
+
raise TypeError(
|
|
2188
|
+
f"The argument list tensor len must be equal to group rank size, but got {len(tensor_list)}."
|
|
2189
|
+
)
|
|
2190
|
+
if tensor.dtype != tensor_list[0].dtype:
|
|
2191
|
+
raise TypeError(
|
|
2192
|
+
f"The argument list tensor type must be equal to tensor type, but got {tensor_list[0].dtype}."
|
|
2193
|
+
)
|
|
2194
|
+
if tensor.shape != tensor_list[0].shape:
|
|
2195
|
+
raise TypeError(
|
|
2196
|
+
f"The argument list tensor shape must be equal to tensor shape, but got {tensor_list[0].shape}."
|
|
2197
|
+
)
|
|
2198
|
+
|
|
2199
|
+
|
|
2200
|
+
def all_gather(tensor_list, tensor, group=None, async_op=False):
    """
    Gathers tensors from the specified communication group and returns the tensor list which is all gathered.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        tensor_list (list[Tensor]): Output list.
        tensor (Tensor): The input tensor to be all gathered.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If the type of input `tensor` is not Tensor, `tensor_list` is not Tensor List,
            `group` is not a str or async_op is not bool.
        TypeError: If size of `tensor_list` is not equal to group size.
        TypeError: If the type or shape of `tensor` not equal to the member of `tensor_list`.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore.mint.distributed import all_gather
        >>> from mindspore import Tensor
        >>>
        >>> init_process_group()
        >>> input_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> out_tensors = [Tensor(np.zeros([2, 8]).astype(np.float32)), Tensor(np.zeros([2, 8]).astype(np.float32))]
        >>> output = all_gather(out_tensors, input_tensor)
        >>> print(out_tensors)
        [Tensor(shape=[2, 8], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]]),
         Tensor(shape=[2, 8], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00 ... 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]])]


    """
    _check_all_tensors(tensor_list)
    _check_all_tensor_same_dtype_and_shape(tensor_list)
    if not isinstance(tensor, (Tensor, Tensor_)):
        # The message names this API; it previously named all_gather_into_tensor.
        raise TypeError("For all_gather, the input tensor must be tensor")
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        raise TypeError(
            "The argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )
    group_size = get_cache_group_size(group)
    # The output list must hold one tensor per rank, matching the input's dtype and shape.
    _check_tensor_list(tensor_list, tensor, group_size)
    result = dist_comm_all_gather_op(tensor_list, tensor, group_size, group)
    _, handle = _deal_comm_outputs(result, async_op)
    return handle
|
|
2280
|
+
|
|
2281
|
+
|
|
2282
|
+
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
    r"""
    Reduce the tensors of `input_list` across the specified communication group and scatter the
    reduced result, storing this rank's part in `output`.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        output (Tensor): the output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (str, optional): Specifies an operation used for element-wise reductions,
            like SUM and MAX. Default: ``ReduceOp.SUM`` .
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If the type of `output` parameter is not Tensor, `input_list` is not Tensor List.
        TypeError: If any of `op` and `group` is not a str. async_op is not bool or 'op' is invalid.
        TypeError: If size of `input_list` is not equal to group size.
        TypeError: If the type or shape of `output` not equal to the member of `input_list`.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore import Tensor
        >>> from mindspore.mint.distributed import init_process_group
        >>> from mindspore.mint.distributed import reduce_scatter
        >>> import numpy as np
        >>>
        >>> init_process_group()
        >>> input_tensors = [Tensor(np.ones([4, 8]).astype(np.float32)), Tensor(np.ones([4, 8]).astype(np.float32))]
        >>> output_tensor = Tensor(np.zeros([4, 8]).astype(np.float32))
        >>> output = reduce_scatter(output_tensor ,input_tensors)
        >>> print(output_tensor)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]

    """
    # Tensor arguments first.
    _check_all_tensors(input_list)
    _check_all_tensor_same_dtype_and_shape(input_list)
    if not isinstance(output, (Tensor, Tensor_)):
        raise TypeError("For reduce_scatter, the output tensor must be tensor")

    # Then the group and the async flag.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        raise TypeError(
            "The argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )

    # Finally the reduction operator: must be a str naming one of four reductions.
    if not isinstance(op, str):
        raise TypeError("For reduce_scatter, the input op type must be str")
    supported_ops = ("sum", "prod", "min", "max")
    if op not in supported_ops:
        raise TypeError(
            "For reduce_scatter, the input op value must be one of sum, prod, min, max"
        )

    world = get_cache_group_size(group)
    _check_tensor_list(input_list, output, world)
    comm_output = dist_comm_reduce_scatter_op(output, input_list, world, op, group)
    _unused_result, handle = _deal_comm_outputs(comm_output, async_op)
    return handle
|
|
2368
|
+
|
|
2369
|
+
|
|
2370
|
+
def scatter(tensor, scatter_list, src=0, group=None, async_op=False):
    r"""
    Scatter tensor evenly across the processes in the specified communication group.

    Note:
        - The interface behavior only support Tensor List input and scatter evenly.
        - Only the tensor in process `src` (global rank) will do scatter.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        tensor (Tensor): the output tensor.
        scatter_list (list[Tensor]): List of same-sized tensors to scatter.
            It must be specified on the source rank.
        src (int, optional): Specifies the rank(global rank) of the process that send the tensor.
            And only process `src` will send the tensor. Default: ``0`` .
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If the type of `tensor` parameter is not Tensor, `scatter_list` is not Tensor List.
        TypeError: If `src` is not an integer, `group` is not a str or `async_op` is not bool.
        TypeError: If size of `scatter_list` is not equal to group size.
        TypeError: If the type or shape of `tensor` not equal to the member of `scatter_list`.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore import Tensor
        >>> from mindspore.mint.distributed import init_process_group, scatter
        >>> import numpy as np
        >>> # Launch 2 processes.
        >>>
        >>> init_process_group()
        >>> inputs = [Tensor(np.ones([2, 2]).astype(np.float32)), Tensor(np.ones([2, 2]).astype(np.float32))]
        >>> output = Tensor(np.zeros([2, 2]).astype(np.float32))
        >>> scatter(output, inputs, src=0)
        >>> print(output)
        # rank_0
        [[1. 1.]
         [1. 1.]]
        # rank_1
        [[1. 1.]
         [1. 1.]]
    """
    _check_all_tensors(scatter_list)
    _check_all_tensor_same_dtype_and_shape(scatter_list)
    # Messages name this API ("scatter"); they previously named "scatter_tensor".
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For scatter, the output tensor must be tensor")
    if not isinstance(src, int):
        raise TypeError("For scatter, the src must be int")
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        raise TypeError(
            "The argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(async_op, bool):
        raise TypeError(
            f"The argument 'async_op' must be a bool, but got {type(async_op)}."
        )
    # `src` arrives as a global rank; convert it to a rank inside `group`.
    src = get_group_rank_from_world_rank(src, group)
    rank_size = get_cache_group_size(group)
    rank_id = get_cache_group_rank(group)
    # The size/dtype/shape check of the list is applied on the source rank only.
    if src == rank_id:
        _check_tensor_list(scatter_list, tensor, rank_size)
    output = dist_comm_scatter_op(tensor, scatter_list, rank_size, src, rank_id, group)
    _, handle = _deal_comm_outputs(output, async_op)
    return handle
|
|
2457
|
+
|
|
2458
|
+
|
|
2459
|
+
def gather(tensor, gather_list, dst=0, group=None, async_op=False):
    r"""
    Collect `tensor` from every process of the specified communication group into `gather_list`
    on the destination process. The tensors are gathered according to dimension 0.

    Note:
        - Only the tensor in process `dst` (global rank) will keep the gathered tensor. The other process
          will keep a tensor list which has no mathematical meaning.
        - The tensors must have the same shape and format in all processes of the collection.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        tensor (Tensor): The tensor to be gathered.
        gather_list (list[Tensor]): List of same-sized tensors to use for gathered data.
        dst (int, optional): Specifies the rank(global rank) of the process that receive the tensor.
            And only process `dst` will receive the gathered tensor. Default: ``0`` .
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        async_op (bool, optional): Whether this operator should be an async operator. Default: ``False`` .

    Returns:
        CommHandle, CommHandle is an async work handle, if `async_op` is set to True.
        CommHandle will be None, when `async_op` is False.

    Raises:
        TypeError: If the type of input tensor is not Tensor, or gather_list is not Tensor list.
        TypeError: If dst is not an integer, group is not a string or async_op is not bool.
        TypeError: If size of `gather_list` is not equal to group size.
        TypeError: If the type or shape of `tensor` not equal to the member of `gather_list`.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.mint.distributed import init_process_group, gather
        >>> from mindspore import Tensor
        >>> # Launch 2 processes.
        >>> init_process_group()
        >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
        >>> outputs = [Tensor(np.zeros([2, 2]).astype(np.float32)),Tensor(np.zeros([2, 2]).astype(np.float32))]
        >>> gather(input, outputs, dst=0)
        >>> print(outputs)
        # rank_0
        [Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 1.00000000e+00],
         [ 2.00000000e+00, 3.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 1.00000000e+00], [ 2.00000000e+00, 3.00000000e+00]])]
        [Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 1.00000000e+00],
         [ 2.00000000e+00, 3.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 1.00000000e+00], [ 2.00000000e+00, 3.00000000e+00]])]
        # rank_1
        [Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 0.00000000e+00],
         [ 0.00000000e+00, 0.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00]])]
        [Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 0.00000000e+00],
         [ 0.00000000e+00, 0.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00]])]
    """
    # Validate the tensor arguments and the destination rank.
    if not isinstance(tensor, (Tensor, Tensor_)):
        raise TypeError("For gather, the input tensor must be tensor")
    _check_all_tensors(gather_list)
    _check_all_tensor_same_dtype_and_shape(gather_list)
    if not isinstance(dst, int):
        raise TypeError("For gather, the dst must be int")

    # Resolve and validate the group, then the async flag.
    group = GlobalComm.WORLD_COMM_GROUP if group is None else group
    if not isinstance(group, str):
        raise TypeError(
            "The argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(async_op, bool):
        raise TypeError(f"The argument 'async_op' must be a bool, but got {type(async_op)}.")

    world = get_cache_group_size(group)
    # `dst` arrives as a global rank; convert it to a rank inside `group`.
    dst = get_group_rank_from_world_rank(dst, group)
    local_rank = get_cache_group_rank(group)
    # The size/dtype/shape check of the list is applied on the destination rank only.
    if dst == local_rank:
        _check_tensor_list(gather_list, tensor, world)
    comm_output = dist_comm_gather_op(tensor, gather_list, world, dst, local_rank, group)
    _unused_result, handle = _deal_comm_outputs(comm_output, async_op)
    return handle
|
|
2558
|
+
|
|
2559
|
+
|
|
2560
|
+
def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
    r"""
    Scatters picklable objects in scatter_object_input_list to the whole group.

    Note:
        - Similar to :func:`mindspore.mint.distributed.scatter`, but Python objects can be passed in.
        - Only the objects in process `src` (global rank) will do scatter.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        scatter_object_output_list (list[Any]): Non-empty list whose first element
            will store the object scattered to this rank.
        scatter_object_input_list (list[Any]): List of python objects to scatter.
            it must be specified on the source rank.
        src (int, optional): Specifies the rank(global rank) of the process that send the objects.
            And only process `src` will send the objects. Default: ``0`` .
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.

    Raises:
        TypeError: If `group` is not a str or `src` is not an integer.
        TypeError: If `scatter_object_output_list` is not a non-empty list.
        TypeError: If size of `scatter_object_input_list` is not equal to group size.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group, scatter_object_list
        >>> init_process_group()
        >>> obj = ["test", {1: 2}]
        >>> scatter_object_output_list=[None]
        >>> scatter_object_list(scatter_object_output_list, obj)
        >>> print(scatter_object_output_list)
        # rank_0
        ['test']
        # rank_1
        [{1: 2}]
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        raise TypeError(
            "For 'scatter_object_list', the argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(scatter_object_output_list, list) or not scatter_object_output_list:
        raise TypeError("The scatter_object_output_list can not be empty.")
    if not isinstance(src, int):
        raise TypeError("For scatter_object_list, the src must be int")
    group_size = get_cache_group_size(group)
    # NOTE(review): the rank is queried without `group` and compared with the
    # global-rank `src`, so it is presumably this process's world rank; it is
    # translated to a rank inside `group` further down. Confirm against
    # get_cache_group_rank's default.
    rank_id = get_cache_group_rank()
    tensor_sizes = []
    tensor_list = []
    if rank_id == src:
        if not isinstance(scatter_object_input_list, list) or len(scatter_object_input_list) != group_size:
            # Report the length for a list and the type otherwise, so building
            # the message can never fail on a non-sized argument.
            if isinstance(scatter_object_input_list, list):
                actual = len(scatter_object_input_list)
            else:
                actual = type(scatter_object_input_list)
            raise TypeError(
                "The len of scatter_object_input_list must be equal to group rank size, "
                f"but got {actual}."
            )
        # First pass: record the size reported by _object_to_tensor for each object.
        for obj in scatter_object_input_list:
            _, size = _object_to_tensor(obj)
            tensor_sizes.append(Tensor([size], dtype=mstype.int32))
        max_size = int(max(tensor_sizes).item())
        # Second pass: convert each object again, passing the largest size.
        for obj in scatter_object_input_list:
            tensor, _ = _object_to_tensor(obj, max_size)
            tensor_list.append(tensor)
    else:
        # Non-source ranks use zero placeholders, overwritten by the broadcast below.
        tensor_sizes = [Tensor([0], dtype=mstype.int32) for _ in range(group_size)]

    # Share the per-rank object sizes with every rank.
    object_size = cat(tensor_sizes)
    broadcast(object_size, src, group)
    max_object_size = int(max(object_size).item())
    data = np.zeros((max_object_size)).astype(np.int8)
    if rank_id != src:
        # Non-source ranks pass a list of zero-filled tensors of the broadcast size to scatter.
        tensor_list = [Tensor(data) for _ in range(group_size)]
    out_tensor = Tensor(data)
    scatter(out_tensor, tensor_list, src, group)
    # Pick this rank's own object size using its rank inside `group`.
    group_id = get_group_rank_from_world_rank(rank_id, group)
    scatter_object_output_list[0] = _tensor_to_object(out_tensor, object_size[group_id])
|
|
2651
|
+
|
|
2652
|
+
|
|
2653
|
+
def gather_object(obj, object_gather_list=None, dst=0, group=None):
    r"""
    Gathers python objects from the whole group in a single process.

    Note:
        - Similar to :func:`mindspore.mint.distributed.gather`, but Python objects can be passed in.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        obj (Any): The python objects to be gathered.
        object_gather_list (list[Any], optional): Output list. On the ``dst`` rank it must be a list whose
            length equals the size of the group; its items are overwritten with the gathered objects.
            It is not read on the other ranks. Default: ``None``.
        dst (int, optional): Specifies the rank(global rank) of the process that receive the tensor.
            And only process `dst` will receive the gathered tensor. Default: ``0`` .
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.

    Raises:
        TypeError: If dst is not an integer, or group is not a string.
        TypeError: If `object_gather_list` is not a list on the `dst` rank.
        TypeError: If size of `object_gather_list` is not equal to group size.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group, gather_object, get_rank
        >>> init_process_group()
        >>> rank = get_rank()
        >>> obj = ["test", {1: 2}]
        >>> object_gather_list=[None, None]
        >>> gather_object(obj[rank], object_gather_list)
        >>> print(object_gather_list)
        # rank_0
        ['test', {1: 2}]
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        raise TypeError(
            "For 'gather_object', the argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(dst, int):
        raise TypeError("For gather_object, the dst must be int")
    group_size = get_cache_group_size(group)
    rank_id = get_cache_group_rank()
    if rank_id == dst:
        # The type check is kept separate from the length check so that a non-list
        # (including the default None) is reported clearly instead of failing inside len().
        if not isinstance(object_gather_list, list):
            raise TypeError(
                "For gather_object, the object_gather_list must be a list on the dst rank, "
                f"but got {type(object_gather_list)}."
            )
        if len(object_gather_list) != group_size:
            raise TypeError(
                f"The len of object_gather_list must be equal to group rank size, but got {len(object_gather_list)}."
            )

    # Step 1: exchange the serialized size of every rank's object.
    _, size = _object_to_tensor(obj)
    tensor = Tensor([size], dtype=mstype.int32)
    object_size_list = [Tensor([0], dtype=mstype.int32) for _ in range(group_size)]
    all_gather(object_size_list, tensor, group=group)

    # Step 2: serialize again using the largest size, so every rank contributes an equally sized tensor.
    max_object_size = int(max(object_size_list).item())
    in_tensor, size = _object_to_tensor(obj, max_object_size)
    data = np.zeros(size, dtype=np.int8)
    object_tensor_list = [Tensor(data) for _ in range(group_size)]
    gather(in_tensor, object_tensor_list, dst, group)
    if rank_id != dst:
        return

    # Step 3 (dst only): decode each received buffer using the size reported by its rank.
    for i, item in enumerate(object_size_list):
        tensor_size = int(item.item())
        object_gather_list[i] = _tensor_to_object(object_tensor_list[i], tensor_size)
|
|
2732
|
+
|
|
2733
|
+
|
|
2734
|
+
def broadcast_object_list(object_list, src=0, group=None, device=None):
    """
    Broadcasts the entire group of input Python objects.

    Note:
        - Similar to :func:`mindspore.mint.distributed.broadcast`, but Python objects can be passed in.
        - Only support PyNative mode, Graph mode is not currently supported.

    Args:
        object_list (list[Any]): list of input to be sent if src is the rank of current process,
            and list to be used to save received data otherwise. Must be a non-empty list.
        src (int, optional): Specifies the rank(global rank) of the process that broadcast the Python objects.
            And only process `src` will broadcast the Python objects. Default: ``0`` .
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.
        device (str, optional): Currently it is a reserved parameter. Default: ``None``.

    Raises:
        TypeError: If `src` is not an integer or `group` is not a string.
        TypeError: If `object_list` is not a list or is empty.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group, broadcast_object_list, get_rank
        >>> init_process_group()
        >>> rank = get_rank()
        >>> obj = ["test", 12, {1: 2}]
        >>> if rank == 1:
        ...     obj = [None, None, None]
        >>> broadcast_object_list(obj)
        >>> print(obj)
        ['test', 12, {1: 2}]
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        raise TypeError(
            "For 'broadcast_object_list', the argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    if not isinstance(src, int):
        raise TypeError("For broadcast_object_list, the src must be int")
    if not isinstance(object_list, list) or not object_list:
        raise TypeError("The object_list can not be empty.")

    is_src = get_cache_group_rank() == src
    if is_src:
        # Serialize every object; the per-object sizes are sent first so receivers can split the payload.
        tensor_list, tensor_sizes = zip(*[_object_to_tensor(obj) for obj in object_list])
        object_size_list = [Tensor([obj_size], dtype=mstype.int32) for obj_size in tensor_sizes]
        object_tensor = cat(tensor_list)
    else:
        # Placeholders, one per expected object, to be filled by the size broadcast.
        object_size_list = [Tensor([0], dtype=mstype.int32) for _ in object_list]

    object_size = cat(object_size_list)
    broadcast(object_size, src, group)
    if not is_src:
        # Receive buffer large enough for all serialized objects laid out back to back.
        total_size = int(sum(object_size).item())
        object_tensor = Tensor(np.zeros(total_size, dtype=np.int8))
    broadcast(object_tensor, src, group)
    if is_src:
        return

    # Split the concatenated payload with the broadcast sizes. Sizes are converted to
    # plain ints, as gather_object / all_gather_object do, before slicing and decoding.
    offset = 0
    for i, item in enumerate(object_size):
        obj_size = int(item.item())
        obj_view = object_tensor[offset:offset + obj_size]
        offset += obj_size
        object_list[i] = _tensor_to_object(obj_view, obj_size)
|
|
2816
|
+
|
|
2817
|
+
|
|
2818
|
+
def all_gather_object(object_list, obj, group=None):
    """
    Aggregates Python objects in a specified communication group.

    Note:
        Similar to :func:`mindspore.mint.distributed.all_gather`, but Python objects can be passed in.

    Args:
        object_list (list[Any]): Output Python object list. Its length must equal the size of the group;
            its items are overwritten with the gathered objects.
        obj (Any): Python object to be broadcast from current process.
        group (str, optional): The communication group to work on. If ``None``, which means ``"hccl_world_group"`` in
            Ascend. Default: ``None``.

    Raises:
        TypeError: `group` is not a str.
        TypeError: If `object_list` is not a list.
        TypeError: If size of `object_list` is not equal to group size.
        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, it is recommended to use the msrun startup method
            without any third-party or configuration file dependencies.
            Please see the `msrun start up
            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
            for more details.

            This example should be run with 2 devices.

        >>> from mindspore.mint.distributed import init_process_group, get_rank
        >>> from mindspore.mint.distributed import all_gather_object
        >>> init_process_group()
        >>> rank = get_rank()
        >>> obj = ["test", {1: 2}]
        >>> object_gather_list=[None, None]
        >>> all_gather_object(object_gather_list, obj[rank])
        >>> print(object_gather_list)
        # rank_0
        ['test', {1: 2}]
        # rank_1
        ['test', {1: 2}]
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        raise TypeError(
            "For 'all_gather_object', the argument 'group' must be type of string, "
            "but got 'group' type : {}.".format(type(group))
        )
    group_size = get_cache_group_size(group)
    # The type check is kept separate from the length check so that a non-list argument
    # is reported clearly instead of failing inside len() while the message is built.
    if not isinstance(object_list, list):
        raise TypeError(
            f"For all_gather_object, the argument object_list must be a list, but got {type(object_list)}."
        )
    if len(object_list) != group_size:
        raise TypeError(
            f"The len of argument object_list must be equal to group rank size, but got {len(object_list)}."
        )

    # Step 1: exchange the serialized size of every rank's object.
    _, size = _object_to_tensor(obj)
    tensor = Tensor([size], dtype=mstype.int32)
    object_size_list = [Tensor([0], dtype=mstype.int32) for _ in range(group_size)]
    all_gather(object_size_list, tensor, group=group)

    # Step 2: serialize again using the largest size, so every rank contributes an equally sized tensor.
    max_object_size = int(max(object_size_list).item())
    in_tensor, size = _object_to_tensor(obj, max_object_size)
    data = np.zeros(size, dtype=np.int8)
    object_tensor_list = [Tensor(data) for _ in range(group_size)]
    all_gather(object_tensor_list, in_tensor, group=group)

    # Step 3: decode each received buffer using the size reported by its rank.
    for i, item in enumerate(object_size_list):
        tensor_size = int(item.item())
        object_list[i] = _tensor_to_object(object_tensor_list[i], tensor_size)
|