mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -30,7 +30,7 @@ def _generate_cmd(cmd, cmd_args, output_name):

     """
     if cmd not in ['python', 'pytest', 'python3']:
-        # If user don't set binary file name,
+        # If user don't set binary file name, defaultly use 'python' to launch the job.
         command = f"python {cmd} {' '.join(cmd_args)} > {output_name} 2>&1 &"
     else:
         command = f"{cmd} {' '.join(cmd_args)} > {output_name} 2>&1 &"

@@ -42,7 +42,7 @@ def _generate_cmd_args_list(cmd, cmd_args):
     Generates arguments list for 'Popen'. It consists of a binary file name and subsequential arguments.
     """
     if cmd not in ['python', 'pytest', 'python3']:
-        # If user don't set binary file name,
+        # If user don't set binary file name, defaultly use 'python' to launch the job.
         return ['python'] + [cmd] + cmd_args
     return [cmd] + cmd_args

@@ -55,7 +55,7 @@ def _generate_cmd_args_list_with_core(cmd, cmd_args, affinity_cpu_str):
     taskset_args = ['taskset'] + ['-c'] + [affinity_cpu_str]
     final_cmd = []
     if cmd not in ['python', 'pytest', 'python3']:
-        # If user don't set binary file name,
+        # If user don't set binary file name, defaultly use 'python' to launch the job.
         final_cmd = taskset_args + ['python'] + [cmd] + cmd_args
     else:
         final_cmd = taskset_args + [cmd] + cmd_args

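For orientation, here is a minimal standalone sketch (hypothetical values, not part of the package) of what the launcher helpers patched above produce when the first argument is a script name rather than a known interpreter:

# Illustration only: mirrors the branches shown in the hunks above.
# The real helpers live in mindspore's msrun launcher; the values here are made up.
cmd, cmd_args, affinity_cpu_str = "train.py", ["--epochs", "1"], "0-3"

if cmd not in ['python', 'pytest', 'python3']:
    argv = ['python', cmd] + cmd_args                   # _generate_cmd_args_list path
else:
    argv = [cmd] + cmd_args
bound = ['taskset', '-c', affinity_cpu_str] + argv      # _generate_cmd_args_list_with_core path

print(argv)   # ['python', 'train.py', '--epochs', '1']
print(bound)  # ['taskset', '-c', '0-3', 'python', 'train.py', '--epochs', '1']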
@@ -143,8 +143,14 @@ def _parse_global_device_to_cpu_map(local_rank_id, physical_device_id, device_to
     Parse the global device_to_cpu_map and return a cpu list for assigned local_rank_id.

     """
+    if local_rank_id >= len(list(device_to_cpu_map.keys())):
+        logger.warning(f"Cannot find process[{local_rank_id}] in args '--bind_core'. "
+                       "Will not launch process with taskset.")
+        return ""
     input_device_id = int(list(device_to_cpu_map.keys())[local_rank_id].replace("device", ""))
     if physical_device_id != input_device_id:
+        logger.warning(f"Cannot find physical_device_id[{physical_device_id}] for process[{local_rank_id}] "
+                       "in args '--bind_core'. Will not launch process with taskset.")
         return ""
     affinity_cpu_list = list(device_to_cpu_map.values())[local_rank_id]
     affinity_cpu_str = ",".join(affinity_cpu_list)

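A hedged sketch of the guard added above: the bind-core map is keyed by "deviceN" strings and valued with CPU-range lists (layout inferred from the surrounding code), and an out-of-range local rank now returns an empty string instead of failing.

# Hypothetical illustration of the new bounds check in _parse_global_device_to_cpu_map.
# The map layout is inferred from the code above; names and values are made up.
device_to_cpu_map = {"device0": ["0-7"], "device1": ["8-15"]}

def parse(local_rank_id, physical_device_id):
    if local_rank_id >= len(list(device_to_cpu_map.keys())):
        return ""  # new in 2.7.1: rank not listed in --bind_core, skip taskset
    input_device_id = int(list(device_to_cpu_map.keys())[local_rank_id].replace("device", ""))
    if physical_device_id != input_device_id:
        return ""  # device id mismatch, skip taskset
    return ",".join(list(device_to_cpu_map.values())[local_rank_id])

print(parse(1, 1))  # '8-15'
print(parse(5, 5))  # '' -> process launched without taskset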
@@ -212,8 +218,6 @@ def _generate_bind_core_strategy(local_rank_id, device_to_cpu_map, arg_bind_core
     if isinstance(arg_bind_core, dict):
         affinity_cpu_str = _parse_global_device_to_cpu_map(local_rank_id, physical_device_id, arg_bind_core)
         if not affinity_cpu_str:
-            logger.warning(f"Failed to find physical_device_id[{physical_device_id}] for "
-                           f"process[{local_rank_id}]. Will not launch process with taskset.")
             return None
     elif arg_bind_core is True:
         cpu_list_for_device = device_to_cpu_map.get(physical_device_id, [])

@@ -125,14 +125,16 @@ def get_args():
         default=-1,
         type=int,
         choices=[0, 1, 2, 3],
-        help="specifies simulation level.
-        "
+        help="specifies simulation level. This argument activates dryrun mode, functioning "
+             "equivalently to environment variable 'MS_SIMULATION_LEVEL' while having higher priority."
     )
     parser.add_argument(
         "--sim_rank_id",
         default=-1,
         type=int,
-        help="specifies simulation process's rank id.
+        help="specifies simulation process's rank id. When this argument is set, only one process "
+             "is spawned on dryrun mode, functioning equivalently to environment variable 'RANK_ID' "
+             "while having higher priority."
     )
     parser.add_argument(
         "--rank_table_file",

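Per the help strings above, both options mirror environment variables; a hedged sketch of the environment-based equivalent (variable names taken from the help text, the exact CLI spelling of the simulation-level option is not visible in this hunk):

# Rough equivalent of the dryrun options documented above, using the environment
# variables named in the help strings (which the CLI arguments override).
import os

os.environ["MS_SIMULATION_LEVEL"] = "1"  # enable dryrun/simulation mode
os.environ["RANK_ID"] = "0"              # dryrun a single simulated rank, like --sim_rank_id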
@@ -1,22 +1,21 @@
-# Copyright
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-
-
-
-
-
-
-__all__ = ['LlamaBoostAscendNative']
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""distributed init"""
+from mindspore.parallel.distributed.distributed_data_parallel import DistributedDataParallel
+
+__all__ = [
+    "DistributedDataParallel",
+]

@@ -0,0 +1,393 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Distributed data parallel wrapper. """
+from __future__ import absolute_import
+
+__all__ = ["DistributedDataParallel"]
+
+import itertools
+from contextlib import contextmanager
+from typing import Optional
+import mindspore.nn as nn
+import mindspore.log as logger
+from mindspore import Tensor, mint
+from mindspore.common import dtype as mstype
+from mindspore.mint.distributed import get_world_size
+from mindspore.communication import GlobalComm
+from mindspore.common.api import _pynative_executor
+from mindspore.mint.distributed import broadcast, get_global_rank
+from mindspore.parallel.distributed.flatten_grad_buffer import FlattenGradBuffer
+from mindspore._c_expression import Reducer, _find_unused_parameters
+
+
+def get_data_parallel_group():
+    """get default global data parallel group"""
+    return GlobalComm.WORLD_COMM_GROUP
+
+
+def get_data_parallel_world_size(group):
+    """get group world size"""
+    return get_world_size(group)
+
+
+def _find_tensors(obj):
+    if isinstance(obj, Tensor):
+        return [obj]
+    if isinstance(obj, (list, tuple)):
+        return itertools.chain.from_iterable(map(_find_tensors, obj))
+    if isinstance(obj, dict):
+        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
+
+    return []
+
+
+class DistributedDataParallel(nn.Cell):
+    """
+    DistributedDataParallel wrapper. DistributedDataParallel allocates contiguous memory buffer for gradients.
+    Parameters' gradients will be combined into multiple buckets which are the unit to conduct all-reduce
+    communication among data parallel group to overlap communication latency.
+
+    .. warning::
+        - The method is currently only supported in PyNative mode.
+        - This is an experimental interface, may be changed or canceled in the future.
+
+    Args:
+        module (nn.Cell): the module to be wrapped with DDP.
+        init_sync (bool, optional): whether to sync params from rank0 of process_group when init. Default: ``True``.
+        process_group (str, optional): the comm group of data prallel. Default: ``None``.
+        bucket_cap_mb (int, optional): size of bucket in MB, default is 25MB if not set. Default: ``None``.
+        find_unused_parameters (bool, optional): whether to find unused params in the bucket. Default: ``False``.
+        average_in_collective (bool, optional): True means allreduce sum within DP group firstly then scaling with
+            dp size. Otherwise scaling local rank grad first and then allreduce sum. Default: ``False``.
+        static_graph (bool, optional): Indicate whether it is a static network. When it is a static network, the
+            parameter `find_unused_parameters` will be ignored, and unused parameters will be searched for in the
+            first step. Bucket reconstruction will be performed in execution order before the second step to achieve
+            better performance. Default: ``False``.
+        reducer_mode (str, optional): the backend to be used, could be "CppReducer" for cpp backend or "PythonReducer"
+            for Python backend. Default: ``"CppReducer"``.
+
+    Returns:
+        Model wrapped with DistributedDataParallel.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            - When enabling recomputation or gradient freezing, the model should be wrapped by
+              `DistributedDataParallel` at the outermost layer.
+            - Before running the following examples, you need to configure the communication environment variables.
+              For Ascend devices, it is recommended to use the msrun startup method
+              without any third-party or configuration file dependencies. For detailed information, refer to
+              `msrun launch <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_ .
+
+        >>> from mindspore.parallel.distributed import DistributedDataParallel
+        >>> from mindspore.mint.optim import AdamW
+        >>> from mindspore import Parameter, Tensor, ops, nn
+        >>> import mindspore as ms
+        >>> from mindspore.communication import init
+        >>> from mindspore.mint.distributed.distributed import init_process_group
+        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+        >>> init_process_group()
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> net = DistributedDataParallel(module=net,
+        ...                               bucket_cap_mb=None,
+        ...                               average_in_collective=True,
+        ...                               static_graph=True)
+        >>> optimizer = AdamW(net.trainable_params(), 1e-4)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>>
+        >>> def forward_fn(data, target):
+        ...     logits = net(data)
+        ...     loss = loss_fn(logits, target)
+        ...     return loss, logits
+        >>>
+        >>> grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
+        >>>
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
+        >>> for epoch in range(1):
+        ...     step = 0
+        ...     for image, label in dataset:
+        ...         (loss_value, _), grads = grad_fn(image, label)
+        ...         optimizer(grads)
+        ...         net.zero_grad()
+        ...         step += 1
+        ...         print("epoch: %s, step: %s, loss is %.15f" % (epoch, step, loss_value))
+    """
+
+    def __init__(self, module, init_sync=True, process_group=None, bucket_cap_mb: Optional[int] = None,
+                 find_unused_parameters=False, average_in_collective: bool = False, static_graph=False,
+                 reducer_mode="CppReducer"):
+        super(DistributedDataParallel, self).__init__(auto_prefix=False)
+        self.init_sync = init_sync
+        self.bucket_cap_mb = bucket_cap_mb
+        self.average_in_collective = average_in_collective
+        self.grad_reduce_in_fp32 = False
+        self.process_group = process_group if process_group else get_data_parallel_group()
+        self.static_graph = static_graph
+        self.find_unused_parameters = find_unused_parameters
+
+        self.module = module
+        self.param_to_buffer = {}
+        self.has_buckets_grad_sync = False
+
+        # default is 25MB for each buck
+        if bucket_cap_mb is None:
+            bucket_cap_mb = 25
+        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
+
+        # grads sync with allreduce comm
+        self.sync_enabled = True
+        self.reducer_mode = reducer_mode  # "CppReducer" or "PythonReducer"
+        self.buffers = []
+        self.has_mark_unused_param = False
+
+        bucketed_params = []
+        self.skipped_params = []
+        for _, param in self.module.parameters_and_names():
+            if not param.requires_grad:
+                self.skipped_params.append(param)
+                continue
+            param.grad = None
+            param.main_grad = None
+            bucketed_params.append(param)
+        if self.average_in_collective:
+            # allreduce to add grads, then to scale grads with dp size
+            self.gradient_scaling_factor = 1.0
+        else:
+            # scale grads with dp size locally, then allreduce to add grads
+            data_parallel_world_size = get_data_parallel_world_size(self.process_group)
+            self.gradient_scaling_factor = 1.0 / data_parallel_world_size
+        self.bucketed_params = bucketed_params
+
+        if self.reducer_mode == "CppReducer":
+            self.reducer = Reducer(self.bucketed_params,
+                                   self.process_group,
+                                   bucket_cap_mb,
+                                   self.grad_reduce_in_fp32,
+                                   average_in_collective,
+                                   static_graph,
+                                   find_unused_parameters)
+            if self.init_sync:
+                self.broadcast_coalesced()
+            return
+        # allocate buffer for trained params
+        self.buffers = self.allocate_buffers_for_parameters(
+            self.bucketed_params,
+            group=self.process_group,
+            gradient_scaling_factor=self.gradient_scaling_factor,
+        )
+        if self.init_sync:
+            self.broadcast_coalesced()
+
+        # register hook for bucket grad reduce
+        self._register_hook_for_params()
+
+        # bucket rebuilding
+        self.rebuilt_params_ = []
+        self.buffer_iterations = 0
+        self.has_bucket_rebuilt = False
+        self.buffer_issued = 0
+        self.triggered_once = False
+
+    def _group_params_by_dtype(self, input_params):
+        param_and_grad_dtype_to_params = {}
+        # group all params by parameter's data type and their gradient's data type.
+        for param in input_params:
+            param_dtype = param.dtype
+            grad_dtype = mstype.float32 if self.grad_reduce_in_fp32 else param.dtype
+            if (param_dtype, grad_dtype) not in param_and_grad_dtype_to_params:
+                param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = []
+            param_and_grad_dtype_to_params[(param_dtype, grad_dtype)].append(param)
+        return param_and_grad_dtype_to_params
+
+    def allocate_buffers_for_parameters(self, input_params, group, gradient_scaling_factor):
+        """allocate buffers for parameters in different dtype group."""
+        param_and_grad_dtype_to_params = self._group_params_by_dtype(input_params)
+
+        buffers = []
+        # allocate buffer for each group separately
+        for (param_dtype, grad_dtype,), params in param_and_grad_dtype_to_params.items():
+            buffers.append(
+                FlattenGradBuffer(
+                    average_in_collective=self.average_in_collective,
+                    param_dtype=param_dtype,
+                    grad_dtype=grad_dtype,
+                    params=params,
+                    data_parallel_group=group,
+                    bucket_size=self.bucket_bytes_cap,
+                    gradient_scaling_factor=gradient_scaling_factor,
+                    ddp_handle=self,
+                )
+            )
+            for param in params:
+                self.param_to_buffer[param] = buffers[-1]
+        logger.debug("allocate buffers for parameters: %s", buffers)
+        return buffers
+
+    def final_grad_reduce(self):
+        """trigger final grad reduction"""
+        logger.debug("trigger ddp final grad reduce, %d, %d", self.static_graph, len(self.unused_param))
+        if self._should_rebuild_buckets():
+            for param in self.unused_param:
+                self.rebuilt_params_.append(param)
+        for buffer in self.buffers:
+            buffer.final_grad_reduce()
+            buffer.issued = 0
+        self.buffer_issued = 0
+
+    def _register_hook_for_params(self):
+        """register backward hook for each params."""
+        for param in self.module.get_parameters():
+            if param.requires_grad:
+                param.register_hook(self._make_param_hook(param))
+
+    def _post_forward(self, output):
+        """prepare for backward (e.g. find unused params) if needed"""
+        if self.reducer_mode == "CppReducer":
+            if _pynative_executor.grad_flag() and self.sync_enabled:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+        else:
+            unused_param_idx = []
+            if self.static_graph and not self.triggered_once:
+                self.triggered_once = True
+                self.find_unused_parameters = False
+                unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+            elif self.find_unused_parameters:
+                unused_param_idx = _find_unused_parameters(list(_find_tensors(output)), self.bucketed_params)
+            self.unused_param = [self.bucketed_params[idx] for idx in unused_param_idx]
+            self.unused_param_name = [param.name for param in self.unused_param]
+            self.has_mark_unused_param = False
+
+    def _pre_forward(self):
+        """pre-process of forward pass to allocate buffer for parameters."""
+        if self.reducer_mode == "CppReducer":
+            if _pynative_executor.grad_flag() and self.sync_enabled:
+                self.reducer.prepare_for_forward()
+                self.reducer.rebuild_buckets()
+            return
+        if self.rebuilt_params_ and self._should_rebuild_buckets():
+            for i in self.rebuilt_params_:
+                i.old_grad = i.grad
+
+            self.buffers = self.allocate_buffers_for_parameters(
+                self.rebuilt_params_,
+                group=self.process_group,
+                gradient_scaling_factor=self.gradient_scaling_factor,
+            )
+            for buffer in self.buffers:
+                buffer.sync_enabled = self.sync_enabled
+
+            for i in self.rebuilt_params_:
+                i.grad.copy_(i.old_grad)
+                i.old_grad = None
+
+            logger.debug("register unused param: %s", self.rebuilt_params_)
+            self.has_bucket_rebuilt = True
+            self.rebuilt_params_ = []
+
+    def construct(self, *inputs, **inputs_dict):
+        """construct for DistributedDataParallel."""
+        self._pre_forward()
+        output = self.module(*inputs, **inputs_dict)
+        self._post_forward(output)
+        return output
+
+    def zero_grad(self):
+        """DPP will accumulate grads automatically, it will zero grads when call zero_grad() manually."""
+        if self.reducer_mode == "CppReducer":
+            self.reducer.zero_grad()
+        else:
+            for buffer in self.buffers:
+                buffer.reset()
+
+    def _enable_sync(self, enable):
+        """enable grad buffer sync or not."""
+        for buffer in self.buffers:
+            buffer.sync_enabled = enable
+        self.sync_enabled = enable
+
+    @contextmanager
+    def no_sync(self):
+        """Context manager helper function. When enabled, no grad allreduce synchronization will be executed."""
+        self._enable_sync(False)
+        try:
+            yield
+        finally:
+            self._enable_sync(True)
+
+    def _should_rebuild_buckets(self):
+        if self.static_graph and not self.has_bucket_rebuilt:
+            return True
+        return False
+
+    def _make_param_hook(self, param):
+        """make closure function as the param hook."""
+        def param_hook(grad):
+            if not self.has_mark_unused_param:
+                for cur_param in self.unused_param:
+                    buffer = self.param_to_buffer[cur_param]
+                    logger.debug("register unused param: %s", cur_param)
+                    buffer.register_grad_ready(cur_param)
+                self.has_mark_unused_param = True
+            elif param.name in self.unused_param_name:
+                logger.debug("unused param already registered: %s", param)
+                return param.grad
+
+            logger.debug("register normal param: %s", param)
+            buffer = self.param_to_buffer[param]
+            param.grad.add_(grad)
+            buffer.register_grad_ready(param)
+            if self._should_rebuild_buckets():
+                self.rebuilt_params_.append(param)
+            return param.grad
+
+        return param_hook
+
+    def broadcast_coalesced(self):
+        """broadcast params from rank 0"""
+        if self.reducer_mode == "CppReducer":
+            buckets = [[self.bucketed_params[idx] for idx in bucket] for bucket in self.reducer.bucket_indices]
+        else:
+            buckets = [bucket.params_list for buffer in self.buffers for bucket in buffer.buckets]
+        if self.skipped_params:
+            param_and_grad_dtype_to_params = self._group_params_by_dtype(self.skipped_params)
+            for params_list in param_and_grad_dtype_to_params.values():
+                buckets.append(params_list)
+
+        def finish(rate_limiter):
+            for _ in rate_limiter:
+                handle, coalesced, params = rate_limiter.pop(0)
+                handle.wait()
+                ptr = 0
+                for param in params:
+                    param.view(-1).copy_(coalesced[ptr:ptr + param.numel()])
+                    ptr += param.numel()
+
+        rate_limiter = []
+        for params in buckets:
+            flat_tensors = [t.view(-1) for t in params]
+            coalesced = mint.cat(flat_tensors)
+            global_rank = get_global_rank(self.process_group, 0)
+            handle = broadcast(coalesced, src=global_rank, group=self.process_group, async_op=True)
+            rate_limiter.append((handle, coalesced, params))

+            if len(rate_limiter) >= 2:
+                finish(rate_limiter)
+        finish(rate_limiter)