megatron-core 0.12.0rc3__tar.gz → 0.12.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of megatron-core might be problematic.
- {megatron_core-0.12.0rc3/megatron_core.egg-info → megatron_core-0.12.2}/PKG-INFO +1 -1
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/async_utils.py +29 -11
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +214 -156
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +275 -186
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/distributed_data_parallel_config.py +7 -3
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/finalize_model_grads.py +3 -2
- megatron_core-0.12.2/megatron/core/export/model_type.py +8 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +6 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +10 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trt_model_config.py +1 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trt_model_type.py +1 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_helper.py +20 -2
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_layers.py +9 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +10 -4
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +17 -5
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/extensions/transformer_engine.py +34 -8
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fp8_utils.py +15 -8
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_swiglu.py +57 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/rope_utils.py +20 -3
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/gpt_layer_specs.py +27 -6
- megatron_core-0.12.2/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +209 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/__init__.py +16 -5
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/optimizer.py +99 -21
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/package_info.py +2 -2
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/gpt/model_specs.py +12 -4
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/__init__.py +2 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/random.py +149 -23
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/attention.py +4 -1
- megatron_core-0.12.2/megatron/core/transformer/heterogeneous/heterogeneous_config.py +267 -0
- megatron_core-0.12.2/megatron/core/transformer/heterogeneous/linear_replacements.py +111 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/mlp.py +37 -18
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/experts.py +166 -60
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py +18 -11
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/moe_layer.py +20 -10
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/moe_utils.py +91 -19
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/shared_experts.py +4 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/token_dispatcher.py +63 -64
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/multi_latent_attention.py +121 -69
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/torch_norm.py +49 -1
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/transformer_block.py +22 -5
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/transformer_config.py +88 -14
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/transformer_layer.py +51 -5
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2/megatron_core.egg-info}/PKG-INFO +1 -1
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/SOURCES.txt +3 -0
- megatron_core-0.12.0rc3/megatron/core/export/model_type.py +0 -7
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/LICENSE +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/MANIFEST.in +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/README.md +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/README.md +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/mapping.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/optimizer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/validation.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/enums.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/dynamic_context.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/static_context.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/dynamic_engine.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/inference_request.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/gpt/model_specs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/mamba/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/mamba/model_specs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/sampling_params.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/jit.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/model_parallel_config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/T5/t5_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/T5/t5_spec.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/bert_layer_specs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/language_module/language_module.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/gpt_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/moe_module_specs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/clip_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/qwen_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/mamba/mamba_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/context_parallel.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/llava_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/decoder_spec.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/encoder_spec.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/clip_vit_model.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/multimodal_projector.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/radio.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/clip_grads.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/distrib_optimizer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/optimizer_config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/packed_seq_params.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/parallel_state.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/pipeline_parallel/schedules.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/rerun_state_machine.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_block.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_config.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_layer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/layers.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/timers.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/cuda_graphs.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/enums.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/router.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/multi_token_prediction.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/utils.py +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/requires.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/pyproject.toml +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_24.01/requirements.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_24.07/requirements.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_24.10/requirements.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_25.03/requirements.txt +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/setup.cfg +0 -0
- {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.12.0rc3
+Version: 0.12.2
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Home-page: https://github.com/NVIDIA/Megatron-LM/megatron/core
 Download-URL: https://github.com/NVIDIA/Megatron-LM/releases
megatron/core/dist_checkpointing/strategies/async_utils.py
@@ -155,7 +155,7 @@ class AsyncCaller(ABC):
         logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")

     def __del__(self):
-
+        raise NotImplementedError("This should be implemented")


 class TemporalAsyncCaller(AsyncCaller):
@@ -227,12 +227,22 @@ class TemporalAsyncCaller(AsyncCaller):
         is_alive = int(self.process.is_alive()) if self.process is not None else 0
         is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)

-        if
+        if is_done or blocking:
+            # Process join is called in the following cases
+            # 1. blocking == True -> regardless of is_done
+            # 2. blocking == False (non-blocking)
+            #    -> is_done == True: async requests on all ranks are identified to be finished
+            #       `self.close()` makes sure the async callers terminated
             self.close()
             is_done = True
         return is_done

     def close(self):
+        """For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`
+
+        This method make sure the TemporalAsyncCaller terminated
+        with all its assigned async request completed
+        """
         if self.process:
             logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
             self.process.join()
@@ -243,6 +253,9 @@ class TemporalAsyncCaller(AsyncCaller):
         )
         self.start_time = None

+    def __del__(self):
+        pass
+

 class PersistentAsyncCaller(AsyncCaller):
     """Wrapper around mp.Process that ensures correct semantic of distributed finalization.
@@ -376,6 +389,10 @@ class PersistentAsyncCaller(AsyncCaller):
         return is_done

     def close(self):
+        """Wait on the left async requests and terminate the PersistentAsyncCaller
+
+        Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated
+        """
         logger.info(
             f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
         )
@@ -385,6 +402,9 @@ class PersistentAsyncCaller(AsyncCaller):
             self.process.join()
             self.process = None

+    def __del__(self):
+        self.close()
+
     @staticmethod
     @_disable_gc()
     def async_loop(
@@ -492,13 +512,11 @@ class AsyncCallsQueue:
         # Backward compatibility for local checkpointing built with the old AsyncRequest
         if len(async_request._fields) != len(AsyncRequest._fields):
             async_request = AsyncRequest(**async_request._asdict())
-
-        async_request = async_request._replace(call_idx=self.call_idx)
-        finalize_fns = async_request.finalize_fns
-        async_request = async_request._replace(finalize_fns=None)
         async_request = async_request.freeze()
-        async_caller.schedule_async_call(
-
+        async_caller.schedule_async_call(
+            async_request._replace(call_idx=self.call_idx, finalize_fns=[])
+        )
+        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
         return self.call_idx

     def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
@@ -522,13 +540,13 @@ class AsyncCallsQueue:
             if not next_async_done:
                 break
             with debug_time("finalize", logger):
-                call_idx, _,
+                call_idx, _, async_request = self.async_calls.popleft()
+                for finalize_fn in async_request.finalize_fns:
+                    finalize_fn()
                 ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
                 torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
                 assert ten.item() == call_idx, 'Unmatched async calls. '
                 'That probably means not all ranks are participating in async finalization'
-                for finalize_fn in finalize_fns:
-                    finalize_fn()
                 call_idx_finalized.append(call_idx)
         return call_idx_finalized
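Note on the async_utils.py hunks above: scheduling now stores the full AsyncRequest (including its finalize_fns) on the queue and runs the finalize callbacks when the matching call is popped in maybe_finalize_async_calls, and PersistentAsyncCaller now tears itself down via __del__ -> close(). A minimal usage sketch of the queue follows; only maybe_finalize_async_calls is confirmed by this diff, while the AsyncRequest field names and the schedule_async_request entry point are assumptions based on the surrounding Megatron-LM API and may differ in your version.

# Minimal sketch, not taken from the diff. AsyncRequest(async_fn, async_fn_args,
# finalize_fns) and AsyncCallsQueue.schedule_async_request() are assumed names;
# maybe_finalize_async_calls() is the signature shown in the diff above.
from megatron.core.dist_checkpointing.strategies.async_utils import (
    AsyncCallsQueue,
    AsyncRequest,
)

def save_shards(path):
    # Rank-local checkpoint write; runs inside the async caller.
    ...

queue = AsyncCallsQueue()
request = AsyncRequest(
    async_fn=save_shards,
    async_fn_args=("checkpoints/iter_0000100",),
    finalize_fns=[lambda: print("checkpoint metadata finalized")],
)
queue.schedule_async_request(request)  # assumed entry point; returns a call_idx

# Once per iteration: finalize saves that have completed on every rank.
queue.maybe_finalize_async_calls(blocking=False)

# At shutdown: block until all outstanding saves are finalized.
queue.maybe_finalize_async_calls(blocking=True)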
megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py
@@ -76,7 +76,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         module: Underlying model.
         fsdp_unit_modules: List of modules that should be treated as FSDP Unit,
             i.e., the minimum releasable model unit. If not provided, defaults to
-            [TransformerLayer, LanguageModelEmbedding] for GPT-like models.
+            [TransformerLayer, LanguageModelEmbedding] for GPT-like models. In
+            addition to this, it affects the granularity of the communication
+            parameter grouping and triggers aggregate collective communication
+            in fp8 mixed precision training.
         disable_bucketing: If true, force assign all parameters to a single bucket. If false,
             use standard bucketing policy: assign parameters to smaller buckets and all-reduce
             per bucket.
@@ -122,7 +125,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         if fsdp_unit_modules is not None:
             self.fsdp_unit_modules = fsdp_unit_modules
         else:
-            self.
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                self.fsdp_unit_modules = [TransformerLayer]
+            else:
+                self.fsdp_unit_modules = []
         self.main_weights = True
         self.data_parallel_group = parallel_state.get_data_parallel_group(
             with_context_parallel=True
@@ -177,14 +183,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
             self.module,
             bucketing_policy=BucketingPolicy(
                 suggested_bucket_size=self.bucket_size,
-                fsdp_unit_modules=
-                # Only when model weights need to be sharded, we need to
-                # identify the minimum releasable model unit, which is the
-                # FSDP Unit Module.
-                self.fsdp_unit_modules
-                if self.data_parallel_sharding_strategy == "optim_grads_params"
-                else []
-                ),
+                fsdp_unit_modules=self.fsdp_unit_modules,
                 data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
             ),
             data_parallel_group=self.data_parallel_group,
@@ -208,8 +207,24 @@ class FullyShardedDataParallel(_BaseDataParallel):
         # Initialize the all-gather pipeline.
         self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer)

-
-
+        suggested_communication_unit_size = self.ddp_config.suggested_communication_unit_size
+        if suggested_communication_unit_size is None:
+            if self.data_parallel_sharding_strategy == "optim_grads_params":
+                total_param_elements = 0
+                total_fsdp_module = 0
+                for module in self.module.modules():
+                    if isinstance(module, tuple(self.fsdp_unit_modules)):
+                        total_fsdp_module += 1
+                        total_param_elements += sum(p.numel() for p in module.parameters())
+                # The suggested size is twice the number of elements in the FSDP modules.
+                # This ensures we process the current FSDP module and attempt to prefetch
+                # the next FSDP module, making the flow of communication better.
+                suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
+            elif self.bucket_size is not None:
+                suggested_communication_unit_size = self.bucket_size * 2
+
+        self.suggested_RS_queue_capacity = suggested_communication_unit_size
+        self.suggested_AG_prefetch_size = suggested_communication_unit_size

     def _register_fsdp_hooks(self, root_module):
         """Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
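To make the sizing heuristic introduced in the hunk above concrete, here is a small worked example with illustrative numbers (not taken from the source):

# Hypothetical model: 32 TransformerLayer FSDP units, 50M parameter elements each.
total_fsdp_module = 32
total_param_elements = 32 * 50_000_000

# Twice the average FSDP-unit size: enough to process the current unit and
# prefetch roughly the next one, as the new comment in the diff explains.
suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
assert suggested_communication_unit_size == 100_000_000  # elements, ~2 layers' worth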
@@ -219,8 +234,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         - Pre-forward hook: Unshards parameters before forward pass
         - Post-forward hook: Reshards parameters after forward pass
         - Pre-backward hook: Unshards parameters before backward pass
-        - Post-backward hook: Reshards parameters after backward pass
-        - Gradient accumulation hook: Handles gradient accumulation and reduction across devices
+        - Post-backward hook: Reshards parameters and reduces gradients after backward pass

         Args:
             root_module: The PyTorch module to register FSDP hooks on
@@ -254,10 +268,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         `optim` and `optim_grads` do not require FSDP units because they do not
         shard model parameters.
         """
-
-            fsdp_unit_modules = []
-        else:
-            fsdp_unit_modules = self.fsdp_unit_modules
+        fsdp_unit_modules = self.fsdp_unit_modules

         def release_module_parameters(module, *unused):
             for param in module.parameters():
@@ -280,27 +291,74 @@ class FullyShardedDataParallel(_BaseDataParallel):
             prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
             wait_bucket_ready=True,
         ):
-            wait_list = []
             ag_pipeline = self.all_gather_pipeline
-
-
-
-
-
-
-                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                )
-                wait_list.append(bucket_id)
-
+            ag_pipeline.all_gather_params(
+                params=list(module.parameters()),
+                prefetch=prefetch,
+                prefetch_order=prefetch_order,
+                suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+            )
             if wait_bucket_ready:
-                for
+                for param in module.parameters():
+                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                     ag_pipeline.wait_bucket_ready(bucket_id)

+        def _grad_acc(param):
+            """
+            Accumulate the gradient in the main_grad buffer.
+            """
+            group_id = self.param_and_grad_buffer.param_to_param_group[param]
+            group = self.param_and_grad_buffer.parameter_groups[group_id]
+            if not group.requires_grad:
+                return
+
+            overwrite_main_grad = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if overwrite_main_grad:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.copy_(param.grad)
+                        del param.grad
+                    else:
+                        param.main_grad.zero_()
+            else:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.add_(param.grad)
+                        del param.grad
+            # Reset the grad accumulate flag.
+            param.grad_added_to_main_grad = False
+
+        self._params_require_handle_grad = set()
+
         def _post_backward(module, *unused):
-
-
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                    release_module_parameters(module)
+                    module._training_state = TrainingState.IDLE
+                param_list = list(module.parameters())
+            else:
+                param_list = list(module.parameters(recurse=False))
+
+            for param in param_list:
+                _grad_acc(param)
+                self._params_require_handle_grad.discard(param)
+
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    param_list, suggested_queue_capacity=self.suggested_RS_queue_capacity
+                )

-        def
+        def _pre_forward_param_unshard(
+            module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+        ):
+            # Unshard the parameters before the forward pass.
             input_training_state = module._training_state
             fsdp_forward_prefetch = True
             if input_training_state == TrainingState.PRE_BACKWARD:
@@ -310,72 +368,104 @@ class FullyShardedDataParallel(_BaseDataParallel):
             module._training_state = TrainingState.FORWARD

             if isinstance(module, tuple(fsdp_unit_modules)):
-
-
+                param_list = list(module.parameters())
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                     bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                    self.all_gather_pipeline.queue_bucket_to_all_gather(
-                        bucket_id,
-                        prefetch=fsdp_forward_prefetch,
-                        suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                    )
-                    wait_list.append(bucket_id)
-                for bucket_id in wait_list:
                     self.all_gather_pipeline.wait_bucket_ready(bucket_id)
-
-                if not torch.is_grad_enabled():
-                    return args, kwargs
-
-                # Register the backward function to release the parameters.
-                args_list, args_spec = tree_flatten(args)
-                kwargs_list, kwargs_spec = tree_flatten(kwargs)
-                args_kwargs_list = list(args_list) + list(kwargs_list)
-                inp_tensor_indices: List[int] = []
-                inp_tensors: List[torch.Tensor] = []
-                for i, obj in enumerate(args_kwargs_list):
-                    if torch.is_tensor(obj) and obj.requires_grad:
-                        inp_tensor_indices.append(i)
-                        inp_tensors.append(obj)
-                if len(inp_tensors) == 0:
-                    return args, kwargs
-                inp_tensors = RegisterFSDPBackwardFunction.apply(
-                    functools.partial(_post_backward, module), *inp_tensors
-                )
-                for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
-                    args_kwargs_list[inp_tensor_idx] = inp_tensor
-                args_list = args_kwargs_list[: len(args_list)]
-                kwargs_list = args_kwargs_list[len(args_list) :]
-                args = tree_unflatten(args_list, args_spec)
-                kwargs = tree_unflatten(kwargs_list, kwargs_spec)
-
-                return args, kwargs
             else:
                 # All-gather the parameters in every forward pass for FSDP.
-
-
-
-
-
-
-
-                for param in module.parameters(recurse=False):
+                param_list = list(module.parameters(recurse=False))
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                     bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                     self.all_gather_pipeline.wait_bucket_ready(bucket_id)
+            return args, kwargs
+
+        def _register_post_backward_hook(
+            post_backward_hook: callable,
+            module: nn.Module,
+            args: Tuple[Any, ...],
+            kwargs: Dict[str, Any],
+        ):
+            # Register the backward function to reduce gradients after the backward pass.
+            # And for optim_grads_params, we need to release the parameters after the backward pass.
+            if not torch.is_grad_enabled():
+                return args, kwargs
+
+            args_list, args_spec = tree_flatten(args)
+            kwargs_list, kwargs_spec = tree_flatten(kwargs)
+            args_kwargs_list = list(args_list) + list(kwargs_list)
+            inp_tensor_indices: List[int] = []
+            inp_tensors: List[torch.Tensor] = []
+            for i, obj in enumerate(args_kwargs_list):
+                if torch.is_tensor(obj) and obj.requires_grad:
+                    inp_tensor_indices.append(i)
+                    inp_tensors.append(obj)
+
+            if len(inp_tensors) == 0:
+                return args, kwargs
+
+            inp_tensors = RegisterFSDPBackwardFunction.apply(
+                functools.partial(post_backward_hook, module), *inp_tensors
+            )
+
+            for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
+                args_kwargs_list[inp_tensor_idx] = inp_tensor
+            args_list = args_kwargs_list[: len(args_list)]
+            kwargs_list = args_kwargs_list[len(args_list) :]
+            args = tree_unflatten(args_list, args_spec)
+            kwargs = tree_unflatten(kwargs_list, kwargs_spec)

             return args, kwargs

-
-
-
-
-            if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
-                continue
+        fsdp_modules = []
+        for name, module in root_module.named_modules():
+            if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
+                continue

-
-
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                fsdp_modules.append(module)
+
+            self.forward_pre_hooks[f'module {name} parameter unshard'] = (
+                module.register_forward_pre_hook(
+                    _pre_forward_param_unshard, prepend=True, with_kwargs=True
+                )
+            )
+            self.forward_pre_hooks[f"module {name} register post-backward hook"] = (
+                module.register_forward_pre_hook(
+                    functools.partial(_register_post_backward_hook, _post_backward),
+                    with_kwargs=True,
+                )
+            )

-
-
+        def _root_post_backward(*unused):
+            # Make sure all the gradients are handled.
+            for param in self._params_require_handle_grad:
+                _grad_acc(param)
+
+            # Reduce the remain gradients.
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    list(self._params_require_handle_grad),
+                    suggested_queue_capacity=self.suggested_RS_queue_capacity,
                 )
+            self.grad_reduce_pipeline.reset()
+
+            # Reset root_pre_backward_hook_issued flag.
+            self._root_pre_backward_hook_issued = False

         def _pre_backward(module: nn.Module, *unused):
             module._training_state = TrainingState.PRE_BACKWARD
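The _register_post_backward_hook forward pre-hook added above works by wrapping the module's tensor inputs in a custom autograd Function (RegisterFSDPBackwardFunction in this package) whose backward runs only after gradients have flowed back through the module, which is the point where it is safe to reduce gradients and, for optim_grads_params, release the unsharded parameters. A self-contained sketch of that general pattern, independent of the package's implementation:

import torch

class RunAfterBackward(torch.autograd.Function):
    """Identity on the tensor; calls `hook` once the tensor's gradient is computed."""

    @staticmethod
    def forward(ctx, hook, x):
        ctx.hook = hook
        return x

    @staticmethod
    def backward(ctx, grad_x):
        # Fires after the wrapped module's backward has produced grad_x.
        ctx.hook()
        return None, grad_x

linear = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Equivalent of wrapping the inputs inside a forward pre-hook.
x_hooked = RunAfterBackward.apply(lambda: print("post-backward for linear"), x)
linear(x_hooked).sum().backward()  # message prints during backward, after linear's grads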
@@ -384,6 +474,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
                 module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
             )

+        self._root_pre_backward_hook_issued = False
+
         def _root_pre_backward(module: nn.Module, *unused):
             """Marks the module's training state as 'pre_backward' before the
             backprop, this function is registered on the root module.
@@ -392,13 +484,26 @@ class FullyShardedDataParallel(_BaseDataParallel):
             perform reshard/unshard operations in activation recomputation
             scenarios.
             """
-
-
-
-
-
-
-
+            if self._root_pre_backward_hook_issued:
+                return
+            self._root_pre_backward_hook_issued = True
+
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                for module in root_module.modules():
+                    if isinstance(module, tuple(fsdp_unit_modules)):
+                        module._training_state = TrainingState.PRE_BACKWARD
+                        for param in module.parameters():
+                            bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
+                            self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
+                            self.all_gather_pipeline.release_bucket(bucket_id)
+            self._params_require_handle_grad = set()
+            for param_group in self.param_and_grad_buffer.parameter_groups:
+                if not param_group.requires_grad:
+                    continue
+                self._params_require_handle_grad |= set(param_group.params)
+                for param in param_group.params:
+                    param.grad_added_to_main_grad = False
+            torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward)

         def _post_forward(module: nn.Module, input: Any, output: Any):
             # When composing with module-hook-based activation checkpointing, the
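The queue_callback call at the end of _root_pre_backward above schedules _root_post_backward to run once the whole backward pass has finished, so any gradients whose module-level post-backward hook never fired are still accumulated and reduced. A minimal standalone illustration of that PyTorch mechanism (the hook placement here is illustrative, not the package's code):

import torch

def at_end_of_backward():
    print("entire backward pass finished")

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

def queue_final_callback(grad):
    # queue_callback is meant to be invoked while backward is running, e.g. from a
    # hook, which is why the FSDP code queues it from a backward pre-hook.
    torch.autograd.Variable._execution_engine.queue_callback(at_end_of_backward)
    return grad

x.register_hook(queue_final_callback)
y.backward()  # prints after all gradients have been computed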
@@ -414,7 +519,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         def _release_module_fp8_transpose_cache(module: nn.Module, *unused):
             release_params_fp8_transpose_cache(module.parameters(recurse=False))

-        if
+        if len(fsdp_unit_modules) != 0:
             fsdp_modules = []
             for name, module in root_module.named_modules():
                 if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
@@ -434,68 +539,20 @@ class FullyShardedDataParallel(_BaseDataParallel):
                     _release_module_fp8_transpose_cache, prepend=False
                 )
             )
-        self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
-            _root_pre_backward
-        )
-
-        def _make_param_hook(param: torch.nn.Parameter):
-            """
-            Creates the all-reduce / reduce-scatter hook for backprop.
-            """

-
-
-
-
-
-
-                if param.requires_grad:
-                    if self.ddp_config.overlap_grad_reduce:
-                        assert (
-                            param.grad is not None
-                        ), 'param.grad being None is not safe when overlap_grad_reduce is True'
-
-                    if param.grad is not None and (
-                        not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
-                    ):
-                        if self.is_delay_grad_reduce:
-                            param.main_grad.add_(param.grad.data)
-                        else:
-                            param.main_grad.copy_(param.grad.data)
-                    param.grad = None
-
-                    if self.ddp_config.overlap_grad_reduce and (
-                        not self.is_delay_grad_reduce or self.is_last_microbatch
-                    ):
-                        gr_pipeline = self.grad_reduce_pipeline
-                        bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                        gr_pipeline.place_bucket(bucket_id)
-                        go_rs = gr_pipeline.mark_item_ready(param, async_rs=True)
-                        if go_rs and wait_previous_grad_reduce:
-                            gr_pipeline.wait_for_previous_grad_reduce(
-                                recommeded_queue_capacity=self.suggested_RS_queue_capacity
-                            )
+        # Registering all models with all parameters is to handle some special cases
+        # where the forward function of root_module is not called, but the forward
+        # functions of these equivalent modules are called instead.
+        for name, module in root_module.named_modules():
+            if len(list(module.parameters())) != len(list(root_module.parameters())):
+                continue

-
-
-
-        self.
-
-
-            wbuf = self.param_and_grad_buffer.parameter_groups[bucket_id].model_weight_buffer
-            if param.requires_grad:
-                if wbuf and wbuf.is_data_distributed:
-                    wbuf.fetch_bucket(and_allocate_params_data=True)
-
-                # Expand so we get access to grad_fn.
-                param_tmp = param.expand_as(param)
-                # Get the gradient accumulator function.
-                grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(_make_param_hook(param))
-                self.grad_accs.append(grad_acc)
-
-                if wbuf and wbuf.is_data_distributed:
-                    wbuf.free_bucket_storage()
+            self.backward_pre_hooks[f"{name} _root_pre_backward"] = (
+                module.register_full_backward_pre_hook(_root_pre_backward)
+            )
+        self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
+            _root_pre_backward
+        )

     @contextmanager
     def no_sync(self):
@@ -526,7 +583,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
         """
         if not force_sync and self.ddp_config.overlap_param_gather:
             # All-gather the first bucket before the forward pass.
-            self.
+            first_param = list(self.module.parameters())[0]
+            self.all_gather_pipeline.all_gather_params(params=[first_param], prefetch=False)
         else:
             self.all_gather_pipeline.reset()
             for bucket_id in range(self.all_gather_pipeline.num_buckets):