megatron-core 0.14.0rc1__tar.gz → 0.14.0rc2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megatron-core might be problematic.
- {megatron_core-0.14.0rc1/megatron_core.egg-info → megatron_core-0.14.0rc2}/PKG-INFO +2 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/enums.py +10 -3
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fp8_utils.py +6 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/dynamic_context.py +52 -6
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/static_context.py +1 -1
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/dynamic_engine.py +18 -3
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +2 -10
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +2 -6
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +2 -9
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +15 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +57 -13
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +15 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/utils.py +16 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/model_parallel_config.py +0 -5
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/T5/t5_model.py +2 -7
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/T5/t5_spec.py +2 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_layer_specs.py +2 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/language_model_embedding.py +3 -3
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +2 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/language_module/language_module.py +57 -17
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/gpt_layer_specs.py +4 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/gpt_model.py +19 -15
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +2 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/moe_module_specs.py +2 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/mamba_model.py +12 -16
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/audio.py +1 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/llava_model.py +19 -4
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/decoder_spec.py +2 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/encoder_spec.py +2 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/clip_vit_model.py +9 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/multimodal_projector.py +10 -1
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/radio.py +7 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/__init__.py +38 -4
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/distrib_optimizer.py +54 -6
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/optimizer.py +27 -1
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/package_info.py +1 -1
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/parallel_state.py +42 -451
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/p2p_communication.py +25 -68
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/schedules.py +12 -73
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/utils.py +57 -1
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/rerun_state_machine.py +123 -86
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/cuda_graphs.py +62 -45
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/enums.py +8 -1
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/heterogeneous/linear_replacements.py +4 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/experts.py +1 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/moe_layer.py +2 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/moe_utils.py +6 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/router.py +23 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/multi_latent_attention.py +9 -3
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/multi_token_prediction.py +10 -3
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_block.py +22 -11
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_config.py +31 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_layer.py +0 -4
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2/megatron_core.egg-info}/PKG-INFO +2 -2
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/requires.txt +1 -1
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/pyproject.toml +13 -3
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/LICENSE +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/MANIFEST.in +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/README.md +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/README.md +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/object_storage_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils_object_storage.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/mapping.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/optimizer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/validation.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/finalize_model_grads.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/energy_monitor.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/model_type.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trt_model_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trt_model_type.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/kitchen.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_indices_converter.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/hyper_comm_grid.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/inference_request.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/sampling_params.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/jit.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/backends.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/clip_model.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/qwen_model.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/config/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/config/base_configs.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/model/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/model/base.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/base.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/vision.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/context_parallel.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/msc_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/clip_grads.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/optimizer_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/packed_seq_params.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/quantization/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/quantization/quant_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/quantization/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_block.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_context_parallel.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_layer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/layers.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/random.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/timers.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/attention.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/mlp.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/shared_experts.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/torch_norm.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/utils.py +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/SOURCES.txt +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/setup.cfg +0 -0
- {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/setup.py +0 -0
{megatron_core-0.14.0rc1/megatron_core.egg-info → megatron_core-0.14.0rc2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc1
+Version: 0.14.0rc2
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -44,7 +44,7 @@ Requires-Dist: nvtx; extra == "dev"
 Requires-Dist: transformers; extra == "dev"
 Requires-Dist: multi-storage-client; extra == "dev"
 Requires-Dist: setuptools<80.0.0; extra == "dev"
-Requires-Dist: nvidia-modelopt[torch]; sys_platform != "darwin" and extra == "dev"
+Requires-Dist: nvidia-modelopt[torch]~=0.31.0; sys_platform != "darwin" and extra == "dev"
 Requires-Dist: megatron-energon[av_decode]<7; extra == "dev"
 Provides-Extra: lts
 Requires-Dist: tqdm; extra == "lts"

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/enums.py
@@ -7,9 +7,16 @@ class ModelType(enum.Enum):
     """Model type."""
 
     encoder_or_decoder = 1
-
-
-
+    retro_encoder = 2
+    retro_decoder = 3
+
+    @property
+    def encoder_and_decoder(self):
+        """Deprecated property - use encoder_or_decoder instead."""
+        raise ValueError(
+            "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder "
+            "instead."
+        )
 
 
 class Fp8Recipe(str, enum.Enum):

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fp8_utils.py
@@ -346,8 +346,12 @@ else:
     def _modify_underlying_storage_impl(*args, **kwargs):
         raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
 
-    def _quantize_param_shard_impl(*args, **kwargs):
-
+    def _quantize_param_shard_impl(model_params, *args, **kwargs):
+        if len(model_params) == 0:
+            return
+        else:
+            # If TE is not installed, there shouldn't be any fp8 params.
+            raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
 
     def _correct_amax_history_if_needed_impl(*args, **kwargs):
         # If TE is not installed, we are definitely not using fp8 for training, so no correction

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/dynamic_context.py
@@ -2,9 +2,11 @@
 
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
+from packaging.version import Version as PkgVersion
 from torch import Tensor
 
 from megatron.core import parallel_state
@@ -123,8 +125,10 @@ class DynamicInferenceContext(BaseInferenceContext):
         max_requests_override: Optional[int] = None,
         max_tokens_override: Optional[int] = None,
         tensor_model_parallel_size: Optional[int] = None,
+        materialize_only_last_token_logits: bool = True,
     ):
-
+
+        super().__init__(materialize_only_last_token_logits=materialize_only_last_token_logits)
         # Per partition num heads and hidden size.
         projection_size = kv_channels * num_attention_heads
         if tensor_model_parallel_size is None:
@@ -762,7 +766,7 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.total_request_count += 1
         self.active_token_count += context_length
 
-    def
+    def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens):
         """
         Swaps all the relevent booking tensors with src idxs to dst idxs
         """
@@ -866,7 +870,12 @@ class DynamicInferenceContext(BaseInferenceContext):
         kv_chunks_asigned = self.request_to_kv_chunk_ids[finished_idxs]
         non_zero_values_in_kv_memory = kv_chunks_asigned[kv_chunks_asigned != -1]
         self.chunk_allocator.release_memory_chunks(non_zero_values_in_kv_memory)
-
+
+        # Reset the KV chunks for finished requests.
+        # Note: do not use fill_() (or add_() and similar inplace ops) here.
+        # The combinition of indexing with a tensor (like finished_idxs) and fill_()/add_() creates a clone
+        # and updates it instead of the original tensor.
+        self.request_to_kv_chunk_ids[finished_idxs] = -1
 
         if active_request_count > 0:
             finished_idxs_on_left = (
@@ -881,12 +890,15 @@ class DynamicInferenceContext(BaseInferenceContext):
                 + self.paused_request_count
             )
 
-            self.
+            self._move_book_keeping_tensors(
                 src_idxs=active_idxs_on_right,
                 dst_idxs=finished_idxs_on_left,
                 next_tokens=next_tokens,
             )
 
+            # Reset chunk ids for recently moved requests.
+            self.request_to_kv_chunk_ids[active_idxs_on_right] = -1
+
         # 5. We identify requests that require a new chunk and add them to the paused requests (i.e move them left) :-
         # a) Put requests that have filled their current chunk and require a new one in a pause state temporarily
         # b) Move the paused requests to the left, and active requets to the right
@@ -931,7 +943,7 @@ class DynamicInferenceContext(BaseInferenceContext):
             )
             dst_idxs = torch.cat((active_request_ids_on_left, paused_requests_idxs_on_right))
             src_idxs = torch.cat((paused_requests_idxs_on_right, active_request_ids_on_left))
-            self.
+            self._move_book_keeping_tensors(
                 src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens
             )
 
@@ -974,6 +986,8 @@ class DynamicInferenceContext(BaseInferenceContext):
         if self.paused_request_count > 0:
             self.paused_tokens = next_tokens[: self.paused_request_count]
 
+        # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_)
+        # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors)
         self.request_kv_length_offsets[self.paused_request_count : self.total_request_count].add_(
             self.request_query_lengths[self.paused_request_count : self.total_request_count]
         )
@@ -1027,3 +1041,35 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.token_to_local_position_within_kv_chunk[: self.active_token_count] = (
             self.request_last_kv_chunk_offset[self.paused_request_count : self.total_request_count]
         )
+
+    def calculate_log_probs(self, logits: torch.Tensor) -> List[List[float]]:
+        """Calculate log probs for all active requests and return them.
+
+        TODO: @wdykas support top-n log probs.
+
+        Args:
+            logits: Raw model output logits with shape [1, sequence_length, vocab_size].
+
+        Returns:
+            List of lists where each inner list contains log probs for a request in the
+            same order as the active requests (from paused_request_count to total_request_count).
+        """
+        # Calculate log_probs (sequence_length x vocab_size)
+        log_probs = F.log_softmax(logits, dim=-1).to(torch.float32).squeeze()
+
+        # Extract the log probs for only the selected tokens
+        # (sequence_length x vocab_size) -> (sequence_length)
+        active_token_ids = self.token_to_input_ids[: self.active_token_count]
+        sequence_indices = torch.arange(self.active_token_count, device=log_probs.device)
+        selected_log_probs = log_probs[sequence_indices, active_token_ids]
+
+        # Split the log probs across request boundaries
+        active_query_lengths = self.request_query_lengths[
+            self.paused_request_count : self.total_request_count
+        ]
+        selected_log_probs_list = selected_log_probs.cpu().split(
+            active_query_lengths.tolist(), dim=0
+        )
+
+        # Convert each log prob tensor into a list
+        return [lp.tolist() for lp in selected_log_probs_list]
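
The calculate_log_probs method added above follows a common pattern: take log_softmax over the vocabulary, pick out the log prob of each observed token, then split the flat result back into per-request lists. A minimal standalone sketch of that pattern; the tensors and request lengths below are made up for illustration and are not taken from the package:

    import torch
    import torch.nn.functional as F

    # Fake logits for 5 tokens over a vocabulary of 10 (stands in for the model output).
    logits = torch.randn(1, 5, 10)
    token_ids = torch.tensor([3, 7, 1, 0, 9])  # tokens actually observed at each position
    query_lengths = [2, 3]                     # two requests: 2 tokens and 3 tokens

    log_probs = F.log_softmax(logits, dim=-1).squeeze(0)   # [5, 10]
    positions = torch.arange(token_ids.numel())
    selected = log_probs[positions, token_ids]             # [5], one log prob per token

    # Split the flat per-token log probs back into per-request lists.
    per_request = [chunk.tolist() for chunk in selected.split(query_lengths)]
    print(per_request)  # e.g. [[-2.1, -1.7], [-2.9, -0.8, -3.0]]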

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/static_context.py
@@ -17,7 +17,7 @@ class StaticInferenceContext(BaseInferenceContext):
     """
 
     def __init__(self, max_batch_size: int, max_sequence_length: int):
-        super().__init__(materialize_only_last_token_logits=
+        super().__init__(materialize_only_last_token_logits=True)
         self.max_sequence_length = max_sequence_length
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/dynamic_engine.py
@@ -2,6 +2,7 @@
 
 import asyncio
 from collections import deque
+from itertools import repeat
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -182,6 +183,7 @@ class DynamicInferenceEngine(AbstractEngine):
         finished_request_ids: torch.Tensor,
         step_time: float,
         sample: torch.Tensor,
+        log_probs: torch.Tensor,
     ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
         """
         Handles post-processing for requests after a step.
@@ -191,6 +193,7 @@ class DynamicInferenceEngine(AbstractEngine):
             finished_request_ids (torch.Tensor): A list of finished request ids
             step_time (float): The latency of the last step
             sample: (torch.Tensor): The newly generated tokens for each request
+            log_probs: (List): Log probs for each request
 
         Returns:
             A list of active requests and completed requests as `DynamicInferenceRequest` objects
@@ -200,13 +203,25 @@ class DynamicInferenceEngine(AbstractEngine):
         finished_request_ids = set(finished_request_ids.tolist())
         self.finished_request_count += len(finished_request_ids)
 
-
+        log_probs_iter = log_probs if log_probs else repeat(None)
+
+        for request_id, token, request_log_probs in zip(
+            request_ids.tolist(), sample.tolist(), log_probs_iter
+        ):
             request: DynamicInferenceRequest = self.requests[request_id]
             request.generated_tokens.append(token)
             if request.tpot is None:
                 request.tpot = []
             request.tpot.append(step_time)
 
+            if request_log_probs is not None:
+                # If prompt log probs is None we are in prefill
+                if request.prompt_log_probs is None:
+                    request.prompt_log_probs = request_log_probs
+                    request.generated_log_probs = []
+                else:
+                    request.generated_log_probs.extend(request_log_probs)
+
             if request_id in finished_request_ids:
                 request.generated_length = len(request.generated_tokens)
                 request.status = Status.COMPLETED
@@ -266,11 +281,11 @@ class DynamicInferenceEngine(AbstractEngine):
         step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3
 
         if result is not None:
-            request_ids, finished_request_ids, sample = result
+            request_ids, finished_request_ids, sample, log_probs = result
 
             # TODO: Move this to a background thread?
             (active_requests, finished_requests) = self.post_process_requests(
-                request_ids, finished_request_ids, step_time, sample
+                request_ids, finished_request_ids, step_time, sample, log_probs
             )
 
             # TODO: Move this to a background thread?
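
The post_process_requests change above uses itertools.repeat(None) so the same zip loop works whether or not log probs were computed for the step. A minimal sketch of that idiom with made-up request ids and tokens, not values from the package:

    from itertools import repeat

    request_ids = [11, 12, 13]
    samples = [101, 205, 87]
    log_probs = None  # e.g. log probs were not requested for this step

    # repeat(None) is an endless iterator, so zip still pairs every request with a placeholder.
    log_probs_iter = log_probs if log_probs else repeat(None)
    for rid, token, lp in zip(request_ids, samples, log_probs_iter):
        print(rid, token, lp)  # lp is None when log probs were not requested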

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
@@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, Optional, Union
 
 import torch
 
-from megatron.core import parallel_state
+from megatron.core import parallel_state
 from megatron.core.inference.communication_utils import (
     is_pipeline_first_stage,
     is_pipeline_last_stage,
@@ -152,13 +152,12 @@ class AbstractModelInferenceWrapper(abc.ABC):
         tokens = inference_input["tokens"]
         position_ids = inference_input["position_ids"]
         attention_mask = inference_input["attention_mask"]
-        runtime_gather_output = inference_input.get("runtime_gather_output")
         return self.model(
             tokens,
             position_ids,
             attention_mask,
             inference_context=self.inference_context,
-            runtime_gather_output=
+            runtime_gather_output=True,  # Inference should always gather the logits
         )
 
     def _get_batch_size_and_seq_len(
@@ -201,7 +200,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
         """
         tokens = inference_input["tokens"]
         logits = self._forward(inference_input)
-        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits, self.tp_group)
         self.inference_context.increment_sequence_len_offset(tokens.size(1))
 
         return logits
@@ -243,7 +241,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
         logits = None
         if is_pipeline_last_stage(self.pp_group):
             logits = output_tensor
-            logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits, self.tp_group)
 
             # Explicitly cast logits to expected dtype
             logits = logits.to(self.inference_wrapper_config.params_dtype)
@@ -269,7 +266,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
         tokens = inference_input["tokens"]
         position_ids = inference_input["position_ids"]
         attention_mask = inference_input["attention_mask"]
-        runtime_gather_output = inference_input.get("runtime_gather_output")
         materialize_only_last_token_logits = (
             self.inference_context.materialize_only_last_token_logits
         )
@@ -317,7 +313,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
                     "position_ids": position_ids2use,
                     "attention_mask": attention_mask,
                     "inference_context": self.inference_context,
-                    "runtime_gather_output": runtime_gather_output,
                 }
             )
 
@@ -327,9 +322,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
             self.inference_context.batch_size_offset += current_micro_batch_size
 
             if is_pipeline_last_stage(self.pp_group):
-                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
-                    output_tensor, self.tp_group
-                )
                 assert logits is not None
                 logits[start:end, ...] = output_tensor
 

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py
@@ -10,6 +10,7 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_w
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
     InferenceWrapperConfig,
 )
+from megatron.core.inference.utils import get_attention_mask
 from megatron.core.models.gpt import GPTModel
 from megatron.core.transformer.enums import AttnBackend
 from megatron.core.utils import get_model_config
@@ -74,12 +75,7 @@ class GPTInferenceWrapper(AbstractModelInferenceWrapper):
         attention_backend = config.attention_backend
 
         if attention_backend == AttnBackend.local:
-            attention_mask =
-                torch.ones((1, seq_length, seq_length), device=prompts_tokens.device)
-            ).view(1, 1, seq_length, seq_length)
-
-            # Convert to boolean
-            attention_mask = attention_mask < 0.5
+            attention_mask = get_attention_mask(seq_length)
         elif (
             attention_backend == AttnBackend.flash
             or attention_backend == AttnBackend.fused

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py
@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional
 
 import torch
 
-from megatron.core import parallel_state
 from megatron.core.inference.communication_utils import (
     is_pipeline_first_stage,
     is_pipeline_last_stage,
@@ -48,16 +47,10 @@ class VLMInferenceWrapper(GPTInferenceWrapper):
         # has part of the LM decoder. In this case, the current stage should only receive
         # vision embeddings.
         if pp_rank > 0:
-            self._recv_only_vision_embeds =
-                parallel_state.is_inside_encoder(pp_rank - 1)
-                and (not parallel_state.is_inside_decoder(pp_rank - 1))
-                and parallel_state.is_inside_decoder()
-            )
+            self._recv_only_vision_embeds = False  # TODO: Implement new logic for vision embeddings
 
         # Checks if the current stage only has a vision encoder
-        self._encoder_only =
-            parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()
-        )
+        self._encoder_only = False  # TODO: Implement new logic for encoder-only stages
 
     def prep_inference_input(
         self,

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py
@@ -7,6 +7,7 @@ from megatron.core.inference.inference_request import InferenceRequest
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
+from megatron.core.inference.utils import get_attention_mask
 
 
 class EncoderDecoderTextGenerationController(TextGenerationController):
@@ -18,13 +19,18 @@ class EncoderDecoderTextGenerationController(TextGenerationController):
     """
 
     def prep_inference_input(
-        self,
+        self,
+        prompts_tokens: torch.Tensor,
+        active_requests: OrderedDict[str, InferenceRequest],
+        use_attention_mask: bool = False,
     ) -> Dict[str, Any]:
         """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
 
         Args:
             prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
             active_requests (OrderedDict[str, InferenceRequest]): The input active requests
+            use_attention_mask (bool): Whether to use an attention mask. Should be set to True only
+                when exclusively doing prefill (no decode) with variable prompt lengths.
 
         Returns:
             A dict of the inference input for the current batch.
@@ -33,6 +39,13 @@ class EncoderDecoderTextGenerationController(TextGenerationController):
             map(lambda request: request.encoder_prompt, active_requests.values())
         )
 
-
+        inference_input = self.inference_wrapped_model.prep_inference_input(
             prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
         )
+
+        if use_attention_mask and (
+            attention_mask := inference_input.get("attention_mask", None) is None
+        ):
+            inference_input["attention_mask"] = get_attention_mask(prompts_tokens.size(1))
+
+        return inference_input

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -24,10 +24,13 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_w
     AbstractModelInferenceWrapper,
 )
 from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.inference.utils import get_attention_mask
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import get_model_config
 
 try:
+    import transformer_engine as te  # pylint: disable=unused-import
+
     from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
 
     HAVE_TE = True
@@ -429,6 +432,11 @@ class TextGenerationController:
 
         context = self.inference_wrapped_model.inference_context
 
+        if sampling_params.return_log_probs:
+            assert (
+                context.materialize_only_last_token_logits is False
+            ), "Materialize only last token logits must be false for returning log probs"
+
         # No tokens?
         if context.active_token_count == 0:
             return None
@@ -478,7 +486,13 @@ class TextGenerationController:
             pp_group=self.pp_group,
         )
 
-
+        # Last token logits.
+        if context.materialize_only_last_token_logits:
+            # When materialize_only_last_token_logits is true, last_token_logits is
+            # already called in the forward pass of GPT.
+            last_token_logits = logits.squeeze(0)
+        else:
+            last_token_logits = context.last_token_logits(logits)
 
         # Sample.
         # Use padded vocab size because tokenizer vocab size might not include padding
@@ -505,11 +519,15 @@ class TextGenerationController:
         )
         finished_request_ids = context.request_ids[finished_idxs]
 
+        log_probs = None
+        if sampling_params.return_log_probs:
+            log_probs = context.calculate_log_probs(logits)
+
         # Update requests.
         # New sample gets updated in update_requests, so we pass in a clone
         context.update_requests(active_request_mask, new_sample.clone())
 
-        return current_request_ids, finished_request_ids, new_sample
+        return current_request_ids, finished_request_ids, new_sample, log_probs
 
     def _update_top_n_logprobs_dict(
         self,
@@ -581,13 +599,12 @@ class TextGenerationController:
 
         model_config = get_model_config(self.inference_wrapped_model.model)
 
-        #
-
-
-
-
-
-        sampling_params.add_attributes({"echo": True})
+        # We only need an attention mask if we are exclusively doing prefill over
+        # prompts of variable length
+        use_attention_mask = (
+            sampling_params.num_tokens_to_generate == 0
+            and min_prompt_length_in_batch != max_prompt_length_in_batch
+        )
 
         # Check whether CUDA graphs are enabled
         enable_cuda_graph = model_config.enable_cuda_graph
@@ -689,7 +706,9 @@ class TextGenerationController:
         self.inference_wrapped_model.prep_model_for_inference()
 
         inference_input: Dict[str, Any] = self.prep_inference_input(
-            prompts_tokens=padded_batch_prompt_tokens,
+            prompts_tokens=padded_batch_prompt_tokens,
+            active_requests=active_requests,
+            use_attention_mask=use_attention_mask,
         )
 
         assert (
@@ -706,7 +725,13 @@ class TextGenerationController:
             self.inference_wrapped_model.model.module.set_symmetric_ar(None)
 
         context_start_position = 0
-
+
+        # If we are exclusively doing prefill then we can process all prompt tokens
+        # together even if the prompt lengths are different
+        if sampling_params.num_tokens_to_generate == 0:
+            context_end_position = max_prompt_length_in_batch
+        else:
+            context_end_position = min_prompt_length_in_batch
 
         # The initial iteration of this loop runs the prefill phase up to the shortest
         # prompt length in the batch. Then every subsequent iterations runs a decode step.
@@ -734,6 +759,13 @@ class TextGenerationController:
                 and "attention_mask" in inference_input_for_context_window
             ):
                 inference_input_for_context_window["attention_mask"] = None
+            elif use_attention_mask:
+                assert (
+                    attention_mask := inference_input_for_context_window.get(
+                        "attention_mask", None
+                    )
+                    is not None
+                )
 
             # Only materialize prompt log probs if the user requests log probs
             materialize_only_last_token_logits = (
@@ -985,18 +1017,30 @@ class TextGenerationController:
         return active_requests
 
     def prep_inference_input(
-        self,
+        self,
+        prompts_tokens: torch.Tensor,
+        active_requests: OrderedDict[str, InferenceRequest],
+        use_attention_mask: bool = False,
     ) -> Dict[str, Any]:
         """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
 
         Args:
             prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests
+            use_attention_mask (bool): Whether to use an attention mask. Should be set to True only
+                when exclusively doing prefill (no decode) with variable prompt lengths.
 
         Returns:
             A dict of the inference input for the current batch.
         """
-
+        inference_input = self.inference_wrapped_model.prep_inference_input(prompts_tokens)
+
+        if use_attention_mask and (
+            attention_mask := inference_input.get("attention_mask", None) is None
+        ):
+            inference_input["attention_mask"] = get_attention_mask(prompts_tokens.size(1))
+
+        return inference_input
 
     def stream_tokens(
         self,

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py
@@ -7,13 +7,17 @@ from megatron.core.inference.inference_request import InferenceRequest, VLMInfer
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
+from megatron.core.inference.utils import get_attention_mask
 
 
 class VLMTextGenerationController(TextGenerationController):
     """The text generation controller for VLMs"""
 
     def prep_inference_input(
-        self,
+        self,
+        prompts_tokens: torch.Tensor,
+        active_requests: OrderedDict[str, InferenceRequest],
+        use_attention_mask: bool = False,
     ):
         """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
 
@@ -22,6 +26,8 @@ class VLMTextGenerationController(TextGenerationController):
         Args:
             prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
             active_requests (OrderedDict[str, InferenceRequest]): The input active requests
+            use_attention_mask (bool): Whether to use an attention mask. Should be set to True only
+                when exclusively doing prefill (no decode) with variable prompt lengths.
         """
         assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1"
 
@@ -31,10 +37,17 @@ class VLMTextGenerationController(TextGenerationController):
             request, VLMInferenceRequest
         ), f"Found inference request of type {type(request)}, expected VLMInferenceRequest"
 
-
+        inference_input = self.inference_wrapped_model.prep_inference_input(
             prompts_tokens,
             request.num_img_embeddings_per_tile,
             request.imgs,
             request.num_tiles,
             request.decoder_seq_length,
         )
+
+        if use_attention_mask and (
+            attention_mask := inference_input.get("attention_mask", None) is None
+        ):
+            inference_input["attention_mask"] = get_attention_mask(prompts_tokens.size(1))
+
+        return inference_input

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/utils.py
@@ -1,4 +1,8 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+import torch
+
+
 class Counter:
     """A simple counter class
 
@@ -16,3 +20,15 @@ class Counter:
     def reset(self) -> None:
         """Reset counter"""
         self.counter = 0
+
+
+def get_attention_mask(seq_length: int) -> torch.Tensor:
+    """Constructs an attention mask given the input sequence length."""
+    attention_mask = torch.tril(
+        torch.ones((1, seq_length, seq_length), device=torch.cuda.current_device())
+    ).view(1, 1, seq_length, seq_length)
+
+    # Convert to boolean
+    attention_mask = attention_mask < 0.5
+
+    return attention_mask
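
For reference, the boolean convention used by the new get_attention_mask helper above: positions that may be attended to come out False and masked (future) positions come out True. A minimal CPU-only sketch of the same construction; the CUDA device placement from the package version is dropped here so it runs anywhere:

    import torch

    def causal_mask(seq_length: int) -> torch.Tensor:
        # Same tril-then-threshold construction as get_attention_mask, but on CPU.
        mask = torch.tril(torch.ones((1, seq_length, seq_length))).view(1, 1, seq_length, seq_length)
        return mask < 0.5  # True marks positions that must NOT be attended to

    print(causal_mask(4)[0, 0].int())
    # tensor([[0, 1, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 0, 1],
    #         [0, 0, 0, 0]])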

{megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/model_parallel_config.py
@@ -286,11 +286,6 @@ class ModelParallelConfig:
     Defaults to 0, which means all micro-batches are deferred.
     """
 
-    pipeline_model_parallel_split_rank: Optional[int] = None
-    """If int, rank where encoder and decoder should be split in cases where the model has both an
-    encoder and decoder (e.g., T5). Ignored if None.
-    """
-
     overlap_p2p_comm_warmup_flush: bool = False
     """If true, overlap communication and computation in warm up and flush phase.
     Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.