mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/parallel/distributed/distributed_data_parallel.py
@@ -0,0 +1,393 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Distributed data parallel wrapper. """
+from __future__ import absolute_import
+
+__all__ = ["DistributedDataParallel"]
+
+import itertools
+from contextlib import contextmanager
+from typing import Optional
+import mindspore.nn as nn
+import mindspore.log as logger
+from mindspore import Tensor, mint
+from mindspore.common import dtype as mstype
+from mindspore.mint.distributed import get_world_size
+from mindspore.communication import GlobalComm
+from mindspore.common.api import _pynative_executor
+from mindspore.mint.distributed import broadcast, get_global_rank
+from mindspore.parallel.distributed.flatten_grad_buffer import FlattenGradBuffer
+from mindspore._c_expression import Reducer, _find_unused_parameters
+
+
+def get_data_parallel_group():
+    """get default global data parallel group"""
+    return GlobalComm.WORLD_COMM_GROUP
+
+
+def get_data_parallel_world_size(group):
+    """get group world size"""
+    return get_world_size(group)
+
+
+def _find_tensors(obj):
+    if isinstance(obj, Tensor):
+        return [obj]
+    if isinstance(obj, (list, tuple)):
+        return itertools.chain.from_iterable(map(_find_tensors, obj))
+    if isinstance(obj, dict):
+        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
+
+    return []
+
+
+class DistributedDataParallel(nn.Cell):
+    """
+    DistributedDataParallel wrapper. DistributedDataParallel allocates a contiguous memory buffer for gradients.
+    Parameters' gradients are combined into multiple buckets, which are the unit of all-reduce
+    communication among the data parallel group, so that communication latency is overlapped with computation.
+
+    .. warning::
+        - This interface is currently only supported in PyNative mode.
+        - This is an experimental interface that may be changed or removed in the future.
+
+    Args:
+        module (nn.Cell): the module to be wrapped with DDP.
+        init_sync (bool, optional): whether to sync params from rank0 of process_group when init. Default: ``True``.
+        process_group (str, optional): the communication group for data parallel. Default: ``None``.
+        bucket_cap_mb (int, optional): size of each bucket in MB; 25MB is used if not set. Default: ``None``.
+        find_unused_parameters (bool, optional): whether to find unused params in the bucket. Default: ``False``.
+        average_in_collective (bool, optional): True means allreduce-sum within the DP group first, then scale by
+            the dp size. Otherwise the local gradients are scaled first and then allreduce-summed. Default: ``False``.
+        static_graph (bool, optional): Indicate whether it is a static network. When it is a static network, the
+            parameter `find_unused_parameters` will be ignored, and unused parameters will be searched for in the
+            first step. Bucket reconstruction will be performed in execution order before the second step to achieve
+            better performance. Default: ``False``.
+        reducer_mode (str, optional): the backend to be used, either "CppReducer" for the C++ backend or
+            "PythonReducer" for the Python backend. Default: ``"CppReducer"``.
+
+    Returns:
+        Model wrapped with DistributedDataParallel.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            - When enabling recomputation or gradient freezing, the model should be wrapped by
+              `DistributedDataParallel` at the outermost layer.
+            - Before running the following examples, you need to configure the communication environment variables.
+              For Ascend devices, it is recommended to use the msrun startup method
+              without any third-party or configuration file dependencies. For detailed information, refer to
+              `msrun launch <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_ .
+
+        >>> from mindspore.parallel.distributed import DistributedDataParallel
+        >>> from mindspore.mint.optim import AdamW
+        >>> from mindspore import Parameter, Tensor, ops, nn
+        >>> import mindspore as ms
+        >>> from mindspore.communication import init
+        >>> from mindspore.mint.distributed.distributed import init_process_group
+        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+        >>> init_process_group()
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> net = DistributedDataParallel(module=net,
+        ...                               bucket_cap_mb=None,
+        ...                               average_in_collective=True,
+        ...                               static_graph=True)
+        >>> optimizer = AdamW(net.trainable_params(), 1e-4)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>>
+        >>> def forward_fn(data, target):
+        ...     logits = net(data)
+        ...     loss = loss_fn(logits, target)
+        ...     return loss, logits
+        >>>
+        >>> grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
+        >>>
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
+        >>> for epoch in range(1):
+        ...     step = 0
+        ...     for image, label in dataset:
+        ...         (loss_value, _), grads = grad_fn(image, label)
+        ...         optimizer(grads)
+        ...         net.zero_grad()
+        ...         step += 1
+        ...         print("epoch: %s, step: %s, loss is %.15f" % (epoch, step, loss_value))
+    """
+
+    def __init__(self, module, init_sync=True, process_group=None, bucket_cap_mb: Optional[int] = None,
+                 find_unused_parameters=False, average_in_collective: bool = False, static_graph=False,
+                 reducer_mode="CppReducer"):
+        super(DistributedDataParallel, self).__init__(auto_prefix=False)
+        self.init_sync = init_sync
+        self.bucket_cap_mb = bucket_cap_mb
+        self.average_in_collective = average_in_collective
+        self.grad_reduce_in_fp32 = False
+        self.process_group = process_group if process_group else get_data_parallel_group()
+        self.static_graph = static_graph
+        self.find_unused_parameters = find_unused_parameters
+
+        self.module = module
+        self.param_to_buffer = {}
+        self.has_buckets_grad_sync = False
+
+        # default is 25MB for each bucket
+        if bucket_cap_mb is None:
+            bucket_cap_mb = 25
+        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
+
+        # grads sync with allreduce comm
+        self.sync_enabled = True
+        self.reducer_mode = reducer_mode  # "CppReducer" or "PythonReducer"
+        self.buffers = []
+        self.has_mark_unused_param = False
+
+        bucketed_params = []
+        self.skipped_params = []
+        for _, param in self.module.parameters_and_names():
+            if not param.requires_grad:
+                self.skipped_params.append(param)
+                continue
+            param.grad = None
+            param.main_grad = None
+            bucketed_params.append(param)
+        if self.average_in_collective:
+            # allreduce to add grads, then to scale grads with dp size
+            self.gradient_scaling_factor = 1.0
+        else:
+            # scale grads with dp size locally, then allreduce to add grads
+            data_parallel_world_size = get_data_parallel_world_size(self.process_group)
+            self.gradient_scaling_factor = 1.0 / data_parallel_world_size
+        self.bucketed_params = bucketed_params
+
+        if self.reducer_mode == "CppReducer":
+            self.reducer = Reducer(self.bucketed_params,
+                                   self.process_group,
+                                   bucket_cap_mb,
+                                   self.grad_reduce_in_fp32,
+                                   average_in_collective,
+                                   static_graph,
+                                   find_unused_parameters)
+            if self.init_sync:
+                self.broadcast_coalesced()
+            return
+        # allocate buffer for trained params
+        self.buffers = self.allocate_buffers_for_parameters(
+            self.bucketed_params,
+            group=self.process_group,
+            gradient_scaling_factor=self.gradient_scaling_factor,
+        )
+        if self.init_sync:
+            self.broadcast_coalesced()
+
+        # register hook for bucket grad reduce
+        self._register_hook_for_params()
+
+        # bucket rebuilding
+        self.rebuilt_params_ = []
+        self.buffer_iterations = 0
+        self.has_bucket_rebuilt = False
+        self.buffer_issued = 0
+        self.triggered_once = False
+
+    def _group_params_by_dtype(self, input_params):
+        param_and_grad_dtype_to_params = {}
+        # group all params by parameter's data type and their gradient's data type.
+        for param in input_params:
+            param_dtype = param.dtype
+            grad_dtype = mstype.float32 if self.grad_reduce_in_fp32 else param.dtype
+            if (param_dtype, grad_dtype) not in param_and_grad_dtype_to_params:
+                param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = []
+            param_and_grad_dtype_to_params[(param_dtype, grad_dtype)].append(param)
+        return param_and_grad_dtype_to_params
+
+    def allocate_buffers_for_parameters(self, input_params, group, gradient_scaling_factor):
+        """allocate buffers for parameters in different dtype groups."""
+        param_and_grad_dtype_to_params = self._group_params_by_dtype(input_params)
+
+        buffers = []
+        # allocate buffer for each group separately
+        for (param_dtype, grad_dtype,), params in param_and_grad_dtype_to_params.items():
+            buffers.append(
+                FlattenGradBuffer(
+                    average_in_collective=self.average_in_collective,
+                    param_dtype=param_dtype,
+                    grad_dtype=grad_dtype,
+                    params=params,
+                    data_parallel_group=group,
+                    bucket_size=self.bucket_bytes_cap,
+                    gradient_scaling_factor=gradient_scaling_factor,
+                    ddp_handle=self,
+                )
+            )
+            for param in params:
+                self.param_to_buffer[param] = buffers[-1]
+        logger.debug("allocate buffers for parameters: %s", buffers)
+        return buffers
+
+    def final_grad_reduce(self):
+        """trigger final grad reduction"""
+        logger.debug("trigger ddp final grad reduce, %d, %d", self.static_graph, len(self.unused_param))
+        if self._should_rebuild_buckets():
+            for param in self.unused_param:
+                self.rebuilt_params_.append(param)
+        for buffer in self.buffers:
+            buffer.final_grad_reduce()
+            buffer.issued = 0
+        self.buffer_issued = 0
+
+    def _register_hook_for_params(self):
+        """register backward hook for each param."""
+        for param in self.module.get_parameters():
+            if param.requires_grad:
+                param.register_hook(self._make_param_hook(param))
+
+    def _post_forward(self, output):
+        """prepare for backward (e.g. find unused params) if needed"""
+        if self.reducer_mode == "CppReducer":
+            if _pynative_executor.grad_flag() and self.sync_enabled:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+        else:
+            unused_param_idx = []
+            if self.static_graph and not self.triggered_once:
+                self.triggered_once = True
+                self.find_unused_parameters = False
+                unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+            elif self.find_unused_parameters:
+                unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+            self.unused_param = [self.bucketed_params[idx] for idx in unused_param_idx]
+            self.unused_param_name = [param.name for param in self.unused_param]
+            self.has_mark_unused_param = False
+
+    def _pre_forward(self):
+        """pre-process of forward pass to allocate buffers for parameters."""
+        if self.reducer_mode == "CppReducer":
+            if _pynative_executor.grad_flag() and self.sync_enabled:
+                self.reducer.prepare_for_forward()
+                self.reducer.rebuild_buckets()
+            return
+        if self.rebuilt_params_ and self._should_rebuild_buckets():
+            for i in self.rebuilt_params_:
+                i.old_grad = i.grad
+
+            self.buffers = self.allocate_buffers_for_parameters(
+                self.rebuilt_params_,
+                group=self.process_group,
+                gradient_scaling_factor=self.gradient_scaling_factor,
+            )
+            for buffer in self.buffers:
+                buffer.sync_enabled = self.sync_enabled
+
+            for i in self.rebuilt_params_:
+                i.grad.copy_(i.old_grad)
+                i.old_grad = None
+
+            logger.debug("register unused param: %s", self.rebuilt_params_)
+            self.has_bucket_rebuilt = True
+            self.rebuilt_params_ = []
+
+    def construct(self, *inputs, **inputs_dict):
+        """construct for DistributedDataParallel."""
+        self._pre_forward()
+        output = self.module(*inputs, **inputs_dict)
+        self._post_forward(output)
+        return output
+
+    def zero_grad(self):
+        """DDP accumulates grads automatically; grads are zeroed only when zero_grad() is called manually."""
+        if self.reducer_mode == "CppReducer":
+            self.reducer.zero_grad()
+        else:
+            for buffer in self.buffers:
+                buffer.reset()
+
+    def _enable_sync(self, enable):
+        """enable grad buffer sync or not."""
+        for buffer in self.buffers:
+            buffer.sync_enabled = enable
+        self.sync_enabled = enable
+
+    @contextmanager
+    def no_sync(self):
+        """Context manager helper function. When enabled, no grad allreduce synchronization will be executed."""
+        self._enable_sync(False)
+        try:
+            yield
+        finally:
+            self._enable_sync(True)
+
+    def _should_rebuild_buckets(self):
+        if self.static_graph and not self.has_bucket_rebuilt:
+            return True
+        return False
+
+    def _make_param_hook(self, param):
+        """make closure function as the param hook."""
+        def param_hook(grad):
+            if not self.has_mark_unused_param:
+                for cur_param in self.unused_param:
+                    buffer = self.param_to_buffer[cur_param]
+                    logger.debug("register unused param: %s", cur_param)
+                    buffer.register_grad_ready(cur_param)
+                self.has_mark_unused_param = True
+            elif param.name in self.unused_param_name:
+                logger.debug("unused param already registered: %s", param)
+                return param.grad
+
+            logger.debug("register normal param: %s", param)
+            buffer = self.param_to_buffer[param]
+            param.grad.add_(grad)
+            buffer.register_grad_ready(param)
+            if self._should_rebuild_buckets():
+                self.rebuilt_params_.append(param)
+            return param.grad
+
+        return param_hook
+
+    def broadcast_coalesced(self):
+        """broadcast params from rank 0"""
+        if self.reducer_mode == "CppReducer":
+            buckets = [[self.bucketed_params[idx] for idx in bucket] for bucket in self.reducer.bucket_indices]
+        else:
+            buckets = [bucket.params_list for buffer in self.buffers for bucket in buffer.buckets]
+        if self.skipped_params:
+            param_and_grad_dtype_to_params = self._group_params_by_dtype(self.skipped_params)
+            for params_list in param_and_grad_dtype_to_params.values():
+                buckets.append(params_list)
+
+        def finish(rate_limiter):
+            for _ in rate_limiter:
+                handle, coalesced, params = rate_limiter.pop(0)
+                handle.wait()
+                ptr = 0
+                for param in params:
+                    param.view(-1).copy_(coalesced[ptr:ptr + param.numel()])
+                    ptr += param.numel()
+
+        rate_limiter = []
+        for params in buckets:
+            flat_tensors = [t.view(-1) for t in params]
+            coalesced = mint.cat(flat_tensors)
+            global_rank = get_global_rank(self.process_group, 0)
+            handle = broadcast(coalesced, src=global_rank, group=self.process_group, async_op=True)
+            rate_limiter.append((handle, coalesced, params))
+
+            if len(rate_limiter) >= 2:
+                finish(rate_limiter)
+        finish(rate_limiter)
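
Note on usage: the no_sync() context manager added in distributed_data_parallel.py above turns off the per-bucket all-reduce so gradients accumulate locally in the flat buffer. A minimal sketch of combining it with the training loop from the class docstring to do gradient accumulation; `net`, `grad_fn` and `optimizer` are assumed to be built as in that docstring example, while `micro_batches` and `accumulation_steps` are illustrative names only:

    # Illustrative gradient-accumulation loop (not part of the package diff).
    # Assumes `net = DistributedDataParallel(...)`, `grad_fn` and `optimizer`
    # are set up as in the DistributedDataParallel docstring example above;
    # `micro_batches` and `accumulation_steps` are illustrative.
    accumulation_steps = 4
    for step, (image, label) in enumerate(micro_batches):
        if (step + 1) % accumulation_steps != 0:
            # Accumulation step: no all-reduce, grads add up in the flat buffer.
            with net.no_sync():
                (loss, _), grads = grad_fn(image, label)
        else:
            # Sync step: backward triggers the bucketed all-reduce as usual.
            (loss, _), grads = grad_fn(image, label)
            optimizer(grads)
            net.zero_grad()  # grads are accumulated by DDP and cleared manually
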
mindspore/parallel/distributed/flatten_grad_buffer.py
@@ -0,0 +1,295 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Param and grad buffer, bucket implementation. """
+from __future__ import absolute_import
+
+__all__ = ["Bucket", "FlattenGradBuffer"]
+
+from enum import Enum
+import numpy as np
+from mindspore import mint, Tensor
+from mindspore.common.initializer import Zero
+from mindspore.communication.management import get_group_size
+import mindspore.communication.comm_func as comm_func
+
+
+class BufferType(Enum):
+    PARAM = 0
+    GRAD = 1
+
+
+MEM_ALIGN_SIZE = 512
+ALIGN_BYTES = 32
+MIN_BUCKET_SIZE = int(1 * 1024 * 1024)
+DEFAULT_BUCKET_SIZE = int(25 * 1024 * 1024)
+
+
+class Bucket:
+    """
+    Bucket to track a subset of parameters and gradients in the buffer. Bucket records the parameters
+    whose gradient has already been computed. It also provides functionality to synchronize gradients among
+    the data parallel group when all parameters' gradients have been computed.
+
+    Args:
+        average_in_collective (bool): Whether to scale grads before or after AllReduce; True means scaling after AllReduce.
+        params (List(Parameters)): Parameters belonging to this bucket.
+        grad_data (Tensor): A section of the buffer's gradient data, corresponding to the parameters in this bucket.
+        offset (int): Start index in the buffer.
+        numel_unpadded (int): Number of unpadded elements in the bucket.
+        data_parallel_group (str): Data parallel group name.
+        data_parallel_world_size (int): Data parallel group size.
+        gradient_scaling_factor (float): Works together with average_in_collective; it is 1.0 when
+            average_in_collective is True, otherwise 1.0 / dp size.
+    """
+
+    def __init__(self, average_in_collective, params, grad_data, offset, numel_unpadded, data_parallel_group,
+                 data_parallel_world_size, gradient_scaling_factor):
+        self.average_in_collective = average_in_collective
+        self.params_list = params
+        self.params = set(params)
+        self.params_grad_ready = set()
+        self.grad_data = grad_data
+        self.grad_data_numel = self.grad_data.numel()
+        self.offset = offset
+        self.numel_unpadded = numel_unpadded
+        self.data_parallel_group = data_parallel_group
+        self.data_parallel_world_size = data_parallel_world_size
+        self.gradient_scaling_factor = gradient_scaling_factor
+
+        if self.data_parallel_world_size > 1:
+            self.grad_reducer = comm_func.all_reduce
+
+        self.reset()
+
+    def inplace_reduce_dp(self, src):
+        """conduct all-reduce/reduce-scatter on the src tensor and update the result into the target in place."""
+        self.communication_result, self.communication_handle = self.grad_reducer(
+            src, "sum", self.data_parallel_group, async_op=True
+        )
+
+    def reset(self):
+        """reset bucket for the next iteration."""
+        self.params_grad_ready = set()
+        self.is_reduce_issued = False
+        self.communication_handle = None
+        self.communication_result = None
+
+    def issue_grad_reduce(self):
+        """issue grad reduce for the local grad data view."""
+        if self.is_reduce_issued:
+            raise RuntimeError("The bucket reduce is already issued")
+
+        if self.gradient_scaling_factor != 1.0:
+            self.grad_data.copy_(mint.mul(self.grad_data, self.gradient_scaling_factor))
+
+        if self.data_parallel_world_size > 1:
+            self.inplace_reduce_dp(self.grad_data)
+
+        self.is_reduce_issued = True
+
+    def final_grad_reduce(self):
+        """finalize grad reduce for the local grad data view."""
+        start_idx = 0
+        end_idx = self.grad_data_numel
+        target = self.grad_data[start_idx:end_idx]
+
+        if not self.is_reduce_issued:
+            raise RuntimeError(
+                f"The bucket reduce has not been issued "
+                f"with only {len(self.params_grad_ready)}/{len(self.params)} params ready"
+            )
+
+        if self.data_parallel_world_size > 1:
+            self.communication_handle.wait()
+            target.copy_(self.communication_result)
+            self.communication_result = None
+        if self.average_in_collective:
+            target.copy_(mint.div(target, self.data_parallel_world_size))
+
+    def register_grad_ready(self, param):
+        """register grad ready and issue bucket grad reduce when the bucket is ready."""
+        if param not in self.params:
+            raise ValueError("The param to be registered is not in the bucket")
+
+        if param in self.params_grad_ready:
+            raise ValueError(f"The param {param} is already registered")
+
+        self.params_grad_ready.add(param)
+        if len(self.params_grad_ready) == len(self.params):
+            self.issue_grad_reduce()
+            return True
+
+        return False
+
+    def __repr__(self):
+        return f"Bucket (offset={self.offset}, param_lens={len(self.params)})"
+
+
+class FlattenGradBuffer:
+    """
+    Allocates a contiguous memory buffer for the given parameters and corresponding gradients, breaking
+    the buffer up into small buckets, which are the unit of all-reduce/reduce-scatter
+    communication during back-propagation.
+
+    Args:
+        average_in_collective (bool): Whether to scale grads before or after AllReduce; True means scaling after AllReduce.
+        param_dtype (mindspore.dtype): The parameters' datatype.
+        grad_dtype (mindspore.dtype): The gradients' datatype.
+        params (List(Parameters)): Parameters belonging to this buffer.
+        data_parallel_group (str): Data parallel group name.
+        bucket_size (int): Bucket size threshold used to partition buckets.
+        gradient_scaling_factor (float): 1.0 when average_in_collective is True, otherwise 1.0 / dp size.
+    """
+
+    def __init__(self, average_in_collective, param_dtype, grad_dtype, params, data_parallel_group,
+                 bucket_size, gradient_scaling_factor, ddp_handle):
+        super(FlattenGradBuffer, self).__init__()
+        self.param_dtype = param_dtype
+        self.grad_dtype = grad_dtype
+        self.data_parallel_group = data_parallel_group
+        self.data_parallel_world_size = get_group_size(group=self.data_parallel_group)
+        self.gradient_scaling_factor = gradient_scaling_factor
+        self.average_in_collective = average_in_collective
+
+        self.buckets = []
+        self.param_index_map = {}
+        self.param_to_bucket = {}
+        self.sync_enabled = True
+        self.issued = 0
+        self.ddp_handle = ddp_handle
+
+        buckets_metadata = self.calc_partition_metadata(bucket_size, params)
+        self.instantiate_buckets(buckets_metadata, params)
+
+    def calc_partition_metadata(self, bucket_size, params):
+        """calc bucket partition metadata"""
+        # helper func
+        def _need_new_bucket(bucket_numel, bucket_id):
+            target_bucket_size = bucket_size
+            if bucket_id == 0 and bucket_size == DEFAULT_BUCKET_SIZE:
+                target_bucket_size = MIN_BUCKET_SIZE
+            return (
+                bucket_size is not None
+                and bucket_numel != 0
+                and bucket_numel >= target_bucket_size
+            )
+
+        def _build_bucket():
+            nonlocal buckets_metadata, bucket_start_index, bucket_params, bucket_id
+            bucket_end_index = data_start_index
+            buckets_metadata.append(
+                (bucket_start_index, bucket_end_index, bucket_params)
+            )
+            bucket_start_index = bucket_end_index
+            bucket_id = bucket_id + 1
+            bucket_params = []
+
+        param_data_list = []
+        buckets_metadata = []
+        data_start_index = 0
+        data_end_index = 0
+        bucket_id = 0
+        bucket_start_index = 0
+        bucket_params = []
+        for param in params[::]:  # traverse from the beginning
+            last_bucket_numel = data_start_index - bucket_start_index
+            if _need_new_bucket(last_bucket_numel, bucket_id):
+                _build_bucket()
+            data_end_index = data_start_index + param.numel()
+            bucket_params.append(param)
+            param_data_list.append(param)
+            self.param_index_map[param] = (data_start_index, data_end_index, bucket_id)
+            data_start_index = data_end_index
+
+        # add a bucket for the last few params which do not reach the bucket_size threshold
+        if data_start_index - bucket_start_index > 0:
+            bucket_end_index = data_start_index
+            buckets_metadata.append(
+                (bucket_start_index, bucket_end_index, bucket_params)
+            )
+            data_start_index = bucket_end_index
+
+        # allocate contiguous memory for parameters and gradients
+        self.numel = data_start_index
+        self.grad_data = Tensor(shape=(self.numel), dtype=self.grad_dtype, init=Zero())
+        self.grad_data.init_data()
+        self.numel_unpadded = 0
+        return buckets_metadata
+
+    def instantiate_buckets(self, buckets_metadata, params):
+        """build bucket instances according to partition metadata"""
+        for bucket_start_index, bucket_end_index, bucket_params in buckets_metadata:
+            local_grad_data = self.grad_data[bucket_start_index:bucket_end_index]
+            self.numel_unpadded += bucket_end_index - bucket_start_index
+            bucket = Bucket(
+                average_in_collective=self.average_in_collective,
+                params=bucket_params,
+                grad_data=local_grad_data,
+                offset=bucket_start_index,
+                numel_unpadded=bucket_end_index - bucket_start_index,
+                data_parallel_group=self.data_parallel_group,
+                data_parallel_world_size=self.data_parallel_world_size,
+                gradient_scaling_factor=self.gradient_scaling_factor,
+            )
+            self.buckets.append(bucket)
+            for param in bucket_params:
+                self.param_to_bucket[param] = bucket
+
+        for param in params:
+            data_start_index, _, _ = self.param_index_map[param]
+            param.grad = self._get_buffer_slice(
+                param.shape, data_start_index, BufferType.GRAD
+            )
+
+    def _get_buffer_slice(self, shape, start_index, buffer_type):
+        """get the buffer view with the given shape"""
+        end_index = start_index + int(np.prod(shape))
+        if start_index < 0 or end_index > self.numel:
+            raise ValueError("index out of range")
+        if buffer_type == BufferType.GRAD:
+            buffer_tensor = self.grad_data[start_index:end_index]
+        else:
+            raise TypeError("Invalid buffer type for _get_buffer_slice.")
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
+
+    def reset(self):
+        """reset buffer for the next iteration."""
+        self.grad_data.zero_()
+        for bucket in self.buckets:
+            bucket.reset()
+        self.sync_enabled = True
+
+    def final_grad_reduce(self):
+        """finalize grad reduce for each bucket"""
+        for bucket in self.buckets:
+            bucket.final_grad_reduce()
+
+    def register_grad_ready(self, param):
+        """register a ready grad in its bucket"""
+        if self.sync_enabled:
+            bucket = self.param_to_bucket[param]
+            if bucket.register_grad_ready(param):
+                self.issued += 1
+                if self.issued == len(self.buckets):
+                    self.ddp_handle.buffer_issued += 1
+                    if self.ddp_handle.buffer_issued == len(self.ddp_handle.buffers):
+                        self.ddp_handle.final_grad_reduce()
+
+    def __repr__(self):
+        param_index_with_name = {
+            param.name: index for (param, index) in self.param_index_map.items()
+        }
+        return f"Buffer has buckets: \n {self.buckets} \n and param_index_map: \n {param_index_with_name}"
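
Note on bucketing: FlattenGradBuffer.calc_partition_metadata above partitions parameters into buckets by cumulative element count, cutting a new bucket once the running count reaches the size threshold (the first bucket is cut early at MIN_BUCKET_SIZE when the default cap is in use) and sweeping any leftover parameters into a final bucket. A standalone sketch of that rule with made-up sizes; the `partition` helper below is illustrative only and not part of the package:

    # Standalone illustration of the bucket partitioning rule used by
    # FlattenGradBuffer.calc_partition_metadata; names and sizes are made up.
    MIN_BUCKET_SIZE = 1 * 1024 * 1024
    DEFAULT_BUCKET_SIZE = 25 * 1024 * 1024

    def partition(param_numels, bucket_size=DEFAULT_BUCKET_SIZE):
        """Return (start, end, param_indices) triples over a flat buffer."""
        buckets, start, cursor, members, bucket_id = [], 0, 0, [], 0
        for idx, numel in enumerate(param_numels):
            # The first bucket is cut early so communication can start sooner.
            target = (MIN_BUCKET_SIZE
                      if bucket_id == 0 and bucket_size == DEFAULT_BUCKET_SIZE
                      else bucket_size)
            if cursor - start >= target:
                buckets.append((start, cursor, members))
                start, members, bucket_id = cursor, [], bucket_id + 1
            members.append(idx)
            cursor += numel
        if cursor > start:  # leftover params that never reached the threshold
            buckets.append((start, cursor, members))
        return buckets

    # Three 600k-element params: the first two fill bucket 0 (cut once the
    # 1 MiB early threshold is exceeded), the third forms bucket 1.
    print(partition([600_000, 600_000, 600_000]))
    # -> [(0, 1200000, [0, 1]), (1200000, 1800000, [2])]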