megatron-core 0.14.0rc1__tar.gz → 0.14.0rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (307)
  1. {megatron_core-0.14.0rc1/megatron_core.egg-info → megatron_core-0.14.0rc2}/PKG-INFO +2 -2
  2. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/enums.py +10 -3
  3. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fp8_utils.py +6 -2
  4. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/dynamic_context.py +52 -6
  5. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/static_context.py +1 -1
  6. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/dynamic_engine.py +18 -3
  7. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +2 -10
  8. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +2 -6
  9. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +2 -9
  10. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +15 -2
  11. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +57 -13
  12. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +15 -2
  13. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/utils.py +16 -0
  14. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/model_parallel_config.py +0 -5
  15. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/T5/t5_model.py +2 -7
  16. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/T5/t5_spec.py +2 -0
  17. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_layer_specs.py +2 -0
  18. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/language_model_embedding.py +3 -3
  19. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +2 -2
  20. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/language_module/language_module.py +57 -17
  21. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/gpt_layer_specs.py +4 -0
  22. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/gpt_model.py +19 -15
  23. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +2 -0
  24. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/moe_module_specs.py +2 -0
  25. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/mamba_model.py +12 -16
  26. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/audio.py +1 -0
  27. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/llava_model.py +19 -4
  28. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/decoder_spec.py +2 -0
  29. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/encoder_spec.py +2 -0
  30. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/clip_vit_model.py +9 -0
  31. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/multimodal_projector.py +10 -1
  32. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/radio.py +7 -0
  33. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/__init__.py +38 -4
  34. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/distrib_optimizer.py +54 -6
  35. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/optimizer.py +27 -1
  36. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/package_info.py +1 -1
  37. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/parallel_state.py +42 -451
  38. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/p2p_communication.py +25 -68
  39. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/schedules.py +12 -73
  40. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/utils.py +57 -1
  41. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/rerun_state_machine.py +123 -86
  42. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/cuda_graphs.py +62 -45
  43. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/enums.py +8 -1
  44. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/heterogeneous/linear_replacements.py +4 -0
  45. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/experts.py +1 -0
  46. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/moe_layer.py +2 -0
  47. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/moe_utils.py +6 -0
  48. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/router.py +23 -2
  49. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/multi_latent_attention.py +9 -3
  50. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/multi_token_prediction.py +10 -3
  51. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_block.py +22 -11
  52. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_config.py +31 -2
  53. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_layer.py +0 -4
  54. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2/megatron_core.egg-info}/PKG-INFO +2 -2
  55. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/requires.txt +1 -1
  56. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/pyproject.toml +13 -3
  57. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/LICENSE +0 -0
  58. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/MANIFEST.in +0 -0
  59. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/README.md +0 -0
  60. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/README.md +0 -0
  61. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/__init__.py +0 -0
  62. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/config.py +0 -0
  63. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/config_logger.py +0 -0
  64. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/__init__.py +0 -0
  65. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/bert_dataset.py +0 -0
  66. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_dataset.py +0 -0
  67. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  68. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  69. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/gpt_dataset.py +0 -0
  70. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/helpers.cpp +0 -0
  71. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/helpers.py +0 -0
  72. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/indexed_dataset.py +0 -0
  73. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/masked_dataset.py +0 -0
  74. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/megatron_dataset.py +0 -0
  75. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  76. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/multimodal_dataset.py +0 -0
  77. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/object_storage_utils.py +0 -0
  78. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/__init__.py +0 -0
  79. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/__init__.py +0 -0
  80. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  81. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/config.py +0 -0
  82. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  83. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  84. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/__init__.py +0 -0
  85. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/build.py +0 -0
  86. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/dataset.py +0 -0
  87. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/utils.py +0 -0
  88. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/external_libs.py +0 -0
  89. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/__init__.py +0 -0
  90. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/build.py +0 -0
  91. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/factory.py +0 -0
  92. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/index.py +0 -0
  93. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  94. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  95. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  96. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/utils.py +0 -0
  97. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/validate.py +0 -0
  98. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/__init__.py +0 -0
  99. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  100. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  101. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/query.py +0 -0
  102. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  103. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/utils.py +0 -0
  104. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/utils.py +0 -0
  105. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/t5_dataset.py +0 -0
  106. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils.py +0 -0
  107. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils_object_storage.py +0 -0
  108. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils_s3.py +0 -0
  109. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/__init__.py +0 -0
  110. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/core.py +0 -0
  111. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  112. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  113. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/mapping.py +0 -0
  114. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  115. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/serialization.py +0 -0
  116. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  117. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  118. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  119. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  120. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  121. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  122. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  123. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  124. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  125. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  126. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  127. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  128. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  129. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  130. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  131. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/utils.py +0 -0
  132. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/validation.py +0 -0
  133. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/__init__.py +0 -0
  134. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  135. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -0
  136. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
  137. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/data_parallel_base.py +0 -0
  138. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  139. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  140. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/finalize_model_grads.py +0 -0
  141. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  142. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  143. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  144. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/energy_monitor.py +0 -0
  145. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/__init__.py +0 -0
  146. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/data_type.py +0 -0
  147. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/export_config.py +0 -0
  148. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/model_type.py +0 -0
  149. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/__init__.py +0 -0
  150. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  151. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  152. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  153. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  154. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  155. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  156. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  157. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  158. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  159. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  160. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  161. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  162. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/__init__.py +0 -0
  163. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/kitchen.py +0 -0
  164. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/transformer_engine.py +0 -0
  165. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  166. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/__init__.py +0 -0
  167. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  168. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  169. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  170. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  171. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  172. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_indices_converter.py +0 -0
  173. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_layer_norm.py +0 -0
  174. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  175. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  176. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_softmax.py +0 -0
  177. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/hyper_comm_grid.py +0 -0
  178. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/__init__.py +0 -0
  179. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/async_stream.py +0 -0
  180. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/common_inference_params.py +0 -0
  181. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/communication_utils.py +0 -0
  182. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/__init__.py +0 -0
  183. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/base_context.py +0 -0
  184. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  185. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/__init__.py +0 -0
  186. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/abstract_engine.py +0 -0
  187. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/mcore_engine.py +0 -0
  188. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/static_engine.py +0 -0
  189. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/inference_request.py +0 -0
  190. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  191. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  192. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  193. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  194. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  195. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/sampling_params.py +0 -0
  196. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/scheduler.py +0 -0
  197. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  198. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  199. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/inference_params.py +0 -0
  200. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/jit.py +0 -0
  201. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/T5/__init__.py +0 -0
  202. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/__init__.py +0 -0
  203. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/backends.py +0 -0
  204. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/__init__.py +0 -0
  205. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_lm_head.py +0 -0
  206. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_model.py +0 -0
  207. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/bert/pooler.py +0 -0
  208. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/__init__.py +0 -0
  209. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/__init__.py +0 -0
  210. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  211. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  212. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  213. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/language_module/__init__.py +0 -0
  214. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/vision_module/__init__.py +0 -0
  215. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  216. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/__init__.py +0 -0
  217. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  218. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/__init__.py +0 -0
  219. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/clip_model.py +0 -0
  220. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/module.py +0 -0
  221. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/qwen_model.py +0 -0
  222. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/__init__.py +0 -0
  223. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  224. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/__init__.py +0 -0
  225. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/config/__init__.py +0 -0
  226. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/config/base_configs.py +0 -0
  227. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/model/__init__.py +0 -0
  228. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/model/base.py +0 -0
  229. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/base.py +0 -0
  230. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/vision.py +0 -0
  231. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/__init__.py +0 -0
  232. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/context_parallel.py +0 -0
  233. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/llava_spec.py +0 -0
  234. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/__init__.py +0 -0
  235. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/base_attention.py +0 -0
  236. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/config.py +0 -0
  237. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/decoder_attention.py +0 -0
  238. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/encoder_attention.py +0 -0
  239. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/model.py +0 -0
  240. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/retro/utils.py +0 -0
  241. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/__init__.py +0 -0
  242. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  243. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/msc_utils.py +0 -0
  244. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/num_microbatches_calculator.py +0 -0
  245. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/clip_grads.py +0 -0
  246. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  247. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  248. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/grad_scaler.py +0 -0
  249. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer/optimizer_config.py +0 -0
  250. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/optimizer_param_scheduler.py +0 -0
  251. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/packed_seq_params.py +0 -0
  252. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/__init__.py +0 -0
  253. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/__init__.py +0 -0
  254. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/__init__.py +0 -0
  255. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  256. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  257. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  258. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/layers.py +0 -0
  259. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  260. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  261. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/process_groups_config.py +0 -0
  262. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/quantization/__init__.py +0 -0
  263. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/quantization/quant_config.py +0 -0
  264. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/quantization/utils.py +0 -0
  265. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/requirements.txt +0 -0
  266. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/__init__.py +0 -0
  267. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_block.py +0 -0
  268. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  269. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  270. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_layer.py +0 -0
  271. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_mixer.py +0 -0
  272. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/mlp_layer.py +0 -0
  273. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/ssm/triton_cache_manager.py +0 -0
  274. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/__init__.py +0 -0
  275. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  276. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/data.py +0 -0
  277. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/layers.py +0 -0
  278. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/mappings.py +0 -0
  279. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/random.py +0 -0
  280. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/utils.py +0 -0
  281. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/timers.py +0 -0
  282. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/__init__.py +0 -0
  283. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/attention.py +0 -0
  284. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  285. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  286. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/dot_product_attention.py +0 -0
  287. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  288. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/identity_op.py +0 -0
  289. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/mlp.py +0 -0
  290. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/module.py +0 -0
  291. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/__init__.py +0 -0
  292. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  293. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  294. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/shared_experts.py +0 -0
  295. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  296. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  297. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  298. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/spec_utils.py +0 -0
  299. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/torch_layer_norm.py +0 -0
  300. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/torch_norm.py +0 -0
  301. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/transformer/utils.py +0 -0
  302. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron/core/utils.py +0 -0
  303. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/SOURCES.txt +0 -0
  304. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/dependency_links.txt +0 -0
  305. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/megatron_core.egg-info/top_level.txt +0 -0
  306. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/setup.cfg +0 -0
  307. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc2}/setup.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: megatron-core
- Version: 0.14.0rc1
+ Version: 0.14.0rc2
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>

@@ -44,7 +44,7 @@ Requires-Dist: nvtx; extra == "dev"
  Requires-Dist: transformers; extra == "dev"
  Requires-Dist: multi-storage-client; extra == "dev"
  Requires-Dist: setuptools<80.0.0; extra == "dev"
- Requires-Dist: nvidia-modelopt[torch]; sys_platform != "darwin" and extra == "dev"
+ Requires-Dist: nvidia-modelopt[torch]~=0.31.0; sys_platform != "darwin" and extra == "dev"
  Requires-Dist: megatron-energon[av_decode]<7; extra == "dev"
  Provides-Extra: lts
  Requires-Dist: tqdm; extra == "lts"
megatron/core/enums.py

@@ -7,9 +7,16 @@ class ModelType(enum.Enum):
      """Model type."""

      encoder_or_decoder = 1
-     encoder_and_decoder = 2
-     retro_encoder = 3
-     retro_decoder = 4
+     retro_encoder = 2
+     retro_decoder = 3
+
+     @property
+     def encoder_and_decoder(self):
+         """Deprecated property - use encoder_or_decoder instead."""
+         raise ValueError(
+             "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder "
+             "instead."
+         )


  class Fp8Recipe(str, enum.Enum):
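The enums.py hunk above removes ModelType.encoder_and_decoder as a real enum member and renumbers retro_encoder/retro_decoder from 3/4 to 2/3. A minimal migration sketch for downstream callers (hypothetical caller code, not part of this diff):

    from megatron.core.enums import ModelType

    # 0.14.0rc1 callers that selected ModelType.encoder_and_decoder for
    # encoder-decoder models (e.g. T5) should switch to the unified member.
    model_type = ModelType.encoder_or_decoder

    # Avoid persisting or comparing raw .value integers across versions,
    # since retro_encoder/retro_decoder were renumbered in this release.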
megatron/core/fp8_utils.py

@@ -346,8 +346,12 @@ else:
      def _modify_underlying_storage_impl(*args, **kwargs):
          raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")

-     def _quantize_param_shard_impl(*args, **kwargs):
-         raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
+     def _quantize_param_shard_impl(model_params, *args, **kwargs):
+         if len(model_params) == 0:
+             return
+         else:
+             # If TE is not installed, there shouldn't be any fp8 params.
+             raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")

      def _correct_amax_history_if_needed_impl(*args, **kwargs):
          # If TE is not installed, we are definitely not using fp8 for training, so no correction
megatron/core/inference/contexts/dynamic_context.py

@@ -2,9 +2,11 @@

  import math
  import warnings
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

  import torch
+ import torch.nn.functional as F
+ from packaging.version import Version as PkgVersion
  from torch import Tensor

  from megatron.core import parallel_state

@@ -123,8 +125,10 @@ class DynamicInferenceContext(BaseInferenceContext):
          max_requests_override: Optional[int] = None,
          max_tokens_override: Optional[int] = None,
          tensor_model_parallel_size: Optional[int] = None,
+         materialize_only_last_token_logits: bool = True,
      ):
-         super().__init__(materialize_only_last_token_logits=True)
+
+         super().__init__(materialize_only_last_token_logits=materialize_only_last_token_logits)
          # Per partition num heads and hidden size.
          projection_size = kv_channels * num_attention_heads
          if tensor_model_parallel_size is None:

@@ -762,7 +766,7 @@ class DynamicInferenceContext(BaseInferenceContext):
          self.total_request_count += 1
          self.active_token_count += context_length

-     def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens):
+     def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens):
          """
          Swaps all the relevent booking tensors with src idxs to dst idxs
          """
@@ -866,7 +870,12 @@ class DynamicInferenceContext(BaseInferenceContext):
          kv_chunks_asigned = self.request_to_kv_chunk_ids[finished_idxs]
          non_zero_values_in_kv_memory = kv_chunks_asigned[kv_chunks_asigned != -1]
          self.chunk_allocator.release_memory_chunks(non_zero_values_in_kv_memory)
-         self.request_to_kv_chunk_ids[finished_idxs].fill_(-1)
+
+         # Reset the KV chunks for finished requests.
+         # Note: do not use fill_() (or add_() and similar inplace ops) here.
+         # The combinition of indexing with a tensor (like finished_idxs) and fill_()/add_() creates a clone
+         # and updates it instead of the original tensor.
+         self.request_to_kv_chunk_ids[finished_idxs] = -1

          if active_request_count > 0:
              finished_idxs_on_left = (
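The in-code note above describes general PyTorch behavior: indexing with an index tensor (advanced indexing) returns a copy, so chained in-place ops such as fill_() or add_() modify that temporary copy, whereas slice indexing returns a view. A standalone sketch of the difference (illustration only, not package code):

    import torch

    x = torch.zeros(5)
    idx = torch.tensor([1, 3])

    x[idx].fill_(-1)   # advanced indexing returns a copy; x is unchanged
    x[idx] = -1        # index assignment writes through to x
    x[1:3].fill_(7)    # slicing returns a view, so fill_() does modify x

    print(x)           # tensor([ 0.,  7.,  7., -1.,  0.])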
@@ -881,12 +890,15 @@ class DynamicInferenceContext(BaseInferenceContext):
                  + self.paused_request_count
              )

-             self._swap_book_keeping_tensors(
+             self._move_book_keeping_tensors(
                  src_idxs=active_idxs_on_right,
                  dst_idxs=finished_idxs_on_left,
                  next_tokens=next_tokens,
              )

+             # Reset chunk ids for recently moved requests.
+             self.request_to_kv_chunk_ids[active_idxs_on_right] = -1
+
          # 5. We identify requests that require a new chunk and add them to the paused requests (i.e move them left) :-
          # a) Put requests that have filled their current chunk and require a new one in a pause state temporarily
          # b) Move the paused requests to the left, and active requets to the right

@@ -931,7 +943,7 @@ class DynamicInferenceContext(BaseInferenceContext):
              )
              dst_idxs = torch.cat((active_request_ids_on_left, paused_requests_idxs_on_right))
              src_idxs = torch.cat((paused_requests_idxs_on_right, active_request_ids_on_left))
-             self._swap_book_keeping_tensors(
+             self._move_book_keeping_tensors(
                  src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens
              )

@@ -974,6 +986,8 @@ class DynamicInferenceContext(BaseInferenceContext):
          if self.paused_request_count > 0:
              self.paused_tokens = next_tokens[: self.paused_request_count]

+         # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_)
+         # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors)
          self.request_kv_length_offsets[self.paused_request_count : self.total_request_count].add_(
              self.request_query_lengths[self.paused_request_count : self.total_request_count]
          )
@@ -1027,3 +1041,35 @@ class DynamicInferenceContext(BaseInferenceContext):
          self.token_to_local_position_within_kv_chunk[: self.active_token_count] = (
              self.request_last_kv_chunk_offset[self.paused_request_count : self.total_request_count]
          )
+
+     def calculate_log_probs(self, logits: torch.Tensor) -> List[List[float]]:
+         """Calculate log probs for all active requests and return them.
+
+         TODO: @wdykas support top-n log probs.
+
+         Args:
+             logits: Raw model output logits with shape [1, sequence_length, vocab_size].
+
+         Returns:
+             List of lists where each inner list contains log probs for a request in the
+             same order as the active requests (from paused_request_count to total_request_count).
+         """
+         # Calculate log_probs (sequence_length x vocab_size)
+         log_probs = F.log_softmax(logits, dim=-1).to(torch.float32).squeeze()
+
+         # Extract the log probs for only the selected tokens
+         # (sequence_length x vocab_size) -> (sequence_length)
+         active_token_ids = self.token_to_input_ids[: self.active_token_count]
+         sequence_indices = torch.arange(self.active_token_count, device=log_probs.device)
+         selected_log_probs = log_probs[sequence_indices, active_token_ids]
+
+         # Split the log probs across request boundaries
+         active_query_lengths = self.request_query_lengths[
+             self.paused_request_count : self.total_request_count
+         ]
+         selected_log_probs_list = selected_log_probs.cpu().split(
+             active_query_lengths.tolist(), dim=0
+         )
+
+         # Convert each log prob tensor into a list
+         return [lp.tolist() for lp in selected_log_probs_list]
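The new calculate_log_probs method follows a common pattern: log-softmax over the vocabulary, select each token's log prob by advanced indexing, then split the flat result into per-request lists by query length. A toy, self-contained sketch of the same pattern (illustrative shapes only, not package code):

    import torch
    import torch.nn.functional as F

    vocab_size = 10
    token_ids = torch.tensor([2, 5, 1, 7, 7])     # 5 active tokens
    query_lengths = [3, 2]                        # request 0 has 3 tokens, request 1 has 2
    logits = torch.randn(1, token_ids.numel(), vocab_size)

    log_probs = F.log_softmax(logits, dim=-1).squeeze(0)     # [5, vocab_size]
    positions = torch.arange(token_ids.numel())
    selected = log_probs[positions, token_ids]               # [5], one value per token
    per_request = [t.tolist() for t in selected.split(query_lengths)]
    print([len(r) for r in per_request])                     # [3, 2]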
megatron/core/inference/contexts/static_context.py

@@ -17,7 +17,7 @@ class StaticInferenceContext(BaseInferenceContext):
      """

      def __init__(self, max_batch_size: int, max_sequence_length: int):
-         super().__init__(materialize_only_last_token_logits=False)
+         super().__init__(materialize_only_last_token_logits=True)
          self.max_sequence_length = max_sequence_length
          self.max_batch_size = max_batch_size
          self.sequence_len_offset = 0
megatron/core/inference/engines/dynamic_engine.py

@@ -2,6 +2,7 @@

  import asyncio
  from collections import deque
+ from itertools import repeat
  from typing import Dict, List, Optional, Tuple, Union

  import torch

@@ -182,6 +183,7 @@ class DynamicInferenceEngine(AbstractEngine):
          finished_request_ids: torch.Tensor,
          step_time: float,
          sample: torch.Tensor,
+         log_probs: torch.Tensor,
      ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
          """
          Handles post-processing for requests after a step.

@@ -191,6 +193,7 @@ class DynamicInferenceEngine(AbstractEngine):
              finished_request_ids (torch.Tensor): A list of finished request ids
              step_time (float): The latency of the last step
              sample: (torch.Tensor): The newly generated tokens for each request
+             log_probs: (List): Log probs for each request

          Returns:
              A list of active requests and completed requests as `DynamicInferenceRequest` objects

@@ -200,13 +203,25 @@ class DynamicInferenceEngine(AbstractEngine):
          finished_request_ids = set(finished_request_ids.tolist())
          self.finished_request_count += len(finished_request_ids)

-         for request_id, token in zip(request_ids.tolist(), sample.tolist()):
+         log_probs_iter = log_probs if log_probs else repeat(None)
+
+         for request_id, token, request_log_probs in zip(
+             request_ids.tolist(), sample.tolist(), log_probs_iter
+         ):
              request: DynamicInferenceRequest = self.requests[request_id]
              request.generated_tokens.append(token)
              if request.tpot is None:
                  request.tpot = []
              request.tpot.append(step_time)

+             if request_log_probs is not None:
+                 # If prompt log probs is None we are in prefill
+                 if request.prompt_log_probs is None:
+                     request.prompt_log_probs = request_log_probs
+                     request.generated_log_probs = []
+                 else:
+                     request.generated_log_probs.extend(request_log_probs)
+
              if request_id in finished_request_ids:
                  request.generated_length = len(request.generated_tokens)
                  request.status = Status.COMPLETED

@@ -266,11 +281,11 @@ class DynamicInferenceEngine(AbstractEngine):
          step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3

          if result is not None:
-             request_ids, finished_request_ids, sample = result
+             request_ids, finished_request_ids, sample, log_probs = result

              # TODO: Move this to a background thread?
              (active_requests, finished_requests) = self.post_process_requests(
-                 request_ids, finished_request_ids, step_time, sample
+                 request_ids, finished_request_ids, step_time, sample, log_probs
              )

              # TODO: Move this to a background thread?
megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py

@@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, Optional, Union

  import torch

- from megatron.core import parallel_state, tensor_parallel
+ from megatron.core import parallel_state
  from megatron.core.inference.communication_utils import (
      is_pipeline_first_stage,
      is_pipeline_last_stage,

@@ -152,13 +152,12 @@ class AbstractModelInferenceWrapper(abc.ABC):
          tokens = inference_input["tokens"]
          position_ids = inference_input["position_ids"]
          attention_mask = inference_input["attention_mask"]
-         runtime_gather_output = inference_input.get("runtime_gather_output")
          return self.model(
              tokens,
              position_ids,
              attention_mask,
              inference_context=self.inference_context,
-             runtime_gather_output=runtime_gather_output,
+             runtime_gather_output=True, # Inference should always gather the logits
          )

      def _get_batch_size_and_seq_len(

@@ -201,7 +200,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
          """
          tokens = inference_input["tokens"]
          logits = self._forward(inference_input)
-         logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits, self.tp_group)
          self.inference_context.increment_sequence_len_offset(tokens.size(1))

          return logits

@@ -243,7 +241,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
          logits = None
          if is_pipeline_last_stage(self.pp_group):
              logits = output_tensor
-             logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits, self.tp_group)

              # Explicitly cast logits to expected dtype
              logits = logits.to(self.inference_wrapper_config.params_dtype)

@@ -269,7 +266,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
          tokens = inference_input["tokens"]
          position_ids = inference_input["position_ids"]
          attention_mask = inference_input["attention_mask"]
-         runtime_gather_output = inference_input.get("runtime_gather_output")
          materialize_only_last_token_logits = (
              self.inference_context.materialize_only_last_token_logits
          )

@@ -317,7 +313,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
                      "position_ids": position_ids2use,
                      "attention_mask": attention_mask,
                      "inference_context": self.inference_context,
-                     "runtime_gather_output": runtime_gather_output,
                  }
              )

@@ -327,9 +322,6 @@ class AbstractModelInferenceWrapper(abc.ABC):
              self.inference_context.batch_size_offset += current_micro_batch_size

              if is_pipeline_last_stage(self.pp_group):
-                 output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
-                     output_tensor, self.tp_group
-                 )
                  assert logits is not None
                  logits[start:end, ...] = output_tensor

megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py

@@ -10,6 +10,7 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_w
  from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
      InferenceWrapperConfig,
  )
+ from megatron.core.inference.utils import get_attention_mask
  from megatron.core.models.gpt import GPTModel
  from megatron.core.transformer.enums import AttnBackend
  from megatron.core.utils import get_model_config

@@ -74,12 +75,7 @@ class GPTInferenceWrapper(AbstractModelInferenceWrapper):
          attention_backend = config.attention_backend

          if attention_backend == AttnBackend.local:
-             attention_mask = torch.tril(
-                 torch.ones((1, seq_length, seq_length), device=prompts_tokens.device)
-             ).view(1, 1, seq_length, seq_length)
-
-             # Convert to boolean
-             attention_mask = attention_mask < 0.5
+             attention_mask = get_attention_mask(seq_length)
          elif (
              attention_backend == AttnBackend.flash
              or attention_backend == AttnBackend.fused
megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py

@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional

  import torch

- from megatron.core import parallel_state
  from megatron.core.inference.communication_utils import (
      is_pipeline_first_stage,
      is_pipeline_last_stage,

@@ -48,16 +47,10 @@ class VLMInferenceWrapper(GPTInferenceWrapper):
          # has part of the LM decoder. In this case, the current stage should only receive
          # vision embeddings.
          if pp_rank > 0:
-             self._recv_only_vision_embeds = (
-                 parallel_state.is_inside_encoder(pp_rank - 1)
-                 and (not parallel_state.is_inside_decoder(pp_rank - 1))
-                 and parallel_state.is_inside_decoder()
-             )
+             self._recv_only_vision_embeds = False # TODO: Implement new logic for vision embeddings

          # Checks if the current stage only has a vision encoder
-         self._encoder_only = (
-             parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()
-         )
+         self._encoder_only = False # TODO: Implement new logic for encoder-only stages

      def prep_inference_input(
          self,
megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py

@@ -7,6 +7,7 @@ from megatron.core.inference.inference_request import InferenceRequest
  from megatron.core.inference.text_generation_controllers.text_generation_controller import (
      TextGenerationController,
  )
+ from megatron.core.inference.utils import get_attention_mask


  class EncoderDecoderTextGenerationController(TextGenerationController):

@@ -18,13 +19,18 @@ class EncoderDecoderTextGenerationController(TextGenerationController):
      """

      def prep_inference_input(
-         self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
+         self,
+         prompts_tokens: torch.Tensor,
+         active_requests: OrderedDict[str, InferenceRequest],
+         use_attention_mask: bool = False,
      ) -> Dict[str, Any]:
          """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long

          Args:
              prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
              active_requests (OrderedDict[str, InferenceRequest]): The input active requests
+             use_attention_mask (bool): Whether to use an attention mask. Should be set to True only
+                 when exclusively doing prefill (no decode) with variable prompt lengths.

          Returns:
              A dict of the inference input for the current batch.

@@ -33,6 +39,13 @@ class EncoderDecoderTextGenerationController(TextGenerationController):
              map(lambda request: request.encoder_prompt, active_requests.values())
          )

-         return self.inference_wrapped_model.prep_inference_input(
+         inference_input = self.inference_wrapped_model.prep_inference_input(
              prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
          )
+
+         if use_attention_mask and (
+             attention_mask := inference_input.get("attention_mask", None) is None
+         ):
+             inference_input["attention_mask"] = get_attention_mask(prompts_tokens.size(1))
+
+         return inference_input
megatron/core/inference/text_generation_controllers/text_generation_controller.py

@@ -24,10 +24,13 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_w
      AbstractModelInferenceWrapper,
  )
  from megatron.core.inference.sampling_params import SamplingParams
+ from megatron.core.inference.utils import get_attention_mask
  from megatron.core.transformer.cuda_graphs import create_cudagraphs
  from megatron.core.utils import get_model_config

  try:
+     import transformer_engine as te # pylint: disable=unused-import
+
      from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding

      HAVE_TE = True

@@ -429,6 +432,11 @@ class TextGenerationController:

          context = self.inference_wrapped_model.inference_context

+         if sampling_params.return_log_probs:
+             assert (
+                 context.materialize_only_last_token_logits is False
+             ), "Materialize only last token logits must be false for returning log probs"
+
          # No tokens?
          if context.active_token_count == 0:
              return None

@@ -478,7 +486,13 @@ class TextGenerationController:
              pp_group=self.pp_group,
          )

-         last_token_logits = logits.squeeze(0)
+         # Last token logits.
+         if context.materialize_only_last_token_logits:
+             # When materialize_only_last_token_logits is true, last_token_logits is
+             # already called in the forward pass of GPT.
+             last_token_logits = logits.squeeze(0)
+         else:
+             last_token_logits = context.last_token_logits(logits)

          # Sample.
          # Use padded vocab size because tokenizer vocab size might not include padding

@@ -505,11 +519,15 @@ class TextGenerationController:
          )
          finished_request_ids = context.request_ids[finished_idxs]

+         log_probs = None
+         if sampling_params.return_log_probs:
+             log_probs = context.calculate_log_probs(logits)
+
          # Update requests.
          # New sample gets updated in update_requests, so we pass in a clone
          context.update_requests(active_request_mask, new_sample.clone())

-         return current_request_ids, finished_request_ids, new_sample
+         return current_request_ids, finished_request_ids, new_sample, log_probs

      def _update_top_n_logprobs_dict(
          self,

@@ -581,13 +599,12 @@ class TextGenerationController:

          model_config = get_model_config(self.inference_wrapped_model.model)

-         # Verify that if echo mode is requested we do not generate any new tokens
-         echo = getattr(sampling_params, "echo", False)
-         assert (
-             not echo or sampling_params.num_tokens_to_generate == 0
-         ), f"Cannot generate new tokens when echoing"
-         if sampling_params.num_tokens_to_generate == 0 and not echo:
-             sampling_params.add_attributes({"echo": True})
+         # We only need an attention mask if we are exclusively doing prefill over
+         # prompts of variable length
+         use_attention_mask = (
+             sampling_params.num_tokens_to_generate == 0
+             and min_prompt_length_in_batch != max_prompt_length_in_batch
+         )

          # Check whether CUDA graphs are enabled
          enable_cuda_graph = model_config.enable_cuda_graph

@@ -689,7 +706,9 @@ class TextGenerationController:
          self.inference_wrapped_model.prep_model_for_inference()

          inference_input: Dict[str, Any] = self.prep_inference_input(
-             prompts_tokens=padded_batch_prompt_tokens, active_requests=active_requests
+             prompts_tokens=padded_batch_prompt_tokens,
+             active_requests=active_requests,
+             use_attention_mask=use_attention_mask,
          )

          assert (

@@ -706,7 +725,13 @@ class TextGenerationController:
              self.inference_wrapped_model.model.module.set_symmetric_ar(None)

          context_start_position = 0
-         context_end_position = min_prompt_length_in_batch
+
+         # If we are exclusively doing prefill then we can process all prompt tokens
+         # together even if the prompt lengths are different
+         if sampling_params.num_tokens_to_generate == 0:
+             context_end_position = max_prompt_length_in_batch
+         else:
+             context_end_position = min_prompt_length_in_batch

          # The initial iteration of this loop runs the prefill phase up to the shortest
          # prompt length in the batch. Then every subsequent iterations runs a decode step.

@@ -734,6 +759,13 @@ class TextGenerationController:
                  and "attention_mask" in inference_input_for_context_window
              ):
                  inference_input_for_context_window["attention_mask"] = None
+             elif use_attention_mask:
+                 assert (
+                     attention_mask := inference_input_for_context_window.get(
+                         "attention_mask", None
+                     )
+                     is not None
+                 )

              # Only materialize prompt log probs if the user requests log probs
              materialize_only_last_token_logits = (

@@ -985,18 +1017,30 @@ class TextGenerationController:
          return active_requests

      def prep_inference_input(
-         self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
+         self,
+         prompts_tokens: torch.Tensor,
+         active_requests: OrderedDict[str, InferenceRequest],
+         use_attention_mask: bool = False,
      ) -> Dict[str, Any]:
          """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long

          Args:
              prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
              active_requests (OrderedDict[str, InferenceRequest]): The input active requests
+             use_attention_mask (bool): Whether to use an attention mask. Should be set to True only
+                 when exclusively doing prefill (no decode) with variable prompt lengths.

          Returns:
              A dict of the inference input for the current batch.
          """
-         return self.inference_wrapped_model.prep_inference_input(prompts_tokens)
+         inference_input = self.inference_wrapped_model.prep_inference_input(prompts_tokens)
+
+         if use_attention_mask and (
+             attention_mask := inference_input.get("attention_mask", None) is None
+         ):
+             inference_input["attention_mask"] = get_attention_mask(prompts_tokens.size(1))
+
+         return inference_input

      def stream_tokens(
          self,
megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py

@@ -7,13 +7,17 @@ from megatron.core.inference.inference_request import InferenceRequest, VLMInfer
  from megatron.core.inference.text_generation_controllers.text_generation_controller import (
      TextGenerationController,
  )
+ from megatron.core.inference.utils import get_attention_mask


  class VLMTextGenerationController(TextGenerationController):
      """The text generation controller for VLMs"""

      def prep_inference_input(
-         self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
+         self,
+         prompts_tokens: torch.Tensor,
+         active_requests: OrderedDict[str, InferenceRequest],
+         use_attention_mask: bool = False,
      ):
          """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long

@@ -22,6 +26,8 @@ class VLMTextGenerationController(TextGenerationController):
          Args:
              prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
              active_requests (OrderedDict[str, InferenceRequest]): The input active requests
+             use_attention_mask (bool): Whether to use an attention mask. Should be set to True only
+                 when exclusively doing prefill (no decode) with variable prompt lengths.
          """
          assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1"

@@ -31,10 +37,17 @@ class VLMTextGenerationController(TextGenerationController):
              request, VLMInferenceRequest
          ), f"Found inference request of type {type(request)}, expected VLMInferenceRequest"

-         return self.inference_wrapped_model.prep_inference_input(
+         inference_input = self.inference_wrapped_model.prep_inference_input(
              prompts_tokens,
              request.num_img_embeddings_per_tile,
              request.imgs,
              request.num_tiles,
              request.decoder_seq_length,
          )
+
+         if use_attention_mask and (
+             attention_mask := inference_input.get("attention_mask", None) is None
+         ):
+             inference_input["attention_mask"] = get_attention_mask(prompts_tokens.size(1))
+
+         return inference_input
megatron/core/inference/utils.py

@@ -1,4 +1,8 @@
  # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ import torch
+
+
  class Counter:
      """A simple counter class

@@ -16,3 +20,15 @@ class Counter:
      def reset(self) -> None:
          """Reset counter"""
          self.counter = 0
+
+
+ def get_attention_mask(seq_length: int) -> torch.Tensor:
+     """Constructs an attention mask given the input sequence length."""
+     attention_mask = torch.tril(
+         torch.ones((1, seq_length, seq_length), device=torch.cuda.current_device())
+     ).view(1, 1, seq_length, seq_length)
+
+     # Convert to boolean
+     attention_mask = attention_mask < 0.5
+
+     return attention_mask
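To make the convention of the new get_attention_mask helper concrete: the returned boolean mask has shape [1, 1, seq_length, seq_length] and is True at future (masked) positions. A CPU-only sketch that mirrors the same math for seq_length=3 (the helper itself allocates on torch.cuda.current_device(), so it needs a CUDA device):

    import torch

    seq_length = 3
    mask = torch.tril(
        torch.ones((1, seq_length, seq_length))
    ).view(1, 1, seq_length, seq_length) < 0.5

    print(mask[0, 0])
    # tensor([[False,  True,  True],
    #         [False, False,  True],
    #         [False, False, False]])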
megatron/core/model_parallel_config.py

@@ -286,11 +286,6 @@ class ModelParallelConfig:
      Defaults to 0, which means all micro-batches are deferred.
      """

-     pipeline_model_parallel_split_rank: Optional[int] = None
-     """If int, rank where encoder and decoder should be split in cases where the model has both an
-     encoder and decoder (e.g., T5). Ignored if None.
-     """
-
      overlap_p2p_comm_warmup_flush: bool = False
      """If true, overlap communication and computation in warm up and flush phase.
      Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.