megatron-core 0.14.0rc2__tar.gz → 0.14.0rc4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megatron-core might be problematic. Click here for more details.
- {megatron_core-0.14.0rc2/megatron_core.egg-info → megatron_core-0.14.0rc4}/PKG-INFO +11 -8
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/torch.py +2 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/validation.py +21 -15
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +10 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine.py +5 -27
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine_spec_provider.py +5 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fp8_utils.py +119 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +81 -16
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/__init__.py +1 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_context.py +191 -86
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/dynamic_engine.py +79 -18
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +4 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +6 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +26 -39
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/model_parallel_config.py +8 -3
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rope_utils.py +20 -32
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +10 -4
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +11 -6
- megatron_core-0.14.0rc4/megatron/core/models/common/model_chunk_schedule_plan.py +502 -0
- megatron_core-0.14.0rc4/megatron/core/models/gpt/fine_grained_callables.py +474 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_layer_specs.py +2 -2
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_model.py +62 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/__init__.py +143 -44
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer.py +11 -4
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer_config.py +1 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/package_info.py +1 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/packed_seq_params.py +2 -2
- megatron_core-0.14.0rc4/megatron/core/pipeline_parallel/combined_1f1b.py +331 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/schedules.py +169 -101
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/utils.py +91 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_block.py +4 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_layer.py +1 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/layers.py +23 -12
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/random.py +4 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/attention.py +3 -7
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/cuda_graphs.py +178 -43
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/mlp.py +20 -2
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/experts.py +22 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_latent_attention.py +81 -9
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_config.py +60 -7
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_layer.py +11 -10
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/utils.py +17 -11
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/utils.py +27 -3
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4/megatron_core.egg-info}/PKG-INFO +11 -8
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron_core.egg-info/SOURCES.txt +2 -0
- megatron_core-0.14.0rc4/megatron_core.egg-info/requires.txt +33 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/pyproject.toml +13 -10
- megatron_core-0.14.0rc2/megatron/core/models/gpt/fine_grained_callables.py +0 -195
- megatron_core-0.14.0rc2/megatron_core.egg-info/requires.txt +0 -30
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/LICENSE +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/MANIFEST.in +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/README.md +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/README.md +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/object_storage_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_object_storage.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/mapping.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/optimizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/finalize_model_grads.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/energy_monitor.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/enums.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/model_type.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_type.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/kitchen.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_indices_converter.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/hyper_comm_grid.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/static_context.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/inference_request.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/sampling_params.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/jit.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/backends.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/language_module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/moe_module_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/clip_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/qwen_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/base_configs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/audio.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/vision.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/context_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/clip_vit_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/multimodal_projector.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/radio.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/msc_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/clip_grads.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/distrib_optimizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/parallel_state.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/quantization/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/quantization/quant_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/quantization/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/rerun_state_machine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_context_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/timers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/enums.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_layer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/router.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/shared_experts.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_token_prediction.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_norm.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_block.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/setup.cfg +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: megatron-core
|
|
3
|
-
Version: 0.14.
|
|
3
|
+
Version: 0.14.0rc4
|
|
4
4
|
Summary: Megatron Core - a library for efficient and scalable training of transformer based models
|
|
5
5
|
Author-email: NVIDIA <nemo-toolkit@nvidia.com>
|
|
6
6
|
Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
|
|
@@ -31,6 +31,7 @@ Description-Content-Type: text/markdown
|
|
|
31
31
|
License-File: LICENSE
|
|
32
32
|
Requires-Dist: torch
|
|
33
33
|
Requires-Dist: numpy<2.0.0
|
|
34
|
+
Requires-Dist: packaging~=25.0
|
|
34
35
|
Provides-Extra: mlm
|
|
35
36
|
Requires-Dist: flask-restful; extra == "mlm"
|
|
36
37
|
Requires-Dist: sentencepiece; extra == "mlm"
|
|
@@ -38,14 +39,16 @@ Requires-Dist: tiktoken; extra == "mlm"
|
|
|
38
39
|
Requires-Dist: wandb; extra == "mlm"
|
|
39
40
|
Provides-Extra: dev
|
|
40
41
|
Requires-Dist: tqdm; extra == "dev"
|
|
41
|
-
Requires-Dist: einops; extra == "dev"
|
|
42
|
-
Requires-Dist: tensorstore!=0.1.46,!=0.1.72; extra == "dev"
|
|
43
|
-
Requires-Dist: nvtx; extra == "dev"
|
|
44
|
-
Requires-Dist: transformers; extra == "dev"
|
|
45
|
-
Requires-Dist: multi-storage-client; extra == "dev"
|
|
42
|
+
Requires-Dist: einops~=0.8; extra == "dev"
|
|
43
|
+
Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
|
|
44
|
+
Requires-Dist: nvtx~=0.2; extra == "dev"
|
|
45
|
+
Requires-Dist: transformers~=4.53; extra == "dev"
|
|
46
|
+
Requires-Dist: multi-storage-client~=0.20.3; extra == "dev"
|
|
47
|
+
Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
|
|
46
48
|
Requires-Dist: setuptools<80.0.0; extra == "dev"
|
|
47
|
-
Requires-Dist: nvidia-modelopt[torch]
|
|
48
|
-
Requires-Dist: megatron-energon[av_decode]
|
|
49
|
+
Requires-Dist: nvidia-modelopt[torch]<0.32.0,>=0.31.0a0; sys_platform != "darwin" and extra == "dev"
|
|
50
|
+
Requires-Dist: megatron-energon[av_decode]~=6.0; extra == "dev"
|
|
51
|
+
Requires-Dist: flashinfer-python; extra == "dev"
|
|
49
52
|
Provides-Extra: lts
|
|
50
53
|
Requires-Dist: tqdm; extra == "lts"
|
|
51
54
|
Requires-Dist: einops; extra == "lts"
|
|
@@ -374,7 +374,8 @@ def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]
|
|
|
374
374
|
ten = ten.view(-1)
|
|
375
375
|
else:
|
|
376
376
|
for _ in range(mcore_sh_ten.prepend_axis_num):
|
|
377
|
-
|
|
377
|
+
assert ten.size(0) == 1
|
|
378
|
+
ten = ten[0] # NOTE: ten.squeeze(0) uses more memory for FP8 tensors
|
|
378
379
|
ret_tensors.append(ten)
|
|
379
380
|
return ret_tensors
|
|
380
381
|
|
{megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/validation.py
RENAMED
|
@@ -375,28 +375,34 @@ def maybe_report_missing_and_unexpected_keys(
|
|
|
375
375
|
def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
|
|
376
376
|
"""Validate consistancy across ranks for the common state dict
|
|
377
377
|
|
|
378
|
-
We save the common state dict only on rank 0. We validate to make sure that the common dict is
|
|
378
|
+
We save the common state dict only on rank 0. We validate to make sure that the common dict is consistent across ranks before saving.
|
|
379
379
|
|
|
380
380
|
Args:
|
|
381
381
|
common_state_dict: The common state dict present in all ransk
|
|
382
382
|
"""
|
|
383
|
+
if not torch.distributed.is_initialized():
|
|
384
|
+
return
|
|
383
385
|
|
|
384
|
-
#
|
|
386
|
+
# Broadcast the common state dict from rank 0 to all other ranks
|
|
387
|
+
# Each rank will do a comparison with its local rank vs the broadcasted state dict from rank 0
|
|
385
388
|
rank = torch.distributed.get_rank()
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
if
|
|
389
|
+
|
|
390
|
+
object_list = [common_state_dict] if rank == 0 else [None]
|
|
391
|
+
torch.distributed.broadcast_object_list(object_list, src=0)
|
|
392
|
+
rank0_state_dict = object_list[0]
|
|
393
|
+
|
|
394
|
+
# Skip comparing rank 0 with itself
|
|
395
|
+
if rank > 0:
|
|
396
|
+
current_rank_state_dict = common_state_dict
|
|
397
|
+
only_in_rank0, only_in_current_rank, mismatch = diff(
|
|
398
|
+
rank0_state_dict, current_rank_state_dict
|
|
399
|
+
)
|
|
400
|
+
if only_in_rank0 or only_in_current_rank or mismatch:
|
|
398
401
|
logger.warning(
|
|
399
|
-
f"
|
|
402
|
+
f"Rank {rank} common state dict differs from rank 0 common state dict. "
|
|
403
|
+
f"Keys only on rank 0: {only_in_rank0}, "
|
|
404
|
+
f"Keys only on {rank}: {only_in_current_rank}, "
|
|
405
|
+
f"Mismatched keys: {mismatch}"
|
|
400
406
|
)
|
|
401
407
|
|
|
402
408
|
|
|
@@ -217,6 +217,16 @@ class FullyShardedDataParallel(_BaseDataParallel):
|
|
|
217
217
|
|
|
218
218
|
self.module.apply(unmap_weight_tensor)
|
|
219
219
|
|
|
220
|
+
for param in self.module.parameters():
|
|
221
|
+
if not hasattr(param, 'grad_added_to_main_grad'):
|
|
222
|
+
# This is to ensure that the param.grad_added_to_main_grad is set to False
|
|
223
|
+
# when the parameter is created.
|
|
224
|
+
param.grad_added_to_main_grad = False
|
|
225
|
+
if not hasattr(param, '__fsdp_param__'):
|
|
226
|
+
# This is to ensure that the param.__fsdp_param__ is set to True
|
|
227
|
+
# when the parameter is created.
|
|
228
|
+
param.__fsdp_param__ = True
|
|
229
|
+
|
|
220
230
|
def _init_fsdp_param_and_grad_buffer(self):
|
|
221
231
|
if self.config.calculate_per_token_loss:
|
|
222
232
|
# We don't need to scale the gradients in this case.
|
{megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine.py
RENAMED
|
@@ -889,25 +889,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
|
|
|
889
889
|
if packed_seq_params is not None
|
|
890
890
|
else {}
|
|
891
891
|
)
|
|
892
|
-
|
|
893
|
-
# after init
|
|
894
|
-
if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
|
|
895
|
-
self.qkv_format = "bshd"
|
|
896
|
-
|
|
897
|
-
qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
|
|
898
|
-
|
|
899
|
-
# WAR for peak memory usage.
|
|
900
|
-
# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
|
|
901
|
-
if self.config.apply_rope_fusion and qkv_format == "bshd":
|
|
902
|
-
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
|
|
903
|
-
# In PyTorch, the following two tensors are in fact the same:
|
|
904
|
-
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
|
|
905
|
-
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
|
|
906
|
-
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
|
|
907
|
-
# can have same shape but different strides.
|
|
908
|
-
# We unify them to the first one to pass the stride check in TE
|
|
909
|
-
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
|
|
910
|
-
value = value.as_strided(value.shape, key.stride())
|
|
892
|
+
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
|
|
911
893
|
|
|
912
894
|
attention_bias_kwargs = {}
|
|
913
895
|
if attention_bias is not None:
|
|
@@ -942,10 +924,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
|
|
|
942
924
|
query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
|
|
943
925
|
)
|
|
944
926
|
|
|
945
|
-
|
|
946
|
-
return core_attn_out.transpose(0, 1)
|
|
947
|
-
else:
|
|
948
|
-
return core_attn_out
|
|
927
|
+
return core_attn_out
|
|
949
928
|
|
|
950
929
|
|
|
951
930
|
if HAVE_TE and is_te_min_version("1.9.0.dev0"):
|
|
@@ -1633,10 +1612,8 @@ try:
|
|
|
1633
1612
|
else:
|
|
1634
1613
|
if interleaved:
|
|
1635
1614
|
raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.")
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
else:
|
|
1639
|
-
raise ValueError("Only TE >= 1.4.0.dev0 supports fused RoPE.")
|
|
1615
|
+
|
|
1616
|
+
return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)
|
|
1640
1617
|
|
|
1641
1618
|
def fused_apply_rotary_pos_emb_thd(
|
|
1642
1619
|
t: torch.Tensor,
|
|
@@ -1659,6 +1636,7 @@ try:
|
|
|
1659
1636
|
cp_rank=cp_rank,
|
|
1660
1637
|
)
|
|
1661
1638
|
else:
|
|
1639
|
+
assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP."
|
|
1662
1640
|
return apply_rotary_pos_emb(
|
|
1663
1641
|
t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens
|
|
1664
1642
|
)
|
|
@@ -8,6 +8,7 @@ from megatron.core.extensions.transformer_engine import (
|
|
|
8
8
|
TEColumnParallelLinear,
|
|
9
9
|
TEDotProductAttention,
|
|
10
10
|
TELayerNormColumnParallelLinear,
|
|
11
|
+
TELinear,
|
|
11
12
|
TENorm,
|
|
12
13
|
TERowParallelGroupedLinear,
|
|
13
14
|
TERowParallelLinear,
|
|
@@ -23,6 +24,10 @@ from megatron.core.utils import get_te_version, is_te_min_version
|
|
|
23
24
|
class TESpecProvider(BackendSpecProvider):
|
|
24
25
|
"""A protocol for providing the submodules used in Spec building."""
|
|
25
26
|
|
|
27
|
+
def linear(self) -> type:
|
|
28
|
+
"""Which linear module TE backend uses"""
|
|
29
|
+
return TELinear
|
|
30
|
+
|
|
26
31
|
def column_parallel_linear(self) -> type:
|
|
27
32
|
"""Which column parallel linear module TE backend uses"""
|
|
28
33
|
return TEColumnParallelLinear
|
|
@@ -2,7 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
"""Utility functions related to FP8 that are used throughout Megatron core"""
|
|
4
4
|
|
|
5
|
+
import weakref
|
|
5
6
|
from contextlib import nullcontext
|
|
7
|
+
from functools import wraps
|
|
6
8
|
from typing import List, Optional
|
|
7
9
|
|
|
8
10
|
import torch
|
|
@@ -53,6 +55,29 @@ except (ImportError, ModuleNotFoundError):
|
|
|
53
55
|
# MXFP8Tensor not found
|
|
54
56
|
HAVE_TE_MXFP8TENSOR = False
|
|
55
57
|
|
|
58
|
+
if HAVE_TE:
|
|
59
|
+
from megatron.core.extensions.transformer_engine import (
|
|
60
|
+
TEColumnParallelLinear,
|
|
61
|
+
TELayerNormColumnParallelLinear,
|
|
62
|
+
TELinear,
|
|
63
|
+
TERowParallelLinear,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
TE_LINEAR_TYPES = (
|
|
67
|
+
TELinear,
|
|
68
|
+
TEColumnParallelLinear,
|
|
69
|
+
TERowParallelLinear,
|
|
70
|
+
TELayerNormColumnParallelLinear,
|
|
71
|
+
)
|
|
72
|
+
else:
|
|
73
|
+
TE_LINEAR_TYPES = ()
|
|
74
|
+
|
|
75
|
+
try:
|
|
76
|
+
from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
|
|
77
|
+
except ImportError:
|
|
78
|
+
Fp8Padding = None
|
|
79
|
+
Fp8Unpadding = None
|
|
80
|
+
|
|
56
81
|
|
|
57
82
|
def is_float8tensor(tensor: torch.Tensor) -> bool:
|
|
58
83
|
"""Check if a tensor is a Transformer Engine Float8Tensor.
|
|
@@ -511,3 +536,97 @@ else:
|
|
|
511
536
|
def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
|
|
512
537
|
"""Returns dummy fp8 context manager since TE is not available."""
|
|
513
538
|
return nullcontext()
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
if HAVE_TE:
|
|
542
|
+
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
|
|
543
|
+
|
|
544
|
+
# Modules that have been wrapped for inference for fp8
|
|
545
|
+
_fp8_inference_wrapped_modules = weakref.WeakSet()
|
|
546
|
+
|
|
547
|
+
def _wrap_te_linear_for_padding(module: torch.nn.Module):
|
|
548
|
+
"""Wrap a TE linear module to automatically pad sequences for FP8 inference.
|
|
549
|
+
|
|
550
|
+
Modifies the module's forward method to:
|
|
551
|
+
1. Pad input sequences to FP8 alignment requirements
|
|
552
|
+
2. Run the original forward pass
|
|
553
|
+
3. Unpad outputs to original sequence length
|
|
554
|
+
|
|
555
|
+
Args:
|
|
556
|
+
module: A Transformer Engine linear layer (TELinear, TEColumnParallelLinear, etc.)
|
|
557
|
+
"""
|
|
558
|
+
if module in _fp8_inference_wrapped_modules:
|
|
559
|
+
return
|
|
560
|
+
_pad_func = Fp8Padding(1)
|
|
561
|
+
_unpad_func = Fp8Unpadding(1)
|
|
562
|
+
|
|
563
|
+
original_forward = module.forward
|
|
564
|
+
|
|
565
|
+
@wraps(original_forward)
|
|
566
|
+
def padded_forward(input_tensor, *args, **kwargs):
|
|
567
|
+
# Only do padding for fp8 if we are in fp8 context
|
|
568
|
+
if not FP8GlobalStateManager.is_fp8_enabled():
|
|
569
|
+
return original_forward(input_tensor, *args, **kwargs)
|
|
570
|
+
|
|
571
|
+
seq_len, batch_size, hidden_size = input_tensor.shape
|
|
572
|
+
# Reshape to (S, B*H) to pad sequence dimension
|
|
573
|
+
input_2d = input_tensor.reshape(seq_len, -1)
|
|
574
|
+
# Pad the sequence dimension
|
|
575
|
+
padded_input_2d, _ = _pad_func(input_2d, [seq_len])
|
|
576
|
+
padded_seq_len = padded_input_2d.shape[0]
|
|
577
|
+
|
|
578
|
+
# Reshape back to (padded_S, B, H)
|
|
579
|
+
padded_input_3d = padded_input_2d.view(padded_seq_len, batch_size, hidden_size)
|
|
580
|
+
output = original_forward(padded_input_3d, *args, **kwargs)
|
|
581
|
+
|
|
582
|
+
# Handle output
|
|
583
|
+
if isinstance(output, tuple):
|
|
584
|
+
output_tensor = output[0]
|
|
585
|
+
other_outputs = output[1:]
|
|
586
|
+
else:
|
|
587
|
+
output_tensor = output
|
|
588
|
+
other_outputs = ()
|
|
589
|
+
|
|
590
|
+
# Unpad output - reshape to 2D, unpad, reshape back
|
|
591
|
+
_, _, output_hidden_size = output_tensor.shape
|
|
592
|
+
output_2d = output_tensor.reshape(padded_seq_len, -1)
|
|
593
|
+
unpadded_output_2d = _unpad_func(output_2d, [seq_len])
|
|
594
|
+
unpadded_output = unpadded_output_2d.reshape(seq_len, batch_size, output_hidden_size)
|
|
595
|
+
|
|
596
|
+
if other_outputs:
|
|
597
|
+
return (unpadded_output,) + other_outputs
|
|
598
|
+
else:
|
|
599
|
+
return unpadded_output
|
|
600
|
+
|
|
601
|
+
module.forward = padded_forward
|
|
602
|
+
_fp8_inference_wrapped_modules.add(module)
|
|
603
|
+
|
|
604
|
+
def prepare_model_for_fp8_inference(model):
|
|
605
|
+
"""Prepare a model for FP8 inference by wrapping TE linear layers with padding support.
|
|
606
|
+
|
|
607
|
+
FP8 TE Gemms have specific shape requirements. This function wraps all Transformer
|
|
608
|
+
Engine linear layers in the model to automatically pad/unpad sequences during inference.
|
|
609
|
+
|
|
610
|
+
Args:
|
|
611
|
+
model (model (GPTModel): Model containing TE linear layers.
|
|
612
|
+
|
|
613
|
+
Returns:
|
|
614
|
+
GPTModel: The same model with wrapped linear layers (modified in-place).
|
|
615
|
+
|
|
616
|
+
"""
|
|
617
|
+
assert Fp8Padding and Fp8Unpadding, "TE version does not have FP8 padding functions"
|
|
618
|
+
# Find and wrap all TE linear layers
|
|
619
|
+
for module in model.modules():
|
|
620
|
+
if isinstance(module, TE_LINEAR_TYPES):
|
|
621
|
+
_wrap_te_linear_for_padding(module)
|
|
622
|
+
|
|
623
|
+
return model
|
|
624
|
+
|
|
625
|
+
else:
|
|
626
|
+
|
|
627
|
+
def prepare_model_for_fp8_inference(model):
|
|
628
|
+
"""If trys using prepare_model_for_fp8_inference without TE we error"""
|
|
629
|
+
raise RuntimeError(
|
|
630
|
+
"prepare_model_for_fp8_inference requires Transformer Engine to be installed. "
|
|
631
|
+
"Please install transformer-engine to use FP8 inference."
|
|
632
|
+
)
|
|
@@ -28,16 +28,25 @@ if not HAVE_TRITON:
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
@triton.jit
|
|
31
|
-
def _get_thd_token_idx(cu_seqlens, pid_m, seq_num):
|
|
31
|
+
def _get_thd_token_idx(cu_seqlens, pid_m, seq_num, cp_rank, cp_size):
|
|
32
32
|
token_idx = -1
|
|
33
|
+
this_seq_len = 0
|
|
33
34
|
seq_idx = 0
|
|
34
|
-
last_cum_seqlen = tl.load(cu_seqlens)
|
|
35
|
+
last_cum_seqlen = tl.load(cu_seqlens) // cp_size
|
|
35
36
|
while seq_idx < seq_num:
|
|
36
|
-
cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1)
|
|
37
|
+
cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1) // cp_size
|
|
37
38
|
if token_idx == -1 and cur_cum_seqlen > pid_m:
|
|
38
39
|
token_idx = pid_m - last_cum_seqlen
|
|
40
|
+
this_seq_len = cur_cum_seqlen - last_cum_seqlen
|
|
39
41
|
last_cum_seqlen = cur_cum_seqlen
|
|
40
42
|
seq_idx += 1
|
|
43
|
+
if cp_size > 1:
|
|
44
|
+
if token_idx < this_seq_len // 2:
|
|
45
|
+
token_idx = token_idx + cp_rank * this_seq_len // 2
|
|
46
|
+
else:
|
|
47
|
+
token_idx = (token_idx - this_seq_len // 2) + (
|
|
48
|
+
2 * cp_size - cp_rank - 1
|
|
49
|
+
) * this_seq_len // 2
|
|
41
50
|
return token_idx
|
|
42
51
|
|
|
43
52
|
|
|
@@ -68,6 +77,8 @@ def rotary_fwd_q_kernel(
|
|
|
68
77
|
cu_seqlens_q,
|
|
69
78
|
stride_x_seq,
|
|
70
79
|
stride_x_nheads,
|
|
80
|
+
cp_rank,
|
|
81
|
+
cp_size,
|
|
71
82
|
BLOCK_H: tl.constexpr,
|
|
72
83
|
):
|
|
73
84
|
"""
|
|
@@ -89,7 +100,7 @@ def rotary_fwd_q_kernel(
|
|
|
89
100
|
if cu_seqlens_q is None:
|
|
90
101
|
token_idx = pid_m // batch_size
|
|
91
102
|
else:
|
|
92
|
-
token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
|
|
103
|
+
token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)
|
|
93
104
|
|
|
94
105
|
cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
|
|
95
106
|
sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
|
|
@@ -146,6 +157,8 @@ def rotary_bwd_q_kernel(
|
|
|
146
157
|
cu_seqlens_q,
|
|
147
158
|
stride_x_seq,
|
|
148
159
|
stride_x_nheads,
|
|
160
|
+
cp_rank,
|
|
161
|
+
cp_size,
|
|
149
162
|
BLOCK_H: tl.constexpr,
|
|
150
163
|
):
|
|
151
164
|
"""
|
|
@@ -165,7 +178,7 @@ def rotary_bwd_q_kernel(
|
|
|
165
178
|
if cu_seqlens_q is None:
|
|
166
179
|
token_idx = pid_m // batch_size
|
|
167
180
|
else:
|
|
168
|
-
token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
|
|
181
|
+
token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)
|
|
169
182
|
|
|
170
183
|
cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
|
|
171
184
|
sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
|
|
@@ -200,7 +213,18 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
|
|
|
200
213
|
"""
|
|
201
214
|
|
|
202
215
|
@staticmethod
|
|
203
|
-
def forward(
|
|
216
|
+
def forward(
|
|
217
|
+
ctx,
|
|
218
|
+
q,
|
|
219
|
+
cos,
|
|
220
|
+
sin,
|
|
221
|
+
qk_head_dim,
|
|
222
|
+
emb_dim,
|
|
223
|
+
cu_seqlens_q,
|
|
224
|
+
cp_rank,
|
|
225
|
+
cp_size,
|
|
226
|
+
rotary_interleaved=False,
|
|
227
|
+
):
|
|
204
228
|
"""
|
|
205
229
|
Forward function for ApplyMLARotaryEmbQ.
|
|
206
230
|
|
|
@@ -243,12 +267,16 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
|
|
|
243
267
|
cu_seqlens_q,
|
|
244
268
|
q.stride(0),
|
|
245
269
|
q.stride(1),
|
|
270
|
+
cp_rank,
|
|
271
|
+
cp_size,
|
|
246
272
|
)
|
|
247
273
|
ctx.save_for_backward(cos, sin)
|
|
248
274
|
ctx.qk_head_dim = qk_head_dim
|
|
249
275
|
ctx.emb_dim = emb_dim
|
|
250
276
|
ctx.cu_seqlens_q = cu_seqlens_q
|
|
251
277
|
ctx.rotary_interleaved = rotary_interleaved
|
|
278
|
+
ctx.cp_rank = cp_rank
|
|
279
|
+
ctx.cp_size = cp_size
|
|
252
280
|
if cu_seqlens_q is None:
|
|
253
281
|
q = q.view(max_seqlen, batch_size, nheads, headdim)
|
|
254
282
|
return q
|
|
@@ -268,7 +296,7 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
|
|
|
268
296
|
seq_num = None
|
|
269
297
|
if ctx.cu_seqlens_q is None:
|
|
270
298
|
max_seqlen, batch_size, nheads, headdim = grad.shape
|
|
271
|
-
grad = grad.view(-1, nheads, headdim)
|
|
299
|
+
grad = grad.contiguous().view(-1, nheads, headdim)
|
|
272
300
|
total_seqlen = grad.shape[0]
|
|
273
301
|
else:
|
|
274
302
|
seq_num = len(ctx.cu_seqlens_q) - 1
|
|
@@ -288,10 +316,12 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
|
|
|
288
316
|
ctx.cu_seqlens_q,
|
|
289
317
|
grad.stride(0),
|
|
290
318
|
grad.stride(1),
|
|
319
|
+
ctx.cp_rank,
|
|
320
|
+
ctx.cp_size,
|
|
291
321
|
)
|
|
292
322
|
if ctx.cu_seqlens_q is None:
|
|
293
323
|
grad = grad.view(max_seqlen, batch_size, nheads, headdim)
|
|
294
|
-
return grad, None, None, None, None, None, None
|
|
324
|
+
return grad, None, None, None, None, None, None, None, None
|
|
295
325
|
|
|
296
326
|
|
|
297
327
|
@experimental_fn(introduced_with_version="0.13.0")
|
|
@@ -302,6 +332,8 @@ def fused_apply_mla_rope_for_q(
|
|
|
302
332
|
qk_head_dim: int,
|
|
303
333
|
emb_dim: int,
|
|
304
334
|
cu_seqlens_q: Optional[torch.Tensor] = None,
|
|
335
|
+
cp_rank: int = 0,
|
|
336
|
+
cp_size: int = 1,
|
|
305
337
|
rotary_interleaved: bool = False,
|
|
306
338
|
):
|
|
307
339
|
"""
|
|
@@ -327,7 +359,7 @@ def fused_apply_mla_rope_for_q(
|
|
|
327
359
|
t: inplace modified input tensor
|
|
328
360
|
"""
|
|
329
361
|
return ApplyMLARotaryEmbQ.apply(
|
|
330
|
-
t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, rotary_interleaved
|
|
362
|
+
t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, cp_rank, cp_size, rotary_interleaved
|
|
331
363
|
)
|
|
332
364
|
|
|
333
365
|
|
|
@@ -366,6 +398,8 @@ def rotary_fwd_kv_kernel(
|
|
|
366
398
|
stride_k_nheads,
|
|
367
399
|
stride_v_seq,
|
|
368
400
|
stride_v_nheads,
|
|
401
|
+
cp_rank,
|
|
402
|
+
cp_size,
|
|
369
403
|
BLOCK_H: tl.constexpr,
|
|
370
404
|
):
|
|
371
405
|
"""
|
|
@@ -394,7 +428,7 @@ def rotary_fwd_kv_kernel(
|
|
|
394
428
|
if cu_seqlens_kv is None:
|
|
395
429
|
token_idx = pid_m // batch_size
|
|
396
430
|
else:
|
|
397
|
-
token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
|
|
431
|
+
token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)
|
|
398
432
|
|
|
399
433
|
cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
|
|
400
434
|
sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
|
|
@@ -472,6 +506,8 @@ def rotary_bwd_kv_kernel(
|
|
|
472
506
|
stride_dkv_seq,
|
|
473
507
|
stride_dkv_nheads,
|
|
474
508
|
stride_demb_seq,
|
|
509
|
+
cp_rank,
|
|
510
|
+
cp_size,
|
|
475
511
|
BLOCK_H: tl.constexpr,
|
|
476
512
|
):
|
|
477
513
|
"""
|
|
@@ -496,7 +532,7 @@ def rotary_bwd_kv_kernel(
|
|
|
496
532
|
if cu_seqlens_kv is None:
|
|
497
533
|
token_idx = pid_m // batch_size
|
|
498
534
|
else:
|
|
499
|
-
token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
|
|
535
|
+
token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)
|
|
500
536
|
|
|
501
537
|
dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
|
|
502
538
|
dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
|
|
@@ -550,7 +586,18 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
|
|
|
550
586
|
|
|
551
587
|
@staticmethod
|
|
552
588
|
def forward(
|
|
553
|
-
ctx,
|
|
589
|
+
ctx,
|
|
590
|
+
kv,
|
|
591
|
+
k_pos_emb,
|
|
592
|
+
cos,
|
|
593
|
+
sin,
|
|
594
|
+
emb_dim,
|
|
595
|
+
k_dim,
|
|
596
|
+
v_dim,
|
|
597
|
+
cu_seqlens_kv,
|
|
598
|
+
cp_rank,
|
|
599
|
+
cp_size,
|
|
600
|
+
rotary_interleaved=False,
|
|
554
601
|
):
|
|
555
602
|
"""
|
|
556
603
|
Forward function for ApplyMLARotaryEmbKV.
|
|
@@ -609,6 +656,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
|
|
|
609
656
|
o_key.stride(1),
|
|
610
657
|
o_value.stride(0),
|
|
611
658
|
o_value.stride(1),
|
|
659
|
+
cp_rank,
|
|
660
|
+
cp_size,
|
|
612
661
|
)
|
|
613
662
|
ctx.save_for_backward(cos, sin)
|
|
614
663
|
ctx.rotary_interleaved = rotary_interleaved
|
|
@@ -616,6 +665,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
|
|
|
616
665
|
ctx.k_dim = k_dim
|
|
617
666
|
ctx.v_dim = v_dim
|
|
618
667
|
ctx.cu_seqlens_kv = cu_seqlens_kv
|
|
668
|
+
ctx.cp_rank = cp_rank
|
|
669
|
+
ctx.cp_size = cp_size
|
|
619
670
|
if cu_seqlens_kv is None:
|
|
620
671
|
o_key = o_key.view(max_seqlen, -1, nheads, emb_dim + k_dim)
|
|
621
672
|
o_value = o_value.view(max_seqlen, -1, nheads, v_dim)
|
|
@@ -638,8 +689,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
|
|
|
638
689
|
if ctx.cu_seqlens_kv is None:
|
|
639
690
|
# sbhd
|
|
640
691
|
max_seqlen, batch_size, nheads, _ = dk.shape
|
|
641
|
-
dk = dk.view(-1, nheads, ctx.emb_dim + ctx.k_dim)
|
|
642
|
-
dv = dv.view(-1, nheads, ctx.v_dim)
|
|
692
|
+
dk = dk.contiguous().view(-1, nheads, ctx.emb_dim + ctx.k_dim)
|
|
693
|
+
dv = dv.contiguous().view(-1, nheads, ctx.v_dim)
|
|
643
694
|
total_seqlen = dk.shape[0]
|
|
644
695
|
else:
|
|
645
696
|
# thd
|
|
@@ -673,11 +724,13 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
|
|
|
673
724
|
d_kv.stride(0),
|
|
674
725
|
d_kv.stride(1),
|
|
675
726
|
d_emb.stride(0),
|
|
727
|
+
ctx.cp_rank,
|
|
728
|
+
ctx.cp_size,
|
|
676
729
|
)
|
|
677
730
|
if ctx.cu_seqlens_kv is None:
|
|
678
731
|
d_kv = d_kv.view(max_seqlen, batch_size, nheads, ctx.k_dim + ctx.v_dim)
|
|
679
732
|
d_emb = d_emb.view(max_seqlen, batch_size, 1, ctx.emb_dim)
|
|
680
|
-
return d_kv, d_emb, None, None, None, None, None, None, None
|
|
733
|
+
return d_kv, d_emb, None, None, None, None, None, None, None, None, None
|
|
681
734
|
|
|
682
735
|
|
|
683
736
|
@experimental_fn(introduced_with_version="0.13.0")
|
|
@@ -690,6 +743,8 @@ def fused_apply_mla_rope_for_kv(
|
|
|
690
743
|
k_dim: int,
|
|
691
744
|
v_dim: int,
|
|
692
745
|
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
|
746
|
+
cp_rank: int = 0,
|
|
747
|
+
cp_size: int = 1,
|
|
693
748
|
rotary_interleaved: bool = False,
|
|
694
749
|
):
|
|
695
750
|
"""
|
|
@@ -715,5 +770,15 @@ def fused_apply_mla_rope_for_kv(
|
|
|
715
770
|
value: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
|
|
716
771
|
"""
|
|
717
772
|
return ApplyMLARotaryEmbKV.apply(
|
|
718
|
-
kv,
|
|
773
|
+
kv,
|
|
774
|
+
k_pos_emb,
|
|
775
|
+
cos,
|
|
776
|
+
sin,
|
|
777
|
+
emb_dim,
|
|
778
|
+
k_dim,
|
|
779
|
+
v_dim,
|
|
780
|
+
cu_seqlens_kv,
|
|
781
|
+
cp_rank,
|
|
782
|
+
cp_size,
|
|
783
|
+
rotary_interleaved,
|
|
719
784
|
)
|