megatron-core 0.12.0rc2__tar.gz → 0.12.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of megatron-core might be problematic.
- {megatron_core-0.12.0rc2/megatron_core.egg-info → megatron_core-0.12.1}/PKG-INFO +1 -1
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/async_utils.py +29 -11
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +214 -159
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +275 -186
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/distributed_data_parallel_config.py +7 -3
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/finalize_model_grads.py +3 -2
- megatron_core-0.12.1/megatron/core/export/model_type.py +8 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +6 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +10 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trt_model_config.py +1 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trt_model_type.py +1 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_helper.py +20 -2
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_layers.py +9 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +10 -4
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +17 -5
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/extensions/transformer_engine.py +34 -8
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fp8_utils.py +15 -8
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_swiglu.py +57 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/dynamic_context.py +19 -1
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/dynamic_engine.py +8 -2
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +6 -3
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/rope_utils.py +20 -3
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/gpt_layer_specs.py +27 -6
- megatron_core-0.12.1/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +209 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/__init__.py +16 -5
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/optimizer.py +99 -21
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/package_info.py +2 -2
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/gpt/model_specs.py +12 -4
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/__init__.py +2 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/random.py +149 -23
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/attention.py +4 -1
- megatron_core-0.12.1/megatron/core/transformer/heterogeneous/heterogeneous_config.py +267 -0
- megatron_core-0.12.1/megatron/core/transformer/heterogeneous/linear_replacements.py +111 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/mlp.py +37 -18
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/experts.py +166 -60
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py +18 -11
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/moe_layer.py +20 -10
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/moe_utils.py +91 -19
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/shared_experts.py +4 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/token_dispatcher.py +63 -64
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/multi_latent_attention.py +121 -69
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/torch_norm.py +49 -1
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/transformer_block.py +22 -5
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/transformer_config.py +88 -14
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/transformer_layer.py +51 -5
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/utils.py +25 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1/megatron_core.egg-info}/PKG-INFO +1 -1
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/SOURCES.txt +3 -0
- megatron_core-0.12.0rc2/megatron/core/export/model_type.py +0 -7
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/LICENSE +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/MANIFEST.in +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/README.md +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/README.md +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/mapping.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/optimizer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/validation.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/enums.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/static_context.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/inference_request.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/gpt/model_specs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/mamba/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/mamba/model_specs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/sampling_params.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/jit.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/model_parallel_config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/T5/t5_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/T5/t5_spec.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/bert_layer_specs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/language_module/language_module.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/gpt_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/moe_module_specs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/clip_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/qwen_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/mamba/mamba_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/context_parallel.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/llava_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/decoder_spec.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/encoder_spec.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/clip_vit_model.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/multimodal_projector.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/radio.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/clip_grads.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/distrib_optimizer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/optimizer_config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/packed_seq_params.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/parallel_state.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/pipeline_parallel/schedules.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/rerun_state_machine.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_block.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_config.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_layer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/layers.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/timers.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/cuda_graphs.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/enums.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/router.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/multi_token_prediction.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/utils.py +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/requires.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/pyproject.toml +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_24.01/requirements.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_24.07/requirements.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_24.10/requirements.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_25.03/requirements.txt +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/setup.cfg +0 -0
- {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/setup.py +0 -0
PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.12.0rc2
+Version: 0.12.1
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Home-page: https://github.com/NVIDIA/Megatron-LM/megatron/core
 Download-URL: https://github.com/NVIDIA/Megatron-LM/releases
```
megatron/core/dist_checkpointing/strategies/async_utils.py

```diff
@@ -155,7 +155,7 @@ class AsyncCaller(ABC):
         logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")
 
     def __del__(self):
-
+        raise NotImplementedError("This should be implemented")
 
 
 class TemporalAsyncCaller(AsyncCaller):
@@ -227,12 +227,22 @@ class TemporalAsyncCaller(AsyncCaller):
         is_alive = int(self.process.is_alive()) if self.process is not None else 0
         is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
 
-        if
+        if is_done or blocking:
+            # Process join is called in the following cases
+            # 1. blocking == True -> regardless of is_done
+            # 2. blocking == False (non-blocking)
+            #    -> is_done == True: async requests on all ranks are identified to be finished
+            #       `self.close()` makes sure the async callers terminated
             self.close()
             is_done = True
         return is_done
 
     def close(self):
+        """For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`
+
+        This method make sure the TemporalAsyncCaller terminated
+        with all its assigned async request completed
+        """
         if self.process:
             logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
             self.process.join()
@@ -243,6 +253,9 @@ class TemporalAsyncCaller(AsyncCaller):
         )
         self.start_time = None
 
+    def __del__(self):
+        pass
+
 
 class PersistentAsyncCaller(AsyncCaller):
     """Wrapper around mp.Process that ensures correct semantic of distributed finalization.
@@ -376,6 +389,10 @@ class PersistentAsyncCaller(AsyncCaller):
         return is_done
 
     def close(self):
+        """Wait on the left async requests and terminate the PersistentAsyncCaller
+
+        Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated
+        """
         logger.info(
             f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
         )
@@ -385,6 +402,9 @@ class PersistentAsyncCaller(AsyncCaller):
             self.process.join()
             self.process = None
 
+    def __del__(self):
+        self.close()
+
     @staticmethod
     @_disable_gc()
     def async_loop(
@@ -492,13 +512,11 @@ class AsyncCallsQueue:
         # Backward compatibility for local checkpointing built with the old AsyncRequest
         if len(async_request._fields) != len(AsyncRequest._fields):
             async_request = AsyncRequest(**async_request._asdict())
-
-        async_request = async_request._replace(call_idx=self.call_idx)
-        finalize_fns = async_request.finalize_fns
-        async_request = async_request._replace(finalize_fns=None)
         async_request = async_request.freeze()
-        async_caller.schedule_async_call(
-
+        async_caller.schedule_async_call(
+            async_request._replace(call_idx=self.call_idx, finalize_fns=[])
+        )
+        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
         return self.call_idx
 
     def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
@@ -522,13 +540,13 @@ class AsyncCallsQueue:
             if not next_async_done:
                 break
             with debug_time("finalize", logger):
-                call_idx, _,
+                call_idx, _, async_request = self.async_calls.popleft()
+                for finalize_fn in async_request.finalize_fns:
+                    finalize_fn()
                 ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
                 torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
                 assert ten.item() == call_idx, 'Unmatched async calls. '
                 'That probably means not all ranks are participating in async finalization'
-                for finalize_fn in finalize_fns:
-                    finalize_fn()
                 call_idx_finalized.append(call_idx)
         return call_idx_finalized
 
```
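With this change, `AsyncCallsQueue` keeps the frozen request (including its `finalize_fns`) in `self.async_calls` and runs the finalization callbacks when the completed call is popped inside `maybe_finalize_async_calls`, instead of stripping them at scheduling time. Below is a minimal, self-contained sketch of that ordering only, with the worker process and the cross-rank consistency check omitted; the class and field names are illustrative, not the package's API.

```python
from collections import deque, namedtuple

# Illustrative stand-ins; the real AsyncRequest/AsyncCallsQueue carry more state.
Request = namedtuple("Request", ["call_idx", "finalize_fns"])

class TinyAsyncQueue:
    def __init__(self):
        self.call_idx = 0
        self.calls = deque()

    def schedule(self, finalize_fns):
        self.call_idx += 1
        # The request is stored together with its finalize callbacks ...
        self.calls.append(Request(self.call_idx, list(finalize_fns)))
        return self.call_idx

    def maybe_finalize(self, done):
        finalized = []
        while self.calls and done(self.calls[0].call_idx):
            req = self.calls.popleft()
            for fn in req.finalize_fns:  # ... and the callbacks run when the call is popped
                fn()
            finalized.append(req.call_idx)
        return finalized

q = TinyAsyncQueue()
q.schedule([lambda: print("finalize checkpoint metadata")])
print(q.maybe_finalize(done=lambda idx: True))  # prints the message, then [1]
```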
megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py

```diff
@@ -22,7 +22,6 @@ from megatron.core.distributed.custom_fsdp.param_and_grad_buffer import (
 from megatron.core.distributed.data_parallel_base import _BaseDataParallel
 from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
 from megatron.core.fp8_utils import is_float8tensor
-from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer
 from megatron.core.utils import is_submodule, log_single_rank
@@ -77,7 +76,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         module: Underlying model.
         fsdp_unit_modules: List of modules that should be treated as FSDP Unit,
             i.e., the minimum releasable model unit. If not provided, defaults to
-            [TransformerLayer, LanguageModelEmbedding] for GPT-like models.
+            [TransformerLayer, LanguageModelEmbedding] for GPT-like models. In
+            addition to this, it affects the granularity of the communication
+            parameter grouping and triggers aggregate collective communication
+            in fp8 mixed precision training.
         disable_bucketing: If true, force assign all parameters to a single bucket. If false,
             use standard bucketing policy: assign parameters to smaller buckets and all-reduce
             per bucket.
@@ -123,9 +125,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         if fsdp_unit_modules is not None:
             self.fsdp_unit_modules = fsdp_unit_modules
         else:
-            self.
-
-
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                self.fsdp_unit_modules = [TransformerLayer]
+            else:
+                self.fsdp_unit_modules = []
         self.main_weights = True
         self.data_parallel_group = parallel_state.get_data_parallel_group(
             with_context_parallel=True
@@ -180,14 +183,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
             self.module,
             bucketing_policy=BucketingPolicy(
                 suggested_bucket_size=self.bucket_size,
-                fsdp_unit_modules=
-                # Only when model weights need to be sharded, we need to
-                # identify the minimum releasable model unit, which is the
-                # FSDP Unit Module.
-                self.fsdp_unit_modules
-                if self.data_parallel_sharding_strategy == "optim_grads_params"
-                else []
-                ),
+                fsdp_unit_modules=self.fsdp_unit_modules,
                 data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
             ),
             data_parallel_group=self.data_parallel_group,
@@ -211,8 +207,24 @@ class FullyShardedDataParallel(_BaseDataParallel):
         # Initialize the all-gather pipeline.
         self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer)
 
-
-
+        suggested_communication_unit_size = self.ddp_config.suggested_communication_unit_size
+        if suggested_communication_unit_size is None:
+            if self.data_parallel_sharding_strategy == "optim_grads_params":
+                total_param_elements = 0
+                total_fsdp_module = 0
+                for module in self.module.modules():
+                    if isinstance(module, tuple(self.fsdp_unit_modules)):
+                        total_fsdp_module += 1
+                        total_param_elements += sum(p.numel() for p in module.parameters())
+                # The suggested size is twice the number of elements in the FSDP modules.
+                # This ensures we process the current FSDP module and attempt to prefetch
+                # the next FSDP module, making the flow of communication better.
+                suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
+            elif self.bucket_size is not None:
+                suggested_communication_unit_size = self.bucket_size * 2
+
+        self.suggested_RS_queue_capacity = suggested_communication_unit_size
+        self.suggested_AG_prefetch_size = suggested_communication_unit_size
 
     def _register_fsdp_hooks(self, root_module):
         """Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
```
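The block added to `fully_sharded_data_parallel.py` above derives a default communication unit size when `ddp_config.suggested_communication_unit_size` is unset: roughly twice the average parameter count of an FSDP unit module, so the current unit can be processed while the next one is prefetched. Below is a standalone, slightly simplified sketch of that heuristic (it falls back to `bucket_size` whenever no unit modules are found, rather than branching on the sharding strategy, and the module types are ordinary `torch.nn` placeholders rather than Megatron layers).

```python
import torch
from torch import nn

def suggested_comm_unit_size(model: nn.Module, fsdp_unit_types, bucket_size=None):
    """Mirror of the heuristic above: 2x the mean element count of the FSDP unit modules."""
    total_elements, total_units = 0, 0
    for module in model.modules():
        if isinstance(module, tuple(fsdp_unit_types)):
            total_units += 1
            total_elements += sum(p.numel() for p in module.parameters())
    if total_units > 0:
        return total_elements // total_units * 2
    return bucket_size * 2 if bucket_size is not None else None

# Example with a toy "layer" type standing in for TransformerLayer.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
print(suggested_comm_unit_size(model, [nn.Linear]))  # 2 * (8*8 + 8) = 144
```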
```diff
@@ -222,8 +234,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         - Pre-forward hook: Unshards parameters before forward pass
         - Post-forward hook: Reshards parameters after forward pass
         - Pre-backward hook: Unshards parameters before backward pass
-        - Post-backward hook: Reshards parameters after backward pass
-        - Gradient accumulation hook: Handles gradient accumulation and reduction across devices
+        - Post-backward hook: Reshards parameters and reduces gradients after backward pass
 
         Args:
             root_module: The PyTorch module to register FSDP hooks on
@@ -257,10 +268,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         `optim` and `optim_grads` do not require FSDP units because they do not
         shard model parameters.
         """
-
-            fsdp_unit_modules = []
-        else:
-            fsdp_unit_modules = self.fsdp_unit_modules
+        fsdp_unit_modules = self.fsdp_unit_modules
 
         def release_module_parameters(module, *unused):
             for param in module.parameters():
@@ -283,27 +291,74 @@ class FullyShardedDataParallel(_BaseDataParallel):
             prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
             wait_bucket_ready=True,
         ):
-            wait_list = []
             ag_pipeline = self.all_gather_pipeline
-
-
-
-
-
-
-                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                )
-                wait_list.append(bucket_id)
-
+            ag_pipeline.all_gather_params(
+                params=list(module.parameters()),
+                prefetch=prefetch,
+                prefetch_order=prefetch_order,
+                suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+            )
             if wait_bucket_ready:
-                for
+                for param in module.parameters():
+                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                     ag_pipeline.wait_bucket_ready(bucket_id)
 
+        def _grad_acc(param):
+            """
+            Accumulate the gradient in the main_grad buffer.
+            """
+            group_id = self.param_and_grad_buffer.param_to_param_group[param]
+            group = self.param_and_grad_buffer.parameter_groups[group_id]
+            if not group.requires_grad:
+                return
+
+            overwrite_main_grad = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if overwrite_main_grad:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.copy_(param.grad)
+                        del param.grad
+                    else:
+                        param.main_grad.zero_()
+            else:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.add_(param.grad)
+                        del param.grad
+            # Reset the grad accumulate flag.
+            param.grad_added_to_main_grad = False
+
+        self._params_require_handle_grad = set()
+
         def _post_backward(module, *unused):
-
-
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                    release_module_parameters(module)
+                    module._training_state = TrainingState.IDLE
+                param_list = list(module.parameters())
+            else:
+                param_list = list(module.parameters(recurse=False))
+
+            for param in param_list:
+                _grad_acc(param)
+                self._params_require_handle_grad.discard(param)
+
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    param_list, suggested_queue_capacity=self.suggested_RS_queue_capacity
+                )
 
-        def
+        def _pre_forward_param_unshard(
+            module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+        ):
+            # Unshard the parameters before the forward pass.
             input_training_state = module._training_state
             fsdp_forward_prefetch = True
             if input_training_state == TrainingState.PRE_BACKWARD:
@@ -313,72 +368,104 @@ class FullyShardedDataParallel(_BaseDataParallel):
                 module._training_state = TrainingState.FORWARD
 
             if isinstance(module, tuple(fsdp_unit_modules)):
-
-
+                param_list = list(module.parameters())
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                     bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                    self.all_gather_pipeline.queue_bucket_to_all_gather(
-                        bucket_id,
-                        prefetch=fsdp_forward_prefetch,
-                        suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                    )
-                    wait_list.append(bucket_id)
-                for bucket_id in wait_list:
                     self.all_gather_pipeline.wait_bucket_ready(bucket_id)
-
-            if not torch.is_grad_enabled():
-                return args, kwargs
-
-            # Register the backward function to release the parameters.
-            args_list, args_spec = tree_flatten(args)
-            kwargs_list, kwargs_spec = tree_flatten(kwargs)
-            args_kwargs_list = list(args_list) + list(kwargs_list)
-            inp_tensor_indices: List[int] = []
-            inp_tensors: List[torch.Tensor] = []
-            for i, obj in enumerate(args_kwargs_list):
-                if torch.is_tensor(obj) and obj.requires_grad:
-                    inp_tensor_indices.append(i)
-                    inp_tensors.append(obj)
-            if len(inp_tensors) == 0:
-                return args, kwargs
-            inp_tensors = RegisterFSDPBackwardFunction.apply(
-                functools.partial(_post_backward, module), *inp_tensors
-            )
-            for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
-                args_kwargs_list[inp_tensor_idx] = inp_tensor
-            args_list = args_kwargs_list[: len(args_list)]
-            kwargs_list = args_kwargs_list[len(args_list) :]
-            args = tree_unflatten(args_list, args_spec)
-            kwargs = tree_unflatten(kwargs_list, kwargs_spec)
-
-            return args, kwargs
             else:
                 # All-gather the parameters in every forward pass for FSDP.
-
-
-
-
-
-
-
-                for param in module.parameters(recurse=False):
+                param_list = list(module.parameters(recurse=False))
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                     bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                     self.all_gather_pipeline.wait_bucket_ready(bucket_id)
+            return args, kwargs
+
+        def _register_post_backward_hook(
+            post_backward_hook: callable,
+            module: nn.Module,
+            args: Tuple[Any, ...],
+            kwargs: Dict[str, Any],
+        ):
+            # Register the backward function to reduce gradients after the backward pass.
+            # And for optim_grads_params, we need to release the parameters after the backward pass.
+            if not torch.is_grad_enabled():
+                return args, kwargs
+
+            args_list, args_spec = tree_flatten(args)
+            kwargs_list, kwargs_spec = tree_flatten(kwargs)
+            args_kwargs_list = list(args_list) + list(kwargs_list)
+            inp_tensor_indices: List[int] = []
+            inp_tensors: List[torch.Tensor] = []
+            for i, obj in enumerate(args_kwargs_list):
+                if torch.is_tensor(obj) and obj.requires_grad:
+                    inp_tensor_indices.append(i)
+                    inp_tensors.append(obj)
+
+            if len(inp_tensors) == 0:
+                return args, kwargs
+
+            inp_tensors = RegisterFSDPBackwardFunction.apply(
+                functools.partial(post_backward_hook, module), *inp_tensors
+            )
+
+            for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
+                args_kwargs_list[inp_tensor_idx] = inp_tensor
+            args_list = args_kwargs_list[: len(args_list)]
+            kwargs_list = args_kwargs_list[len(args_list) :]
+            args = tree_unflatten(args_list, args_spec)
+            kwargs = tree_unflatten(kwargs_list, kwargs_spec)
 
             return args, kwargs
 
-
-
-
-
-            if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
-                continue
+        fsdp_modules = []
+        for name, module in root_module.named_modules():
+            if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
+                continue
 
-
-
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                fsdp_modules.append(module)
+
+            self.forward_pre_hooks[f'module {name} parameter unshard'] = (
+                module.register_forward_pre_hook(
+                    _pre_forward_param_unshard, prepend=True, with_kwargs=True
+                )
+            )
+            self.forward_pre_hooks[f"module {name} register post-backward hook"] = (
+                module.register_forward_pre_hook(
+                    functools.partial(_register_post_backward_hook, _post_backward),
+                    with_kwargs=True,
+                )
+            )
 
-
-
+        def _root_post_backward(*unused):
+            # Make sure all the gradients are handled.
+            for param in self._params_require_handle_grad:
+                _grad_acc(param)
+
+            # Reduce the remain gradients.
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    list(self._params_require_handle_grad),
+                    suggested_queue_capacity=self.suggested_RS_queue_capacity,
                 )
+            self.grad_reduce_pipeline.reset()
+
+            # Reset root_pre_backward_hook_issued flag.
+            self._root_pre_backward_hook_issued = False
 
         def _pre_backward(module: nn.Module, *unused):
             module._training_state = TrainingState.PRE_BACKWARD
```
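The new `_register_post_backward_hook` above threads the module's forward inputs through `RegisterFSDPBackwardFunction` so that a post-backward callback fires once gradients flow back to those inputs. Below is a minimal sketch of that general pattern using a generic `torch.autograd.Function`; the class here is an illustrative stand-in, not the `RegisterFSDPBackwardFunction` shipped in the package.

```python
import torch

class CallbackOnBackward(torch.autograd.Function):
    """Pass tensors through unchanged; run a callback when their gradients arrive."""

    @staticmethod
    def forward(ctx, callback, *inputs):
        ctx.callback = callback
        return inputs  # identity on the tensor inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        ctx.callback()                 # post-backward work (e.g. reduce grads, release params)
        return (None, *grad_outputs)   # no gradient for the callback argument

x = torch.randn(4, requires_grad=True)
(y,) = CallbackOnBackward.apply(lambda: print("post-backward hook fired"), x)
y.sum().backward()  # prints the message once gradients reach x
```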
```diff
@@ -387,6 +474,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
                 module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
             )
 
+        self._root_pre_backward_hook_issued = False
+
         def _root_pre_backward(module: nn.Module, *unused):
             """Marks the module's training state as 'pre_backward' before the
             backprop, this function is registered on the root module.
@@ -395,13 +484,26 @@ class FullyShardedDataParallel(_BaseDataParallel):
             perform reshard/unshard operations in activation recomputation
             scenarios.
             """
-
-
-
-
-
-
-
+            if self._root_pre_backward_hook_issued:
+                return
+            self._root_pre_backward_hook_issued = True
+
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                for module in root_module.modules():
+                    if isinstance(module, tuple(fsdp_unit_modules)):
+                        module._training_state = TrainingState.PRE_BACKWARD
+                        for param in module.parameters():
+                            bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
+                            self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
+                            self.all_gather_pipeline.release_bucket(bucket_id)
+            self._params_require_handle_grad = set()
+            for param_group in self.param_and_grad_buffer.parameter_groups:
+                if not param_group.requires_grad:
+                    continue
+                self._params_require_handle_grad |= set(param_group.params)
+                for param in param_group.params:
+                    param.grad_added_to_main_grad = False
+            torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward)
 
         def _post_forward(module: nn.Module, input: Any, output: Any):
             # When composing with module-hook-based activation checkpointing, the
@@ -417,7 +519,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         def _release_module_fp8_transpose_cache(module: nn.Module, *unused):
             release_params_fp8_transpose_cache(module.parameters(recurse=False))
 
-        if
+        if len(fsdp_unit_modules) != 0:
             fsdp_modules = []
             for name, module in root_module.named_modules():
                 if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
@@ -437,68 +539,20 @@ class FullyShardedDataParallel(_BaseDataParallel):
                         _release_module_fp8_transpose_cache, prepend=False
                     )
                 )
-        self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
-            _root_pre_backward
-        )
-
-        def _make_param_hook(param: torch.nn.Parameter):
-            """
-            Creates the all-reduce / reduce-scatter hook for backprop.
-            """
 
-
-
-
-
-
-
-                if param.requires_grad:
-                    if self.ddp_config.overlap_grad_reduce:
-                        assert (
-                            param.grad is not None
-                        ), 'param.grad being None is not safe when overlap_grad_reduce is True'
-
-                    if param.grad is not None and (
-                        not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
-                    ):
-                        if self.is_delay_grad_reduce:
-                            param.main_grad.add_(param.grad.data)
-                        else:
-                            param.main_grad.copy_(param.grad.data)
-                        param.grad = None
-
-                    if self.ddp_config.overlap_grad_reduce and (
-                        not self.is_delay_grad_reduce or self.is_last_microbatch
-                    ):
-                        gr_pipeline = self.grad_reduce_pipeline
-                        bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                        gr_pipeline.place_bucket(bucket_id)
-                        go_rs = gr_pipeline.mark_item_ready(param, async_rs=True)
-                        if go_rs and wait_previous_grad_reduce:
-                            gr_pipeline.wait_for_previous_grad_reduce(
-                                recommeded_queue_capacity=self.suggested_RS_queue_capacity
-                            )
+        # Registering all models with all parameters is to handle some special cases
+        # where the forward function of root_module is not called, but the forward
+        # functions of these equivalent modules are called instead.
+        for name, module in root_module.named_modules():
+            if len(list(module.parameters())) != len(list(root_module.parameters())):
+                continue
 
-
-
-
-            self.
-
-
-            wbuf = self.param_and_grad_buffer.parameter_groups[bucket_id].model_weight_buffer
-            if param.requires_grad:
-                if wbuf and wbuf.is_data_distributed:
-                    wbuf.fetch_bucket(and_allocate_params_data=True)
-
-                # Expand so we get access to grad_fn.
-                param_tmp = param.expand_as(param)
-                # Get the gradient accumulator function.
-                grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(_make_param_hook(param))
-                self.grad_accs.append(grad_acc)
-
-                if wbuf and wbuf.is_data_distributed:
-                    wbuf.free_bucket_storage()
+            self.backward_pre_hooks[f"{name} _root_pre_backward"] = (
+                module.register_full_backward_pre_hook(_root_pre_backward)
+            )
+        self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
+            _root_pre_backward
+        )
 
     @contextmanager
     def no_sync(self):
@@ -529,7 +583,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
         """
         if not force_sync and self.ddp_config.overlap_param_gather:
             # All-gather the first bucket before the forward pass.
-            self.
+            first_param = list(self.module.parameters())[0]
+            self.all_gather_pipeline.all_gather_params(params=[first_param], prefetch=False)
         else:
             self.all_gather_pipeline.reset()
             for bucket_id in range(self.all_gather_pipeline.num_buckets):
```