megatron-core 0.14.0rc6__tar.gz → 0.15.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {megatron_core-0.14.0rc6/megatron_core.egg-info → megatron_core-0.15.0rc0}/PKG-INFO +3 -3
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/dict_utils.py +13 -5
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/mapping.py +11 -11
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/optimizer.py +6 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/async_utils.py +52 -14
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/base.py +1 -5
- megatron_core-0.15.0rc0/megatron/core/dist_checkpointing/strategies/checkpointable.py +196 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/torch.py +38 -15
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/zarr.py +6 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/validation.py +13 -3
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/__init__.py +1 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/distributed_data_parallel_config.py +20 -6
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/finalize_model_grads.py +27 -14
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/__init__.py +3 -0
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +317 -0
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/__init__.py +13 -0
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +22 -0
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +141 -0
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +387 -0
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +1107 -0
- {megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp → megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp}/param_and_grad_buffer.py +1649 -522
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +458 -0
- megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +908 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/param_and_grad_buffer.py +5 -7
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +8 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/kitchen.py +4 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/transformer_engine.py +72 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/transformer_engine_spec_provider.py +5 -0
- megatron_core-0.15.0rc0/megatron/core/inference/data_parallel_inference_coordinator.py +322 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/dynamic_engine.py +323 -6
- megatron_core-0.15.0rc0/megatron/core/inference/headers.py +17 -0
- megatron_core-0.15.0rc0/megatron/core/inference/inference_client.py +190 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/inference_request.py +11 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/sampling_params.py +11 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +19 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/model_parallel_config.py +2 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/backends.py +9 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +23 -21
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/language_module/language_module.py +19 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/gpt_layer_specs.py +13 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/moe_module_specs.py +7 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/clip_model.py +1 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/qwen_model.py +1 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/context_parallel.py +25 -13
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/llava_model.py +5 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/multimodal_projector.py +35 -30
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/radio.py +26 -0
- megatron_core-0.15.0rc0/megatron/core/nccl_allocator.py +249 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/__init__.py +3 -22
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/clip_grads.py +15 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/distrib_optimizer.py +556 -248
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/optimizer.py +15 -10
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/optimizer_config.py +6 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/package_info.py +2 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +1 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/layers.py +8 -8
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/random.py +5 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/attention.py +30 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/cuda_graphs.py +16 -5
- megatron_core-0.15.0rc0/megatron/core/transformer/fsdp_dtensor_checkpoint.py +195 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/mlp.py +20 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/experts.py +25 -31
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/moe_layer.py +28 -1
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/shared_experts.py +33 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/multi_latent_attention.py +28 -3
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/transformer_config.py +55 -5
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/transformer_layer.py +12 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/utils.py +0 -3
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/utils.py +4 -41
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0/megatron_core.egg-info}/PKG-INFO +3 -3
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/SOURCES.txt +16 -3
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/requires.txt +2 -2
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/pyproject.toml +3 -3
- megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/__init__.py +0 -3
- megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -835
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/LICENSE +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/MANIFEST.in +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/README.md +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/README.md +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/activations.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/object_storage_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/utils_object_storage.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/energy_monitor.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/enums.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/model_type.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trt_model_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trt_model_type.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fp8_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/full_cuda_graph.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_indices_converter.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/hyper_comm_grid.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/dynamic_context.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/static_context.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/jit.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/T5/t5_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/T5/t5_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/bert_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/gpt_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mamba/mamba_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/config/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/config/base_configs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/model/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/model/base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/submodules/audio.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/submodules/base.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/submodules/vision.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/decoder_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/encoder_spec.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/clip_vit_model.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/msc_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/packed_seq_params.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/parallel_state.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/schedules.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/quantization/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/quantization/quant_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/quantization/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/rerun_state_machine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_block.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_context_parallel.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_layer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/timers.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/enums.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/moe_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/router.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/multi_token_prediction.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/torch_norm.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/transformer_block.py +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/setup.cfg +0 -0
- {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/setup.py +0 -0
{megatron_core-0.14.0rc6/megatron_core.egg-info → megatron_core-0.15.0rc0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc6
+Version: 0.15.0rc0
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -31,7 +31,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch
 Requires-Dist: numpy<2.0.0
-Requires-Dist: packaging
+Requires-Dist: packaging>=24.2
 Provides-Extra: mlm
 Requires-Dist: flask-restful; extra == "mlm"
 Requires-Dist: sentencepiece; extra == "mlm"
@@ -43,7 +43,7 @@ Requires-Dist: einops~=0.8; extra == "dev"
 Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
 Requires-Dist: nvtx~=0.2; extra == "dev"
 Requires-Dist: transformers~=4.53; extra == "dev"
-Requires-Dist: multi-storage-client
+Requires-Dist: multi-storage-client<0.26,~=0.25; extra == "dev"
 Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
 Requires-Dist: setuptools<80.0.0; extra == "dev"
 Requires-Dist: mamba-ssm~=2.2; extra == "dev"
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/dict_utils.py
RENAMED
@@ -103,11 +103,19 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
     else:
         only_left = []
         only_right = []
+        mismatch_debug_data = [prefix, type(x1), type(x2)]
         if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
-
-
-
-
+            try:
+                if x1.device != x2.device:
+                    _is_mismatch = not torch.all(x1.cpu() == x2.cpu())
+                else:
+                    _is_mismatch = not torch.all(x1 == x2)
+                mismatch_debug_data.extend(
+                    [(x1 != x2).sum(), (x1 != x2).shape, (x1 != x2).nonzero().tolist()]
+                )
+            except (RuntimeError, TypeError, ValueError):
+                _is_mismatch = True
+                mismatch_debug_data.extend([x1.shape, x2.shape])
         # TODO: change with concrete type that has both replica_id and data attrs
         elif hasattr(x1, "replica_id") and hasattr(x2, "replica_id"):
             assert type(x1) == type(x2)
@@ -122,7 +130,7 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
             _is_mismatch = True

     if _is_mismatch:
-        mismatch.append((
+        mismatch.append(tuple(mismatch_debug_data))

     return only_left, only_right, mismatch

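The change above makes `diff()` return richer mismatch records (key prefix, operand types and, when the comparison succeeds, the count, shape and indices of differing elements) and makes it tolerant of tensors living on different devices. A minimal usage sketch, not part of the package, with hypothetical state dicts `sd_a` and `sd_b`:

    import torch
    from megatron.core.dist_checkpointing.dict_utils import diff

    # Two (possibly nested) state dicts to compare, e.g. loaded from different checkpoints.
    sd_a = {'decoder': {'weight': torch.zeros(4, 4)}}
    sd_b = {'decoder': {'weight': torch.ones(4, 4)}}

    only_a, only_b, mismatch = diff(sd_a, sd_b)
    for record in mismatch:
        # Each record now starts with (key prefix, type(x1), type(x2)); for tensor pairs it
        # also carries the element-wise debug information added in this release.
        print('mismatch at', record[0], 'debug:', record[3:])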
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/mapping.py
RENAMED
@@ -135,23 +135,23 @@ class ShardedTensor(ShardedBase):
                 f"equal to global shape dimensions for {self}"
             )

-
-
-
-
-
-
-            # the shape of p2 on GPU0 is zero.
-            if sh != 0 and off % sh != 0:
-                raise CheckpointingException(
-                    f"Global offset ({off}) must be divisible by local shape ({sh}) for {self}."
-                )
+        if self.axis_fragmentations is not None:
+            for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
+                if sh != 0 and off % sh != 0:
+                    raise CheckpointingException(
+                        f"Global offset ({off}) must be divisible by local shape ({sh}) for {self}."
+                    )

         if has_flattened_range and self.flattened_range.step is not None:
             raise CheckpointingException(
                 f"`step` argument in the flattened range of a ShardedTensor is not supported."
             )

+    @property
+    def has_regular_grid(self):
+        """Alias for having a regular sharding grid."""
+        return self.axis_fragmentations is not None
+
     def global_slice(self) -> Tuple[Union[int, slice], ...]:
         """
         Returns a tuple of int and slice objects representing a slice of the
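The new `has_regular_grid` property marks tensors whose shards form a regular grid (offsets divisible by the local shape); irregular layouts are handled by the new uneven-sharding path elsewhere in this release. A rough sketch of querying it, assuming the `ShardedTensor.from_rank_offsets` factory behaves as in earlier releases:

    import torch
    from megatron.core.dist_checkpointing import ShardedTensor

    local = torch.zeros(2, 4)
    # Shard axis 0 across 4 fragments; this rank holds fragment index 1.
    sh_ten = ShardedTensor.from_rank_offsets('decoder.weight', local, (0, 1, 4))

    if sh_ten.has_regular_grid:
        # Regular grid: the legacy TorchShardedTensor-based save/load path applies.
        print(sh_ten.local_shape, sh_ten.global_offset)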
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/optimizer.py
RENAMED
@@ -25,6 +25,12 @@ from .mapping import (
 )
 from .utils import extract_sharded_tensors_and_factories

+KEEP_VARS_HINT = (
+    " Make sure state dict contains original torch.nn.Parameters (not pure torch.Tensors)"
+    " by passing `keep_vars=True` to `.state_dict()`. If any transformation of the original"
+    " parameter is needed, use a ShardedTensorFactory."
+)
+

 def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
     """Generate mapping from optimizer param to optimizer state id."""
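The new `KEEP_VARS_HINT` text is presumably attached to errors raised when optimizer state cannot be mapped back to model parameters. The remedy it points to is standard PyTorch behaviour, illustrated by this small sketch (not part of the package):

    import torch

    model = torch.nn.Linear(8, 8)

    # Default: values are detached tensors, parameter identity is lost.
    plain = model.state_dict()
    # keep_vars=True: values are the original nn.Parameter objects, so the
    # dist-checkpointing optimizer helpers can map them to optimizer state ids.
    with_params = model.state_dict(keep_vars=True)

    assert with_params['weight'] is model.weight   # same Parameter object
    assert plain['weight'] is not model.weight     # detached copy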
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/async_utils.py
RENAMED
@@ -79,9 +79,24 @@ class AsyncRequest(NamedTuple):

         This logic is equivalent to what should happen in case of the async call.
         """
+        # preload tensors.
+        async_fn_args = list(self.async_fn_args)
+        if self.preload_fn:
+            assert len(async_fn_args) == 3, "Expected 3 args to be passed to async function"
+            # The async_fn is passed as a partial functool with pre-determined args
+            # In the async_fn_args we pass the remaining positional args required by the async_fn
+            # async_fn_args[1] refers to the write_buckets
+            # To ensure we stage the write_buckets to CPU memory for sync CP,
+            # we replace it with preload_fn callable that returns the CPU staged tensors
+            async_fn_args[1] = self.preload_fn()
+        # persist the state
         if self.async_fn is not None:
-            self.async_fn(*self.
+            self.async_fn(*async_fn_args, **self.async_fn_kwargs)
+
+        # This utility implements a sync cp save. Hence the barrier.
         torch.distributed.barrier()
+
+        # Finalize the CP state
         for finalize_fn in self.finalize_fns:
             finalize_fn()

@@ -150,7 +165,7 @@ class AsyncCaller(ABC):
         return ten[0] == 0

     @abstractmethod
-    def close(self):
+    def close(self, abort=False):
         """Terminate the async caller at exit of an application or some termination conditions"""
         logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")

@@ -237,15 +252,23 @@ class TemporalAsyncCaller(AsyncCaller):
                 is_done = True
         return is_done

-    def close(self):
+    def close(self, abort=False):
         """For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`

        This method make sure the TemporalAsyncCaller terminated
        with all its assigned async request completed
+
+        Args:
+            abort (bool, optional): Default to False. Needs to be manually set to true when
+                the checkpoint async process needs to be aborted.
        """
        if self.process:
            logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
-
+            if abort:
+                logger.warning(f"Temporal worker aborted in rank {torch.distributed.get_rank()}")
+                self.process.kill()
+            else:
+                self.process.join()
            self.process = None
            logger.debug(
                "TemporalAsyncCaller: Async process join finished "
@@ -388,18 +411,25 @@ class PersistentAsyncCaller(AsyncCaller):

         return is_done

-    def close(self):
+    def close(self, abort=False):
         """Wait on the left async requests and terminate the PersistentAsyncCaller

         Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated
+        Args:
+            abort (bool, optional): Default to False. Needs to be manually set to true when
+                the checkpoint async process needs to be aborted.
         """
         logger.info(
             f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
         )
         if self.process:
-
-
-
+            if abort:
+                logger.warning(f"Persistent worker aborted in rank {torch.distributed.get_rank()}")
+                self.process.kill()
+            else:
+                self.queue.put('DONE')
+                self.queue.join()
+                self.process.join()
             self.process = None

     def __del__(self):
@@ -528,6 +558,9 @@ class AsyncCallsQueue:
             blocking (bool, optional): if True, will wait until all active requests
                 are done. Otherwise, finalizes only the async request that already
                 finished. Defaults to False.
+
+            no_dist (bool, Optional): if True, training ranks simply check its
+                asynchronous checkpoint writer without synchronization.
         Returns:
             List[int]: list of indices (as returned by `schedule_async_request`)
                 of async calls that have been successfully finalized.
@@ -545,8 +578,8 @@ class AsyncCallsQueue:
                 finalize_fn()
             ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
             torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
-            assert ten.item() == call_idx,
-
+            assert ten.item() == call_idx, "Unmatched async calls. "
+            "That probably means not all ranks are participating in async finalization"
             call_idx_finalized.append(call_idx)
         return call_idx_finalized

@@ -554,8 +587,13 @@ class AsyncCallsQueue:
         """Get the number of active async calls."""
         return len(self.async_calls)

-    def close(self):
-        """Finalize all calls upon closing.
-
+    def close(self, abort=False):
+        """Finalize all calls upon closing.
+        Args:
+            abort (bool, optional): Default to False. Needs to be manually set to true when
+                the checkpoint async process needs to be aborted.
+        """
+        if not abort:
+            self.maybe_finalize_async_calls(blocking=True)
         if self.persistent and self.persistent_caller:
-            self.persistent_caller.close()
+            self.persistent_caller.close(abort=abort)
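The `abort` path added above lets callers tear down the checkpoint worker without waiting for in-flight writes. A rough usage sketch of the queue API as it appears in this diff (the request construction is elided; the surrounding names are illustrative, not prescribed by the library):

    from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue

    queue = AsyncCallsQueue()
    # async_request would normally come from an AsyncSaveShardedStrategy.async_save(...) call:
    # queue.schedule_async_request(async_request)

    # Periodically, e.g. once per training iteration, finalize whatever already finished:
    finished = queue.maybe_finalize_async_calls(blocking=False)

    # Normal shutdown: wait for outstanding saves, then stop the worker.
    queue.close()
    # Error-handling paths: kill the worker without finalizing pending saves.
    # queue.close(abort=True)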
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/base.py
RENAMED
@@ -221,8 +221,4 @@ class AsyncSaveShardedStrategy(SaveShardedStrategy):
     def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
         """Each async strategy can be trivially used as a sync strategy."""
         async_request = self.async_save(sharded_state_dict, checkpoint_dir)
-
-        # We keep this verbose call for now
-        global async_calls
-        async_calls.schedule_async_request(async_request)
-        async_calls.maybe_finalize_async_calls(blocking=True)
+        async_request.execute_sync()
megatron_core-0.15.0rc0/megatron/core/dist_checkpointing/strategies/checkpointable.py
ADDED
@@ -0,0 +1,196 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+from itertools import chain
+
+import torch
+from torch.distributed.checkpoint.metadata import (
+    ChunkStorageMetadata,
+    MetadataIndex,
+    TensorProperties,
+)
+from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
+
+from ..mapping import ShardedTensor
+
+
+class CheckpointableShardedTensor(torch.Tensor):
+    """ShardedTensor extension compatible with PyTorch DCP checkpointing library.
+
+    Implements the torch.distributed._checkpointable._Checkpointable protocol.
+    """
+
+    def __new__(cls, data: torch.Tensor, sh_ten: ShardedTensor):
+        return torch.Tensor._make_wrapper_subclass(cls, torch.Size(sh_ten.global_shape))
+
+    def __init__(self, data: torch.Tensor, sh_ten: ShardedTensor):
+        self._data = data
+        self._sh_ten = sh_ten
+
+    def __create_write_items__(
+        self, fqn: str, sh_ten: 'CheckpointableShardedTensor', index: int = None
+    ) -> list[WriteItem]:
+        """Simple translation from ShardedTensor offsets into DCP offsets.
+
+        Args:
+            fqn (str): tensor FQN.
+            sh_ten (CheckpointableShardedTensor): same as `self`
+            index (int): specifies index within the LocalShardsContainer.
+                This is an optimization hint used in DCP.
+
+        Returns:
+            List[WriteItem]: list of DCP WriteItem metadata objects.
+        """
+        offsets = torch.Size(sh_ten._sh_ten.global_offset)
+        global_shape = torch.Size(sh_ten._sh_ten.global_shape)
+        chunk_size = torch.Size(sh_ten._sh_ten.local_shape)
+        assert chunk_size == sh_ten._sh_ten.data.size()
+
+        return [
+            WriteItem(
+                index=MetadataIndex(fqn, offsets, index),
+                type=WriteItemType.SHARD,
+                tensor_data=TensorWriteData(
+                    chunk=ChunkStorageMetadata(offsets=offsets, sizes=chunk_size),
+                    properties=TensorProperties.create_from_tensor(sh_ten._sh_ten.data),
+                    size=global_shape,
+                ),
+            )
+        ]
+
+    def __create_chunk_list__(self) -> list[ChunkStorageMetadata]:
+        """Simple translation from ShardedTensor offsets into DCP offsets.
+
+        Returns:
+            List[ChunkStorageMetadata]: list of DCP ChunkStorageMetadata metadata objects.
+        """
+        offsets = torch.Size(self._sh_ten.global_offset)
+        chunk_size = torch.Size(self._sh_ten.local_shape)
+        assert chunk_size == self._sh_ten.data.size()
+
+        return [ChunkStorageMetadata(offsets=offsets, sizes=chunk_size)]
+
+    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
+        """Trivial implementation which simply yields the underlying tensor.
+
+        Args:
+            index (MetadataIndex): unused
+
+        Returns:
+            Tensor: the underlying data tensor
+        """
+        return self._sh_ten.data
+
+    @classmethod
+    def from_sh_ten(cls, sh_ten: ShardedTensor) -> 'CheckpointableShardedTensor':
+        """Constructor which turns a ShardedTensor into CheckpointableShardedTensor
+
+        Args:
+            sh_ten (ShardedTensor): a sharded tensor to wrap
+
+        Returns:
+            CheckpointableShardedTensor: wrapped ShardedTensor
+        """
+        assert isinstance(sh_ten, ShardedTensor)
+        return cls(sh_ten.data, sh_ten)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        """Placeholder implementation."""
+        raise NotImplementedError(
+            f"{cls.__name__}.__torch_dispatch__ not implemented."
+            f" {cls.__name__} shouldn't be used with Tensor operations."
+        )
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self._sh_ten.__repr__()})'
+
+
+class LocalShardsContainer(torch.Tensor):
+    """DCP compatible container for local shards.
+
+    PyTorch DCP requires a single tensor per rank for a given global tensor FQN.
+    This class acts as a container allowing multiple checkpointable shards per rank.
+
+    Implements the torch.distributed._checkpointable._Checkpointable protocol.
+    """
+
+    @staticmethod
+    def __new__(cls, local_shards: list[torch.Tensor]) -> "LocalShardsContainer":
+        assert len(local_shards) > 0
+        # This assumes local shard already has correct size info
+        return torch.Tensor._make_wrapper_subclass(cls, local_shards[0].size())
+
+    def __init__(self, local_shards: list[torch.Tensor]):
+        for local_shard in local_shards:
+            # this is needed only for __get_tensor_shard__
+            assert isinstance(local_shard, CheckpointableShardedTensor)
+        self._local_shards = local_shards
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        """Placeholder implementation."""
+        raise NotImplementedError(
+            f"{cls.__name__}.__torch_dispatch__ not implemented."
+            f" {cls.__name__} shouldn't be used with Tensor operations."
+        )
+
+    def __create_write_items__(
+        self, fqn: str, local_shards_cont: 'LocalShardsContainer'
+    ) -> list[object]:
+        """Delegates creating write items to local shards.
+
+        Args:
+            fqn (str): tensor FQN.
+            local_shards_cont (LocalShardsContainer): same as `self`
+
+        Returns:
+            List[WriteItem]: list of DCP WriteItem metadata objects.
+        """
+        return list(
+            chain.from_iterable(
+                shard.__create_write_items__(fqn, shard, index=index)
+                for index, shard in enumerate(local_shards_cont._local_shards)
+            )
+        )
+
+    def __create_chunk_list__(self) -> list[ChunkStorageMetadata]:
+        """Delegates creating chunk items to local shards.
+
+        Returns:
+            List[ChunkStorageMetadata]: list of DCP ChunkStorageMetadata metadata objects.
+        """
+        return list(
+            chain.from_iterable(shard.__create_chunk_list__() for shard in self._local_shards)
+        )
+
+    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
+        """Performs shard matching lookup based on index hint or offset.
+
+        Args:
+            index (MetadataIndex): metadata specifying the offset of the queried shard.
+                Optionally provides an index hint which speeds up the lookup.
+
+        Returns:
+            Tensor: the matching shard data tensor
+        """
+        if index.offset is None:
+            raise ValueError(
+                f"Cannot lookup {index.fqn} for a LocalShardsContainer without an offset"
+            )
+
+        shards = self._local_shards
+        # index hint direct lookup
+        if index.index is not None:
+            if (
+                len(shards) > index.index
+                and torch.Size(shards[index.index]._sh_ten.global_offset) == index.offset
+            ):
+                return shards[index.index].__get_tensor_shard__(index)
+
+        # slow linear search
+        for shard in shards:
+            if torch.Size(shard._sh_ten.global_offset) == index.offset:
+                return shard.__get_tensor_shard__(index)
+        raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self._local_shards.__repr__()})'
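For orientation, a small sketch of how these wrappers compose. It constructs them directly for illustration only; in the library they are built inside `mcore_to_pyt_state_dict` in `strategies/torch.py` when a tensor has no regular sharding grid, and the `from_rank_offsets` factory used here is assumed to behave as in earlier releases:

    import torch
    from megatron.core.dist_checkpointing import ShardedTensor
    from megatron.core.dist_checkpointing.strategies.checkpointable import (
        CheckpointableShardedTensor,
        LocalShardsContainer,
    )

    # Any ShardedTensor can be wrapped; irregular-grid shards are the intended use case.
    sh_ten = ShardedTensor.from_rank_offsets('mlp.weight', torch.zeros(2, 4), (0, 1, 4))

    wrapped = CheckpointableShardedTensor.from_sh_ten(sh_ten)
    container = LocalShardsContainer([wrapped])

    # Both objects speak the DCP _Checkpointable protocol:
    print(wrapped.__create_chunk_list__())                          # one ChunkStorageMetadata
    print(container.__create_write_items__('mlp.weight', container))  # delegated WriteItems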
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/torch.py
RENAMED
@@ -57,6 +57,7 @@ from .base import (
     register_default_strategy,
 )
 from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
+from .checkpointable import CheckpointableShardedTensor, LocalShardsContainer
 from .filesystem_async import FileSystemWriterAsync
 from .resharding import (
     TensorReformulationMetadata,
@@ -240,14 +241,18 @@ def sharded_tensor_to_torch_sharded_tensor(
            placement = f"rank:{rank}/cuda"
            for sh_ten in local_global_offsets[offset]:
                if has_flattened_range:
-                    assert offset == sh_ten.local_chunk_offset_in_global()
+                    assert offset == sh_ten.local_chunk_offset_in_global(), (
+                        offset,
+                        sh_ten.local_chunk_offset_in_global(),
+                    )
                    # This is not an actual offset, but an offset of the whole shard
                    # This is needed for a PyT Dist internal integrity check
-
+                    _shard_offset = sh_ten.local_chunk_offset_in_global() + (0,)
                    size = (1,) * len(offsets_shape) + global_shape[-1:]
                else:
                    size = sh_ten.data.shape
-
+                    _shard_offset = offset
+                shard_metadata.append(ShardMetadata(_shard_offset, size, placement))

        else:
            # pylint: disable=line-too-long
@@ -312,7 +317,7 @@ def mcore_to_pyt_state_dict(
     rank = torch.distributed.get_rank()
     pyt_state_dict = {}

-    def
+    def _mcore_to_dcp_compatible_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
         """Build a PyT ShardedTensor from given shards.

         During loading:
@@ -335,11 +340,24 @@ def mcore_to_pyt_state_dict(
         if sh_ten.allow_shape_mismatch and is_loading:
             sh_ten.data.zero_()

-
-
-
-
-
+        if not sh_tens[0].has_regular_grid:
+            if not is_torch_min_version("2.6a0"):
+                raise CheckpointingException(
+                    f"Uneven sharding not supported for PyTorch version {get_torch_version()}"
+                )
+            assert sh_tens[0].flattened_range is None
+            if len(sh_tens) > 1:
+                return LocalShardsContainer(
+                    [CheckpointableShardedTensor.from_sh_ten(sh_ten) for sh_ten in sh_tens]
+                )
+            else:
+                return CheckpointableShardedTensor.from_sh_ten(sh_tens[0])
+        else:
+            torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
+                sh_tens, rank, load_legacy_1d_flatten_tensors
+            )
+            torch_sh_ten.key = sh_tens[0].key
+            return torch_sh_ten

     def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
         """Build io.BytesIO from given sharded objects data."""
@@ -351,7 +369,7 @@ def mcore_to_pyt_state_dict(
     for k, v in state_dict.items():
         if isinstance(v[0], ShardedTensor):
             v = cast(List[ShardedTensor], v)
-            pyt_state_dict[k] =
+            pyt_state_dict[k] = _mcore_to_dcp_compatible_tensor(v)
         else:
             v = cast(List[ShardedObject], v)
             pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)
@@ -359,12 +377,20 @@ def mcore_to_pyt_state_dict(
     return pyt_state_dict


-def _unwrap_pyt_sharded_tensor(
+def _unwrap_pyt_sharded_tensor(
+    sh_ten: Union[TorchShardedTensor, CheckpointableShardedTensor, LocalShardsContainer, Any]
+) -> Union[List[torch.Tensor], Any]:
     """Unwrap tensor from PyT ShardedTensor instance.

     If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
     then the tensor has additional singleton dimensions which should be squeezed.
     """
+    if isinstance(sh_ten, CheckpointableShardedTensor):
+        return [sh_ten._sh_ten.data]
+    if isinstance(sh_ten, LocalShardsContainer):
+        return [local_shard._sh_ten.data for local_shard in sh_ten._local_shards]
+    if not isinstance(sh_ten, TorchShardedTensor):
+        return sh_ten
     mcore_sh_ten = sh_ten.mcore_sh_ten
     ret_tensors = []
     for sh in sh_ten.local_shards():
@@ -930,10 +956,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
                Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
            )
            # Unwrap ShardedTensors and return to original state dict
-            mcore_state_dict = {
-                k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
-                for k, v in pyt_state_dict.items()
-            }
+            mcore_state_dict = {k: _unwrap_pyt_sharded_tensor(v) for k, v in pyt_state_dict.items()}
            mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
                mcore_state_dict, flat_mapping, rename_mapping  # type: ignore[arg-type]
            )
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/zarr.py
RENAMED
@@ -175,6 +175,11 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
             compressor=None,
             fill_value=None,
             write_empty_chunks=True,
+            synchronizer=(
+                zarr.ProcessSynchronizer(str(checkpoint_dir / f'{sharded_tensor.key}.sync'))
+                if sharded_tensor.flattened_range is not None
+                else None
+            ),
         )
         logger.debug(f"Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}")
     except zarr.errors.ContainsArrayError as e:
@@ -328,7 +333,7 @@ def load_zarr_based_sharded_metadata(

     sharded_state_dict = {}
     for subdir in checkpoint_dir.iterdir():
-        if not subdir.is_dir() or not (subdir / ".zarray").exists():
+        if not subdir.is_dir() or not (subdir / ".zarray").exists() or subdir.suffix == ".sync":
             continue
         key = subdir.name
         arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))
|
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/validation.py
RENAMED
|
@@ -450,6 +450,7 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
|
|
|
450
450
|
local_shape = some_rank_shard.local_shape
|
|
451
451
|
dtype = some_rank_shard.dtype
|
|
452
452
|
has_flattened_range = some_rank_shard.flattened_range is not None
|
|
453
|
+
has_regular_sharding_grid = some_rank_shard.has_regular_grid
|
|
453
454
|
for rank, sharding in rank_sharding:
|
|
454
455
|
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
|
|
455
456
|
assert sharding.global_shape == global_shape, (
|
|
@@ -457,17 +458,26 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
|
|
|
457
458
|
global_shape,
|
|
458
459
|
some_rank_shard,
|
|
459
460
|
)
|
|
460
|
-
assert sharding.
|
|
461
|
-
|
|
462
|
-
local_shape,
|
|
461
|
+
assert sharding.has_regular_grid == has_regular_sharding_grid, (
|
|
462
|
+
has_regular_sharding_grid,
|
|
463
463
|
some_rank_shard,
|
|
464
464
|
)
|
|
465
|
+
if has_regular_sharding_grid:
|
|
466
|
+
assert sharding.local_shape == local_shape, (
|
|
467
|
+
sharding.local_shape,
|
|
468
|
+
local_shape,
|
|
469
|
+
some_rank_shard,
|
|
470
|
+
)
|
|
465
471
|
assert (sharding.flattened_range is not None) == has_flattened_range, (
|
|
466
472
|
(sharding.flattened_range is not None),
|
|
467
473
|
has_flattened_range,
|
|
468
474
|
some_rank_shard,
|
|
469
475
|
)
|
|
470
476
|
|
|
477
|
+
if not has_regular_sharding_grid:
|
|
478
|
+
# In case of uneven sharding we defer the validation to DCP
|
|
479
|
+
return
|
|
480
|
+
|
|
471
481
|
shard_access_cnt = _compute_shards_access(rank_sharding)
|
|
472
482
|
if has_flattened_range:
|
|
473
483
|
map_reduce(
|
|
{megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/__init__.py
RENAMED
@@ -8,5 +8,6 @@ except ImportError:
 from .distributed_data_parallel import DistributedDataParallel
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .finalize_model_grads import finalize_model_grads
+from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
 from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
 from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig