megatron-core 0.16.0rc0.dev125968__tar.gz → 0.16.0rc0.dev126546__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megatron-core might be problematic; consult the package registry's advisory listing for more details.

Files changed (361)
  1. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/PKG-INFO +1 -1
  2. megatron_core-0.16.0rc0.dev126546/megatron/core/inference/contexts/attention_context/metadata_base.py +72 -0
  3. megatron_core-0.16.0rc0.dev126546/megatron/core/inference/contexts/attention_context/mha_metadata.py +220 -0
  4. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/contexts/dynamic_context.py +77 -126
  5. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/engines/dynamic_engine.py +15 -1
  6. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +2 -1
  7. megatron_core-0.16.0rc0.dev126546/megatron/core/inference/unified_memory.py +127 -0
  8. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/gpt/gpt_model.py +1 -2
  9. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/package_info.py +1 -1
  10. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron_core.egg-info/PKG-INFO +1 -1
  11. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron_core.egg-info/SOURCES.txt +2 -0
  12. megatron_core-0.16.0rc0.dev125968/megatron/core/inference/unified_memory.py +0 -89
  13. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/MANIFEST.in +0 -0
  14. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/README.md +0 -0
  15. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/README.md +0 -0
  16. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/__init__.py +0 -0
  17. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/activations.py +0 -0
  18. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/config.py +0 -0
  19. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/config_logger.py +0 -0
  20. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/__init__.py +0 -0
  21. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/bert_dataset.py +0 -0
  22. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/blended_dataset.py +0 -0
  23. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  24. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  25. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/gpt_dataset.py +0 -0
  26. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/helpers.cpp +0 -0
  27. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/helpers.py +0 -0
  28. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/indexed_dataset.py +0 -0
  29. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/masked_dataset.py +0 -0
  30. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/megatron_dataset.py +0 -0
  31. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  32. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/multimodal_dataset.py +0 -0
  33. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/object_storage_utils.py +0 -0
  34. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/__init__.py +0 -0
  35. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/config/__init__.py +0 -0
  36. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  37. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/config/config.py +0 -0
  38. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  39. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  40. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/db/__init__.py +0 -0
  41. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/db/build.py +0 -0
  42. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/db/dataset.py +0 -0
  43. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/db/utils.py +0 -0
  44. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/external_libs.py +0 -0
  45. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/__init__.py +0 -0
  46. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/build.py +0 -0
  47. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/factory.py +0 -0
  48. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/index.py +0 -0
  49. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  50. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  51. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  52. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/utils.py +0 -0
  53. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/index/validate.py +0 -0
  54. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/query/__init__.py +0 -0
  55. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  56. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  57. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/query/query.py +0 -0
  58. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  59. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/query/utils.py +0 -0
  60. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/retro/utils.py +0 -0
  61. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/t5_dataset.py +0 -0
  62. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/utils.py +0 -0
  63. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/datasets/utils_s3.py +0 -0
  64. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/__init__.py +0 -0
  65. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/core.py +0 -0
  66. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  67. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  68. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/mapping.py +0 -0
  69. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  70. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/serialization.py +0 -0
  71. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  72. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  73. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  74. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  75. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  76. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/checkpointable.py +0 -0
  77. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  78. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  79. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  80. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  81. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  82. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  83. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  84. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  85. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  86. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  87. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/utils.py +0 -0
  88. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/dist_checkpointing/validation.py +0 -0
  89. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/__init__.py +0 -0
  90. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/data_parallel_base.py +0 -0
  91. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  92. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  93. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/finalize_model_grads.py +0 -0
  94. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/__init__.py +0 -0
  95. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +0 -0
  96. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/__init__.py +0 -0
  97. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +0 -0
  98. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +0 -0
  99. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +0 -0
  100. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +0 -0
  101. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py +0 -0
  102. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +0 -0
  103. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +0 -0
  104. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +0 -0
  105. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  106. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/reduce_scatter_with_fp32_accumulation.py +0 -0
  107. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  108. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  109. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/energy_monitor.py +0 -0
  110. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/enums.py +0 -0
  111. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/__init__.py +0 -0
  112. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/data_type.py +0 -0
  113. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/export_config.py +0 -0
  114. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/model_type.py +0 -0
  115. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/__init__.py +0 -0
  116. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  117. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  118. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  119. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  120. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  121. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  122. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  123. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  124. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  125. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  126. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  127. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  128. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/extensions/__init__.py +0 -0
  129. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/extensions/kitchen.py +0 -0
  130. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/extensions/transformer_engine.py +0 -0
  131. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  132. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fp4_utils.py +0 -0
  133. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fp8_utils.py +0 -0
  134. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/full_cuda_graph.py +0 -0
  135. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/__init__.py +0 -0
  136. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  137. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  138. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  139. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  140. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  141. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_indices_converter.py +0 -0
  142. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_layer_norm.py +0 -0
  143. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  144. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  145. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_softmax.py +0 -0
  146. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
  147. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/hyper_comm_grid.py +0 -0
  148. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/__init__.py +0 -0
  149. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/async_stream.py +0 -0
  150. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/common_inference_params.py +0 -0
  151. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/communication_utils.py +0 -0
  152. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/contexts/__init__.py +0 -0
  153. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/contexts/base_context.py +0 -0
  154. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/contexts/dynamic_block_allocator.py +0 -0
  155. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/contexts/fused_kv_append_kernel.py +0 -0
  156. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/contexts/static_context.py +0 -0
  157. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/data_parallel_inference_coordinator.py +0 -0
  158. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/engines/__init__.py +0 -0
  159. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/engines/abstract_engine.py +0 -0
  160. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/engines/mcore_engine.py +0 -0
  161. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/engines/static_engine.py +0 -0
  162. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/headers.py +0 -0
  163. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/inference_client.py +0 -0
  164. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/inference_request.py +0 -0
  165. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  166. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  167. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  168. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  169. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  170. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  171. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  172. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  173. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/sampling_params.py +0 -0
  174. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/scheduler.py +0 -0
  175. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  176. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  177. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  178. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  179. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_server/__init__.py +0 -0
  180. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_server/endpoints/common.py +0 -0
  181. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_server/endpoints/completions.py +0 -0
  182. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_server/run_mcore_engine.py +0 -0
  183. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_server/text_generation_server.py +0 -0
  184. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/text_generation_server/tokenization.py +0 -0
  185. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference/utils.py +0 -0
  186. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/inference_params.py +0 -0
  187. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/jit.py +0 -0
  188. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/model_parallel_config.py +0 -0
  189. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/T5/__init__.py +0 -0
  190. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/T5/t5_model.py +0 -0
  191. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/T5/t5_spec.py +0 -0
  192. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/__init__.py +0 -0
  193. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/backends.py +0 -0
  194. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/bert/__init__.py +0 -0
  195. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  196. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/bert/bert_lm_head.py +0 -0
  197. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/bert/bert_model.py +0 -0
  198. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/bert/pooler.py +0 -0
  199. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/__init__.py +0 -0
  200. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/embeddings/__init__.py +0 -0
  201. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  202. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  203. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  204. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  205. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  206. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/language_module/__init__.py +0 -0
  207. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/language_module/language_module.py +0 -0
  208. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  209. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/vision_module/__init__.py +0 -0
  210. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  211. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/gpt/__init__.py +0 -0
  212. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  213. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  214. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  215. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  216. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/huggingface/__init__.py +0 -0
  217. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/huggingface/clip_model.py +0 -0
  218. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/huggingface/module.py +0 -0
  219. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/huggingface/qwen_model.py +0 -0
  220. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mamba/__init__.py +0 -0
  221. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  222. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mamba/mamba_model.py +0 -0
  223. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/__init__.py +0 -0
  224. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/config/__init__.py +0 -0
  225. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/config/base_configs.py +0 -0
  226. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/model/__init__.py +0 -0
  227. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/model/base.py +0 -0
  228. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/submodules/audio.py +0 -0
  229. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/submodules/base.py +0 -0
  230. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/mimo/submodules/vision.py +0 -0
  231. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/multimodal/__init__.py +0 -0
  232. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/multimodal/context_parallel.py +0 -0
  233. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/multimodal/llava_model.py +0 -0
  234. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/multimodal/llava_spec.py +0 -0
  235. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/__init__.py +0 -0
  236. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/base_attention.py +0 -0
  237. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/config.py +0 -0
  238. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/decoder_attention.py +0 -0
  239. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/decoder_spec.py +0 -0
  240. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/encoder_attention.py +0 -0
  241. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/encoder_spec.py +0 -0
  242. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/model.py +0 -0
  243. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/retro/utils.py +0 -0
  244. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/vision/__init__.py +0 -0
  245. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/vision/clip_vit_model.py +0 -0
  246. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/vision/multimodal_projector.py +0 -0
  247. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/vision/radio.py +0 -0
  248. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  249. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/msc_utils.py +0 -0
  250. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/nccl_allocator.py +0 -0
  251. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/num_microbatches_calculator.py +0 -0
  252. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/__init__.py +0 -0
  253. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/clip_grads.py +0 -0
  254. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  255. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  256. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  257. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/grad_scaler.py +0 -0
  258. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/optimizer.py +0 -0
  259. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer/optimizer_config.py +0 -0
  260. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/optimizer_param_scheduler.py +0 -0
  261. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/packed_seq_params.py +0 -0
  262. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/parallel_state.py +0 -0
  263. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/pipeline_parallel/__init__.py +0 -0
  264. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/pipeline_parallel/bridge_communicator.py +0 -0
  265. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  266. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  267. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/pipeline_parallel/schedules.py +0 -0
  268. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/pipeline_parallel/utils.py +0 -0
  269. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/__init__.py +0 -0
  270. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/modelopt/__init__.py +0 -0
  271. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  272. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  273. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  274. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/modelopt/layers.py +0 -0
  275. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  276. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  277. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/process_groups_config.py +0 -0
  278. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/quantization/__init__.py +0 -0
  279. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/quantization/quant_config.py +0 -0
  280. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/quantization/utils.py +0 -0
  281. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/requirements.txt +0 -0
  282. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/rerun_state_machine.py +0 -0
  283. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/safe_globals.py +0 -0
  284. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/__init__.py +0 -0
  285. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/mamba_block.py +0 -0
  286. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  287. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  288. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/mamba_layer.py +0 -0
  289. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/mamba_mixer.py +0 -0
  290. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/mlp_layer.py +0 -0
  291. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/ssm/triton_cache_manager.py +0 -0
  292. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tensor_parallel/__init__.py +0 -0
  293. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  294. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tensor_parallel/data.py +0 -0
  295. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tensor_parallel/layers.py +0 -0
  296. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tensor_parallel/mappings.py +0 -0
  297. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tensor_parallel/random.py +0 -0
  298. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tensor_parallel/utils.py +0 -0
  299. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/timers.py +0 -0
  300. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/__init__.py +0 -0
  301. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/base_tokenizer.py +0 -0
  302. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/megatron_tokenizer.py +0 -0
  303. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/__init__.py +0 -0
  304. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/__init__.py +0 -0
  305. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +0 -0
  306. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/bytelevel_tokenizer.py +0 -0
  307. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/chat_template.py +0 -0
  308. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +0 -0
  309. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +0 -0
  310. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/null_tokenizer.py +0 -0
  311. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +0 -0
  312. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +0 -0
  313. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/models/__init__.py +0 -0
  314. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/models/bert_tokenizer.py +0 -0
  315. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/models/default_tokenizer.py +0 -0
  316. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/models/gpt_tokenizer.py +0 -0
  317. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/models/mamba_tokenizer.py +0 -0
  318. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/models/retro_tokenizer.py +0 -0
  319. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/models/t5_tokenizer.py +0 -0
  320. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/text_tokenizer.py +0 -0
  321. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/tokenizers/text/utils/build_tokenizer.py +0 -0
  322. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/__init__.py +0 -0
  323. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/attention.py +0 -0
  324. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/cuda_graphs.py +0 -0
  325. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  326. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  327. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/dot_product_attention.py +0 -0
  328. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/enums.py +0 -0
  329. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/fsdp_dtensor_checkpoint.py +0 -0
  330. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  331. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  332. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/identity_op.py +0 -0
  333. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/mlp.py +0 -0
  334. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/module.py +0 -0
  335. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/__init__.py +0 -0
  336. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/experts.py +0 -0
  337. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  338. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  339. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/moe_layer.py +0 -0
  340. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/moe_utils.py +0 -0
  341. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/router.py +0 -0
  342. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/shared_experts.py +0 -0
  343. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  344. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  345. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/multi_latent_attention.py +0 -0
  346. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/multi_token_prediction.py +0 -0
  347. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  348. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/spec_utils.py +0 -0
  349. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/torch_layer_norm.py +0 -0
  350. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/torch_norm.py +0 -0
  351. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/transformer_block.py +0 -0
  352. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/transformer_config.py +0 -0
  353. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/transformer_layer.py +0 -0
  354. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/transformer/utils.py +0 -0
  355. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron/core/utils.py +0 -0
  356. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron_core.egg-info/dependency_links.txt +0 -0
  357. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron_core.egg-info/requires.txt +0 -0
  358. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/megatron_core.egg-info/top_level.txt +0 -0
  359. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/pyproject.toml +0 -0
  360. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/setup.cfg +0 -0
  361. {megatron_core-0.16.0rc0.dev125968 → megatron_core-0.16.0rc0.dev126546}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: megatron-core
3
- Version: 0.16.0rc0.dev125968
3
+ Version: 0.16.0rc0.dev126546
4
4
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
5
5
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
6
6
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -0,0 +1,72 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+
3
+
4
+ class MetadataBase:
5
+ """
6
+ Base class for attention metadata.
7
+ High-performance attention kernels often require input metadata in specific
8
+ formats—such as cumulative query lengths, cumulative key/value lengths,
9
+ and similar structures. Moreover, when using CUDA Graphs, these metadata
10
+ buffers must be statically allocated. This class serves as a unified container
11
+ that manages all such metadata in one place.
12
+ """
13
+
14
+ def __init__(self):
15
+ """
16
+ Initialize the metadata.
17
+ """
18
+ self.state_data = {}
19
+
20
+ def update(self, *args, **kwargs):
21
+ """
22
+ Construct the metadata from request states.
23
+ """
24
+ pass
25
+
26
+ def reset(self):
27
+ """
28
+ Reset the metadata.
29
+ """
30
+ pass
31
+
32
+ def tensor_copy_and_pad(
33
+ self,
34
+ tensor_buf,
35
+ unpadded_tensor,
36
+ real_batch_size,
37
+ padded_batch_size,
38
+ is_cumulative_tensor=False,
39
+ pad_value=0,
40
+ ):
41
+ """
42
+ Copy the unpadded tensor to the tensor_buf,
43
+ pad the tensor_buf with zero or the last value of the tensor,
44
+ depending on whether the tensor is cumulative.
45
+ Args:
46
+ tensor_buf: The destination tensor, at least padded_batch_size long.
47
+ unpadded_tensor: The tensor to copy, at least real_batch_size long.
48
+ real_batch_size: The real batch size.
49
+ padded_batch_size: Padded boundary of the tensor.
50
+ is_cumulative_tensor: Whether the tensor is cumulative.
51
+ If True, we pad the tensor_buf with the last value of the unpadded_tensor.
52
+ pad_value: The value to pad the tensor_buf with when the tensor is not cumulative.
53
+ """
54
+ assert real_batch_size <= padded_batch_size
55
+ assert tensor_buf.shape[0] >= padded_batch_size
56
+ assert unpadded_tensor.shape[0] >= real_batch_size
57
+ if is_cumulative_tensor:
58
+ if real_batch_size == 0:
59
+ value = pad_value
60
+ else:
61
+ value = unpadded_tensor[real_batch_size - 1]
62
+ else:
63
+ value = pad_value
64
+ tensor_buf[0:real_batch_size] = unpadded_tensor[:real_batch_size]
65
+ tensor_buf[real_batch_size:padded_batch_size] = value
66
+ return tensor_buf
67
+
68
+ def __str__(self):
69
+ """
70
+ Return a string representation of the metadata.
71
+ """
72
+ return "\n".join([f"{key}: {value}" for key, value in self.state_data.items()])
@@ -0,0 +1,220 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from .metadata_base import MetadataBase
8
+
9
+
10
+ class MHAMetadata(MetadataBase):
11
+ """
12
+ Metadata for MHA layer using flash-attention.
13
+ """
14
+
15
+ def __init__(
16
+ self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
17
+ ):
18
+ super().__init__()
19
+ device = torch.cuda.current_device()
20
+ self.device = device
21
+ self.max_blocks = block_count_total
22
+ self.max_kv_blocks = max_kv_block_count
23
+ self.max_bs = max_requests
24
+ self.max_seqlen = max_seqlen
25
+ self._query_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device)
26
+ self._cu_query_seq_lengths_buf = torch.zeros(
27
+ self.max_bs + 1, dtype=torch.int32, device=device
28
+ )
29
+ self._cu_kv_seq_lengths_buf = torch.zeros(self.max_bs + 1, dtype=torch.int32, device=device)
30
+ self._kv_seq_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device)
31
+ self._block_table_buf = torch.zeros(
32
+ (self.max_bs, self.max_kv_blocks), dtype=torch.int32, device=device
33
+ )
34
+ self._max_seqlen_q = 0
35
+ self._max_seqlen_k = 0
36
+ self.state_data = {}
37
+
38
+ def update(
39
+ self,
40
+ request_query_lengths: torch.Tensor,
41
+ request_kv_length_offsets: torch.Tensor,
42
+ request_to_kv_block_ids: torch.Tensor,
43
+ padded_active_token_count: int,
44
+ real_batch_size: int,
45
+ padded_active_request_count: Optional[int] = None,
46
+ decode_only: bool = False,
47
+ ):
48
+ """
49
+ Args:
50
+ request_query_lengths: (>real_batch_size,)
51
+ request_kv_length_offsets: (>real_batch_size,)
52
+ request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
53
+ padded_active_token_count: int
54
+ real_batch_size: int
55
+ padded_active_request_count: Optional[int]
56
+ decode_only: bool
57
+ """
58
+ if padded_active_request_count is None:
59
+ padded_active_request_count = real_batch_size
60
+
61
+ assert real_batch_size <= padded_active_request_count <= self.max_bs
62
+ assert request_query_lengths.shape[0] == real_batch_size
63
+ assert request_kv_length_offsets.shape[0] == real_batch_size
64
+ assert request_to_kv_block_ids.shape[0] == real_batch_size
65
+
66
+ self.tensor_copy_and_pad(
67
+ self._query_lengths_buf,
68
+ request_query_lengths,
69
+ real_batch_size,
70
+ padded_active_request_count,
71
+ )
72
+ self._cu_query_seq_lengths_buf[0] = 0
73
+ self.tensor_copy_and_pad(
74
+ self._cu_query_seq_lengths_buf[1:],
75
+ torch.cumsum(request_query_lengths, dim=0),
76
+ real_batch_size,
77
+ padded_active_request_count,
78
+ is_cumulative_tensor=True,
79
+ )
80
+ self.tensor_copy_and_pad(
81
+ self._kv_seq_lengths_buf,
82
+ request_kv_length_offsets + request_query_lengths,
83
+ real_batch_size,
84
+ padded_active_request_count,
85
+ )
86
+ self.tensor_copy_and_pad(
87
+ self._block_table_buf,
88
+ request_to_kv_block_ids,
89
+ real_batch_size,
90
+ padded_active_request_count,
91
+ pad_value=torch.tensor(self.max_kv_blocks, dtype=torch.int32, device=self.device).fill_(
92
+ -1
93
+ ),
94
+ )
95
+ self._cu_kv_seq_lengths_buf[0] = 0
96
+ self.tensor_copy_and_pad(
97
+ self._cu_kv_seq_lengths_buf[1:],
98
+ torch.cumsum(self._kv_seq_lengths_buf, dim=0),
99
+ real_batch_size,
100
+ padded_active_request_count,
101
+ is_cumulative_tensor=True,
102
+ )
103
+
104
+ if decode_only:
105
+ self._max_seqlen_q = 1
106
+ else:
107
+ self._max_seqlen_q = max(2, padded_active_token_count)
108
+ self._max_seqlen_k = self.max_seqlen
109
+
110
+ self.state_data = {
111
+ "query_lengths": self._query_lengths_buf[:padded_active_request_count],
112
+ "cu_query_seq_lengths": self._cu_query_seq_lengths_buf[
113
+ : padded_active_request_count + 1
114
+ ],
115
+ "cu_kv_seq_lengths": self._cu_kv_seq_lengths_buf[: padded_active_request_count + 1],
116
+ "kv_seq_lengths": self._kv_seq_lengths_buf[:padded_active_request_count],
117
+ "block_table": self._block_table_buf[0:padded_active_request_count, :],
118
+ "max_seqlen_q": self._max_seqlen_q,
119
+ "max_seqlen_k": self._max_seqlen_k,
120
+ }
121
+
122
+ def reset(self):
123
+ """
124
+ Reset the metadata for the next batch.
125
+ """
126
+ self._query_lengths_buf.fill_(0)
127
+ self._cu_query_seq_lengths_buf.fill_(0)
128
+ self._cu_kv_seq_lengths_buf.fill_(0)
129
+ self._kv_seq_lengths_buf.fill_(0)
130
+ self._block_table_buf.fill_(0)
131
+ self._max_seqlen_q = 0
132
+ self._max_seqlen_k = 0
133
+
134
+
135
+ class GraphedMHAMetadata(MHAMetadata):
136
+ """
137
+ Metadata for MHA layer using flash-attention with CUDA graphs.
138
+ """
139
+
140
+ def __init__(
141
+ self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
142
+ ):
143
+ super().__init__(
144
+ block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
145
+ )
146
+
147
+ def update(
148
+ self,
149
+ request_query_lengths: torch.Tensor,
150
+ request_kv_length_offsets: torch.Tensor,
151
+ request_to_kv_block_ids: torch.Tensor,
152
+ padded_active_token_count: int,
153
+ real_batch_size: int,
154
+ padded_active_request_count: Optional[int] = None,
155
+ decode_only: bool = False,
156
+ ):
157
+ """
158
+ Args:
159
+ request_query_lengths: (>real_batch_size,)
160
+ request_kv_length_offsets: (>real_batch_size,)
161
+ request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
162
+ padded_active_token_count: int
163
+ real_batch_size: int
164
+ padded_active_request_count: Optional[int]
165
+ decode_only: bool
166
+ """
167
+ super().update(
168
+ request_query_lengths,
169
+ request_kv_length_offsets,
170
+ request_to_kv_block_ids,
171
+ padded_active_token_count,
172
+ real_batch_size,
173
+ padded_active_request_count,
174
+ decode_only,
175
+ )
176
+
177
+ def reset(self):
178
+ super().reset()
179
+
180
+
181
+ class NonGraphedMHAMetadata(MHAMetadata):
182
+ """
183
+ Metadata for MHA layer using flash-attention without CUDA graphs.
184
+ """
185
+
186
+ def update(
187
+ self,
188
+ request_query_lengths: torch.Tensor,
189
+ request_kv_length_offsets: torch.Tensor,
190
+ request_to_kv_block_ids: torch.Tensor,
191
+ padded_active_token_count: int,
192
+ real_batch_size: int,
193
+ padded_active_request_count: Optional[int] = None,
194
+ decode_only: bool = False,
195
+ ):
196
+ """
197
+ Args:
198
+ request_query_lengths: (>real_batch_size,)
199
+ request_kv_length_offsets: (>real_batch_size,)
200
+ request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
201
+ padded_active_token_count: int
202
+ real_batch_size: int
203
+ padded_active_request_count: Optional[int]
204
+ decode_only: bool
205
+ """
206
+ super().update(
207
+ request_query_lengths,
208
+ request_kv_length_offsets,
209
+ request_to_kv_block_ids,
210
+ padded_active_token_count,
211
+ real_batch_size,
212
+ padded_active_request_count,
213
+ decode_only,
214
+ )
215
+ if len(self.state_data["query_lengths"]) > 0:
216
+ self.state_data["max_seqlen_q"] = torch.max(self.state_data["query_lengths"]).item()
217
+ self.state_data["max_seqlen_k"] = torch.max(self.state_data["kv_seq_lengths"]).item()
218
+ else:
219
+ self.state_data["max_seqlen_q"] = 1
220
+ self.state_data["max_seqlen_k"] = 1
@@ -16,13 +16,17 @@ from megatron.core.inference.inference_request import DynamicInferenceRequest
16
16
  from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
17
17
  InferenceWrapperConfig,
18
18
  )
19
- from megatron.core.inference.unified_memory import create_unified_mempool, has_unified_memory
19
+ from megatron.core.inference.unified_memory import (
20
+ UnifiedMemoryUnsupportedError,
21
+ create_unified_mempool,
22
+ )
20
23
  from megatron.core.inference.utils import tensor_swap
21
24
  from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
22
25
  from megatron.core.package_info import __version__ as mcore_version
23
26
  from megatron.core.transformer import TransformerConfig
24
27
  from megatron.core.utils import divide as core_divide
25
28
 
29
+ from .attention_context.mha_metadata import GraphedMHAMetadata, NonGraphedMHAMetadata
26
30
  from .base_context import BaseInferenceContext
27
31
  from .dynamic_block_allocator import BlockAllocator
28
32
 
@@ -322,16 +326,20 @@ class DynamicInferenceContext(BaseInferenceContext):
322
326
  self.params_dtype = params_dtype
323
327
  self.num_layers = num_layers
324
328
  self.max_sequence_length = max_sequence_length
329
+
330
+ # Unified memory.
325
331
  self.unified_memory_level = unified_memory_level
326
332
  if unified_memory_level > 0:
327
- if not has_unified_memory and torch.distributed.get_rank() == 0:
328
- warnings.warn(
329
- "Unified memory requested but not available; defaulting to GPU memory."
330
- )
331
- self.unified_memory_level = 0
332
- else:
333
+ try:
333
334
  self.unified_memory_mempool = create_unified_mempool()
335
+ except UnifiedMemoryUnsupportedError:
336
+ if torch.distributed.get_rank() == 0:
337
+ warnings.warn(
338
+ "Unified memory requested but not available; defaulting to GPU memory."
339
+ )
340
+ self.unified_memory_level = 0
334
341
 
342
+ # Request and token counts.
335
343
  self.total_request_count = 0
336
344
  self.active_token_count = 0
337
345
  self.paused_request_count = 0
@@ -447,30 +455,26 @@ class DynamicInferenceContext(BaseInferenceContext):
447
455
  num_cuda_graphs is not None
448
456
  )
449
457
 
450
- # `*_cudagraph_only` tensors are for use with cuda graphs to maintain
451
- # consistent input shapes, which is required to use cuda graphs.
452
- # During these steps, the `*_cudagraph_only`
453
- # tensors are used, otherwise their same-name but un-suffixed
454
- # corresponding tensors are used.
458
+ # Attention metadata initialization (tensors are now handled by MHAMetadata classes)
455
459
 
456
- self.query_seq_lengths_cudagraph_only = torch.full(
457
- (self.max_requests,), 0, dtype=torch.int32, device=torch.cuda.current_device()
458
- )
459
- self.cu_query_seq_lengths_cudagraph_only = torch.full(
460
- (self.max_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
461
- )
462
- self.kv_seq_lengths_cudagraph_only = torch.full(
463
- (self.max_requests,), 0, dtype=torch.int32, device=torch.cuda.current_device()
464
- )
465
- self.cu_kv_seq_lengths_cudagraph_only = torch.full(
466
- (self.max_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
460
+ self.graph_attn_metadata = {}
461
+ self.non_graph_attn_metadata = {}
462
+ self.active_attn_metadata = None
463
+
464
+ self.graph_attn_metadata["mha_metadata"] = GraphedMHAMetadata(
465
+ block_count_total=block_count_total,
466
+ max_kv_block_count=self.max_kv_block_count,
467
+ max_requests=self.max_requests,
468
+ block_size_tokens=self.block_size_tokens,
469
+ max_seqlen=self.max_sequence_length,
467
470
  )
468
471
 
469
- self.request_to_kv_block_ids_cudagraph_only = torch.full(
470
- (self.max_requests, self.max_kv_block_count),
471
- 0,
472
- dtype=torch.int,
473
- device=torch.cuda.current_device(),
472
+ self.non_graph_attn_metadata["mha_metadata"] = NonGraphedMHAMetadata(
473
+ block_count_total=block_count_total,
474
+ max_kv_block_count=self.max_kv_block_count,
475
+ max_requests=self.max_requests,
476
+ block_size_tokens=self.block_size_tokens,
477
+ max_seqlen=self.max_sequence_length,
474
478
  )
475
479
 
476
480
  # Guaranteed active requests.
@@ -620,11 +624,18 @@ class DynamicInferenceContext(BaseInferenceContext):
620
624
 
621
625
  def cu_query_lengths(self) -> Tuple[Tensor, int]:
622
626
  """Cumulative query sequence lengths."""
623
- return self.cu_query_seq_lengths, self.max_seqlen_q
627
+ return (
628
+ self.active_attn_metadata["mha_metadata"].state_data["cu_query_seq_lengths"],
629
+ self.active_attn_metadata["mha_metadata"].state_data["max_seqlen_q"],
630
+ )
624
631
 
625
- def cu_kv_lengths(self) -> Tensor:
632
+ def cu_kv_lengths(self) -> Tuple[Tensor, Tensor, int]:
626
633
  """Cumulative key/value sequence lengths."""
627
- return (self.cu_kv_seq_lengths, self.kv_seq_lengths, self.max_seqlen_k)
634
+ return (
635
+ self.active_attn_metadata["mha_metadata"].state_data["cu_kv_seq_lengths"],
636
+ self.active_attn_metadata["mha_metadata"].state_data["kv_seq_lengths"],
637
+ self.active_attn_metadata["mha_metadata"].state_data["max_seqlen_k"],
638
+ )
628
639
 
629
640
  def get_active_sequence_lengths(self) -> Tensor:
630
641
  """Total sequence length (query + key) for active requests."""
@@ -702,12 +713,16 @@ class DynamicInferenceContext(BaseInferenceContext):
702
713
  to blocks within the block-level memory buffer.
703
714
  """
704
715
  if self.cache_mla_latent:
705
- return (self.memory_buffer[layer_number - 1], None, self.block_table)
716
+ return (
717
+ self.memory_buffer[layer_number - 1],
718
+ None,
719
+ self.active_attn_metadata["mha_metadata"].state_data["block_table"],
720
+ )
706
721
  else:
707
722
  return (
708
723
  self.memory_buffer[0, layer_number - 1],
709
724
  self.memory_buffer[1, layer_number - 1],
710
- self.block_table,
725
+ self.active_attn_metadata["mha_metadata"].state_data["block_table"],
711
726
  )
712
727
 
713
728
  def apply_fused_qk_rotary_emb(
@@ -817,17 +832,12 @@ class DynamicInferenceContext(BaseInferenceContext):
817
832
 
818
833
  def reset_attention_state(self) -> None:
819
834
  """Reset state used within attention, after each step."""
820
- self.max_seqlen_q = None
821
- self.max_seqlen_k = None
822
- self.cu_query_seq_lengths = None
823
- self.cu_query_seq_lengths_cudagraph_only.fill_(0)
824
- self.query_seq_lengths_cudagraph_only.fill_(0)
825
- self.cu_kv_seq_lengths = None
826
- self.cu_kv_seq_lengths_cudagraph_only.fill_(0)
827
- self.kv_seq_lengths = None
828
- self.kv_seq_lengths_cudagraph_only.fill_(0)
829
- self.request_to_kv_block_ids_cudagraph_only.fill_(0)
830
- self.block_table = None
835
+ # Attention metadata reset is now handled by MHAMetadata.reset()
836
+ for attn_metadata in self.non_graph_attn_metadata.values():
837
+ attn_metadata.reset()
838
+ for attn_metadata in self.graph_attn_metadata.values():
839
+ attn_metadata.reset()
840
+ self.active_attn_metadata = None
831
841
 
832
842
  def using_cuda_graph_this_step(self) -> bool:
833
843
  """Returns True if cuda graphs are being used for this step."""
@@ -927,89 +937,30 @@ class DynamicInferenceContext(BaseInferenceContext):
927
937
  self.active_token_count : self.padded_active_token_count
928
938
  ] = 0
929
939
 
930
- # Update cu_query_seq_lengths, max_seqlen_q.
931
- query_lengths = self.request_query_lengths[
932
- self.paused_request_count : self.total_request_count
933
- ]
934
- if self.is_decode_only() or self.using_cuda_graph_this_step():
935
- self.query_seq_lengths_cudagraph_only[
936
- 0 : self.total_request_count - self.paused_request_count
937
- ] = query_lengths
938
- if self.is_decode_only():
939
- self.cu_query_seq_lengths = None # ensure no accidental use
940
- self.max_seqlen_q = 1
941
- else:
942
- self.cu_query_seq_lengths_cudagraph_only[
943
- 1 : self.padded_active_request_count + 1
944
- ] = torch.cumsum(
945
- self.query_seq_lengths_cudagraph_only[: self.padded_active_request_count], dim=0
946
- )
947
-
948
- # The following will be passed to the FA kernel.
949
- self.cu_query_seq_lengths = self.cu_query_seq_lengths_cudagraph_only[
950
- : (self.padded_active_request_count + 1)
951
- ]
952
- self.max_seqlen_q = self.padded_active_token_count
953
- else:
954
- cu_query_lengths = torch.cumsum(query_lengths, dim=0)
955
- self.cu_query_seq_lengths = torch.full(
956
- (self.total_request_count - self.paused_request_count + 1,),
957
- 0,
958
- dtype=torch.int32,
959
- device=torch.cuda.current_device(),
960
- )
961
- self.cu_query_seq_lengths[1:] = cu_query_lengths
962
- self.max_seqlen_q = query_lengths.max().item()
963
-
964
- kv_seq_lengths = self.request_kv_length_offsets + self.request_query_lengths
965
- self.kv_seq_lengths = kv_seq_lengths[self.paused_request_count : self.total_request_count]
966
- if self.is_decode_only() or self.using_cuda_graph_this_step():
967
- # Re-assign `kv_seq_lengths` to be a view of the first
968
- # `active_cuda_graph_request_count` tokens of `kv_seq_lengths_decode_only`,
969
- # such that `kv_seq_lengths` has a static memory address and is therefore
970
- # cuda graph compatible. This allows `kv_seq_lengths` to transition between,
971
- # cuda graph sizes, which makes multi-batch-size cuda graphs possible.
972
- self.kv_seq_lengths_cudagraph_only[
973
- 0 : self.total_request_count - self.paused_request_count
974
- ] = self.kv_seq_lengths
975
- self.kv_seq_lengths = self.kv_seq_lengths_cudagraph_only[
976
- : self.padded_active_request_count
977
- ]
978
- self.max_seqlen_k = self.max_sequence_length
979
- if self.is_decode_only():
980
- self.cu_kv_seq_lengths = None # ensure no accidental use
981
- else:
982
- cu_kv_lengths = torch.cumsum(self.kv_seq_lengths, dim=0)
983
- # The following will be passed to the FA kernel.
984
- self.cu_kv_seq_lengths_cudagraph_only[1 : cu_kv_lengths.size(0) + 1] = cu_kv_lengths
985
- self.cu_kv_seq_lengths = self.cu_kv_seq_lengths_cudagraph_only[
986
- : (self.padded_active_request_count + 1)
987
- ]
988
- else:
989
- self.cu_kv_seq_lengths = torch.full(
990
- (self.total_request_count - self.paused_request_count + 1,),
991
- 0,
992
- dtype=torch.int32,
993
- device=torch.cuda.current_device(),
994
- )
995
- self.cu_kv_seq_lengths[1:] = torch.cumsum(self.kv_seq_lengths, dim=0)
996
- self.max_seqlen_k = self.kv_seq_lengths.max().item()
940
+ real_req_batch_size = (
941
+ self.total_request_count - self.paused_request_count
942
+ ) # how many requests are indeed active
943
+ self.active_attn_metadata = (
944
+ self.graph_attn_metadata
945
+ if self.using_cuda_graph_this_step()
946
+ else self.non_graph_attn_metadata
947
+ )
997
948
 
998
- # Update KV block IDs, block table.
999
- request_to_kv_block_ids = self.request_to_kv_block_ids[
1000
- self.paused_request_count : self.total_request_count
1001
- ]
1002
- if self.is_decode_only() or self.using_cuda_graph_this_step():
1003
- self.request_to_kv_block_ids_cudagraph_only[
1004
- 0 : self.total_request_count - self.paused_request_count
1005
- ] = request_to_kv_block_ids
1006
- self.block_table = self.request_to_kv_block_ids_cudagraph_only[
1007
- : self.padded_active_request_count
1008
- ]
1009
- else:
1010
- self.block_table = self.request_to_kv_block_ids[
1011
- self.paused_request_count : self.total_request_count
1012
- ]
949
+ # Update cu_query_seq_lengths, max_seqlen_q.
950
+ active_slice = slice(self.paused_request_count, self.total_request_count)
951
+ query_lengths_view = self.request_query_lengths[active_slice]
952
+ request_kv_length_offsets_view = self.request_kv_length_offsets[active_slice]
953
+ request_to_kv_block_ids_view = self.request_to_kv_block_ids[active_slice]
954
+ self.active_attn_metadata["mha_metadata"].update(
955
+ request_query_lengths=query_lengths_view,
956
+ request_kv_length_offsets=request_kv_length_offsets_view,
957
+ request_to_kv_block_ids=request_to_kv_block_ids_view,
958
+ padded_active_token_count=self.padded_active_token_count,
959
+ real_batch_size=real_req_batch_size,
960
+ padded_active_request_count=self.padded_active_request_count,
961
+ decode_only=self.is_decode_only(),
962
+ )
963
+ # All attention metadata calculations are now handled by MHAMetadata.update()
1013
964
 
1014
965
  def reset(self) -> None:
1015
966
  """Reset entire context.
@@ -165,6 +165,17 @@ class DynamicInferenceEngine(AbstractEngine):
165
165
  context = self.context
166
166
  controller = self.controller
167
167
 
168
+ config = controller.inference_wrapped_model.inference_wrapper_config
169
+ moe_pad_experts = config.moe_pad_experts_for_cuda_graph_inference
170
+
171
+ if moe_pad_experts and context.non_decode_cuda_graphs:
172
+ context.non_decode_cuda_graphs = False
173
+ if torch.distributed.get_rank() == 0:
174
+ warnings.warn(
175
+ "MoE models do not support non-decode cuda graphs. "
176
+ "Forcing non_decode_cuda_graphs to False."
177
+ )
178
+
168
179
  time_start = time.time()
169
180
  mem_stats_start = torch.cuda.memory_stats()
170
181
 
@@ -174,15 +185,18 @@ class DynamicInferenceEngine(AbstractEngine):
174
185
  context.cuda_graph_token_counts,
175
186
  )
176
187
  for warmup_engine_mode in [WarmupEngineMode.DECODE, WarmupEngineMode.NON_DECODE]:
177
- # Iterate cuda graph dims.
188
+ # Check whether to skip non-decode graphs.
178
189
  if (
179
190
  warmup_engine_mode == WarmupEngineMode.NON_DECODE
180
191
  and not context.non_decode_cuda_graphs
181
192
  ):
182
193
  continue
194
+
183
195
  tbar = enumerate(context.cuda_graph_token_counts)
184
196
  if HAVE_TQDM:
185
197
  tbar = tqdm(tbar, total=len(context.cuda_graph_token_counts))
198
+
199
+ # Iterate cuda graph dims.
186
200
  for tbar_idx, cuda_graph_token_count in tbar:
187
201
  if (
188
202
  cuda_graph_token_count == 1
@@ -508,7 +508,8 @@ class TextGenerationController:
508
508
  inference_wrapper_config.moe_pad_experts_for_cuda_graph_inference
509
509
  )
510
510
  if moe_pad_experts_for_cuda_graph_inference:
511
- if context.is_decode_only() or warmup_engine_mode is not None:
511
+ assert warmup_engine_mode is not WarmupEngineMode.NON_DECODE
512
+ if context.is_decode_only():
512
513
  capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk
513
514
  set_decode_expert_padding(unwrapped_model, True, capacity_factor=capacity_factor)
514
515
  else: