megatron-core 0.14.0rc0__tar.gz → 0.14.0rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of megatron-core might be problematic.

Files changed (307)
  1. {megatron_core-0.14.0rc0/megatron_core.egg-info → megatron_core-0.14.0rc2}/PKG-INFO +2 -2
  2. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/indexed_dataset.py +5 -0
  3. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/distributed_data_parallel_config.py +9 -0
  4. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/enums.py +10 -3
  5. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/extensions/transformer_engine.py +10 -9
  6. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fp8_utils.py +6 -2
  7. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/dynamic_context.py +52 -6
  8. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/static_context.py +1 -1
  9. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/dynamic_engine.py +78 -34
  10. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/inference_request.py +1 -0
  11. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +2 -10
  12. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +2 -6
  13. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +2 -9
  14. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +15 -2
  15. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +112 -15
  16. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +15 -2
  17. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/utils.py +16 -0
  18. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/model_parallel_config.py +0 -5
  19. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/T5/t5_model.py +2 -7
  20. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/T5/t5_spec.py +2 -0
  21. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_layer_specs.py +2 -0
  22. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/language_model_embedding.py +3 -3
  23. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +2 -2
  24. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/language_module/language_module.py +57 -17
  25. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/gpt_layer_specs.py +4 -0
  26. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/gpt_model.py +19 -15
  27. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +2 -0
  28. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/moe_module_specs.py +2 -0
  29. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/mamba_model.py +12 -16
  30. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/audio.py +1 -0
  31. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/llava_model.py +19 -4
  32. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/decoder_spec.py +2 -0
  33. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/encoder_spec.py +2 -0
  34. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/vision/clip_vit_model.py +9 -0
  35. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/vision/multimodal_projector.py +10 -1
  36. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/vision/radio.py +7 -0
  37. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/__init__.py +38 -4
  38. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/distrib_optimizer.py +54 -6
  39. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/optimizer.py +27 -1
  40. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/package_info.py +1 -1
  41. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/parallel_state.py +42 -451
  42. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/p2p_communication.py +25 -68
  43. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/schedules.py +12 -73
  44. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/utils.py +57 -1
  45. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/rerun_state_machine.py +123 -86
  46. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/layers.py +9 -7
  47. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/attention.py +2 -1
  48. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/cuda_graphs.py +67 -46
  49. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/enums.py +8 -1
  50. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/heterogeneous/linear_replacements.py +4 -0
  51. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/experts.py +1 -0
  52. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/moe_layer.py +2 -0
  53. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/moe_utils.py +6 -0
  54. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/router.py +23 -2
  55. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/multi_latent_attention.py +9 -3
  56. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/multi_token_prediction.py +10 -3
  57. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_block.py +22 -11
  58. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_config.py +31 -2
  59. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/transformer_layer.py +0 -4
  60. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/utils.py +3 -0
  61. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2/megatron_core.egg-info}/PKG-INFO +2 -2
  62. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron_core.egg-info/requires.txt +1 -1
  63. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/pyproject.toml +13 -3
  64. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/LICENSE +0 -0
  65. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/MANIFEST.in +0 -0
  66. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/README.md +0 -0
  67. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/README.md +0 -0
  68. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/__init__.py +0 -0
  69. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/config.py +0 -0
  70. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/config_logger.py +0 -0
  71. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/__init__.py +0 -0
  72. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/bert_dataset.py +0 -0
  73. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_dataset.py +0 -0
  74. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  75. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  76. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/gpt_dataset.py +0 -0
  77. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/helpers.cpp +0 -0
  78. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/helpers.py +0 -0
  79. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/masked_dataset.py +0 -0
  80. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/megatron_dataset.py +0 -0
  81. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  82. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/multimodal_dataset.py +0 -0
  83. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/object_storage_utils.py +0 -0
  84. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/__init__.py +0 -0
  85. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/__init__.py +0 -0
  86. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  87. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/config.py +0 -0
  88. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  89. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  90. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/__init__.py +0 -0
  91. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/build.py +0 -0
  92. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/dataset.py +0 -0
  93. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/db/utils.py +0 -0
  94. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/external_libs.py +0 -0
  95. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/__init__.py +0 -0
  96. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/build.py +0 -0
  97. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/factory.py +0 -0
  98. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/index.py +0 -0
  99. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  100. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  101. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  102. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/utils.py +0 -0
  103. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/index/validate.py +0 -0
  104. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/__init__.py +0 -0
  105. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  106. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  107. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/query.py +0 -0
  108. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  109. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/query/utils.py +0 -0
  110. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/retro/utils.py +0 -0
  111. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/t5_dataset.py +0 -0
  112. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils.py +0 -0
  113. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils_object_storage.py +0 -0
  114. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/datasets/utils_s3.py +0 -0
  115. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/__init__.py +0 -0
  116. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/core.py +0 -0
  117. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  118. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  119. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/mapping.py +0 -0
  120. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  121. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/serialization.py +0 -0
  122. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  123. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  124. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  125. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  126. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  127. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  128. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  129. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  130. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  131. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  132. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  133. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  134. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  135. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  136. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  137. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/utils.py +0 -0
  138. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/dist_checkpointing/validation.py +0 -0
  139. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/__init__.py +0 -0
  140. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  141. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -0
  142. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
  143. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/data_parallel_base.py +0 -0
  144. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  145. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/finalize_model_grads.py +0 -0
  146. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  147. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  148. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  149. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/energy_monitor.py +0 -0
  150. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/__init__.py +0 -0
  151. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/data_type.py +0 -0
  152. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/export_config.py +0 -0
  153. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/model_type.py +0 -0
  154. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/__init__.py +0 -0
  155. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  156. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  157. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  158. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  159. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  160. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  161. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  162. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  163. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  164. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  165. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  166. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  167. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/extensions/__init__.py +0 -0
  168. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/extensions/kitchen.py +0 -0
  169. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  170. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/__init__.py +0 -0
  171. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  172. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  173. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  174. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  175. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  176. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_indices_converter.py +0 -0
  177. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_layer_norm.py +0 -0
  178. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  179. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  180. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/fusions/fused_softmax.py +0 -0
  181. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/hyper_comm_grid.py +0 -0
  182. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/__init__.py +0 -0
  183. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/async_stream.py +0 -0
  184. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/common_inference_params.py +0 -0
  185. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/communication_utils.py +0 -0
  186. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/__init__.py +0 -0
  187. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/base_context.py +0 -0
  188. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  189. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/__init__.py +0 -0
  190. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/abstract_engine.py +0 -0
  191. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/mcore_engine.py +0 -0
  192. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/engines/static_engine.py +0 -0
  193. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  194. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  195. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  196. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  197. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  198. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/sampling_params.py +0 -0
  199. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/scheduler.py +0 -0
  200. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  201. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  202. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/inference_params.py +0 -0
  203. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/jit.py +0 -0
  204. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/T5/__init__.py +0 -0
  205. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/__init__.py +0 -0
  206. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/backends.py +0 -0
  207. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/bert/__init__.py +0 -0
  208. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_lm_head.py +0 -0
  209. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/bert/bert_model.py +0 -0
  210. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/bert/pooler.py +0 -0
  211. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/__init__.py +0 -0
  212. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/__init__.py +0 -0
  213. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  214. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  215. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  216. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/language_module/__init__.py +0 -0
  217. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/vision_module/__init__.py +0 -0
  218. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  219. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/__init__.py +0 -0
  220. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  221. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/__init__.py +0 -0
  222. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/clip_model.py +0 -0
  223. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/module.py +0 -0
  224. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/huggingface/qwen_model.py +0 -0
  225. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/__init__.py +0 -0
  226. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  227. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/__init__.py +0 -0
  228. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/config/__init__.py +0 -0
  229. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/config/base_configs.py +0 -0
  230. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/model/__init__.py +0 -0
  231. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/model/base.py +0 -0
  232. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/base.py +0 -0
  233. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/mimo/submodules/vision.py +0 -0
  234. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/__init__.py +0 -0
  235. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/context_parallel.py +0 -0
  236. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/multimodal/llava_spec.py +0 -0
  237. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/__init__.py +0 -0
  238. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/base_attention.py +0 -0
  239. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/config.py +0 -0
  240. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/decoder_attention.py +0 -0
  241. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/encoder_attention.py +0 -0
  242. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/model.py +0 -0
  243. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/retro/utils.py +0 -0
  244. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/vision/__init__.py +0 -0
  245. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  246. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/msc_utils.py +0 -0
  247. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/num_microbatches_calculator.py +0 -0
  248. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/clip_grads.py +0 -0
  249. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  250. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  251. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/grad_scaler.py +0 -0
  252. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer/optimizer_config.py +0 -0
  253. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/optimizer_param_scheduler.py +0 -0
  254. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/packed_seq_params.py +0 -0
  255. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/pipeline_parallel/__init__.py +0 -0
  256. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/__init__.py +0 -0
  257. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/__init__.py +0 -0
  258. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  259. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  260. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  261. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/layers.py +0 -0
  262. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  263. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  264. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/process_groups_config.py +0 -0
  265. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/quantization/__init__.py +0 -0
  266. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/quantization/quant_config.py +0 -0
  267. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/quantization/utils.py +0 -0
  268. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/requirements.txt +0 -0
  269. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/__init__.py +0 -0
  270. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_block.py +0 -0
  271. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  272. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  273. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_layer.py +0 -0
  274. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/mamba_mixer.py +0 -0
  275. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/mlp_layer.py +0 -0
  276. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/ssm/triton_cache_manager.py +0 -0
  277. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/__init__.py +0 -0
  278. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  279. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/data.py +0 -0
  280. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/mappings.py +0 -0
  281. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/random.py +0 -0
  282. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/tensor_parallel/utils.py +0 -0
  283. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/timers.py +0 -0
  284. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/__init__.py +0 -0
  285. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  286. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  287. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/dot_product_attention.py +0 -0
  288. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  289. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/identity_op.py +0 -0
  290. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/mlp.py +0 -0
  291. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/module.py +0 -0
  292. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/__init__.py +0 -0
  293. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  294. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  295. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/shared_experts.py +0 -0
  296. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  297. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  298. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  299. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/spec_utils.py +0 -0
  300. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/torch_layer_norm.py +0 -0
  301. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/torch_norm.py +0 -0
  302. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron/core/transformer/utils.py +0 -0
  303. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron_core.egg-info/SOURCES.txt +0 -0
  304. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron_core.egg-info/dependency_links.txt +0 -0
  305. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/megatron_core.egg-info/top_level.txt +0 -0
  306. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/setup.cfg +0 -0
  307. {megatron_core-0.14.0rc0 → megatron_core-0.14.0rc2}/setup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: megatron-core
- Version: 0.14.0rc0
+ Version: 0.14.0rc2
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -44,7 +44,7 @@ Requires-Dist: nvtx; extra == "dev"
  Requires-Dist: transformers; extra == "dev"
  Requires-Dist: multi-storage-client; extra == "dev"
  Requires-Dist: setuptools<80.0.0; extra == "dev"
- Requires-Dist: nvidia-modelopt[torch]; sys_platform != "darwin" and extra == "dev"
+ Requires-Dist: nvidia-modelopt[torch]~=0.31.0; sys_platform != "darwin" and extra == "dev"
  Requires-Dist: megatron-energon[av_decode]<7; extra == "dev"
  Provides-Extra: lts
  Requires-Dist: tqdm; extra == "lts"
@@ -5,6 +5,7 @@

  # Essentially re-written in entirety

+ import gc
  import logging
  import os
  import shutil
@@ -906,6 +907,10 @@ class IndexedDatasetBuilder(object):
  assert index.sequence_modes is not None, "sequence_modes cannot not be None"
  self.sequence_modes.extend(index.sequence_modes)

+ # Free up memory to make space for new indices
+ del index
+ gc.collect()
+
  # Concatenate data
  with self._open(get_bin_path(path_prefix), "rb") as f:
  shutil.copyfileobj(f, self.data_file)
@@ -113,6 +113,15 @@ class DistributedDataParallelConfig:
  """

  def __post_init__(self):
+ import os
+
  """Check the validity of the config."""
  if self.reuse_grad_buf_for_mxfp8_param_ag:
  assert self.fp8_param_gather, "Reuse grad buffer only when keeping params in MXFP8."
+
+ if self.nccl_ub:
+ if 'expandable_segments:True' in os.getenv('PYTORCH_CUDA_ALLOC_CONF', '').split(','):
+ raise ValueError(
+ "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True is currently not supported "
+ "with nccl_ub due to compatibility issue with torch.cuda.MemPool API."
+ )
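The new `nccl_ub` guard in `distributed_data_parallel_config.py` rejects the `expandable_segments:True` allocator mode. A minimal, standalone sketch of the same environment-variable check (the helper name below is illustrative, not part of the package):

```python
import os

def allocator_uses_expandable_segments() -> bool:
    # PYTORCH_CUDA_ALLOC_CONF is a comma-separated list of key:value options,
    # e.g. "expandable_segments:True,max_split_size_mb:128".
    options = os.getenv("PYTORCH_CUDA_ALLOC_CONF", "").split(",")
    return "expandable_segments:True" in options

if allocator_uses_expandable_segments():
    raise ValueError("nccl_ub cannot be combined with expandable_segments:True")
```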
@@ -7,9 +7,16 @@ class ModelType(enum.Enum):
  """Model type."""

  encoder_or_decoder = 1
- encoder_and_decoder = 2
- retro_encoder = 3
- retro_decoder = 4
+ retro_encoder = 2
+ retro_decoder = 3
+
+ @property
+ def encoder_and_decoder(self):
+ """Deprecated property - use encoder_or_decoder instead."""
+ raise ValueError(
+ "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder "
+ "instead."
+ )


  class Fp8Recipe(str, enum.Enum):
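For callers migrating across this change, the deprecation message above points at the single unified member; a hedged migration sketch (no model-specific behaviour implied):

```python
from megatron.core.enums import ModelType

# Code that previously selected ModelType.encoder_and_decoder (removed in 0.14.0rc2)
# is directed by the deprecation message to use the unified member instead:
model_type = ModelType.encoder_or_decoder
```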
@@ -39,6 +39,8 @@ from megatron.core.transformer.enums import AttnMaskType
  from megatron.core.transformer.transformer_config import TransformerConfig
  from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
  from megatron.core.utils import (
+ get_pg_rank,
+ get_pg_size,
  get_te_version,
  get_tensor_model_parallel_group_if_none,
  is_te_min_version,
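The hunks that follow replace direct `tp_group.size()` / `tp_group.rank()` calls with the `get_pg_size` / `get_pg_rank` helpers from `megatron.core.utils`. Judging from the removed assertion ("Parallel linear should always have tp_group set"), the helpers are presumably tolerant of a `None` process group; a hedged sketch of that assumed behaviour, not the package's actual implementation:

```python
from typing import Optional
import torch.distributed as dist

def get_pg_size(group: Optional[dist.ProcessGroup]) -> int:
    # Assumed behaviour: fall back to the default group, or 1 when
    # torch.distributed is not initialized, instead of asserting.
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=group)

def get_pg_rank(group: Optional[dist.ProcessGroup]) -> int:
    if not dist.is_initialized():
        return 0
    return dist.get_rank(group=group)
```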
@@ -228,8 +230,7 @@ class TELinear(te.pytorch.Linear):
  assert tp_group is None, "duplicated linear should not have tp_group set"
  tp_size = 1
  else:
- assert tp_group is not None, "Parallel linear should always have tp_group set"
- tp_size = tp_group.size()
+ tp_size = get_pg_size(tp_group)

  self.expert_parallel = self.config.expert_model_parallel_size > 1
  if is_expert:
@@ -374,8 +375,8 @@ class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
  self.is_first_microbatch = True
  self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
  extra_kwargs = _get_extra_te_kwargs(config)
- self.tp_size = tp_group.size()
- self.tp_rank = tp_group.rank()
+ self.tp_size = get_pg_size(tp_group)
+ self.tp_rank = get_pg_rank(tp_group)

  if self.config.delay_wgrad_compute:
  if is_te_min_version("2.3.0"):
@@ -542,8 +543,8 @@ class TEColumnParallelLinear(TELinear):
  if gather_output:
  raise ValueError("Transformer Engine linear layers do not support gather_output = True")
  tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
- world_size = tp_group.size()
- rank = tp_group.rank()
+ world_size = get_pg_size(tp_group)
+ rank = get_pg_rank(tp_group)

  super().__init__(
  input_size=input_size,
@@ -657,8 +658,8 @@ class TERowParallelLinear(TELinear):
  tp_group=tp_group,
  )
  if config.use_cpu_initialization:
- world_size = tp_group.size()
- rank = tp_group.rank()
+ world_size = get_pg_size(tp_group)
+ rank = get_pg_rank(tp_group)
  input_size_per_partition = divide(input_size, world_size)
  self.master_weight = _initialize_affine_weight_cpu(
  self.weight,
@@ -1003,7 +1004,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"):
  # The comms between TP and EP group is explicitly handled by MoE token dispatcher.
  # So we disable comms by making TE agnostic of model parallel.
  tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
- tp_size = tp_group.size()
+ tp_size = get_pg_size(tp_group)

  self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

@@ -346,8 +346,12 @@ else:
  def _modify_underlying_storage_impl(*args, **kwargs):
  raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")

- def _quantize_param_shard_impl(*args, **kwargs):
- raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
+ def _quantize_param_shard_impl(model_params, *args, **kwargs):
+ if len(model_params) == 0:
+ return
+ else:
+ # If TE is not installed, there shouldn't be any fp8 params.
+ raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")

  def _correct_amax_history_if_needed_impl(*args, **kwargs):
  # If TE is not installed, we are definitely not using fp8 for training, so no correction
@@ -2,9 +2,11 @@

  import math
  import warnings
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

  import torch
+ import torch.nn.functional as F
+ from packaging.version import Version as PkgVersion
  from torch import Tensor

  from megatron.core import parallel_state
@@ -123,8 +125,10 @@ class DynamicInferenceContext(BaseInferenceContext):
  max_requests_override: Optional[int] = None,
  max_tokens_override: Optional[int] = None,
  tensor_model_parallel_size: Optional[int] = None,
+ materialize_only_last_token_logits: bool = True,
  ):
- super().__init__(materialize_only_last_token_logits=True)
+
+ super().__init__(materialize_only_last_token_logits=materialize_only_last_token_logits)
  # Per partition num heads and hidden size.
  projection_size = kv_channels * num_attention_heads
  if tensor_model_parallel_size is None:
@@ -762,7 +766,7 @@
  self.total_request_count += 1
  self.active_token_count += context_length

- def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens):
+ def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens):
  """
  Swaps all the relevent booking tensors with src idxs to dst idxs
  """
@@ -866,7 +870,12 @@
  kv_chunks_asigned = self.request_to_kv_chunk_ids[finished_idxs]
  non_zero_values_in_kv_memory = kv_chunks_asigned[kv_chunks_asigned != -1]
  self.chunk_allocator.release_memory_chunks(non_zero_values_in_kv_memory)
- self.request_to_kv_chunk_ids[finished_idxs].fill_(-1)
+
+ # Reset the KV chunks for finished requests.
+ # Note: do not use fill_() (or add_() and similar inplace ops) here.
+ # The combinition of indexing with a tensor (like finished_idxs) and fill_()/add_() creates a clone
+ # and updates it instead of the original tensor.
+ self.request_to_kv_chunk_ids[finished_idxs] = -1

  if active_request_count > 0:
  finished_idxs_on_left = (
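The comment in this hunk describes a real PyTorch pitfall: indexing with a tensor (advanced indexing) returns a copy, so in-place ops on that copy never reach the original storage, while slice indexing returns a view and index assignment writes through. A small standalone illustration:

```python
import torch

x = torch.zeros(6, dtype=torch.long)
idx = torch.tensor([1, 3])

# Advanced (tensor) indexing returns a copy; fill_() mutates only the copy.
x[idx].fill_(-1)
print(x.tolist())  # [0, 0, 0, 0, 0, 0] -- unchanged

# Index assignment writes through to the original tensor.
x[idx] = -1
print(x.tolist())  # [0, -1, 0, -1, 0, 0]

# Slice indexing returns a view, so in-place ops do work there.
x[4:6].fill_(7)
print(x.tolist())  # [0, -1, 0, -1, 7, 7]
```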
@@ -881,12 +890,15 @@
  + self.paused_request_count
  )

- self._swap_book_keeping_tensors(
+ self._move_book_keeping_tensors(
  src_idxs=active_idxs_on_right,
  dst_idxs=finished_idxs_on_left,
  next_tokens=next_tokens,
  )

+ # Reset chunk ids for recently moved requests.
+ self.request_to_kv_chunk_ids[active_idxs_on_right] = -1
+
  # 5. We identify requests that require a new chunk and add them to the paused requests (i.e move them left) :-
  # a) Put requests that have filled their current chunk and require a new one in a pause state temporarily
  # b) Move the paused requests to the left, and active requets to the right
@@ -931,7 +943,7 @@
  )
  dst_idxs = torch.cat((active_request_ids_on_left, paused_requests_idxs_on_right))
  src_idxs = torch.cat((paused_requests_idxs_on_right, active_request_ids_on_left))
- self._swap_book_keeping_tensors(
+ self._move_book_keeping_tensors(
  src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens
  )

@@ -974,6 +986,8 @@
  if self.paused_request_count > 0:
  self.paused_tokens = next_tokens[: self.paused_request_count]

+ # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_)
+ # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors)
  self.request_kv_length_offsets[self.paused_request_count : self.total_request_count].add_(
  self.request_query_lengths[self.paused_request_count : self.total_request_count]
  )
@@ -1027,3 +1041,35 @@
  self.token_to_local_position_within_kv_chunk[: self.active_token_count] = (
  self.request_last_kv_chunk_offset[self.paused_request_count : self.total_request_count]
  )
+
+ def calculate_log_probs(self, logits: torch.Tensor) -> List[List[float]]:
+ """Calculate log probs for all active requests and return them.
+
+ TODO: @wdykas support top-n log probs.
+
+ Args:
+ logits: Raw model output logits with shape [1, sequence_length, vocab_size].
+
+ Returns:
+ List of lists where each inner list contains log probs for a request in the
+ same order as the active requests (from paused_request_count to total_request_count).
+ """
+ # Calculate log_probs (sequence_length x vocab_size)
+ log_probs = F.log_softmax(logits, dim=-1).to(torch.float32).squeeze()
+
+ # Extract the log probs for only the selected tokens
+ # (sequence_length x vocab_size) -> (sequence_length)
+ active_token_ids = self.token_to_input_ids[: self.active_token_count]
+ sequence_indices = torch.arange(self.active_token_count, device=log_probs.device)
+ selected_log_probs = log_probs[sequence_indices, active_token_ids]
+
+ # Split the log probs across request boundaries
+ active_query_lengths = self.request_query_lengths[
+ self.paused_request_count : self.total_request_count
+ ]
+ selected_log_probs_list = selected_log_probs.cpu().split(
+ active_query_lengths.tolist(), dim=0
+ )
+
+ # Convert each log prob tensor into a list
+ return [lp.tolist() for lp in selected_log_probs_list]
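The per-request split at the end of `calculate_log_probs` relies on `torch.Tensor.split` accepting a list of segment lengths. A tiny standalone example of that pattern with illustrative values:

```python
import torch

# One log prob per token, for three requests of query lengths 2, 3 and 1.
token_log_probs = torch.tensor([-0.1, -0.5, -0.2, -0.3, -0.9, -0.4])
query_lengths = [2, 3, 1]

per_request = token_log_probs.split(query_lengths, dim=0)
# per_request is a tuple of three tensors with lengths 2, 3 and 1,
# in the same order as the requests the lengths came from.
print([t.shape[0] for t in per_request])  # [2, 3, 1]
```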
@@ -17,7 +17,7 @@ class StaticInferenceContext(BaseInferenceContext):
  """

  def __init__(self, max_batch_size: int, max_sequence_length: int):
- super().__init__(materialize_only_last_token_logits=False)
+ super().__init__(materialize_only_last_token_logits=True)
  self.max_sequence_length = max_sequence_length
  self.max_batch_size = max_batch_size
  self.sequence_len_offset = 0
@@ -1,8 +1,8 @@
  # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

  import asyncio
- import time
  from collections import deque
+ from itertools import repeat
  from typing import Dict, List, Optional, Tuple, Union

  import torch
@@ -70,6 +70,8 @@ class DynamicInferenceEngine(AbstractEngine):
  self.request_counter = Counter()
  self.requests: Dict[int, DynamicInferenceRequest] = {}
  self.request_completion_futures: Dict[int, asyncio.Future] = {}
+ self.step_start_event = torch.cuda.Event(enable_timing=True)
+ self.step_end_event = torch.cuda.Event(enable_timing=True)

  # Initialize the asyncio loop if it has not already been initialized.
  # TODO: Start the engine loop here.
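Later hunks replace the `time.time()` step timer with these CUDA events; `elapsed_time` reports milliseconds, hence the `/ 1e3` in `async_step`. A minimal sketch of the timing pattern on its own (the timed kernel here is just a placeholder matmul):

```python
import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")

start.record()        # enqueue the start marker on the current stream
y = x @ x             # work being timed (placeholder)
end.record()          # enqueue the end marker
end.synchronize()     # wait until the end marker has actually executed

step_time_s = start.elapsed_time(end) / 1e3  # elapsed_time() is in milliseconds
print(f"step took {step_time_s:.6f} s")
```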
@@ -176,26 +178,49 @@ class DynamicInferenceEngine(AbstractEngine):
  return self.request_completion_futures[request_id]

  def post_process_requests(
- self, request_ids: torch.Tensor, finished_request_ids: torch.Tensor, sample: torch.Tensor
- ) -> List[DynamicInferenceRequest]:
+ self,
+ request_ids: torch.Tensor,
+ finished_request_ids: torch.Tensor,
+ step_time: float,
+ sample: torch.Tensor,
+ log_probs: torch.Tensor,
+ ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
  """
  Handles post-processing for requests after a step.

  Args:
  request_ids (torch.Tensor): A list of request_ids
  finished_request_ids (torch.Tensor): A list of finished request ids
+ step_time (float): The latency of the last step
  sample: (torch.Tensor): The newly generated tokens for each request
+ log_probs: (List): Log probs for each request

  Returns:
- A list of completed requests as `DynamicInferenceRequest` objects
+ A list of active requests and completed requests as `DynamicInferenceRequest` objects
  """
+ active_requests: List[DynamicInferenceRequest] = []
  finished_requests: List[DynamicInferenceRequest] = []
  finished_request_ids = set(finished_request_ids.tolist())
  self.finished_request_count += len(finished_request_ids)

- for request_id, token in zip(request_ids.tolist(), sample.tolist()):
+ log_probs_iter = log_probs if log_probs else repeat(None)
+
+ for request_id, token, request_log_probs in zip(
+ request_ids.tolist(), sample.tolist(), log_probs_iter
+ ):
  request: DynamicInferenceRequest = self.requests[request_id]
  request.generated_tokens.append(token)
+ if request.tpot is None:
+ request.tpot = []
+ request.tpot.append(step_time)
+
+ if request_log_probs is not None:
+ # If prompt log probs is None we are in prefill
+ if request.prompt_log_probs is None:
+ request.prompt_log_probs = request_log_probs
+ request.generated_log_probs = []
+ else:
+ request.generated_log_probs.extend(request_log_probs)

  if request_id in finished_request_ids:
  request.generated_length = len(request.generated_tokens)
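The new `tpot` field accumulates one step latency (in seconds, per the CUDA-event timing above) for every token generated for a request. A hedged sketch of how a caller might summarize it after generation; `finished_request` is a placeholder name:

```python
# finished_request.tpot is assumed to be a list of per-step latencies in seconds.
tpot = finished_request.tpot or []

if tpot:
    mean_tpot_ms = 1e3 * sum(tpot) / len(tpot)   # mean time per output token
    decode_time_s = sum(tpot[1:])                # steps after the first (prefill) step
    print(f"tokens: {len(tpot)}, mean TPOT: {mean_tpot_ms:.2f} ms, "
          f"decode time: {decode_time_s:.3f} s")
```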
@@ -207,50 +232,67 @@
  finished_request.generated_tokens
  )
  self.request_completion_futures[request_id].set_result(finished_request)
-
- return finished_requests
+ else:
+ active_requests.append(request)
+
+ return active_requests, finished_requests
+
+ def schedule_waiting_requests(self):
+ """Tries to schedule any requests in the waiting pool."""
+ for waiting_request_id in self.waiting_request_ids.copy():
+ waiting_request: DynamicInferenceRequest = self.requests[waiting_request_id]
+ try:
+ self.context.add_request(
+ waiting_request_id,
+ waiting_request.prompt_tokens,
+ waiting_request.sampling_params.num_tokens_to_generate,
+ )
+ self.waiting_request_ids.popleft()
+ except Exception as e:
+ break

  async def async_step(
  self, sampling_params: SamplingParams, *, verbose: Optional[bool] = False
- ) -> Tuple[List[DynamicInferenceRequest], float]:
- """Wrapper for controller.generate_output_tokens_dynamic_batch(), to
- match vLLM API.
-
- Uses `asyncio` for continuous generation which allows this
+ ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest], float]:
+ """
+ Wrapper for controller.generate_output_tokens_dynamic_batch(), to
+ match vLLM API. Uses `asyncio` for continuous generation which allows this
  method to sleep and wake up when new requests are available.
+
+ Args:
+ sampling_params (SamplingParams): The sampling parameters.
+ verbose (bool): Whether to run in verbose mode.
+
+ Returns:
+ A tuple comprised of:
+ 1. Requests that ran in the last step and are still active.
+ 2. Requests that ran in the last step and have now finished.
+ 3. The step time in seconds.
  """

  # Generate tokens.
- t = time.time()
  is_decode_only = self.context.is_decode_only()
+ self.step_start_event.record()
  result = self.controller.generate_output_tokens_dynamic_batch(
  sampling_params, self.termination_id
  )
- step_time = time.time() - t
-
- finished_requests: List[DynamicInferenceRequest] = []
+ self.step_end_event.record()
+ self.step_end_event.synchronize()
+ step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3

  if result is not None:
- request_ids, finished_request_ids, sample = result
+ request_ids, finished_request_ids, sample, log_probs = result

  # TODO: Move this to a background thread?
- finished_requests.extend(
- self.post_process_requests(request_ids, finished_request_ids, sample)
+ (active_requests, finished_requests) = self.post_process_requests(
+ request_ids, finished_request_ids, step_time, sample, log_probs
  )

- # Schedule waiting requests
  # TODO: Move this to a background thread?
- for waiting_request_id in self.waiting_request_ids.copy():
- waiting_request: DynamicInferenceRequest = self.requests[waiting_request_id]
- try:
- self.context.add_request(
- waiting_request_id,
- waiting_request.prompt_tokens,
- waiting_request.sampling_params.num_tokens_to_generate,
- )
- self.waiting_request_ids.popleft()
- except Exception as e:
- break
+ self.schedule_waiting_requests()
+ else:
+ active_requests: List[DynamicInferenceRequest] = []
+ finished_requests: List[DynamicInferenceRequest] = []

  # Print context state.
  if verbose:
@@ -278,9 +320,11 @@ class DynamicInferenceEngine(AbstractEngine):
  )
  )

- return finished_requests, step_time
+ return active_requests, finished_requests, step_time

- def step(self, sampling_params: SamplingParams, *, verbose: Optional[bool] = False):
+ def step(
+ self, sampling_params: SamplingParams, *, verbose: Optional[bool] = False
+ ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest], float]:
  """Synchronous wrapper for `self.async_step`."""
  return self._loop.run_until_complete(
  self.async_step(sampling_params=sampling_params, verbose=verbose)
@@ -297,7 +341,7 @@

  finished_requests_list = []
  while self.has_unfinished_requests():
- finished_requests, step_time = self.step(sampling_params)
+ active_requests, finished_requests, step_time = self.step(sampling_params)
  finished_requests_list.extend(finished_requests)

  return finished_requests_list
@@ -46,6 +46,7 @@ class InferenceRequest:
  prompt_top_n_logprobs: Optional[List[Dict[str, float]]] = None
  generated_top_n_logprobs: Optional[List[Dict[str, float]]] = None
  generated_length: Optional[int] = None
+ tpot: Optional[List[int]] = None

  def __post_init__(self):
  if self.sampling_params is None and self.inference_parameters is not None:
@@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, Optional, Union

  import torch

- from megatron.core import parallel_state, tensor_parallel
+ from megatron.core import parallel_state
  from megatron.core.inference.communication_utils import (
  is_pipeline_first_stage,
  is_pipeline_last_stage,
@@ -152,13 +152,12 @@
  tokens = inference_input["tokens"]
  position_ids = inference_input["position_ids"]
  attention_mask = inference_input["attention_mask"]
- runtime_gather_output = inference_input.get("runtime_gather_output")
  return self.model(
  tokens,
  position_ids,
  attention_mask,
  inference_context=self.inference_context,
- runtime_gather_output=runtime_gather_output,
+ runtime_gather_output=True, # Inference should always gather the logits
  )

  def _get_batch_size_and_seq_len(
@@ -201,7 +200,6 @@
  """
  tokens = inference_input["tokens"]
  logits = self._forward(inference_input)
- logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits, self.tp_group)
  self.inference_context.increment_sequence_len_offset(tokens.size(1))

  return logits
@@ -243,7 +241,6 @@
  logits = None
  if is_pipeline_last_stage(self.pp_group):
  logits = output_tensor
- logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits, self.tp_group)

  # Explicitly cast logits to expected dtype
  logits = logits.to(self.inference_wrapper_config.params_dtype)
@@ -269,7 +266,6 @@
  tokens = inference_input["tokens"]
  position_ids = inference_input["position_ids"]
  attention_mask = inference_input["attention_mask"]
- runtime_gather_output = inference_input.get("runtime_gather_output")
  materialize_only_last_token_logits = (
  self.inference_context.materialize_only_last_token_logits
  )
@@ -317,7 +313,6 @@
  "position_ids": position_ids2use,
  "attention_mask": attention_mask,
  "inference_context": self.inference_context,
- "runtime_gather_output": runtime_gather_output,
  }
  )

@@ -327,9 +322,6 @@
  self.inference_context.batch_size_offset += current_micro_batch_size

  if is_pipeline_last_stage(self.pp_group):
- output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
- output_tensor, self.tp_group
- )
  assert logits is not None
  logits[start:end, ...] = output_tensor

@@ -10,6 +10,7 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_w
  from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
  InferenceWrapperConfig,
  )
+ from megatron.core.inference.utils import get_attention_mask
  from megatron.core.models.gpt import GPTModel
  from megatron.core.transformer.enums import AttnBackend
  from megatron.core.utils import get_model_config
@@ -74,12 +75,7 @@ class GPTInferenceWrapper(AbstractModelInferenceWrapper):
  attention_backend = config.attention_backend

  if attention_backend == AttnBackend.local:
- attention_mask = torch.tril(
- torch.ones((1, seq_length, seq_length), device=prompts_tokens.device)
- ).view(1, 1, seq_length, seq_length)
-
- # Convert to boolean
- attention_mask = attention_mask < 0.5
+ attention_mask = get_attention_mask(seq_length)
  elif (
  attention_backend == AttnBackend.flash
  or attention_backend == AttnBackend.fused
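The inline mask construction removed here is centralized in the new `megatron.core.inference.utils.get_attention_mask` helper (part of the `+16` lines in `utils.py`). Based on the removed code, the helper presumably builds the same boolean causal mask, where `True` marks positions to mask out; a sketch of that assumed behaviour, not the package's actual implementation:

```python
import torch

def get_attention_mask(seq_length: int, device=None) -> torch.Tensor:
    # Assumed shape [1, 1, seq_length, seq_length]; True = "do not attend",
    # mirroring the torch.tril(...) < 0.5 construction removed above.
    causal = torch.tril(torch.ones((1, seq_length, seq_length), device=device))
    return (causal < 0.5).view(1, 1, seq_length, seq_length)

mask = get_attention_mask(4)
print(mask[0, 0].int())
# Upper triangle is 1 (masked), lower triangle and diagonal are 0 (attended).
```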
@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional

  import torch

- from megatron.core import parallel_state
  from megatron.core.inference.communication_utils import (
  is_pipeline_first_stage,
  is_pipeline_last_stage,
@@ -48,16 +47,10 @@ class VLMInferenceWrapper(GPTInferenceWrapper):
  # has part of the LM decoder. In this case, the current stage should only receive
  # vision embeddings.
  if pp_rank > 0:
- self._recv_only_vision_embeds = (
- parallel_state.is_inside_encoder(pp_rank - 1)
- and (not parallel_state.is_inside_decoder(pp_rank - 1))
- and parallel_state.is_inside_decoder()
- )
+ self._recv_only_vision_embeds = False # TODO: Implement new logic for vision embeddings

  # Checks if the current stage only has a vision encoder
- self._encoder_only = (
- parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()
- )
+ self._encoder_only = False # TODO: Implement new logic for encoder-only stages

  def prep_inference_input(
  self,
@@ -7,6 +7,7 @@ from megatron.core.inference.inference_request import InferenceRequest
  from megatron.core.inference.text_generation_controllers.text_generation_controller import (
  TextGenerationController,
  )
+ from megatron.core.inference.utils import get_attention_mask


  class EncoderDecoderTextGenerationController(TextGenerationController):
@@ -18,13 +19,18 @@ class EncoderDecoderTextGenerationController(TextGenerationController):
  """

  def prep_inference_input(
- self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
+ self,
+ prompts_tokens: torch.Tensor,
+ active_requests: OrderedDict[str, InferenceRequest],
+ use_attention_mask: bool = False,
  ) -> Dict[str, Any]:
  """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long

  Args:
  prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
  active_requests (OrderedDict[str, InferenceRequest]): The input active requests
+ use_attention_mask (bool): Whether to use an attention mask. Should be set to True only
+ when exclusively doing prefill (no decode) with variable prompt lengths.

  Returns:
  A dict of the inference input for the current batch.
@@ -33,6 +39,13 @@
  map(lambda request: request.encoder_prompt, active_requests.values())
  )

- return self.inference_wrapped_model.prep_inference_input(
+ inference_input = self.inference_wrapped_model.prep_inference_input(
  prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
  )
+
+ if use_attention_mask and (
+ attention_mask := inference_input.get("attention_mask", None) is None
+ ):
+ inference_input["attention_mask"] = get_attention_mask(prompts_tokens.size(1))
+
+ return inference_input