megatron-core 0.16.0rc0.dev131152__tar.gz → 0.16.0rc0.dev131564__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of megatron-core has been flagged as potentially problematic.

Files changed (360):
  1. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/PKG-INFO +1 -1
  2. megatron_core-0.16.0rc0.dev131564/megatron/core/inference/contexts/attention_context/metadata_base.py +72 -0
  3. megatron_core-0.16.0rc0.dev131564/megatron/core/inference/contexts/attention_context/mha_metadata.py +220 -0
  4. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/contexts/dynamic_context.py +143 -120
  5. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/engines/dynamic_engine.py +72 -0
  6. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/unified_memory.py +1 -1
  7. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/gpt/gpt_model.py +1 -2
  8. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/package_info.py +1 -1
  9. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron_core.egg-info/PKG-INFO +1 -1
  10. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron_core.egg-info/SOURCES.txt +2 -0
  11. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/MANIFEST.in +0 -0
  12. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/README.md +0 -0
  13. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/README.md +0 -0
  14. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/__init__.py +0 -0
  15. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/activations.py +0 -0
  16. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/config.py +0 -0
  17. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/config_logger.py +0 -0
  18. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/__init__.py +0 -0
  19. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/bert_dataset.py +0 -0
  20. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/blended_dataset.py +0 -0
  21. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  22. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  23. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/gpt_dataset.py +0 -0
  24. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/helpers.cpp +0 -0
  25. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/helpers.py +0 -0
  26. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/indexed_dataset.py +0 -0
  27. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/masked_dataset.py +0 -0
  28. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/megatron_dataset.py +0 -0
  29. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  30. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/multimodal_dataset.py +0 -0
  31. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/object_storage_utils.py +0 -0
  32. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/__init__.py +0 -0
  33. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/config/__init__.py +0 -0
  34. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  35. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/config/config.py +0 -0
  36. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  37. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  38. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/db/__init__.py +0 -0
  39. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/db/build.py +0 -0
  40. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/db/dataset.py +0 -0
  41. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/db/utils.py +0 -0
  42. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/external_libs.py +0 -0
  43. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/__init__.py +0 -0
  44. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/build.py +0 -0
  45. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/factory.py +0 -0
  46. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/index.py +0 -0
  47. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  48. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  49. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  50. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/utils.py +0 -0
  51. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/index/validate.py +0 -0
  52. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/query/__init__.py +0 -0
  53. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  54. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  55. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/query/query.py +0 -0
  56. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  57. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/query/utils.py +0 -0
  58. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/retro/utils.py +0 -0
  59. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/t5_dataset.py +0 -0
  60. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/utils.py +0 -0
  61. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/datasets/utils_s3.py +0 -0
  62. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/__init__.py +0 -0
  63. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/core.py +0 -0
  64. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  65. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  66. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/mapping.py +0 -0
  67. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  68. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/serialization.py +0 -0
  69. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  70. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  71. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  72. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  73. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  74. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/checkpointable.py +0 -0
  75. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  76. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  77. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  78. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  79. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  80. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  81. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  82. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  83. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  84. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  85. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/utils.py +0 -0
  86. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/dist_checkpointing/validation.py +0 -0
  87. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/__init__.py +0 -0
  88. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/data_parallel_base.py +0 -0
  89. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  90. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  91. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/finalize_model_grads.py +0 -0
  92. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/__init__.py +0 -0
  93. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +0 -0
  94. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/__init__.py +0 -0
  95. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +0 -0
  96. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +0 -0
  97. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +0 -0
  98. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +0 -0
  99. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py +0 -0
  100. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +0 -0
  101. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +0 -0
  102. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +0 -0
  103. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  104. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/reduce_scatter_with_fp32_accumulation.py +0 -0
  105. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  106. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  107. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/energy_monitor.py +0 -0
  108. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/enums.py +0 -0
  109. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/__init__.py +0 -0
  110. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/data_type.py +0 -0
  111. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/export_config.py +0 -0
  112. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/model_type.py +0 -0
  113. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/__init__.py +0 -0
  114. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  115. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  116. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  117. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  118. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  119. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  120. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  121. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  122. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  123. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  124. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  125. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  126. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/extensions/__init__.py +0 -0
  127. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/extensions/kitchen.py +0 -0
  128. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/extensions/transformer_engine.py +0 -0
  129. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  130. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fp4_utils.py +0 -0
  131. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fp8_utils.py +0 -0
  132. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/full_cuda_graph.py +0 -0
  133. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/__init__.py +0 -0
  134. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  135. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  136. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  137. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  138. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  139. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_indices_converter.py +0 -0
  140. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_layer_norm.py +0 -0
  141. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  142. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  143. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_softmax.py +0 -0
  144. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
  145. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/hyper_comm_grid.py +0 -0
  146. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/__init__.py +0 -0
  147. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/async_stream.py +0 -0
  148. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/common_inference_params.py +0 -0
  149. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/communication_utils.py +0 -0
  150. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/contexts/__init__.py +0 -0
  151. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/contexts/base_context.py +0 -0
  152. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/contexts/dynamic_block_allocator.py +0 -0
  153. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/contexts/fused_kv_append_kernel.py +0 -0
  154. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/contexts/static_context.py +0 -0
  155. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/data_parallel_inference_coordinator.py +0 -0
  156. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/engines/__init__.py +0 -0
  157. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/engines/abstract_engine.py +0 -0
  158. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/engines/mcore_engine.py +0 -0
  159. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/engines/static_engine.py +0 -0
  160. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/headers.py +0 -0
  161. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/inference_client.py +0 -0
  162. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/inference_request.py +0 -0
  163. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  164. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  165. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  166. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  167. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  168. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  169. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  170. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  171. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/sampling_params.py +0 -0
  172. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/scheduler.py +0 -0
  173. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  174. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  175. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  176. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +0 -0
  177. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  178. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_server/__init__.py +0 -0
  179. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_server/endpoints/common.py +0 -0
  180. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_server/endpoints/completions.py +0 -0
  181. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_server/run_mcore_engine.py +0 -0
  182. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_server/text_generation_server.py +0 -0
  183. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/text_generation_server/tokenization.py +0 -0
  184. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference/utils.py +0 -0
  185. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/inference_params.py +0 -0
  186. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/jit.py +0 -0
  187. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/model_parallel_config.py +0 -0
  188. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/T5/__init__.py +0 -0
  189. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/T5/t5_model.py +0 -0
  190. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/T5/t5_spec.py +0 -0
  191. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/__init__.py +0 -0
  192. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/backends.py +0 -0
  193. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/bert/__init__.py +0 -0
  194. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  195. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/bert/bert_lm_head.py +0 -0
  196. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/bert/bert_model.py +0 -0
  197. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/bert/pooler.py +0 -0
  198. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/__init__.py +0 -0
  199. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/embeddings/__init__.py +0 -0
  200. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  201. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  202. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  203. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  204. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  205. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/language_module/__init__.py +0 -0
  206. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/language_module/language_module.py +0 -0
  207. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  208. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/vision_module/__init__.py +0 -0
  209. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  210. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/gpt/__init__.py +0 -0
  211. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  212. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  213. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  214. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  215. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/huggingface/__init__.py +0 -0
  216. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/huggingface/clip_model.py +0 -0
  217. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/huggingface/module.py +0 -0
  218. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/huggingface/qwen_model.py +0 -0
  219. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mamba/__init__.py +0 -0
  220. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  221. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mamba/mamba_model.py +0 -0
  222. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/__init__.py +0 -0
  223. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/config/__init__.py +0 -0
  224. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/config/base_configs.py +0 -0
  225. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/model/__init__.py +0 -0
  226. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/model/base.py +0 -0
  227. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/submodules/audio.py +0 -0
  228. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/submodules/base.py +0 -0
  229. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/mimo/submodules/vision.py +0 -0
  230. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/multimodal/__init__.py +0 -0
  231. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/multimodal/context_parallel.py +0 -0
  232. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/multimodal/llava_model.py +0 -0
  233. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/multimodal/llava_spec.py +0 -0
  234. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/__init__.py +0 -0
  235. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/base_attention.py +0 -0
  236. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/config.py +0 -0
  237. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/decoder_attention.py +0 -0
  238. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/decoder_spec.py +0 -0
  239. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/encoder_attention.py +0 -0
  240. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/encoder_spec.py +0 -0
  241. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/model.py +0 -0
  242. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/retro/utils.py +0 -0
  243. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/vision/__init__.py +0 -0
  244. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/vision/clip_vit_model.py +0 -0
  245. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/vision/multimodal_projector.py +0 -0
  246. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/vision/radio.py +0 -0
  247. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  248. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/msc_utils.py +0 -0
  249. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/nccl_allocator.py +0 -0
  250. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/num_microbatches_calculator.py +0 -0
  251. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/__init__.py +0 -0
  252. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/clip_grads.py +0 -0
  253. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  254. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  255. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  256. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/grad_scaler.py +0 -0
  257. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/optimizer.py +0 -0
  258. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer/optimizer_config.py +0 -0
  259. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/optimizer_param_scheduler.py +0 -0
  260. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/packed_seq_params.py +0 -0
  261. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/parallel_state.py +0 -0
  262. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/pipeline_parallel/__init__.py +0 -0
  263. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/pipeline_parallel/bridge_communicator.py +0 -0
  264. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  265. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  266. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/pipeline_parallel/schedules.py +0 -0
  267. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/pipeline_parallel/utils.py +0 -0
  268. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/__init__.py +0 -0
  269. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/modelopt/__init__.py +0 -0
  270. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  271. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  272. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  273. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/modelopt/layers.py +0 -0
  274. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  275. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  276. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/process_groups_config.py +0 -0
  277. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/quantization/__init__.py +0 -0
  278. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/quantization/quant_config.py +0 -0
  279. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/quantization/utils.py +0 -0
  280. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/requirements.txt +0 -0
  281. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/rerun_state_machine.py +0 -0
  282. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/safe_globals.py +0 -0
  283. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/__init__.py +0 -0
  284. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/mamba_block.py +0 -0
  285. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  286. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  287. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/mamba_layer.py +0 -0
  288. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/mamba_mixer.py +0 -0
  289. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/mlp_layer.py +0 -0
  290. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/ssm/triton_cache_manager.py +0 -0
  291. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tensor_parallel/__init__.py +0 -0
  292. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  293. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tensor_parallel/data.py +0 -0
  294. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tensor_parallel/layers.py +0 -0
  295. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tensor_parallel/mappings.py +0 -0
  296. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tensor_parallel/random.py +0 -0
  297. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tensor_parallel/utils.py +0 -0
  298. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/timers.py +0 -0
  299. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/__init__.py +0 -0
  300. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/base_tokenizer.py +0 -0
  301. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/megatron_tokenizer.py +0 -0
  302. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/__init__.py +0 -0
  303. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/__init__.py +0 -0
  304. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +0 -0
  305. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/bytelevel_tokenizer.py +0 -0
  306. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/chat_template.py +0 -0
  307. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +0 -0
  308. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +0 -0
  309. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/null_tokenizer.py +0 -0
  310. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +0 -0
  311. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +0 -0
  312. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/models/__init__.py +0 -0
  313. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/models/bert_tokenizer.py +0 -0
  314. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/models/default_tokenizer.py +0 -0
  315. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/models/gpt_tokenizer.py +0 -0
  316. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/models/mamba_tokenizer.py +0 -0
  317. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/models/retro_tokenizer.py +0 -0
  318. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/models/t5_tokenizer.py +0 -0
  319. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/text_tokenizer.py +0 -0
  320. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/tokenizers/text/utils/build_tokenizer.py +0 -0
  321. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/__init__.py +0 -0
  322. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/attention.py +0 -0
  323. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/cuda_graphs.py +0 -0
  324. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  325. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  326. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/dot_product_attention.py +0 -0
  327. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/enums.py +0 -0
  328. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/fsdp_dtensor_checkpoint.py +0 -0
  329. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  330. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  331. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/identity_op.py +0 -0
  332. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/mlp.py +0 -0
  333. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/module.py +0 -0
  334. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/__init__.py +0 -0
  335. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/experts.py +0 -0
  336. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  337. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  338. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/moe_layer.py +0 -0
  339. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/moe_utils.py +0 -0
  340. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/router.py +0 -0
  341. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/shared_experts.py +0 -0
  342. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  343. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  344. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/multi_latent_attention.py +0 -0
  345. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/multi_token_prediction.py +0 -0
  346. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  347. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/spec_utils.py +0 -0
  348. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/torch_layer_norm.py +0 -0
  349. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/torch_norm.py +0 -0
  350. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/transformer_block.py +0 -0
  351. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/transformer_config.py +0 -0
  352. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/transformer_layer.py +0 -0
  353. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/transformer/utils.py +0 -0
  354. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron/core/utils.py +0 -0
  355. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron_core.egg-info/dependency_links.txt +0 -0
  356. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron_core.egg-info/requires.txt +0 -0
  357. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/megatron_core.egg-info/top_level.txt +0 -0
  358. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/pyproject.toml +0 -0
  359. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/setup.cfg +0 -0
  360. {megatron_core-0.16.0rc0.dev131152 → megatron_core-0.16.0rc0.dev131564}/setup.py +0 -0
PKG-INFO:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: megatron-core
- Version: 0.16.0rc0.dev131152
+ Version: 0.16.0rc0.dev131564
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
megatron/core/inference/contexts/attention_context/metadata_base.py (new file):
@@ -0,0 +1,72 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+
+ class MetadataBase:
+     """
+     Base class for attention metadata.
+     High-performance attention kernels often require input metadata in specific
+     formats—such as cumulative query lengths, cumulative key/value lengths,
+     and similar structures. Moreover, when using CUDA Graphs, these metadata
+     buffers must be statically allocated. This class serves as a unified container
+     that manages all such metadata in one place.
+     """
+
+     def __init__(self):
+         """
+         Initialize the metadata.
+         """
+         self.state_data = {}
+
+     def update(self, *args, **kwargs):
+         """
+         Construct the metadata from request states.
+         """
+         pass
+
+     def reset(self):
+         """
+         Reset the metadata.
+         """
+         pass
+
+     def tensor_copy_and_pad(
+         self,
+         tensor_buf,
+         unpadded_tensor,
+         real_batch_size,
+         padded_batch_size,
+         is_cumulative_tensor=False,
+         pad_value=0,
+     ):
+         """
+         Copy the unpadded tensor to the tensor_buf,
+         pad the tensor_buf with zero or the last value of the tensor,
+         depending on whether the tensor is cumulative.
+         Args:
+             tensor_buf: The destination tensor, at least padded_batch_size long.
+             unpadded_tensor: The tensor to copy, at least real_batch_size long.
+             real_batch_size: The real batch size.
+             padded_batch_size: Padded boundary of the tensor.
+             is_cumulative_tensor: Whether the tensor is cumulative.
+                 If True, we pad the tensor_buf with the last value of the unpadded_tensor.
+             pad_value: The value to pad the tensor_buf with when the tensor is not cumulative.
+         """
+         assert real_batch_size <= padded_batch_size
+         assert tensor_buf.shape[0] >= padded_batch_size
+         assert unpadded_tensor.shape[0] >= real_batch_size
+         if is_cumulative_tensor:
+             if real_batch_size == 0:
+                 value = pad_value
+             else:
+                 value = unpadded_tensor[real_batch_size - 1]
+         else:
+             value = pad_value
+         tensor_buf[0:real_batch_size] = unpadded_tensor[:real_batch_size]
+         tensor_buf[real_batch_size:padded_batch_size] = value
+         return tensor_buf
+
+     def __str__(self):
+         """
+         Return a string representation of the metadata.
+         """
+         return "\n".join([f"{key}: {value}" for key, value in self.state_data.items()])
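As an aside, the padding semantics of tensor_copy_and_pad can be checked in isolation. The sketch below is ours, not part of the package; it runs on CPU tensors (the base class allocates nothing on GPU) and the values are made up for illustration:

    import torch

    from megatron.core.inference.contexts.attention_context.metadata_base import MetadataBase

    meta = MetadataBase()
    buf = torch.zeros(8, dtype=torch.int32)

    # Non-cumulative: entries past real_batch_size become pad_value (default 0).
    lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
    meta.tensor_copy_and_pad(buf, lengths, real_batch_size=3, padded_batch_size=8)
    # buf is now [3, 5, 2, 0, 0, 0, 0, 0]

    # Cumulative: entries past real_batch_size repeat the last prefix sum, so a
    # kernel reading cumulative lengths sees the padded requests as zero-length.
    cu = torch.cumsum(lengths, dim=0).to(torch.int32)
    meta.tensor_copy_and_pad(buf, cu, real_batch_size=3, padded_batch_size=8, is_cumulative_tensor=True)
    # buf is now [3, 8, 10, 10, 10, 10, 10, 10]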
@@ -0,0 +1,220 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from .metadata_base import MetadataBase
8
+
9
+
10
+ class MHAMetadata(MetadataBase):
11
+ """
12
+ Metadata for MHA layer using flash-attention.
13
+ """
14
+
15
+ def __init__(
16
+ self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
17
+ ):
18
+ super().__init__()
19
+ device = torch.cuda.current_device()
20
+ self.device = device
21
+ self.max_blocks = block_count_total
22
+ self.max_kv_blocks = max_kv_block_count
23
+ self.max_bs = max_requests
24
+ self.max_seqlen = max_seqlen
25
+ self._query_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device)
26
+ self._cu_query_seq_lengths_buf = torch.zeros(
27
+ self.max_bs + 1, dtype=torch.int32, device=device
28
+ )
29
+ self._cu_kv_seq_lengths_buf = torch.zeros(self.max_bs + 1, dtype=torch.int32, device=device)
30
+ self._kv_seq_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device)
31
+ self._block_table_buf = torch.zeros(
32
+ (self.max_bs, self.max_kv_blocks), dtype=torch.int32, device=device
33
+ )
34
+ self._max_seqlen_q = 0
35
+ self._max_seqlen_k = 0
36
+ self.state_data = {}
37
+
38
+ def update(
39
+ self,
40
+ request_query_lengths: torch.Tensor,
41
+ request_kv_length_offsets: torch.Tensor,
42
+ request_to_kv_block_ids: torch.Tensor,
43
+ padded_active_token_count: int,
44
+ real_batch_size: int,
45
+ padded_active_request_count: Optional[int] = None,
46
+ decode_only: bool = False,
47
+ ):
48
+ """
49
+ Args:
50
+ request_query_lengths: (>real_batch_size,)
51
+ request_kv_length_offsets: (>real_batch_size,)
52
+ request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
53
+ padded_active_token_count: int
54
+ real_batch_size: int
55
+ padded_active_request_count: Optional[int]
56
+ decode_only: bool
57
+ """
58
+ if padded_active_request_count is None:
59
+ padded_active_request_count = real_batch_size
60
+
61
+ assert real_batch_size <= padded_active_request_count <= self.max_bs
62
+ assert request_query_lengths.shape[0] == real_batch_size
63
+ assert request_kv_length_offsets.shape[0] == real_batch_size
64
+ assert request_to_kv_block_ids.shape[0] == real_batch_size
65
+
66
+ self.tensor_copy_and_pad(
67
+ self._query_lengths_buf,
68
+ request_query_lengths,
69
+ real_batch_size,
70
+ padded_active_request_count,
71
+ )
72
+ self._cu_query_seq_lengths_buf[0] = 0
73
+ self.tensor_copy_and_pad(
74
+ self._cu_query_seq_lengths_buf[1:],
75
+ torch.cumsum(request_query_lengths, dim=0),
76
+ real_batch_size,
77
+ padded_active_request_count,
78
+ is_cumulative_tensor=True,
79
+ )
80
+ self.tensor_copy_and_pad(
81
+ self._kv_seq_lengths_buf,
82
+ request_kv_length_offsets + request_query_lengths,
83
+ real_batch_size,
84
+ padded_active_request_count,
85
+ )
86
+ self.tensor_copy_and_pad(
87
+ self._block_table_buf,
88
+ request_to_kv_block_ids,
89
+ real_batch_size,
90
+ padded_active_request_count,
91
+ pad_value=torch.tensor(self.max_kv_blocks, dtype=torch.int32, device=self.device).fill_(
92
+ -1
93
+ ),
94
+ )
95
+ self._cu_kv_seq_lengths_buf[0] = 0
96
+ self.tensor_copy_and_pad(
97
+ self._cu_kv_seq_lengths_buf[1:],
98
+ torch.cumsum(self._kv_seq_lengths_buf, dim=0),
99
+ real_batch_size,
100
+ padded_active_request_count,
101
+ is_cumulative_tensor=True,
102
+ )
103
+
104
+ if decode_only:
105
+ self._max_seqlen_q = 1
106
+ else:
107
+ self._max_seqlen_q = max(2, padded_active_token_count)
108
+ self._max_seqlen_k = self.max_seqlen
109
+
110
+ self.state_data = {
111
+ "query_lengths": self._query_lengths_buf[:padded_active_request_count],
112
+ "cu_query_seq_lengths": self._cu_query_seq_lengths_buf[
113
+ : padded_active_request_count + 1
114
+ ],
115
+ "cu_kv_seq_lengths": self._cu_kv_seq_lengths_buf[: padded_active_request_count + 1],
116
+ "kv_seq_lengths": self._kv_seq_lengths_buf[:padded_active_request_count],
117
+ "block_table": self._block_table_buf[0:padded_active_request_count, :],
118
+ "max_seqlen_q": self._max_seqlen_q,
119
+ "max_seqlen_k": self._max_seqlen_k,
120
+ }
121
+
122
+ def reset(self):
123
+ """
124
+ Reset the metadata for the next batch.
125
+ """
126
+ self._query_lengths_buf.fill_(0)
127
+ self._cu_query_seq_lengths_buf.fill_(0)
128
+ self._cu_kv_seq_lengths_buf.fill_(0)
129
+ self._kv_seq_lengths_buf.fill_(0)
130
+ self._block_table_buf.fill_(0)
131
+ self._max_seqlen_q = 0
132
+ self._max_seqlen_k = 0
133
+
134
+
135
+ class GraphedMHAMetadata(MHAMetadata):
136
+ """
137
+ Metadata for MHA layer using flash-attention with CUDA graphs.
138
+ """
139
+
140
+ def __init__(
141
+ self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
142
+ ):
143
+ super().__init__(
144
+ block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
145
+ )
146
+
147
+ def update(
148
+ self,
149
+ request_query_lengths: torch.Tensor,
150
+ request_kv_length_offsets: torch.Tensor,
151
+ request_to_kv_block_ids: torch.Tensor,
152
+ padded_active_token_count: int,
153
+ real_batch_size: int,
154
+ padded_active_request_count: Optional[int] = None,
155
+ decode_only: bool = False,
156
+ ):
157
+ """
158
+ Args:
159
+ request_query_lengths: (>real_batch_size,)
160
+ request_kv_length_offsets: (>real_batch_size,)
161
+ request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
162
+ padded_active_token_count: int
163
+ real_batch_size: int
164
+ padded_active_request_count: Optional[int]
165
+ decode_only: bool
166
+ """
167
+ super().update(
168
+ request_query_lengths,
169
+ request_kv_length_offsets,
170
+ request_to_kv_block_ids,
171
+ padded_active_token_count,
172
+ real_batch_size,
173
+ padded_active_request_count,
174
+ decode_only,
175
+ )
176
+
177
+ def reset(self):
178
+ super().reset()
179
+
180
+
+class NonGraphedMHAMetadata(MHAMetadata):
+    """
+    Metadata for MHA layer using flash-attention without CUDA graphs.
+    """
+
+    def update(
+        self,
+        request_query_lengths: torch.Tensor,
+        request_kv_length_offsets: torch.Tensor,
+        request_to_kv_block_ids: torch.Tensor,
+        padded_active_token_count: int,
+        real_batch_size: int,
+        padded_active_request_count: Optional[int] = None,
+        decode_only: bool = False,
+    ):
+        """
+        Args:
+            request_query_lengths: (>real_batch_size,)
+            request_kv_length_offsets: (>real_batch_size,)
+            request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
+            padded_active_token_count: int
+            real_batch_size: int
+            padded_active_request_count: Optional[int]
+            decode_only: bool
+        """
+        super().update(
+            request_query_lengths,
+            request_kv_length_offsets,
+            request_to_kv_block_ids,
+            padded_active_token_count,
+            real_batch_size,
+            padded_active_request_count,
+            decode_only,
+        )
+        if len(self.state_data["query_lengths"]) > 0:
+            self.state_data["max_seqlen_q"] = torch.max(self.state_data["query_lengths"]).item()
+            self.state_data["max_seqlen_k"] = torch.max(self.state_data["kv_seq_lengths"]).item()
+        else:
+            self.state_data["max_seqlen_q"] = 1
+            self.state_data["max_seqlen_k"] = 1
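The split between the two subclasses comes down to how tight the max_seqlen bounds can be: the non-graphed path calls `.item()`, a device-to-host synchronization that is illegal during graph capture but yields exact bounds, while the graphed path (via the base `update`) keeps static, shape-derived upper bounds. A toy contrast, with the static bound as an assumed placeholder value:

    import torch

    lengths = torch.tensor([5, 3, 9], device='cuda')

    # Non-graphed: exact bound, at the cost of a GPU->CPU synchronization.
    max_seqlen_exact = lengths.max().item()  # 9

    # Graphed: a static upper bound (hypothetical value) that never touches
    # the host, so a captured graph can replay with any batch contents.
    MAX_SEQLEN_STATIC = 4096
    max_seqlen_graphed = MAX_SEQLEN_STATIC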
@@ -4,7 +4,7 @@ import math
 import warnings
 from contextlib import nullcontext
 from enum import Enum
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
@@ -26,6 +26,7 @@ from megatron.core.package_info import __version__ as mcore_version
 from megatron.core.transformer import TransformerConfig
 from megatron.core.utils import divide as core_divide

+from .attention_context.mha_metadata import GraphedMHAMetadata, NonGraphedMHAMetadata
 from .base_context import BaseInferenceContext
 from .dynamic_block_allocator import BlockAllocator

@@ -48,6 +49,17 @@ try:
 except ImportError:
     HAVE_FLASHINFER = False

+try:
+    import wandb  # pylint: disable=unused-import
+
+    HAVE_WANDB = True
+except ImportError:
+    HAVE_WANDB = False
+    wandb = None
+
+if TYPE_CHECKING:
+    import wandb as WandbModule
+

 class ContextOverflowError(Exception):
     """Base exception for when a new request does not fit.
@@ -225,6 +237,7 @@ class DynamicInferenceContext(BaseInferenceContext):
             levels will be included to control other tensors within the context.
         use_flashinfer_fused_rope (bool): If True, use flashinfer's fused rope implementation.
             If None, defaults to using flash-infer if available.
+        metrics_writer (Optional['WandbModule']): Wandb module for writing metrics.
     """

     def __init__(
@@ -250,6 +263,7 @@ class DynamicInferenceContext(BaseInferenceContext):
         use_cuda_graphs_for_non_decode_steps: bool = True,
         use_flashinfer_fused_rope: bool = False,
         unified_memory_level: Optional[int] = 0,
+        metrics_writer: Optional['WandbModule'] = None,
     ):
         super().__init__(materialize_only_last_token_logits=materialize_only_last_token_logits)

@@ -259,6 +273,8 @@ class DynamicInferenceContext(BaseInferenceContext):
             block_size_tokens == 64
         ), "Flash MLA requires a block size of 64. Set --inference-dynamic-batching-block-size 64 to fix this assert"

+        self.metrics_writer = metrics_writer
+
         # Per partition num heads and hidden size.
         projection_size = kv_channels * num_attention_heads
         if tensor_model_parallel_size is None:
@@ -454,30 +470,26 @@ class DynamicInferenceContext(BaseInferenceContext):
             num_cuda_graphs is not None
         )

-        # `*_cudagraph_only` tensors are for use with cuda graphs to maintain
-        # consistent input shapes, which is required to use cuda graphs.
-        # During these steps, the `*_cudagraph_only`
-        # tensors are used, otherwise their same-name but un-suffixed
-        # corresponding tensors are used.
+        # Attention metadata initialization (tensors are now handled by MHAMetadata classes)

-        self.query_seq_lengths_cudagraph_only = torch.full(
-            (self.max_requests,), 0, dtype=torch.int32, device=torch.cuda.current_device()
-        )
-        self.cu_query_seq_lengths_cudagraph_only = torch.full(
-            (self.max_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
-        )
-        self.kv_seq_lengths_cudagraph_only = torch.full(
-            (self.max_requests,), 0, dtype=torch.int32, device=torch.cuda.current_device()
-        )
-        self.cu_kv_seq_lengths_cudagraph_only = torch.full(
-            (self.max_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
+        self.graph_attn_metadata = {}
+        self.non_graph_attn_metadata = {}
+        self.active_attn_metadata = None
+
+        self.graph_attn_metadata["mha_metadata"] = GraphedMHAMetadata(
+            block_count_total=block_count_total,
+            max_kv_block_count=self.max_kv_block_count,
+            max_requests=self.max_requests,
+            block_size_tokens=self.block_size_tokens,
+            max_seqlen=self.max_sequence_length,
         )

-        self.request_to_kv_block_ids_cudagraph_only = torch.full(
-            (self.max_requests, self.max_kv_block_count),
-            0,
-            dtype=torch.int,
-            device=torch.cuda.current_device(),
+        self.non_graph_attn_metadata["mha_metadata"] = NonGraphedMHAMetadata(
+            block_count_total=block_count_total,
+            max_kv_block_count=self.max_kv_block_count,
+            max_requests=self.max_requests,
+            block_size_tokens=self.block_size_tokens,
+            max_seqlen=self.max_sequence_length,
         )

         # Guaranteed active requests.
@@ -627,11 +639,18 @@ class DynamicInferenceContext(BaseInferenceContext):

     def cu_query_lengths(self) -> Tuple[Tensor, int]:
         """Cumulative query sequence lengths."""
-        return self.cu_query_seq_lengths, self.max_seqlen_q
+        return (
+            self.active_attn_metadata["mha_metadata"].state_data["cu_query_seq_lengths"],
+            self.active_attn_metadata["mha_metadata"].state_data["max_seqlen_q"],
+        )

-    def cu_kv_lengths(self) -> Tensor:
+    def cu_kv_lengths(self) -> Tuple[Tensor, Tensor, int]:
         """Cumulative key/value sequence lengths."""
-        return (self.cu_kv_seq_lengths, self.kv_seq_lengths, self.max_seqlen_k)
+        return (
+            self.active_attn_metadata["mha_metadata"].state_data["cu_kv_seq_lengths"],
+            self.active_attn_metadata["mha_metadata"].state_data["kv_seq_lengths"],
+            self.active_attn_metadata["mha_metadata"].state_data["max_seqlen_k"],
+        )

     def get_active_sequence_lengths(self) -> Tensor:
         """Total sequence length (query + key) for active requests.
@@ -709,12 +728,16 @@ class DynamicInferenceContext(BaseInferenceContext):
         to blocks within the block-level memory buffer.
         """
         if self.cache_mla_latent:
-            return (self.memory_buffer[layer_number - 1], None, self.block_table)
+            return (
+                self.memory_buffer[layer_number - 1],
+                None,
+                self.active_attn_metadata["mha_metadata"].state_data["block_table"],
+            )
         else:
             return (
                 self.memory_buffer[0, layer_number - 1],
                 self.memory_buffer[1, layer_number - 1],
-                self.block_table,
+                self.active_attn_metadata["mha_metadata"].state_data["block_table"],
             )

     def apply_fused_qk_rotary_emb(
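The block table returned above is the per-request page map for the KV cache: row r lists, in order, the physical block ids that hold request r's keys/values. A small worked example of the indexing arithmetic (the helper below is illustrative, not part of the diff):

    import torch

    # Maps a (request, token) pair to its slot in a block-paged KV cache.
    def kv_slot(block_table: torch.Tensor, request_idx: int, token_idx: int,
                block_size_tokens: int):
        block_id = block_table[request_idx, token_idx // block_size_tokens]
        offset = token_idx % block_size_tokens
        return int(block_id), offset

    block_table = torch.tensor([[7, 2, -1], [4, -1, -1]])  # -1 pads unused slots
    print(kv_slot(block_table, 0, 70, block_size_tokens=64))  # (2, 6)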
@@ -824,17 +847,12 @@ class DynamicInferenceContext(BaseInferenceContext):

     def reset_attention_state(self) -> None:
         """Reset state used within attention, after each step."""
-        self.max_seqlen_q = None
-        self.max_seqlen_k = None
-        self.cu_query_seq_lengths = None
-        self.cu_query_seq_lengths_cudagraph_only.fill_(0)
-        self.query_seq_lengths_cudagraph_only.fill_(0)
-        self.cu_kv_seq_lengths = None
-        self.cu_kv_seq_lengths_cudagraph_only.fill_(0)
-        self.kv_seq_lengths = None
-        self.kv_seq_lengths_cudagraph_only.fill_(0)
-        self.request_to_kv_block_ids_cudagraph_only.fill_(0)
-        self.block_table = None
+        # Attention metadata reset is now handled by MHAMetadata.reset()
+        for attn_metadata in self.non_graph_attn_metadata.values():
+            attn_metadata.reset()
+        for attn_metadata in self.graph_attn_metadata.values():
+            attn_metadata.reset()
+        self.active_attn_metadata = None

     def using_cuda_graph_this_step(self) -> bool:
         """Returns True if cuda graphs are being used for this step."""
@@ -934,89 +952,30 @@ class DynamicInferenceContext(BaseInferenceContext):
             self.active_token_count : self.padded_active_token_count
         ] = 0

-        # Update cu_query_seq_lengths, max_seqlen_q.
-        query_lengths = self.request_query_lengths[
-            self.paused_request_count : self.total_request_count
-        ]
-        if self.is_decode_only() or self.using_cuda_graph_this_step():
-            self.query_seq_lengths_cudagraph_only[
-                0 : self.total_request_count - self.paused_request_count
-            ] = query_lengths
-            if self.is_decode_only():
-                self.cu_query_seq_lengths = None  # ensure no accidental use
-                self.max_seqlen_q = 1
-            else:
-                self.cu_query_seq_lengths_cudagraph_only[
-                    1 : self.padded_active_request_count + 1
-                ] = torch.cumsum(
-                    self.query_seq_lengths_cudagraph_only[: self.padded_active_request_count], dim=0
-                )
-
-                # The following will be passed to the FA kernel.
-                self.cu_query_seq_lengths = self.cu_query_seq_lengths_cudagraph_only[
-                    : (self.padded_active_request_count + 1)
-                ]
-                self.max_seqlen_q = self.padded_active_token_count
-        else:
-            cu_query_lengths = torch.cumsum(query_lengths, dim=0)
-            self.cu_query_seq_lengths = torch.full(
-                (self.total_request_count - self.paused_request_count + 1,),
-                0,
-                dtype=torch.int32,
-                device=torch.cuda.current_device(),
-            )
-            self.cu_query_seq_lengths[1:] = cu_query_lengths
-            self.max_seqlen_q = query_lengths.max().item()
-
-        kv_seq_lengths = self.request_kv_length_offsets + self.request_query_lengths
-        self.kv_seq_lengths = kv_seq_lengths[self.paused_request_count : self.total_request_count]
-        if self.is_decode_only() or self.using_cuda_graph_this_step():
-            # Re-assign `kv_seq_lengths` to be a view of the first
-            # `active_cuda_graph_request_count` tokens of `kv_seq_lengths_decode_only`,
-            # such that `kv_seq_lengths` has a static memory address and is therefore
-            # cuda graph compatible. This allows `kv_seq_lengths` to transition between
-            # cuda graph sizes, which makes multi-batch-size cuda graphs possible.
-            self.kv_seq_lengths_cudagraph_only[
-                0 : self.total_request_count - self.paused_request_count
-            ] = self.kv_seq_lengths
-            self.kv_seq_lengths = self.kv_seq_lengths_cudagraph_only[
-                : self.padded_active_request_count
-            ]
-            self.max_seqlen_k = self.max_sequence_length
-            if self.is_decode_only():
-                self.cu_kv_seq_lengths = None  # ensure no accidental use
-            else:
-                cu_kv_lengths = torch.cumsum(self.kv_seq_lengths, dim=0)
-                # The following will be passed to the FA kernel.
-                self.cu_kv_seq_lengths_cudagraph_only[1 : cu_kv_lengths.size(0) + 1] = cu_kv_lengths
-                self.cu_kv_seq_lengths = self.cu_kv_seq_lengths_cudagraph_only[
-                    : (self.padded_active_request_count + 1)
-                ]
-        else:
-            self.cu_kv_seq_lengths = torch.full(
-                (self.total_request_count - self.paused_request_count + 1,),
-                0,
-                dtype=torch.int32,
-                device=torch.cuda.current_device(),
-            )
-            self.cu_kv_seq_lengths[1:] = torch.cumsum(self.kv_seq_lengths, dim=0)
-            self.max_seqlen_k = self.kv_seq_lengths.max().item()
+        real_req_batch_size = (
+            self.total_request_count - self.paused_request_count
+        )  # number of requests that are actually active
+        self.active_attn_metadata = (
+            self.graph_attn_metadata
+            if self.using_cuda_graph_this_step()
+            else self.non_graph_attn_metadata
+        )

-        # Update KV block IDs, block table.
-        request_to_kv_block_ids = self.request_to_kv_block_ids[
-            self.paused_request_count : self.total_request_count
-        ]
-        if self.is_decode_only() or self.using_cuda_graph_this_step():
-            self.request_to_kv_block_ids_cudagraph_only[
-                0 : self.total_request_count - self.paused_request_count
-            ] = request_to_kv_block_ids
-            self.block_table = self.request_to_kv_block_ids_cudagraph_only[
-                : self.padded_active_request_count
-            ]
-        else:
-            self.block_table = self.request_to_kv_block_ids[
-                self.paused_request_count : self.total_request_count
-            ]
+        # Update cu_query_seq_lengths, max_seqlen_q.
+        active_slice = slice(self.paused_request_count, self.total_request_count)
+        query_lengths_view = self.request_query_lengths[active_slice]
+        request_kv_length_offsets_view = self.request_kv_length_offsets[active_slice]
+        request_to_kv_block_ids_view = self.request_to_kv_block_ids[active_slice]
+        self.active_attn_metadata["mha_metadata"].update(
+            request_query_lengths=query_lengths_view,
+            request_kv_length_offsets=request_kv_length_offsets_view,
+            request_to_kv_block_ids=request_to_kv_block_ids_view,
+            padded_active_token_count=self.padded_active_token_count,
+            real_batch_size=real_req_batch_size,
+            padded_active_request_count=self.padded_active_request_count,
+            decode_only=self.is_decode_only(),
+        )
+        # All attention metadata calculations are now handled by MHAMetadata.update()

     def reset(self) -> None:
         """Reset entire context.
@@ -1625,3 +1584,67 @@ class DynamicInferenceContext(BaseInferenceContext):

         # Convert each log prob tensor into a list
         return [lp.tolist() for lp in selected_log_probs_list]
+
+    def get_kvcache_utilization_stats(self) -> dict:
+        """Compute KV cache buffer utilization stats for the current step.
+
+        Returns a dictionary with counts and utilization fractions for both
+        allocated block usage (overall buffer occupancy) and active usage
+        (blocks referenced by currently active requests this step).
+
+        Returns:
+            {
+                'total_blocks': int,
+                'allocated_blocks': int,
+                'active_unique_blocks': int,
+                'allocated_utilization': float,
+                'active_utilization': float,
+                'active_request_count': int,
+                'paused_request_count': int,
+                'gtd_block_count': int,
+                'block_count_avail': int,
+                'num_non_gtd_blocks': int,
+                'active_token_count': int,
+                'total_request_count': int,
+                'max_requests': int,
+            }
+        """
+        # Total usable blocks exclude the reserved dummy block.
+        total_blocks = max(self.block_allocator.block_count_total - 1, 1)
+        block_count_avail = int(self.block_allocator.block_count_avail)
+
+        # Overall allocated blocks in the buffer right now.
+        allocated_blocks = (self.block_allocator.block_count_total - 1) - block_count_avail
+        allocated_blocks = int(max(0, allocated_blocks))
+
+        # Active unique blocks referenced by current active requests only.
+        active_start = self.paused_request_count
+        active_end = self.total_request_count
+        if active_end > active_start:
+            active_rows = self.request_to_kv_block_ids[active_start:active_end]
+            # Filter valid block ids (>= 0) and count unique ids.
+            valid_ids = active_rows[active_rows >= 0]
+            if valid_ids.numel() > 0:
+                unique_ids = torch.unique(valid_ids)
+                active_unique_blocks = int(unique_ids.numel())
+            else:
+                active_unique_blocks = 0
+        else:
+            active_unique_blocks = 0
+
+        allocated_utilization = float(allocated_blocks) / float(total_blocks)
+        active_utilization = float(active_unique_blocks) / float(total_blocks)
+
+        # Diagnostic helpers
+        num_non_gtd_blocks = max(0, block_count_avail - int(self.gtd_block_count))
+        total_request_count = int(self.total_request_count)
+        return {
+            'total_blocks': int(total_blocks),
+            'allocated_blocks': int(allocated_blocks),
+            'active_unique_blocks': int(active_unique_blocks),
+            'allocated_utilization': allocated_utilization,
+            'active_utilization': active_utilization,
+            'active_request_count': int(self.get_active_request_count()),
+            'paused_request_count': int(self.paused_request_count),
+            'gtd_block_count': int(self.gtd_block_count),
+            'block_count_avail': int(block_count_avail),
+            'num_non_gtd_blocks': int(num_non_gtd_blocks),
+            'active_token_count': int(self.active_token_count),
+            'total_request_count': int(total_request_count),
+            'max_requests': int(self.max_requests),
+        }
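A plausible way the new `metrics_writer` hook and these stats fit together; the actual logging call site lives in the engine (dynamic_engine.py) and is not shown in this hunk, so treat the names and metric keys below as illustrative:

    stats = context.get_kvcache_utilization_stats()
    if HAVE_WANDB and context.metrics_writer is not None:
        context.metrics_writer.log(
            {
                'inference/kv_allocated_util': stats['allocated_utilization'],
                'inference/kv_active_util': stats['active_utilization'],
                'inference/active_requests': stats['active_request_count'],
            }
        )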