megatron-core 0.14.0rc6__tar.gz → 0.14.0rc7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megatron-core might be problematic.
- {megatron_core-0.14.0rc6/megatron_core.egg-info → megatron_core-0.14.0rc7}/PKG-INFO +1 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/__init__.py +6 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/mapping.py +0 -6
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/common.py +6 -6
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/__init__.py +1 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/distributed_data_parallel_config.py +20 -6
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/finalize_model_grads.py +27 -14
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/__init__.py +3 -0
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +317 -0
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/__init__.py +13 -0
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +22 -0
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +141 -0
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +387 -0
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +1107 -0
- {megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp → megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp}/param_and_grad_buffer.py +1658 -522
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +458 -0
- megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +908 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/param_and_grad_buffer.py +6 -7
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +8 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/transformer_engine.py +14 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/language_module/language_module.py +19 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/clip_model.py +1 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/qwen_model.py +1 -1
- megatron_core-0.14.0rc7/megatron/core/nccl_allocator.py +249 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/__init__.py +3 -22
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/clip_grads.py +15 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/distrib_optimizer.py +155 -129
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/optimizer.py +3 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/optimizer_config.py +6 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/package_info.py +1 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/parallel_state.py +6 -3
- megatron_core-0.14.0rc7/megatron/core/safe_globals.py +33 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/layers.py +8 -8
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/cuda_graphs.py +318 -7
- megatron_core-0.14.0rc7/megatron/core/transformer/fsdp_dtensor_checkpoint.py +195 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/experts.py +1 -25
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/transformer_config.py +5 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/transformer_layer.py +1 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/utils.py +0 -3
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/utils.py +4 -41
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7/megatron_core.egg-info}/PKG-INFO +1 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/SOURCES.txt +13 -3
- megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/__init__.py +0 -3
- megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -835
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/LICENSE +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/MANIFEST.in +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/README.md +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/README.md +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/activations.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/object_storage_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/utils_object_storage.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/optimizer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/validation.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/energy_monitor.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/enums.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/model_type.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trt_model_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trt_model_type.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/kitchen.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fp8_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/full_cuda_graph.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_indices_converter.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/hyper_comm_grid.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/dynamic_context.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/static_context.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/dynamic_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/inference_request.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/sampling_params.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/jit.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/model_parallel_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/T5/t5_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/T5/t5_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/backends.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/bert_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/gpt_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/moe_module_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mamba/mamba_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/config/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/config/base_configs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/model/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/model/base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/submodules/audio.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/submodules/base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/submodules/vision.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/context_parallel.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/llava_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/decoder_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/encoder_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/clip_vit_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/multimodal_projector.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/radio.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/msc_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/packed_seq_params.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/schedules.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/quantization/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/quantization/quant_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/quantization/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/rerun_state_machine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_block.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_context_parallel.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_layer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/random.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/timers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/enums.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/mlp.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/moe_layer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/moe_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/router.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/shared_experts.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/multi_latent_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/multi_token_prediction.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/torch_norm.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/transformer_block.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/requires.txt +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/pyproject.toml +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/setup.cfg +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/setup.py +0 -0

{megatron_core-0.14.0rc6/megatron_core.egg-info → megatron_core-0.14.0rc7}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc6
+Version: 0.14.0rc7
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>

{megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/__init__.py
RENAMED
@@ -20,6 +20,7 @@ from megatron.core.package_info import (
     __version__,
 )
 from megatron.core.timers import Timers
+from megatron.core.utils import is_torch_min_version
 
 # Alias parallel_state as mpu, its legacy name
 mpu = parallel_state
@@ -33,3 +34,8 @@ __all__ = [
     "ModelParallelConfig",
     "Timers",
 ]
+
+from .safe_globals import register_safe_globals
+
+if is_torch_min_version("2.6a0"):
+    register_safe_globals()
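
Context for the hunk above: the new megatron/core/safe_globals.py module (listed earlier as +33 lines) is only activated on PyTorch >= 2.6, where torch.load defaults to weights_only=True. A minimal sketch of what such a registration typically looks like follows; the allow-listed classes are illustrative assumptions, not the actual contents of safe_globals.py.

# Hypothetical sketch of a safe-globals registration. On PyTorch >= 2.6,
# torch.load defaults to weights_only=True, so non-tensor objects stored in
# checkpoints must be allow-listed before they can be deserialized.
import argparse
import io

import torch


def register_safe_globals():
    # torch.serialization.add_safe_globals is available since PyTorch 2.4;
    # the exact classes Megatron-Core registers are an assumption here.
    torch.serialization.add_safe_globals([argparse.Namespace, io.BytesIO])
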
{megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/mapping.py
RENAMED
@@ -136,12 +136,6 @@ class ShardedTensor(ShardedBase):
         )
 
         for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
-            # NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
-            # For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
-            # GPU0 receives p0 and a portion of p1, while GPU1 receives the
-            # remaining portion of p1 and p2.
-            # As a result, there is no parameter shard of p2 on GPU0, and
-            # the shape of p2 on GPU0 is zero.
             if sh != 0 and off % sh != 0:
                 raise CheckpointingException(
                     f"Global offset ({off}) must be divisible by local shape ({sh}) for {self}."

{megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/common.py
RENAMED
@@ -84,9 +84,9 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
         try:
             if MultiStorageClientFeature.is_enabled():
                 msc = MultiStorageClientFeature.import_package()
-                return msc.torch.load(load_path, map_location='cpu'
+                return msc.torch.load(load_path, map_location='cpu')
             else:
-                return torch.load(load_path, map_location='cpu'
+                return torch.load(load_path, map_location='cpu')
         except FileNotFoundError as e:
             err_msg = f'Common file {load_path} does not exist'
             if MultiStorageClientFeature.is_enabled():
@@ -118,9 +118,9 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
         try:
             if MultiStorageClientFeature.is_enabled():
                 msc = MultiStorageClientFeature.import_package()
-                loaded_obj = msc.torch.load(load_path
+                loaded_obj = msc.torch.load(load_path)
             else:
-                loaded_obj = torch.load(load_path
+                loaded_obj = torch.load(load_path)
         except FileNotFoundError as e:
             # Backward compatible logic: previously the save format was incorrect
             base, _ = os.path.splitext(sh_obj.unique_key)
@@ -128,9 +128,9 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
             try:
                 if MultiStorageClientFeature.is_enabled():
                     msc = MultiStorageClientFeature.import_package()
-                    loaded_obj = msc.torch.load(old_load_path
+                    loaded_obj = msc.torch.load(old_load_path)
                 else:
-                    loaded_obj = torch.load(old_load_path
+                    loaded_obj = torch.load(old_load_path)
             except FileNotFoundError:
                 err_msg = f'Object shard {load_path} not found'
                 obj_subdir = os.path.join(checkpoint_dir, sh_obj.key)

{megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/__init__.py
RENAMED
@@ -8,5 +8,6 @@ except ImportError:
 from .distributed_data_parallel import DistributedDataParallel
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .finalize_model_grads import finalize_model_grads
+from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
 from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
 from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig

{megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/distributed_data_parallel_config.py
RENAMED
@@ -61,9 +61,16 @@ class DistributedDataParallelConfig:
     """If true, reuse the grad buffer for param AG when using mxfp8 recipe. Should be
     set to True only when fp8_recipe is mxfp8 and fp8_param_gather is True."""
 
-    use_custom_fsdp: bool = False
+    use_megatron_fsdp: bool = False
     """If true, use the FSDP code path for DDP."""
 
+    use_custom_fsdp: bool = False
+    """
+    NOTE: The flag `use_custom_fsdp` is deprecated and will be removed in future versions.
+    Please use `use_megatron_fsdp` instead, as all functionality will be migrated there.
+    Future updates will drop support for `use_custom_fsdp` to avoid confusion.
+    """
+
     data_parallel_sharding_strategy: str = 'no_shard'
     """Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
     'optim_grads', 'optim_grads_params'."""
@@ -80,10 +87,10 @@ class DistributedDataParallelConfig:
     based on your system's memory and performance requirements."""
 
     preserve_fp32_weights: bool = True
-    """If true, preserve fp32 weights in the
+    """If true, preserve fp32 weights in the Megatron FSDP ParamAndGradBuffer."""
 
-
-    """If true, keep the fp8 transpose cache when using
+    keep_fp8_transpose_cache: bool = False
+    """If true, keep the fp8 transpose cache when using Megatron FSDP."""
 
     nccl_ub: bool = False
     """If true, allocate and register NCCL userbuffer for param and grad buffer.
@@ -106,12 +113,19 @@ class DistributedDataParallelConfig:
 
     fsdp_double_buffer: bool = False
     """If true, use persistently allocated double buffers for the
-    temporary memory needed in the
+    temporary memory needed in the Megatron FSDP communications.
     This option will cause additional memory overhead, however, it is necessary for
-    to register user buffer (nccl_ub=True) for the
+    to register user buffer (nccl_ub=True) for the Megatron FSDP.
     This option will be automatically set to True when nccl_ub=True.
     """
 
+    outer_dp_sharding_strategy: str = 'no_shard'
+    """
+    Sharding strategy for outer data parallel group in Hybrid Sharded Data Parallel (HSDP) mode.
+    Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'.
+    This option is only effective when Hybrid FSDP is enabled.
+    """
+
     def __post_init__(self):
         import os
 
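
For orientation, a short sketch of how the flags added above would be set; the field names come from the hunk, while the specific values chosen here are illustrative and all other fields keep their defaults.

# Illustrative: enable the Megatron FSDP code path through the updated config.
from megatron.core.distributed import DistributedDataParallelConfig

ddp_config = DistributedDataParallelConfig(
    use_megatron_fsdp=True,  # replaces the deprecated use_custom_fsdp flag
    data_parallel_sharding_strategy="optim_grads_params",
    # Only consulted when Hybrid Sharded Data Parallel (HSDP) is enabled.
    outer_dp_sharding_strategy="no_shard",
)
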

{megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/finalize_model_grads.py
RENAMED
@@ -31,9 +31,7 @@ from ..utils import (
 )
 
 
-def _get_main_grad_attr(param: torch.nn.Parameter,
-    if use_custom_fsdp:
-        return "fsdp_managed_main_grad"
+def _get_main_grad_attr(param: torch.nn.Parameter, use_megatron_fsdp: bool = False):
     if hasattr(param, "main_grad"):
         return "main_grad"
     return "grad"
@@ -241,8 +239,10 @@ def _allreduce_embedding_grad(
     if weight is None and skip_if_none:
         return
 
-    grad_attr = _get_main_grad_attr(weight, ddp_config.
+    grad_attr = _get_main_grad_attr(weight, ddp_config.use_megatron_fsdp)
     orig_grad = getattr(weight, grad_attr)
+    if ddp_config.use_megatron_fsdp:
+        orig_grad = orig_grad._local_tensor if orig_grad is not None else None
     grad = _unshard_if_dtensor(orig_grad)
     # When the embedding is frozen, the grad is None.
     if grad is None and skip_if_none:
@@ -320,20 +320,30 @@ def _allreduce_non_tensor_model_parallel_grads(
         if param.requires_grad:
             # Check if this param needs average reduction (average_gradients_across_tp_domain)
             if getattr(param, "average_gradients_across_tp_domain", False):
-
-                grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                grad_attr = _get_main_grad_attr(param, ddp_config.use_megatron_fsdp)
                 grad = getattr(param, grad_attr)
-                grad
-
+                if grad is None:
+                    continue
+                params_avg.append(param)
+                if ddp_config.use_megatron_fsdp:
+                    grads_avg.append(grad._local_tensor.data)
+                else:
+                    grad = _unshard_if_dtensor(grad)
+                    grads_avg.append(grad.data)
             # Check if this param needs sum reduction (sequence parallel or qk_layernorm)
             elif (config.sequence_parallel and getattr(param, "sequence_parallel", False)) or (
                 config.qk_layernorm and ("q_layernorm" in name or "k_layernorm" in name)
             ):
-
-                grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                grad_attr = _get_main_grad_attr(param, ddp_config.use_megatron_fsdp)
                 grad = getattr(param, grad_attr)
-                grad
-
+                if grad is None:
+                    continue
+                params_sum.append(param)
+                if ddp_config.use_megatron_fsdp:
+                    grads_sum.append(grad._local_tensor.data)
+                else:
+                    grad = _unshard_if_dtensor(grad)
+                    grads_sum.append(grad.data)
 
     # Loop grads and perform correct all-reduce
     for params, grads, all_reduce_op in zip(
@@ -348,9 +358,12 @@ def _allreduce_non_tensor_model_parallel_grads(
            params, grads, _unflatten_dense_tensors(coalesced, grads)
        ):
            buf.copy_(synced)
-           grad_attr = _get_main_grad_attr(param, ddp_config.
+           grad_attr = _get_main_grad_attr(param, ddp_config.use_megatron_fsdp)
            orig_grad = getattr(param, grad_attr)
-
+           if ddp_config.use_megatron_fsdp:
+               setattr(param, grad_attr, orig_grad)
+           else:
+               setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
 
 
 """
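
The Megatron FSDP branches above read grad._local_tensor, i.e. the shard of a DTensor gradient owned by the current rank. A rough sketch of the same idea using the public DTensor API (PyTorch 2.5+); the helper below is hypothetical and not part of Megatron-Core.

# Hypothetical helper: return the rank-local shard of a (possibly distributed) gradient.
import torch
from torch.distributed.tensor import DTensor


def local_grad(grad: torch.Tensor) -> torch.Tensor:
    if isinstance(grad, DTensor):
        # Public equivalent of the private `_local_tensor` attribute used in the diff.
        return grad.to_local()
    return grad
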

megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
ADDED
@@ -0,0 +1,317 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import List, Optional
+
+try:
+    import einops
+
+    HAVE_EINOPS = True
+except ImportError:
+    HAVE_EINOPS = False
+
+import torch
+import torch.distributed as dist
+
+try:
+    from torch.distributed import DeviceMesh
+
+    HAVE_DTENSOR = True
+except ImportError:
+    HAVE_DTENSOR = False
+
+from megatron.core import parallel_state
+from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.distributed.data_parallel_base import _BaseDataParallel
+from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
+from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.transformer_layer import TransformerLayer
+from megatron.core.utils import log_single_rank
+
+try:
+    from megatron.core.distributed.fsdp.src.megatron_fsdp import FSDPDistributedIndex, MegatronFSDP
+
+    HAVE_MEGATRON_FSDP = True
+except ImportError as import_megatron_fsdp_error:
+    IMPORT_MEGATRON_FSDP_ERROR = import_megatron_fsdp_error
+    HAVE_MEGATRON_FSDP = False
+
+logger = logging.getLogger(__name__)
+
+
+class FullyShardedDataParallel(_BaseDataParallel):
+    """
+    Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        ddp_config: DistributedDataParallelConfig,
+        module: torch.nn.Module,
+        fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
+        disable_bucketing: bool = False,
+        device: Optional[torch.device] = None,
+        grad_comm_pgs: Optional[GradCommProcessGroups] = None,
+        model_comm_pgs: Optional[ModelCommProcessGroups] = None,
+    ):
+        if not HAVE_MEGATRON_FSDP:
+            raise IMPORT_MEGATRON_FSDP_ERROR
+
+        if has_config_logger_enabled(config):
+            log_config_to_disk(config, locals(), prefix=type(self).__name__)
+
+        self.ddp_config = ddp_config
+        log_single_rank(
+            logger,
+            logging.INFO,
+            f'Setting up DistributedDataParallel with config {self.ddp_config}',
+        )
+        self.megatron_fsdp_dist_index = self._init_dist_index(grad_comm_pgs, model_comm_pgs)
+
+        self.bucket_size = self.ddp_config.bucket_size
+        if disable_bucketing:
+            self.bucket_size = None
+        self.device = device if device else torch.device(f'cuda:{torch.cuda.current_device()}')
+
+        if fsdp_unit_modules is not None:
+            self.fsdp_unit_modules = fsdp_unit_modules
+        else:
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                self.fsdp_unit_modules = [TransformerLayer]
+            else:
+                self.fsdp_unit_modules = []
+
+        super().__init__(
+            config=config,
+            module=MegatronFSDP(
+                ddp_config=ddp_config,
+                module=module,
+                fsdp_unit_modules=self.fsdp_unit_modules,
+                disable_bucketing=disable_bucketing,
+                device=self.device,
+                dist_index=self.megatron_fsdp_dist_index,
+                calculate_per_token_loss=config.calculate_per_token_loss,
+                init_model_with_meta_device=config.init_model_with_meta_device,
+            ),
+        )
+        self.param_and_grad_buffer = self.module.param_and_grad_buffer
+        self.no_sync = self.module.no_sync
+        self.start_param_sync = self.module.start_param_sync
+        self.start_grad_sync = self.module.start_grad_sync
+        self.finish_grad_sync = self.module.finish_grad_sync
+        self.scale_gradients = self.module.scale_gradients
+        self.zero_grad_buffer = self.module.zero_grad_buffer
+        self.broadcast_params = self.module.broadcast_params
+        self.module.state_dict_for_save_checkpoint = self.module.state_dict
+        self.state_dict_for_save_checkpoint = self.state_dict
+
+    def load_state_dict(self, state_dict, strict=True):
+        """
+        Load the state dictionary into the module.
+        """
+        custom_state_dict = {}
+        for key, value in state_dict.items():
+            if self.config.fp8 and key.endswith('._extra_state'):
+                # Skip extra state keys
+                continue
+            custom_state_dict[f"module.{key}"] = value
+
+        if self.config.fp8 or self.config.gated_linear_unit:
+            strict = False
+            log_single_rank(
+                logger,
+                logging.WARNING,
+                "Loading state_dict with strict=False due to fp8 configuration. "
+                "This is expected as some keys may not match exactly.",
+            )
+
+        self.module.load_state_dict(custom_state_dict, strict=strict)
+
+    def _init_dist_index(self, grad_comm_pgs, model_comm_pgs):
+        """
+        Initialize the distributed index for the module.
+        """
+        if not HAVE_DTENSOR:
+            raise ImportError(
+                "This module requires PyTorch with DTensor support. "
+                "Please install a compatible version of PyTorch."
+            )
+
+        enable_hsdp = self.ddp_config.num_distributed_optimizer_instances > 1
+        if grad_comm_pgs is None and model_comm_pgs is None:
+            tp_group = parallel_state.get_tensor_model_parallel_group()
+            if enable_hsdp:
+                dp_cp_group = parallel_state.get_data_parallel_group(
+                    with_context_parallel=True, partial_data_parallel=True
+                )
+                inter_fsdp_group = parallel_state.get_inter_distributed_optimizer_instance_group()
+                hybrid_fsdp_group = parallel_state.get_data_parallel_group(
+                    with_context_parallel=True, partial_data_parallel=False
+                )
+            else:
+                dp_cp_group = parallel_state.get_data_parallel_group(
+                    with_context_parallel=True, partial_data_parallel=False
+                )
+                inter_fsdp_group = None
+                hybrid_fsdp_group = None
+        elif grad_comm_pgs is not None and model_comm_pgs is not None:
+            tp_group = getattr(model_comm_pgs, 'tp', None)
+            if enable_hsdp:
+                dp_cp_group = grad_comm_pgs.intra_dp_cp
+                inter_fsdp_group = grad_comm_pgs.inter_dist_opt
+                hybrid_fsdp_group = grad_comm_pgs.dp_cp
+            else:
+                dp_cp_group = grad_comm_pgs.dp_cp
+                inter_fsdp_group = None
+                hybrid_fsdp_group = None
+        else:
+            raise ValueError(
+                "Both grad_comm_pgs and model_comm_pgs must be either None or provided together."
+            )
+
+        if tp_group is None:
+            single_rank_group = dist.new_group(ranks=[dist.get_rank()])
+            tp_group = single_rank_group
+
+        if enable_hsdp:
+            mesh = _get_hsdp_tp_mesh(inter_fsdp_group, dp_cp_group, tp_group)
+            dist_index = FSDPDistributedIndex(
+                use_hybrid_fsdp=True,
+                hsdp_outer_dp_shard=self.ddp_config.outer_dp_sharding_strategy != "no_shard",
+                device_mesh=DeviceMesh.from_group(
+                    [inter_fsdp_group, dp_cp_group, tp_group],
+                    device_type="cuda",
+                    mesh=mesh.tolist(),
+                    mesh_dim_names=["inter_fsdp_dp", "dp_cp", "tp"],
+                ),
+                dp_inter_dim="inter_fsdp_dp",
+                dp_shard_dim="dp_cp",
+                tp_dim="tp",
+                hybrid_fsdp_group=hybrid_fsdp_group,
+            )
+        else:
+            mesh = _get_dp_tp_mesh(dp_cp_group, tp_group)
+            dist_index = FSDPDistributedIndex(
+                device_mesh=DeviceMesh.from_group(
+                    [dp_cp_group, tp_group],
+                    device_type="cuda",
+                    mesh=mesh.tolist(),
+                    mesh_dim_names=["dp_cp", "tp"],
+                ),
+                dp_shard_dim="dp_cp",
+                tp_dim="tp",
+            )
+
+        return dist_index
+
+    def stop_communication(self):
+        """
+        Stop communication for the module.
+        """
+        self.module.synchronize_gradient_reduce()
+        self.module.synchronize_param_gather()
+
+
+def _get_hsdp_tp_mesh(inter_fsdp_dp_group, dp_cp_group, tp_group):
+    assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
+    world_size = dist.get_world_size()
+
+    mesh = einops.rearrange(
+        torch.arange(world_size),
+        "(inter_fsdp_dp fsdp tp) -> inter_fsdp_dp fsdp tp",
+        inter_fsdp_dp=inter_fsdp_dp_group.size(),
+        tp=tp_group.size(),
+    )
+
+    mesh_fsdp_ranks = einops.rearrange(
+        mesh,
+        'inter_fsdp_dp fsdp tp -> (inter_fsdp_dp tp) fsdp',
+        tp=tp_group.size(),
+        fsdp=dp_cp_group.size(),
+    )
+    fsdp_group_ranks = dist.get_process_group_ranks(dp_cp_group)
+    assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_fsdp_ranks, fsdp_group_ranks), (
+        f"[Megatron-FSDP] FSDP ranks in the mesh {mesh_fsdp_ranks} "
+        f"do not match the ranks in the FSDP group {fsdp_group_ranks}."
+    )
+
+    mesh_tp_ranks = einops.rearrange(
+        mesh,
+        'inter_fsdp_dp fsdp tp -> (inter_fsdp_dp fsdp) tp',
+        tp=tp_group.size(),
+        fsdp=dp_cp_group.size(),
+    )
+    tp_group_ranks = dist.get_process_group_ranks(tp_group)
+    assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_tp_ranks, tp_group_ranks), (
+        f"[Megatron-FSDP] Tensor Parallel ranks in the mesh {mesh_tp_ranks} "
+        f"do not match the ranks in the TP group {tp_group_ranks}."
+    )
+
+    mesh_inter_fsdp_dp_ranks = einops.rearrange(
+        mesh,
+        'inter_fsdp_dp fsdp tp -> (fsdp tp) inter_fsdp_dp',
+        tp=tp_group.size(),
+        fsdp=dp_cp_group.size(),
+    )
+    inter_fsdp_dp_group_ranks = dist.get_process_group_ranks(inter_fsdp_dp_group)
+    assert _check_mesh_ranks_and_group_ranks_are_consistent(
+        mesh_inter_fsdp_dp_ranks, inter_fsdp_dp_group_ranks
+    ), (
+        f"[Megatron-FSDP] Inter FSDP Data Parallel ranks in the mesh {mesh_inter_fsdp_dp_ranks} "
+        f"do not match the ranks in the Inter FSDP DP group {inter_fsdp_dp_group_ranks}."
+    )
+
+    return mesh
+
+
+def _get_dp_tp_mesh(dp_cp_group, tp_group):
+    assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
+    world_size = dist.get_world_size()
+
+    tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
+    # TODO: Supports configurable (dp, cp, tp) order.
+    mesh = einops.rearrange(torch.arange(world_size), "(dp_cp tp) -> dp_cp tp", tp=tp_size)
+
+    mesh_dp_ranks = einops.rearrange(mesh, 'dp_cp tp -> tp dp_cp', tp=tp_size)
+    dp_cp_group_ranks = dist.get_process_group_ranks(dp_cp_group)
+    assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_dp_ranks, dp_cp_group_ranks), (
+        f"[Megatron-FSDP] Data Parallel ranks in the mesh {mesh_dp_ranks} "
+        f"do not match the ranks in the DP group {dp_cp_group_ranks}."
+    )
+
+    mesh_tp_ranks = einops.rearrange(mesh, 'dp_cp tp -> (dp_cp) tp', tp=tp_size)
+    tp_group_ranks = dist.get_process_group_ranks(tp_group)
+    assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_tp_ranks, tp_group_ranks), (
+        f"[Megatron-FSDP] Tensor Parallel ranks in the mesh {mesh_tp_ranks} "
+        f"do not match the ranks in the TP group {tp_group_ranks}."
+    )
+
+    return mesh
+
+
+def _check_mesh_ranks_and_group_ranks_are_consistent(mesh_ranks, group_ranks):
+    current_rank = dist.get_rank()
+    current_ranks = list(filter(lambda ranks: current_rank in ranks, mesh_ranks.tolist()))
+    assert len(current_ranks) == 1, (
+        f"[Megatron-FSDP] Current rank {current_rank} is not unique in "
+        f"the mesh ranks {mesh_ranks.tolist()}."
+    )
+    assert sorted(current_ranks[0]) == sorted(group_ranks), (
+        f"[Megatron-FSDP] Current rank {current_rank} in the mesh ranks "
+        f"{mesh_ranks.tolist()} does not match the group ranks {group_ranks}."
+    )
+    return sorted(current_ranks[0]) == sorted(group_ranks)
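
Given the constructor above, wrapping a model with the new adapter looks roughly like the following sketch; model and config stand in for an existing Megatron module and its TransformerConfig, and Megatron's parallel state is assumed to be initialized already.

# Sketch only: wrap a Megatron module with the new Megatron FSDP adapter.
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel

ddp_config = DistributedDataParallelConfig(
    use_megatron_fsdp=True,
    data_parallel_sharding_strategy="optim_grads_params",
)
# With fsdp_unit_modules left as None and the strategy above, each
# TransformerLayer becomes its own FSDP unit (see __init__ above).
fsdp_model = FullyShardedDataParallel(
    config=config,          # TransformerConfig of the wrapped model
    ddp_config=ddp_config,
    module=model,
)
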

megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .distributed_data_parallel_config import DistributedDataParallelConfig
+from .megatron_fsdp import MegatronFSDP
+from .utils import FSDPDistributedIndex
+
+try:
+    from .fully_shard import fully_shard
+except ImportError as e:
+    print(f"Failed to import fully_shard: {e}")