megatron-core 0.14.0rc3__tar.gz → 0.14.0rc4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {megatron_core-0.14.0rc3/megatron_core.egg-info → megatron_core-0.14.0rc4}/PKG-INFO +1 -1
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/torch.py +2 -1
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/validation.py +21 -15
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine.py +5 -27
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine_spec_provider.py +5 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +81 -16
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_context.py +44 -28
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +23 -2
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/model_parallel_config.py +8 -3
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rope_utils.py +20 -32
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +11 -6
- megatron_core-0.14.0rc4/megatron/core/models/common/model_chunk_schedule_plan.py +502 -0
- megatron_core-0.14.0rc4/megatron/core/models/gpt/fine_grained_callables.py +474 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_layer_specs.py +2 -2
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_model.py +62 -1
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer.py +11 -1
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer_config.py +1 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/package_info.py +1 -1
- megatron_core-0.14.0rc4/megatron/core/pipeline_parallel/combined_1f1b.py +331 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/schedules.py +169 -101
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/utils.py +91 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_block.py +4 -1
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_layer.py +1 -1
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/layers.py +23 -12
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/attention.py +1 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/mlp.py +20 -2
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/experts.py +22 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_latent_attention.py +81 -9
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_config.py +60 -7
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_layer.py +11 -10
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/utils.py +17 -11
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/utils.py +27 -3
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4/megatron_core.egg-info}/PKG-INFO +1 -1
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/SOURCES.txt +2 -0
- megatron_core-0.14.0rc3/megatron/core/models/gpt/fine_grained_callables.py +0 -195
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/LICENSE +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/MANIFEST.in +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/README.md +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/README.md +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/config_logger.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/bert_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.cpp +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/indexed_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/masked_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_tokenizer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/multimodal_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/object_storage_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/build.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/external_libs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/build.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/factory.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/index.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/validate.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/query.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/t5_dataset.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_object_storage.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_s3.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/core.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/mapping.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/optimizer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/serialization.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/data_parallel_base.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/finalize_model_grads.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/energy_monitor.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/enums.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/data_type.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/export_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/model_type.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_type.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/kitchen.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fp8_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_dropout.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_geglu.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_gelu.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_cross_entropy.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_indices_converter.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_layer_norm.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_softmax.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/hyper_comm_grid.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/async_stream.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/common_inference_params.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/communication_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/base_context.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/static_context.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/abstract_engine.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/dynamic_engine.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/mcore_engine.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/static_engine.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/inference_request.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/sampling_params.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/scheduler.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference_params.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/jit.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/T5/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_spec.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/backends.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_layer_specs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_lm_head.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/pooler.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/language_module.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/vision_module.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/moe_module_specs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/clip_model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/module.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/qwen_model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/base_configs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/base.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/audio.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/base.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/vision.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/context_parallel.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_spec.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/base_attention.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_attention.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_spec.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_attention.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_spec.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/clip_vit_model.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/multimodal_projector.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/radio.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/vit_layer_specs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/msc_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/num_microbatches_calculator.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/clip_grads.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/distrib_optimizer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/grad_scaler.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer_param_scheduler.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/packed_seq_params.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/parallel_state.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/layers.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/process_groups_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/quantization/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/quantization/quant_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/quantization/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/requirements.txt +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/rerun_state_machine.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_context_parallel.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_mixer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mlp_layer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/triton_cache_manager.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/data.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/mappings.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/random.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/timers.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/cuda_graphs.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/dot_product_attention.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/enums.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/identity_op.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/module.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/__init__.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/fused_a2a.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_layer.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/router.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/shared_experts.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_token_prediction.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/spec_utils.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_layer_norm.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_norm.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_block.py +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/dependency_links.txt +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/requires.txt +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/top_level.txt +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/pyproject.toml +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/setup.cfg +0 -0
- {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/setup.py +0 -0
- {megatron_core-0.14.0rc3/megatron_core.egg-info → megatron_core-0.14.0rc4}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc3
+Version: 0.14.0rc4
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/torch.py
RENAMED
@@ -374,7 +374,8 @@ def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]
             ten = ten.view(-1)
         else:
             for _ in range(mcore_sh_ten.prepend_axis_num):
-                ten = ten.squeeze(0)
+                assert ten.size(0) == 1
+                ten = ten[0]  # NOTE: ten.squeeze(0) uses more memory for FP8 tensors
         ret_tensors.append(ten)
     return ret_tensors
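
The change above replaces `squeeze(0)` with an assert plus integer indexing when stripping prepended singleton axes. Below is a minimal plain-PyTorch sketch of that unwrap pattern; it uses an ordinary float tensor, not the FP8 case the NOTE in the diff refers to, and `drop_prepended_axes` is an illustrative helper, not a megatron API.

```python
import torch

def drop_prepended_axes(ten: torch.Tensor, prepend_axis_num: int) -> torch.Tensor:
    # Strip `prepend_axis_num` leading singleton axes by integer indexing
    # rather than squeeze(); the result is a view of the same storage.
    for _ in range(prepend_axis_num):
        assert ten.size(0) == 1  # prepended axes are expected to be singleton
        ten = ten[0]             # integer indexing returns a view, no copy
    return ten

x = torch.randn(1, 1, 4, 8)  # two prepended singleton axes
y = drop_prepended_axes(x, 2)
assert y.shape == (4, 8) and y.data_ptr() == x.data_ptr()
```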
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/validation.py
RENAMED
@@ -375,28 +375,34 @@ def maybe_report_missing_and_unexpected_keys(
 def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
     """Validate consistancy across ranks for the common state dict
 
-    We save the common state dict only on rank 0. We validate to make sure that the common dict is
+    We save the common state dict only on rank 0. We validate to make sure that the common dict is consistent across ranks before saving.
 
     Args:
         common_state_dict: The common state dict present in all ransk
     """
+    if not torch.distributed.is_initialized():
+        return
 
-    #
+    # Broadcast the common state dict from rank 0 to all other ranks
+    # Each rank will do a comparison with its local rank vs the broadcasted state dict from rank 0
     rank = torch.distributed.get_rank()
-
-
-
-
-
-
-
-
-
-
-
-    if
+
+    object_list = [common_state_dict] if rank == 0 else [None]
+    torch.distributed.broadcast_object_list(object_list, src=0)
+    rank0_state_dict = object_list[0]
+
+    # Skip comparing rank 0 with itself
+    if rank > 0:
+        current_rank_state_dict = common_state_dict
+        only_in_rank0, only_in_current_rank, mismatch = diff(
+            rank0_state_dict, current_rank_state_dict
+        )
+        if only_in_rank0 or only_in_current_rank or mismatch:
            logger.warning(
-                f"
+                f"Rank {rank} common state dict differs from rank 0 common state dict. "
+                f"Keys only on rank 0: {only_in_rank0}, "
+                f"Keys only on {rank}: {only_in_current_rank}, "
+                f"Mismatched keys: {mismatch}"
            )
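
The validation hunk above broadcasts rank 0's common state dict and has every other rank diff its local copy against it. A minimal, self-contained sketch of the same pattern follows; `simple_diff` is a stand-in for megatron's `dist_checkpointing.dict_utils.diff`, and the whole function is only an illustration of the flow, not the library code itself.

```python
import torch

def simple_diff(a: dict, b: dict):
    # Stand-in for megatron's nested-dict diff: keys missing on either side, plus mismatches.
    only_in_a = [k for k in a if k not in b]
    only_in_b = [k for k in b if k not in a]
    mismatch = [k for k in a if k in b and a[k] != b[k]]
    return only_in_a, only_in_b, mismatch

def validate_common_state_dict(common_state_dict: dict) -> None:
    if not torch.distributed.is_initialized():
        return
    rank = torch.distributed.get_rank()
    # Rank 0 broadcasts its copy; every rank receives it via broadcast_object_list.
    object_list = [common_state_dict] if rank == 0 else [None]
    torch.distributed.broadcast_object_list(object_list, src=0)
    rank0_state_dict = object_list[0]
    if rank > 0:  # rank 0 need not compare with itself
        only_in_rank0, only_in_current, mismatch = simple_diff(rank0_state_dict, common_state_dict)
        if only_in_rank0 or only_in_current or mismatch:
            print(f"rank {rank} differs from rank 0: {only_in_rank0} {only_in_current} {mismatch}")
```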
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine.py
RENAMED
@@ -889,25 +889,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
             if packed_seq_params is not None
             else {}
         )
-
-        # after init
-        if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
-            self.qkv_format = "bshd"
-
-        qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
-
-        # WAR for peak memory usage.
-        # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
-        if self.config.apply_rope_fusion and qkv_format == "bshd":
-            query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
-            # In PyTorch, the following two tensors are in fact the same:
-            # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
-            # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
-            # Stride for a dimension that is 1 has no meaning, so tensors created two different ways
-            # can have same shape but different strides.
-            # We unify them to the first one to pass the stride check in TE
-            if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
-                value = value.as_strided(value.shape, key.stride())
+        qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
 
         attention_bias_kwargs = {}
         if attention_bias is not None:
@@ -942,10 +924,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
             query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
         )
 
-
-            return core_attn_out.transpose(0, 1)
-        else:
-            return core_attn_out
+        return core_attn_out
 
 
 if HAVE_TE and is_te_min_version("1.9.0.dev0"):
@@ -1633,10 +1612,8 @@ try:
         else:
             if interleaved:
                 raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.")
-
-
-            else:
-                raise ValueError("Only TE >= 1.4.0.dev0 supports fused RoPE.")
+
+            return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)
 
     def fused_apply_rotary_pos_emb_thd(
         t: torch.Tensor,
@@ -1659,6 +1636,7 @@ try:
                 cp_rank=cp_rank,
             )
         else:
+            assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP."
             return apply_rotary_pos_emb(
                 t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens
            )
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine_spec_provider.py
RENAMED
@@ -8,6 +8,7 @@ from megatron.core.extensions.transformer_engine import (
     TEColumnParallelLinear,
     TEDotProductAttention,
     TELayerNormColumnParallelLinear,
+    TELinear,
     TENorm,
     TERowParallelGroupedLinear,
     TERowParallelLinear,
@@ -23,6 +24,10 @@ from megatron.core.utils import get_te_version, is_te_min_version
 class TESpecProvider(BackendSpecProvider):
     """A protocol for providing the submodules used in Spec building."""
 
+    def linear(self) -> type:
+        """Which linear module TE backend uses"""
+        return TELinear
+
     def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_mla_yarn_rope_apply.py
RENAMED
@@ -28,16 +28,25 @@ if not HAVE_TRITON:
 
 
 @triton.jit
-def _get_thd_token_idx(cu_seqlens, pid_m, seq_num):
+def _get_thd_token_idx(cu_seqlens, pid_m, seq_num, cp_rank, cp_size):
     token_idx = -1
+    this_seq_len = 0
     seq_idx = 0
-    last_cum_seqlen = tl.load(cu_seqlens)
+    last_cum_seqlen = tl.load(cu_seqlens) // cp_size
     while seq_idx < seq_num:
-        cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1)
+        cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1) // cp_size
         if token_idx == -1 and cur_cum_seqlen > pid_m:
             token_idx = pid_m - last_cum_seqlen
+            this_seq_len = cur_cum_seqlen - last_cum_seqlen
         last_cum_seqlen = cur_cum_seqlen
         seq_idx += 1
+    if cp_size > 1:
+        if token_idx < this_seq_len // 2:
+            token_idx = token_idx + cp_rank * this_seq_len // 2
+        else:
+            token_idx = (token_idx - this_seq_len // 2) + (
+                2 * cp_size - cp_rank - 1
+            ) * this_seq_len // 2
     return token_idx
 
 
@@ -68,6 +77,8 @@ def rotary_fwd_q_kernel(
     cu_seqlens_q,
     stride_x_seq,
     stride_x_nheads,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -89,7 +100,7 @@ def rotary_fwd_q_kernel(
     if cu_seqlens_q is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)
 
     cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
     sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -146,6 +157,8 @@ def rotary_bwd_q_kernel(
     cu_seqlens_q,
     stride_x_seq,
     stride_x_nheads,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -165,7 +178,7 @@ def rotary_bwd_q_kernel(
     if cu_seqlens_q is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)
 
     cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
     sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -200,7 +213,18 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(
+    def forward(
+        ctx,
+        q,
+        cos,
+        sin,
+        qk_head_dim,
+        emb_dim,
+        cu_seqlens_q,
+        cp_rank,
+        cp_size,
+        rotary_interleaved=False,
+    ):
         """
         Forward function for ApplyMLARotaryEmbQ.
 
@@ -243,12 +267,16 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
             cu_seqlens_q,
             q.stride(0),
             q.stride(1),
+            cp_rank,
+            cp_size,
         )
         ctx.save_for_backward(cos, sin)
         ctx.qk_head_dim = qk_head_dim
         ctx.emb_dim = emb_dim
         ctx.cu_seqlens_q = cu_seqlens_q
         ctx.rotary_interleaved = rotary_interleaved
+        ctx.cp_rank = cp_rank
+        ctx.cp_size = cp_size
         if cu_seqlens_q is None:
             q = q.view(max_seqlen, batch_size, nheads, headdim)
         return q
@@ -268,7 +296,7 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
         seq_num = None
         if ctx.cu_seqlens_q is None:
             max_seqlen, batch_size, nheads, headdim = grad.shape
-            grad = grad.view(-1, nheads, headdim)
+            grad = grad.contiguous().view(-1, nheads, headdim)
             total_seqlen = grad.shape[0]
         else:
             seq_num = len(ctx.cu_seqlens_q) - 1
@@ -288,10 +316,12 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
             ctx.cu_seqlens_q,
             grad.stride(0),
             grad.stride(1),
+            ctx.cp_rank,
+            ctx.cp_size,
         )
         if ctx.cu_seqlens_q is None:
             grad = grad.view(max_seqlen, batch_size, nheads, headdim)
-        return grad, None, None, None, None, None, None
+        return grad, None, None, None, None, None, None, None, None
 
 
 @experimental_fn(introduced_with_version="0.13.0")
@@ -302,6 +332,8 @@ def fused_apply_mla_rope_for_q(
     qk_head_dim: int,
     emb_dim: int,
     cu_seqlens_q: Optional[torch.Tensor] = None,
+    cp_rank: int = 0,
+    cp_size: int = 1,
     rotary_interleaved: bool = False,
 ):
     """
@@ -327,7 +359,7 @@ def fused_apply_mla_rope_for_q(
         t: inplace modified input tensor
     """
     return ApplyMLARotaryEmbQ.apply(
-        t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, rotary_interleaved
+        t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, cp_rank, cp_size, rotary_interleaved
     )
 
 
@@ -366,6 +398,8 @@ def rotary_fwd_kv_kernel(
     stride_k_nheads,
     stride_v_seq,
     stride_v_nheads,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -394,7 +428,7 @@ def rotary_fwd_kv_kernel(
     if cu_seqlens_kv is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)
 
     cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
     sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -472,6 +506,8 @@ def rotary_bwd_kv_kernel(
     stride_dkv_seq,
     stride_dkv_nheads,
     stride_demb_seq,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -496,7 +532,7 @@ def rotary_bwd_kv_kernel(
     if cu_seqlens_kv is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)
 
     dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
     dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
@@ -550,7 +586,18 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx,
+        ctx,
+        kv,
+        k_pos_emb,
+        cos,
+        sin,
+        emb_dim,
+        k_dim,
+        v_dim,
+        cu_seqlens_kv,
+        cp_rank,
+        cp_size,
+        rotary_interleaved=False,
     ):
         """
         Forward function for ApplyMLARotaryEmbKV.
@@ -609,6 +656,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
             o_key.stride(1),
             o_value.stride(0),
             o_value.stride(1),
+            cp_rank,
+            cp_size,
         )
         ctx.save_for_backward(cos, sin)
         ctx.rotary_interleaved = rotary_interleaved
@@ -616,6 +665,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
         ctx.k_dim = k_dim
         ctx.v_dim = v_dim
         ctx.cu_seqlens_kv = cu_seqlens_kv
+        ctx.cp_rank = cp_rank
+        ctx.cp_size = cp_size
         if cu_seqlens_kv is None:
             o_key = o_key.view(max_seqlen, -1, nheads, emb_dim + k_dim)
             o_value = o_value.view(max_seqlen, -1, nheads, v_dim)
@@ -638,8 +689,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
         if ctx.cu_seqlens_kv is None:
             # sbhd
             max_seqlen, batch_size, nheads, _ = dk.shape
-            dk = dk.view(-1, nheads, ctx.emb_dim + ctx.k_dim)
-            dv = dv.view(-1, nheads, ctx.v_dim)
+            dk = dk.contiguous().view(-1, nheads, ctx.emb_dim + ctx.k_dim)
+            dv = dv.contiguous().view(-1, nheads, ctx.v_dim)
             total_seqlen = dk.shape[0]
         else:
             # thd
@@ -673,11 +724,13 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
             d_kv.stride(0),
             d_kv.stride(1),
             d_emb.stride(0),
+            ctx.cp_rank,
+            ctx.cp_size,
         )
         if ctx.cu_seqlens_kv is None:
             d_kv = d_kv.view(max_seqlen, batch_size, nheads, ctx.k_dim + ctx.v_dim)
             d_emb = d_emb.view(max_seqlen, batch_size, 1, ctx.emb_dim)
-        return d_kv, d_emb, None, None, None, None, None, None, None
+        return d_kv, d_emb, None, None, None, None, None, None, None, None, None
 
 
 @experimental_fn(introduced_with_version="0.13.0")
@@ -690,6 +743,8 @@ def fused_apply_mla_rope_for_kv(
     k_dim: int,
     v_dim: int,
     cu_seqlens_kv: Optional[torch.Tensor] = None,
+    cp_rank: int = 0,
+    cp_size: int = 1,
     rotary_interleaved: bool = False,
 ):
     """
@@ -715,5 +770,15 @@ def fused_apply_mla_rope_for_kv(
         value: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
     """
     return ApplyMLARotaryEmbKV.apply(
-        kv,
+        kv,
+        k_pos_emb,
+        cos,
+        sin,
+        emb_dim,
+        k_dim,
+        v_dim,
+        cu_seqlens_kv,
+        cp_rank,
+        cp_size,
+        rotary_interleaved,
     )
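
The new `cp_rank`/`cp_size` handling in `_get_thd_token_idx` above translates a token index local to one context-parallel rank back to its global position, assuming the load-balanced CP split where each rank holds chunk `cp_rank` and the mirrored chunk `2*cp_size - cp_rank - 1` of every sequence. A plain-Python sketch of just that index mapping (not the Triton kernel):

```python
def local_to_global_token_idx(token_idx: int, local_seq_len: int, cp_rank: int, cp_size: int) -> int:
    # local_seq_len is the per-rank share of one sequence (full_seq_len // cp_size).
    if cp_size == 1:
        return token_idx
    half = local_seq_len // 2
    if token_idx < half:
        # First local half comes from chunk `cp_rank`.
        return token_idx + cp_rank * half
    # Second local half comes from the mirrored chunk `2*cp_size - cp_rank - 1`.
    return (token_idx - half) + (2 * cp_size - cp_rank - 1) * half

# Example: a 16-token sequence with cp_size=2. Rank 0 holds global positions 0-3 and
# 12-15; rank 1 holds 4-7 and 8-11.
assert [local_to_global_token_idx(i, 8, cp_rank=0, cp_size=2) for i in range(8)] == \
       [0, 1, 2, 3, 12, 13, 14, 15]
assert [local_to_global_token_idx(i, 8, cp_rank=1, cp_size=2) for i in range(8)] == \
       [4, 5, 6, 7, 8, 9, 10, 11]
```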
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_context.py
RENAMED
@@ -155,7 +155,6 @@ class DynamicInferenceContext(BaseInferenceContext):
         tp_size = tensor_model_parallel_size
         hidden_size_per_attention_head = core_divide(projection_size, num_attention_heads)
         num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size)
-
         # Chunk size tokens, bytes.
         dtype_size_bytes = params_dtype.itemsize
         self.chunk_size_tokens = chunk_size_tokens
@@ -177,23 +176,24 @@ class DynamicInferenceContext(BaseInferenceContext):
         def bytes_to_max_requests_and_tokens(n_bytes):
             n_tokens = n_bytes / self.chunk_size_bytes * self.chunk_size_tokens
             n_requests = n_tokens / max_sequence_length
-            return int(n_requests),
+            return self.round_up_requests(int(n_requests), tp_size=tp_size), self.round_up_tokens(
+                int(n_tokens), tp_size=tp_size
+            )
 
         self.max_requests, self.max_tokens = bytes_to_max_requests_and_tokens(buffer_size_bytes)
-
         if buffer_overflow_factor is not None:
             self.max_requests = self.round_up_requests(
-                int(self.max_requests * buffer_overflow_factor)
+                int(self.max_requests * buffer_overflow_factor), tp_size=tp_size
             )
             self.max_tokens = self.round_up_tokens(
-                int(self.max_tokens * buffer_overflow_factor / 50.0)
+                int(self.max_tokens * buffer_overflow_factor / 50.0), tp_size=tp_size
            )
 
         if max_requests_override is not None:
-            self.max_requests = self.round_up_requests(max_requests_override)
+            self.max_requests = self.round_up_requests(max_requests_override, tp_size=tp_size)
 
         if max_tokens_override is not None:
-            self.max_tokens = self.round_up_tokens(max_tokens_override)
+            self.max_tokens = self.round_up_tokens(max_tokens_override, tp_size=tp_size)
 
         self.max_requests = min(self.max_requests, self.max_tokens)  # e.g., decode only.
 
@@ -277,7 +277,8 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.cuda_graph_step_size = cuda_graph_rounder * int(
             math.ceil(int(self.cuda_graph_step_size) / cuda_graph_rounder)
         )
-
+        # Make sure divisble by TP size
+        self.cuda_graph_step_size = math.ceil(self.cuda_graph_step_size / tp_size) * tp_size
         # Cuda graph request counts.
         if num_cuda_graphs == 1:
             self.cuda_graph_request_counts = [self.max_requests]
@@ -355,26 +356,46 @@ class DynamicInferenceContext(BaseInferenceContext):
     REQUEST_ROUNDER = 4
 
     @classmethod
-    def round_up_tokens(cls, value):
-        """Round up to nearest multiple of `TOKEN_ROUNDER` (above)."""
+    def round_up_tokens(cls, value, tp_size=None):
+        """Round up to nearest multiple of `TOKEN_ROUNDER` (above) that is also divisible by tensor model parallel size."""
         if not HAVE_PACKAGING:
             raise ImportError(
                 "`packaging` is required for this functionality, please install it with `pip install packaging`"
             )
         if PkgVersion(mcore_version) < PkgVersion("0.13"):
             return cls.round_up(value)
-
+
+        # Make sure divisible by TP size
+        if tp_size is None:
+            # Check if parallel state is initialized before trying to get TP size
+            if parallel_state.is_initialized():
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+            else:
+                tp_size = 1
+        token_rounder = math.ceil(cls.TOKEN_ROUNDER / tp_size) * tp_size
+
+        return token_rounder * int(math.ceil(int(value) / token_rounder))
 
     @classmethod
-    def round_up_requests(cls, value):
-        """Round up to nearest multiple of `REQUEST_ROUNDER` (above)."""
+    def round_up_requests(cls, value, tp_size=None):
+        """Round up to nearest multiple of `REQUEST_ROUNDER` (above) that is also divisible by tensor model parallel size."""
         if not HAVE_PACKAGING:
             raise ImportError(
                 "`packaging` is required for this functionality, please install it with `pip install packaging`"
            )
         if PkgVersion(mcore_version) < PkgVersion("0.13"):
             return cls.round_up(value)
-
+
+        # Make sure divisible by TP size
+        if tp_size is None:
+            # Check if parallel state is initialized before trying to get TP size
+            if parallel_state.is_initialized():
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+            else:
+                tp_size = 1
+        request_rounder = math.ceil(cls.REQUEST_ROUNDER / tp_size) * tp_size
+
+        return request_rounder * int(math.ceil(int(value) / request_rounder))
 
     @classmethod
     def round_up(cls, value):
@@ -1043,21 +1064,16 @@ class DynamicInferenceContext(BaseInferenceContext):
         # We determine how many requests we can resume and resume them
         # Assign released chunks to paused requests.
         # todo: @shanmugamr, un-pause requests using FIFO, rather than LIFO.
-
-
-
-
-            if active_request_count < self.gtd_request_count:
-                resume_request_count = min(
-                    self.paused_request_count, self.gtd_request_count - active_request_count
-                )
-            else:
-                # If there are more active requests than gtd requests and not enough
-                # chunks available, no requests can be resumed
-                resume_request_count = 0
+        num_non_gtd_chunks = max(0, self.chunk_allocator.chunk_count_avail - self.gtd_chunk_count)
+        if num_non_gtd_chunks:
+            # if we have non-gtd chunks, use them. Do not dip into the gtd-chunk pool
+            resume_request_count = min(num_non_gtd_chunks, self.paused_request_count)
         else:
-            #
-
+            # only dip into the gtd-chunk pool if we have run out of non-gtd-chunks and the active
+            # request count has fallen below a certain threshold.
+            resume_request_count = min(
+                max(self.gtd_request_count - active_request_count, 0), self.paused_request_count
+            )
 
         self.paused_request_count -= resume_request_count
         active_request_count += resume_request_count
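
The `round_up_requests` / `round_up_tokens` changes above implement one rounding rule: lift the class rounder to a multiple of the tensor-parallel size, then round the value up to that effective rounder. A small standalone sketch of that rule (illustrative helper, not the megatron classmethods):

```python
import math

def round_up_to_tp_multiple(value: int, rounder: int, tp_size: int) -> int:
    # Effective rounder is the smallest multiple of tp_size that is >= rounder.
    effective_rounder = math.ceil(rounder / tp_size) * tp_size
    return effective_rounder * int(math.ceil(int(value) / effective_rounder))

# With REQUEST_ROUNDER = 4 (see the class constant above) and tp_size = 8, the effective
# rounder becomes 8, so a request count of 10 rounds up to 16 — a multiple of both.
assert round_up_to_tp_multiple(10, rounder=4, tp_size=8) == 16
```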
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/text_generation_controller.py
RENAMED
@@ -26,6 +26,8 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_w
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.utils import get_attention_mask
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
+from megatron.core.transformer.moe.moe_layer import BaseMoELayer
+from megatron.core.transformer.utils import set_model_to_sequence_parallel
 from megatron.core.utils import get_model_config
 
 try:
@@ -429,9 +431,11 @@ class TextGenerationController:
         # Get flat tokens, position ids.
         input_ids, position_ids = context.current_input_and_position_ids()
 
+        model_config = get_model_config(self.inference_wrapped_model.model)
+
         # If using symmetric kernels and we are using using nccl
         # for prefill turn off symmetric kernels
-        symmetric_ar_type =
+        symmetric_ar_type = model_config.symmetric_ar_type
         nccl_all_reduce_for_prefill = (
             self.inference_wrapped_model.inference_wrapper_config.nccl_all_reduce_for_prefill
         )
@@ -588,7 +592,9 @@ class TextGenerationController:
         )
 
         # Check whether CUDA graphs are enabled
-        enable_cuda_graph =
+        enable_cuda_graph = (
+            model_config.enable_cuda_graph and model_config.cuda_graph_scope != "full_iteration"
+        )
 
         # Pad batch tokens if necessary
         batch_size = len(active_requests)
@@ -681,6 +687,21 @@ class TextGenerationController:
             not self.inference_wrapped_model.inference_context.is_decode_only()
         ), f"Generation must start in prefill mode"
 
+        # Sequence parallelism is required for MoE layers when using expert parallelism (EP)
+        # becausethe expert routing mechanism relies on sequence parallelism's communication
+        # infrastructure to distribute tokens across expert ranks. However, sequence parallelism
+        # is not currently supported for non-MoE layers during inference,so we selectively
+        # disable it for all other layer types. This is safe because MoE layers perform an
+        # all-gather operation on sequences before passing data to subsequent layers, ensuring
+        # that each rank has the complete sequence data needed for the next non-MoE layer.
+        tp_size = model_config.tensor_model_parallel_size
+        ep_size = model_config.expert_model_parallel_size
+        model_is_tp_ep = tp_size > 1 and ep_size > 1
+        if model_is_tp_ep:
+            set_model_to_sequence_parallel(
+                self.inference_wrapped_model.model.module, False, exclude_modules=[BaseMoELayer]
+            )
+
         # If using symmetric kernels and we are using using nccl
         # for prefill turn off symmetric kernels
         symmetric_ar_type = model_config.symmetric_ar_type
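
The last hunk above disables sequence parallelism everywhere except MoE layers before generation starts. A hedged sketch of that selective-toggle pattern follows; it assumes a per-module `sequence_parallel` attribute and is not megatron's actual `set_model_to_sequence_parallel` implementation, only an illustration of the idea.

```python
import torch.nn as nn

def set_sequence_parallel_flag(model: nn.Module, value: bool, exclude_types=()) -> None:
    # Walk the module tree and flip the flag, skipping excluded module types
    # (e.g. MoE layers, which still need SP communication for expert token dispatch).
    for module in model.modules():
        if exclude_types and isinstance(module, tuple(exclude_types)):
            continue
        if hasattr(module, "sequence_parallel"):
            module.sequence_parallel = value
```

Called with `value=False` and the MoE layer class in `exclude_types`, this mirrors the inference-time behaviour the hunk adds for TP+EP models.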
{megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/model_parallel_config.py
RENAMED
@@ -237,6 +237,14 @@ class ModelParallelConfig:
     Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
     """
 
+    overlap_moe_expert_parallel_comm: bool = False
+    """Overlap EP A2A communications with independent computations of different micro-batches
+    in 1f1b phase of pipelining or non-pipelining schedule.
+    """
+
+    delay_wgrad_compute: bool = False
+    """Delay the weight gradient computation to improve batch-level communication overlapping"""
+
     ###################
     # Pipeline Parallel
     ###################
@@ -307,9 +315,6 @@ class ModelParallelConfig:
     rank 1 | 0 1 2 0 1 2 3 4 3 4
     """
 
-    delay_wgrad_compute: bool = False
-    """If true, delay the wgrad compute for better overlapping in combined 1F1B."""
-
     ###################
     # CPU Offloading
     ###################
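
The hunks above add `overlap_moe_expert_parallel_comm` and move `delay_wgrad_compute` into the main dataclass body. A minimal usage sketch, assuming the remaining `ModelParallelConfig` fields keep their defaults and that no additional validation is triggered by this combination:

```python
from megatron.core.model_parallel_config import ModelParallelConfig

# Enable the two overlap-related flags shown in the diff; everything else stays default.
config = ModelParallelConfig(
    tensor_model_parallel_size=2,
    overlap_moe_expert_parallel_comm=True,  # overlap EP all-to-all with other micro-batch compute
    delay_wgrad_compute=True,               # postpone weight-grad computation for better overlap
)
```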