megatron-core 0.14.0rc2__tar.gz → 0.14.0rc3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {megatron_core-0.14.0rc2/megatron_core.egg-info → megatron_core-0.14.0rc3}/PKG-INFO +11 -8
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +10 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fp8_utils.py +119 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/__init__.py +1 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/dynamic_context.py +148 -59
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/dynamic_engine.py +79 -18
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +4 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +6 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +3 -37
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +10 -4
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/__init__.py +143 -44
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/optimizer.py +0 -3
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/package_info.py +1 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/packed_seq_params.py +2 -2
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/random.py +4 -1
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/attention.py +2 -7
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/cuda_graphs.py +178 -43
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3/megatron_core.egg-info}/PKG-INFO +11 -8
- megatron_core-0.14.0rc3/megatron_core.egg-info/requires.txt +33 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/pyproject.toml +13 -10
- megatron_core-0.14.0rc2/megatron_core.egg-info/requires.txt +0 -30
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/LICENSE +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/MANIFEST.in +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/README.md +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/README.md +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/object_storage_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils_object_storage.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/mapping.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/optimizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/validation.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/finalize_model_grads.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/energy_monitor.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/enums.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/model_type.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trt_model_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trt_model_type.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/kitchen.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_indices_converter.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/hyper_comm_grid.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/static_context.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/inference_request.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/sampling_params.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/jit.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/model_parallel_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/T5/t5_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/T5/t5_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/backends.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/language_module/language_module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/gpt_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/moe_module_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/clip_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/qwen_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/mamba_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/config/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/config/base_configs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/model/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/model/base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/audio.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/base.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/vision.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/context_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/llava_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/decoder_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/encoder_spec.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/clip_vit_model.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/multimodal_projector.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/radio.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/msc_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/clip_grads.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/distrib_optimizer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/optimizer_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/parallel_state.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/schedules.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/quantization/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/quantization/quant_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/quantization/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/rerun_state_machine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_block.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_context_parallel.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_layer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/layers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/timers.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/enums.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/mlp.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/experts.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/moe_layer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/moe_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/router.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/shared_experts.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/multi_latent_attention.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/multi_token_prediction.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/torch_norm.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_block.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_config.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_layer.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/utils.py +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron_core.egg-info/SOURCES.txt +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/setup.cfg +0 -0
- {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/setup.py +0 -0
{megatron_core-0.14.0rc2/megatron_core.egg-info → megatron_core-0.14.0rc3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc2
+Version: 0.14.0rc3
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -31,6 +31,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch
 Requires-Dist: numpy<2.0.0
+Requires-Dist: packaging~=25.0
 Provides-Extra: mlm
 Requires-Dist: flask-restful; extra == "mlm"
 Requires-Dist: sentencepiece; extra == "mlm"
@@ -38,14 +39,16 @@ Requires-Dist: tiktoken; extra == "mlm"
 Requires-Dist: wandb; extra == "mlm"
 Provides-Extra: dev
 Requires-Dist: tqdm; extra == "dev"
-Requires-Dist: einops; extra == "dev"
-Requires-Dist: tensorstore!=0.1.46,!=0.1.72; extra == "dev"
-Requires-Dist: nvtx; extra == "dev"
-Requires-Dist: transformers; extra == "dev"
-Requires-Dist: multi-storage-client; extra == "dev"
+Requires-Dist: einops~=0.8; extra == "dev"
+Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
+Requires-Dist: nvtx~=0.2; extra == "dev"
+Requires-Dist: transformers~=4.53; extra == "dev"
+Requires-Dist: multi-storage-client~=0.20.3; extra == "dev"
+Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
 Requires-Dist: setuptools<80.0.0; extra == "dev"
-Requires-Dist: nvidia-modelopt[torch]
-Requires-Dist: megatron-energon[av_decode]
+Requires-Dist: nvidia-modelopt[torch]<0.32.0,>=0.31.0a0; sys_platform != "darwin" and extra == "dev"
+Requires-Dist: megatron-energon[av_decode]~=6.0; extra == "dev"
+Requires-Dist: flashinfer-python; extra == "dev"
 Provides-Extra: lts
 Requires-Dist: tqdm; extra == "lts"
 Requires-Dist: einops; extra == "lts"
{megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py

@@ -217,6 +217,16 @@ class FullyShardedDataParallel(_BaseDataParallel):
 
         self.module.apply(unmap_weight_tensor)
 
+        for param in self.module.parameters():
+            if not hasattr(param, 'grad_added_to_main_grad'):
+                # This is to ensure that the param.grad_added_to_main_grad is set to False
+                # when the parameter is created.
+                param.grad_added_to_main_grad = False
+            if not hasattr(param, '__fsdp_param__'):
+                # This is to ensure that the param.__fsdp_param__ is set to True
+                # when the parameter is created.
+                param.__fsdp_param__ = True
+
     def _init_fsdp_param_and_grad_buffer(self):
         if self.config.calculate_per_token_loss:
             # We don't need to scale the gradients in this case.
{megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fp8_utils.py

@@ -2,7 +2,9 @@
 
 """Utility functions related to FP8 that are used throughout Megatron core"""
 
+import weakref
 from contextlib import nullcontext
+from functools import wraps
 from typing import List, Optional
 
 import torch
@@ -53,6 +55,29 @@ except (ImportError, ModuleNotFoundError):
     # MXFP8Tensor not found
     HAVE_TE_MXFP8TENSOR = False
 
+if HAVE_TE:
+    from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelLinear,
+        TELayerNormColumnParallelLinear,
+        TELinear,
+        TERowParallelLinear,
+    )
+
+    TE_LINEAR_TYPES = (
+        TELinear,
+        TEColumnParallelLinear,
+        TERowParallelLinear,
+        TELayerNormColumnParallelLinear,
+    )
+else:
+    TE_LINEAR_TYPES = ()
+
+try:
+    from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
+except ImportError:
+    Fp8Padding = None
+    Fp8Unpadding = None
+
 
 def is_float8tensor(tensor: torch.Tensor) -> bool:
     """Check if a tensor is a Transformer Engine Float8Tensor.
@@ -511,3 +536,97 @@ else:
     def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
         """Returns dummy fp8 context manager since TE is not available."""
         return nullcontext()
+
+
+if HAVE_TE:
+    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+    # Modules that have been wrapped for inference for fp8
+    _fp8_inference_wrapped_modules = weakref.WeakSet()
+
+    def _wrap_te_linear_for_padding(module: torch.nn.Module):
+        """Wrap a TE linear module to automatically pad sequences for FP8 inference.
+
+        Modifies the module's forward method to:
+        1. Pad input sequences to FP8 alignment requirements
+        2. Run the original forward pass
+        3. Unpad outputs to original sequence length
+
+        Args:
+            module: A Transformer Engine linear layer (TELinear, TEColumnParallelLinear, etc.)
+        """
+        if module in _fp8_inference_wrapped_modules:
+            return
+        _pad_func = Fp8Padding(1)
+        _unpad_func = Fp8Unpadding(1)
+
+        original_forward = module.forward
+
+        @wraps(original_forward)
+        def padded_forward(input_tensor, *args, **kwargs):
+            # Only do padding for fp8 if we are in fp8 context
+            if not FP8GlobalStateManager.is_fp8_enabled():
+                return original_forward(input_tensor, *args, **kwargs)
+
+            seq_len, batch_size, hidden_size = input_tensor.shape
+            # Reshape to (S, B*H) to pad sequence dimension
+            input_2d = input_tensor.reshape(seq_len, -1)
+            # Pad the sequence dimension
+            padded_input_2d, _ = _pad_func(input_2d, [seq_len])
+            padded_seq_len = padded_input_2d.shape[0]
+
+            # Reshape back to (padded_S, B, H)
+            padded_input_3d = padded_input_2d.view(padded_seq_len, batch_size, hidden_size)
+            output = original_forward(padded_input_3d, *args, **kwargs)
+
+            # Handle output
+            if isinstance(output, tuple):
+                output_tensor = output[0]
+                other_outputs = output[1:]
+            else:
+                output_tensor = output
+                other_outputs = ()
+
+            # Unpad output - reshape to 2D, unpad, reshape back
+            _, _, output_hidden_size = output_tensor.shape
+            output_2d = output_tensor.reshape(padded_seq_len, -1)
+            unpadded_output_2d = _unpad_func(output_2d, [seq_len])
+            unpadded_output = unpadded_output_2d.reshape(seq_len, batch_size, output_hidden_size)
+
+            if other_outputs:
+                return (unpadded_output,) + other_outputs
+            else:
+                return unpadded_output
+
+        module.forward = padded_forward
+        _fp8_inference_wrapped_modules.add(module)
+
+    def prepare_model_for_fp8_inference(model):
+        """Prepare a model for FP8 inference by wrapping TE linear layers with padding support.
+
+        FP8 TE Gemms have specific shape requirements. This function wraps all Transformer
+        Engine linear layers in the model to automatically pad/unpad sequences during inference.
+
+        Args:
+            model (model (GPTModel): Model containing TE linear layers.
+
+        Returns:
+            GPTModel: The same model with wrapped linear layers (modified in-place).
+
+        """
+        assert Fp8Padding and Fp8Unpadding, "TE version does not have FP8 padding functions"
+        # Find and wrap all TE linear layers
+        for module in model.modules():
+            if isinstance(module, TE_LINEAR_TYPES):
+                _wrap_te_linear_for_padding(module)
+
+        return model
+
+else:
+
+    def prepare_model_for_fp8_inference(model):
+        """If trys using prepare_model_for_fp8_inference without TE we error"""
+        raise RuntimeError(
+            "prepare_model_for_fp8_inference requires Transformer Engine to be installed. "
+            "Please install transformer-engine to use FP8 inference."
+        )
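The new fp8_utils helpers above are applied to a model once, before serving it with FP8 GEMMs. A minimal usage sketch, assuming `model` is a GPTModel built with Transformer Engine linear layers and `fp8_recipe` is a Transformer Engine recipe object (both are assumptions for illustration, not shown in this diff):

```python
# Minimal sketch (not from this diff) of opting a model into FP8 inference padding.
# Assumptions: `model` is a megatron-core GPTModel whose linear layers are
# Transformer Engine modules, and `fp8_recipe` is a transformer_engine recipe.
import transformer_engine.pytorch as te

from megatron.core.fp8_utils import prepare_model_for_fp8_inference

model = prepare_model_for_fp8_inference(model)  # wraps TELinear-family modules in place

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # Inside fp8_autocast, FP8GlobalStateManager.is_fp8_enabled() is True, so each
    # wrapped forward pads the sequence dimension up to the FP8 alignment before the
    # GEMM and unpads the result back to the original length afterwards.
    output = model(input_ids, position_ids, attention_mask)  # usual GPTModel inputs
```

Outside of `fp8_autocast`, `FP8GlobalStateManager.is_fp8_enabled()` is false and the wrapped forwards fall straight through to the original implementation, so the wrapping is inert for non-FP8 runs.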
{megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/dynamic_context.py

@@ -56,6 +56,18 @@ class ChunkOverflowError(ContextOverflowError):
     pass
 
 
+class ActiveRequestCountOverflowError(ContextOverflowError):
+    '''Used when `initialize_attention_state()` is called with
+    `num_warmup_requests > max_requests.'''
+
+    def __init__(self, max_request_count, active_request_count):
+        assert active_request_count > max_request_count
+        super().__init__(
+            "active_request_count (%d) > max_request_count (%d)."
+            % (active_request_count, max_request_count)
+        )
+
+
 # pylint: disable=line-too-long
 class DynamicInferenceContext(BaseInferenceContext):
     """Inference context that is passed to the main model in order
@@ -108,6 +120,11 @@ class DynamicInferenceContext(BaseInferenceContext):
             from `buffer_overflow_factor`.
         max_tokens_override (Optional[int]): If set, overrides value computed
             from `buffer_overflow_factor`.
+        tensor_model_parallel_size (Optional[int]): Tensor model parallel size.
+        num_cuda_graphs (Optional[int]): Maximum number of cuda graphs to capture,
+            where the cuda graph batch sizes range from 1 to `max_requests` (as
+            computed below). Due to rounding, the actual number of cuda graphs may
+            not equal this argument.
     """
 
     def __init__(
@@ -125,6 +142,7 @@ class DynamicInferenceContext(BaseInferenceContext):
         max_requests_override: Optional[int] = None,
         max_tokens_override: Optional[int] = None,
         tensor_model_parallel_size: Optional[int] = None,
+        num_cuda_graphs: Optional[int] = None,
         materialize_only_last_token_logits: bool = True,
     ):
 
@@ -188,7 +206,7 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.active_token_count = 0
         self.paused_request_count = 0
         self.padded_active_token_count = None
-        self.
+        self.padded_active_request_count = None
         self.paused_tokens = None
 
         # Per-request state.
@@ -246,6 +264,34 @@ class DynamicInferenceContext(BaseInferenceContext):
             device=torch.cuda.current_device(),
         )
 
+        # Cuda graph request counts (i.e., batch sizes used for decode-only steps).
+        self.cuda_graph_request_counts = None
+        if num_cuda_graphs is not None:
+
+            # Ensure valid num_cuda_graphs.
+            num_cuda_graphs = min(max(num_cuda_graphs, 1), self.max_requests)
+
+            # Cuda graph step size.
+            cuda_graph_rounder = 8
+            self.cuda_graph_step_size = self.max_requests / num_cuda_graphs
+            self.cuda_graph_step_size = cuda_graph_rounder * int(
+                math.ceil(int(self.cuda_graph_step_size) / cuda_graph_rounder)
+            )
+
+            # Cuda graph request counts.
+            if num_cuda_graphs == 1:
+                self.cuda_graph_request_counts = [self.max_requests]
+            else:
+                self.cuda_graph_request_counts = list(
+                    range(self.cuda_graph_step_size, self.max_requests, self.cuda_graph_step_size)
+                )
+                if self.cuda_graph_request_counts[-1] != self.max_requests:
+                    self.cuda_graph_request_counts.append(self.max_requests)
+                self.cuda_graph_request_counts.reverse()
+
+            # Set used for validating active cuda graph request count.
+            self.cuda_graph_request_counts_set = set(self.cuda_graph_request_counts)
+
         # `*_decode_only` tensors are for use with cuda graphs to maintain
         # consistent input shapes, which is required to use cuda graphs. Cuda
         # graphs are used only during decode-only steps (i.e., no requests are in
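The bucketing added in the hunk above rounds `max_requests / num_cuda_graphs` up to a multiple of 8 and then enumerates the decode batch sizes to capture. A standalone sketch of that arithmetic, with illustrative values that are not taken from the diff (the real code also special-cases `num_cuda_graphs == 1`):

```python
import math

# Illustrative configuration only; these are whatever the engine was built with.
max_requests = 100
num_cuda_graphs = 8

# Step size, rounded up to a multiple of 8 (mirrors the hunk above).
cuda_graph_rounder = 8
step = cuda_graph_rounder * int(
    math.ceil(int(max_requests / num_cuda_graphs) / cuda_graph_rounder)
)

# Batch sizes to capture: multiples of the step, capped by max_requests, largest first.
counts = list(range(step, max_requests, step))
if counts[-1] != max_requests:
    counts.append(max_requests)
counts.reverse()

print(step)    # 16
print(counts)  # [100, 96, 80, 64, 48, 32, 16]
```

With these inputs the step rounds up to 16, so only seven graph sizes are captured even though eight were requested, which matches the docstring's caveat that rounding can change the actual number of cuda graphs.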
@@ -269,7 +315,7 @@ class DynamicInferenceContext(BaseInferenceContext):
             (self.max_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
         )
 
-        self.
+        self.request_to_kv_chunk_ids_decode_only = torch.full(
             (self.max_requests, self.max_kv_chunk_count),
             0,
             dtype=torch.int,
@@ -278,27 +324,22 @@ class DynamicInferenceContext(BaseInferenceContext):
 
         # Guaranteed active requests.
         # * See details in the class docstring above. `gtd_request_fraction` is
-        # the fraction of the memory buffer that
-        # that some number of active requests can always proceed
-        # generations. The number of
-        #
-        #
-        #
-        #
-        #
-        #
-
-
-
-
-
-        gtd_request_count =
-        gtd_chunk_count = gtd_request_count * self.max_kv_chunk_count
-        assert (
-            gtd_request_count <= self.max_requests
-        ), "gtd_request_count (%d) > max_requests (%d)." % (gtd_request_count, self.max_requests)
-        self.gtd_request_count = gtd_request_count
-        self.gtd_chunk_count = gtd_chunk_count
+        # the fraction of chunks in the memory buffer that are reserved for
+        # guaranteeing that some number of active requests can always proceed
+        # with their generations. The number of chunks defined by
+        # `buffer_guaranteed_fraction * chunk_count_total` is converted to a
+        # number of requests that this reserved space can safely handle
+        # (`gtd_request_count`).
+        # * Note: computing the size of this guaranteed space from chunks rather
+        #   than bytes is safer due to the non-linear impacts of a large
+        #   `chunk_size_tokens` or `max_kv_chunk_count`. When computing from
+        #   chunks, this space will always be less than `chunk_count_total`. When
+        #   computing from bytes, this space can unexpectedly be much larger than
+        #   `chunk_count_total`, resulting in stalled generations.
+        gtd_chunk_count = int(buffer_guaranteed_fraction * chunk_count_total)
+        gtd_chunk_count = min(gtd_chunk_count, chunk_count_total)
+        self.gtd_request_count = max(1, gtd_chunk_count // self.max_kv_chunk_count)
+        self.gtd_chunk_count = self.gtd_request_count * self.max_kv_chunk_count
 
         # Initialize chunk allocator
         self.chunk_allocator = ChunkAllocator(
@@ -368,12 +409,7 @@ class DynamicInferenceContext(BaseInferenceContext):
 
     def cu_kv_lengths(self) -> Tensor:
         """Cumulative key/value sequence lengths."""
-        return (
-            self.cu_kv_seq_lengths,
-            self.kv_seq_lengths,
-            self.kv_seq_lengths_decode_only,
-            self.max_seqlen_k,
-        )
+        return (self.cu_kv_seq_lengths, self.kv_seq_lengths, self.max_seqlen_k)
 
     def get_active_sequence_lengths(self) -> Tensor:
         """Total sequence length (query + key) for active requests."""
@@ -487,7 +523,7 @@ class DynamicInferenceContext(BaseInferenceContext):
             key_seq_idx = self.token_to_position_in_request[:n]
             key_emb = key_emb[key_seq_idx]
         if self.is_decode_only():
-            assert key.shape[0] == n
+            assert key.shape[0] == n
             key = apply_rotary_pos_emb(
                 t=key[:n], freqs=key_emb[:n], config=config, cp_group=cp_group
             )
@@ -506,23 +542,65 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.query_seq_lengths_decode_only.fill_(0)
         self.cu_kv_seq_lengths = None
         self.cu_kv_seq_lengths_decode_only.fill_(0)
+        self.kv_seq_lengths = None
         self.kv_seq_lengths_decode_only.fill_(0)
-        self.
+        self.request_to_kv_chunk_ids_decode_only.fill_(0)
         self.block_table = None
 
-    def initialize_attention_state(self) -> None:
-        """Initialize attention state so that every layer can use it
+    def initialize_attention_state(self, *, num_warmup_requests: Optional[int] = None) -> None:
+        """Initialize attention state so that every layer can use it.
+
+        Args:
+            num_warmup_requests (Optional[int]): Number of requests to use for
+                warming up cuda graphs. Must be less than or equal to
+                `max_requests`.
+
+        Return:
+            None.
+        """
 
+        # Use of num_warmup_requests only for decode-only.
+        if num_warmup_requests is not None:
+            assert self.is_decode_only(), "cuda graph warmup requires decode-only mode."
+
+        # Active request count.
+        active_request_count = (
+            self.total_request_count - self.paused_request_count
+            if num_warmup_requests is None
+            else num_warmup_requests
+        )
+
+        # Active cuda graph count (if decode-only).
+        active_cuda_graph_request_count = None
+        if self.is_decode_only():
+            if active_request_count > self.max_requests:
+                raise ActiveRequestCountOverflowError(self.max_requests, active_request_count)
+
+            if self.cuda_graph_request_counts:
+                active_cuda_graph_request_count = (
+                    math.ceil(active_request_count / self.cuda_graph_step_size)
+                    * self.cuda_graph_step_size
+                )
+                active_cuda_graph_request_count = min(
+                    active_cuda_graph_request_count, self.max_requests
+                )
+                assert active_cuda_graph_request_count in self.cuda_graph_request_counts_set
+            else:
+                active_cuda_graph_request_count = self.max_requests
+
+        # Padded active token/request counts.
         self.padded_active_token_count = (
-
+            active_cuda_graph_request_count
             if self.is_decode_only()
             else self.round_up_tokens(self.active_token_count)
         )
-        self.
-
+        self.padded_active_request_count = (
+            active_cuda_graph_request_count
             if self.is_decode_only()
             else (self.total_request_count - self.paused_request_count)
        )
+
+        # Update token position indexes.
         self.token_to_chunk_idx[self.active_token_count : self.padded_active_token_count] = (
             self.dummy_chunk_idx
         )
@@ -533,6 +611,7 @@ class DynamicInferenceContext(BaseInferenceContext):
             self.active_token_count : self.padded_active_token_count
         ] = 0
 
+        # Update cu_query_seq_lengths, max_seqlen_q.
         query_lengths = self.request_query_lengths[
             self.paused_request_count : self.total_request_count
         ]
@@ -540,9 +619,7 @@ class DynamicInferenceContext(BaseInferenceContext):
             self.query_seq_lengths_decode_only[
                 0 : self.total_request_count - self.paused_request_count
             ] = query_lengths
-
-            self.cu_query_seq_lengths_decode_only[1:] = cu_query_lengths_decode_only
-            self.cu_query_seq_lengths = self.cu_query_seq_lengths_decode_only
+            self.cu_query_seq_lengths = None  # ensure no accidental use
             self.max_seqlen_q = 1
         else:
             cu_query_lengths = torch.cumsum(query_lengths, dim=0)
@@ -558,12 +635,18 @@ class DynamicInferenceContext(BaseInferenceContext):
         kv_seq_lengths = self.request_kv_length_offsets + self.request_query_lengths
         self.kv_seq_lengths = kv_seq_lengths[self.paused_request_count : self.total_request_count]
         if self.is_decode_only():
+            # Re-assign `kv_seq_lengths` to be a view of the first
+            # `active_cuda_graph_request_count` tokens of `kv_seq_lengths_decode_only`,
+            # such that `kv_seq_lengths` has a static memory address and is therefore
+            # cuda graph compatible. This allows `kv_seq_lengths` to transition between,
+            # cuda graph sizes, which makes multi-batch-size cuda graphs possible.
             self.kv_seq_lengths_decode_only[
                 0 : self.total_request_count - self.paused_request_count
             ] = self.kv_seq_lengths
-
-
-
+            self.kv_seq_lengths = self.kv_seq_lengths_decode_only[
+                : self.padded_active_request_count
+            ]
+            self.cu_kv_seq_lengths = None  # ensure no accidental use
             self.max_seqlen_k = self.max_sequence_length
         else:
             self.cu_kv_seq_lengths = torch.full(
@@ -575,14 +658,17 @@ class DynamicInferenceContext(BaseInferenceContext):
             self.cu_kv_seq_lengths[1:] = torch.cumsum(self.kv_seq_lengths, dim=0)
             self.max_seqlen_k = self.kv_seq_lengths.max().item()
 
-
+        # Update KV chunk IDs, block table.
+        request_to_kv_chunk_ids = self.request_to_kv_chunk_ids[
             self.paused_request_count : self.total_request_count
         ]
         if self.is_decode_only():
-            self.
-
-
-            self.block_table = self.
+            self.request_to_kv_chunk_ids_decode_only[
+                0 : self.total_request_count - self.paused_request_count
+            ] = request_to_kv_chunk_ids
+            self.block_table = self.request_to_kv_chunk_ids_decode_only[
+                : self.padded_active_request_count
+            ]
         else:
             self.block_table = self.request_to_kv_chunk_ids[
                 self.paused_request_count : self.total_request_count
@@ -606,7 +692,7 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.active_token_count = 0
         self.paused_request_count = 0
        self.padded_active_token_count = 0
-        self.
+        self.padded_active_request_count = 0
         self.paused_tokens = None
 
         # Reset request indexes.
@@ -632,21 +718,24 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.chunk_allocator.reset()
         self.request_to_kv_chunk_ids.fill_(-1)
 
-    def
-
-
-
-            (Tensor) Flattened active input IDs.
-        """
-        return self.token_to_input_ids[: self.padded_active_token_count].unsqueeze(0)
+    def current_input_and_position_ids(
+        self, *, num_warmup_tokens: Optional[int] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Flattened input and position IDs for forward pass.
 
-
-
+        Args:
+            num_warmup_tokens (Optional[int]): Number of tokens to return for
+                warming up cuda graphs. Must be less than or equal to
+                `max_tokens`.
 
         Return:
-            (Tensor) Flattened active position IDs.
+            (Tuple[Tensor, Tensor]) Flattened active input and position IDs.
         """
-
+        num_tokens = num_warmup_tokens or self.padded_active_token_count
+        return (
+            self.token_to_input_ids[:num_tokens].unsqueeze(0),
+            self.token_to_pos_ids[:num_tokens].unsqueeze(0),
+        )
 
     def last_token_logits(self, logits: Tensor) -> Tensor:
         """Last tokens of logits.