megatron-core 0.16.0rc0.dev122519__tar.gz → 0.16.0rc0.dev123313__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363)
  1. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/PKG-INFO +14 -7
  2. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/megatron_tokenizer.py +9 -0
  3. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fp8_utils.py +49 -0
  4. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/async_stream.py +8 -2
  5. megatron_core-0.16.0rc0.dev123313/megatron/core/inference/contexts/attention_context/mamba_metadata.py +106 -0
  6. megatron_core-0.16.0rc0.dev123313/megatron/core/inference/contexts/attention_context/metadata_base.py +72 -0
  7. megatron_core-0.16.0rc0.dev123313/megatron/core/inference/contexts/attention_context/mha_metadata.py +220 -0
  8. megatron_core-0.16.0rc0.dev123313/megatron/core/inference/contexts/dynamic_block_allocator.py +118 -0
  9. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/contexts/dynamic_context.py +442 -284
  10. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/contexts/fused_kv_append_kernel.py +2 -2
  11. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/data_parallel_inference_coordinator.py +7 -0
  12. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/engines/dynamic_engine.py +125 -21
  13. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/engines/static_engine.py +4 -8
  14. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/inference_client.py +3 -1
  15. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/sampling_params.py +1 -0
  16. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +7 -7
  17. megatron_core-0.16.0rc0.dev123313/megatron/core/inference/unified_memory.py +127 -0
  18. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/utils.py +28 -0
  19. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/gpt/gpt_model.py +2 -5
  20. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mamba/mamba_model.py +30 -1
  21. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/package_info.py +1 -1
  22. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/safe_globals.py +2 -0
  23. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/mamba_block.py +16 -25
  24. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +29 -2
  25. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/mamba_layer.py +5 -5
  26. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/mamba_mixer.py +301 -57
  27. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/attention.py +14 -3
  28. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/cuda_graphs.py +5 -1
  29. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/dot_product_attention.py +2 -0
  30. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/router.py +2 -0
  31. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/pipeline_parallel_layer_layout.py +5 -2
  32. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/utils.py +143 -1
  33. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron_core.egg-info/PKG-INFO +14 -7
  34. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron_core.egg-info/SOURCES.txt +3 -0
  35. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron_core.egg-info/requires.txt +13 -6
  36. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/pyproject.toml +13 -6
  37. megatron_core-0.16.0rc0.dev122519/megatron/core/inference/contexts/dynamic_block_allocator.py +0 -92
  38. megatron_core-0.16.0rc0.dev122519/megatron/core/inference/unified_memory.py +0 -89
  39. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/MANIFEST.in +0 -0
  40. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/README.md +0 -0
  41. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/README.md +0 -0
  42. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/__init__.py +0 -0
  43. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/activations.py +0 -0
  44. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/config.py +0 -0
  45. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/config_logger.py +0 -0
  46. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/__init__.py +0 -0
  47. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/bert_dataset.py +0 -0
  48. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/blended_dataset.py +0 -0
  49. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  50. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  51. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/gpt_dataset.py +0 -0
  52. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/helpers.cpp +0 -0
  53. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/helpers.py +0 -0
  54. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/indexed_dataset.py +0 -0
  55. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/masked_dataset.py +0 -0
  56. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/megatron_dataset.py +0 -0
  57. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/multimodal_dataset.py +0 -0
  58. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/object_storage_utils.py +0 -0
  59. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/__init__.py +0 -0
  60. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/config/__init__.py +0 -0
  61. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  62. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/config/config.py +0 -0
  63. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  64. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  65. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/db/__init__.py +0 -0
  66. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/db/build.py +0 -0
  67. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/db/dataset.py +0 -0
  68. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/db/utils.py +0 -0
  69. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/external_libs.py +0 -0
  70. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/__init__.py +0 -0
  71. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/build.py +0 -0
  72. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/factory.py +0 -0
  73. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/index.py +0 -0
  74. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  75. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  76. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  77. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/utils.py +0 -0
  78. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/index/validate.py +0 -0
  79. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/query/__init__.py +0 -0
  80. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  81. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  82. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/query/query.py +0 -0
  83. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  84. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/query/utils.py +0 -0
  85. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/retro/utils.py +0 -0
  86. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/t5_dataset.py +0 -0
  87. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/utils.py +0 -0
  88. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/datasets/utils_s3.py +0 -0
  89. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/__init__.py +0 -0
  90. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/core.py +0 -0
  91. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  92. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  93. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/mapping.py +0 -0
  94. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  95. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/serialization.py +0 -0
  96. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  97. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  98. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  99. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  100. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  101. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/checkpointable.py +0 -0
  102. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  103. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  104. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  105. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  106. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  107. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  108. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  109. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  110. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  111. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  112. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/utils.py +0 -0
  113. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/dist_checkpointing/validation.py +0 -0
  114. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/__init__.py +0 -0
  115. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/data_parallel_base.py +0 -0
  116. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  117. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  118. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/finalize_model_grads.py +0 -0
  119. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/__init__.py +0 -0
  120. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +0 -0
  121. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/__init__.py +0 -0
  122. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +0 -0
  123. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +0 -0
  124. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +0 -0
  125. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +0 -0
  126. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py +0 -0
  127. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +0 -0
  128. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +0 -0
  129. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +0 -0
  130. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  131. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/reduce_scatter_with_fp32_accumulation.py +0 -0
  132. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  133. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  134. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/energy_monitor.py +0 -0
  135. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/enums.py +0 -0
  136. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/__init__.py +0 -0
  137. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/data_type.py +0 -0
  138. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/export_config.py +0 -0
  139. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/model_type.py +0 -0
  140. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/__init__.py +0 -0
  141. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  142. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  143. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  144. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  145. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  146. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  147. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  148. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  149. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  150. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  151. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  152. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  153. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/extensions/__init__.py +0 -0
  154. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/extensions/kitchen.py +0 -0
  155. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/extensions/transformer_engine.py +0 -0
  156. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  157. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fp4_utils.py +0 -0
  158. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/full_cuda_graph.py +0 -0
  159. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/__init__.py +0 -0
  160. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  161. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  162. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  163. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  164. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  165. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_indices_converter.py +0 -0
  166. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_layer_norm.py +0 -0
  167. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  168. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  169. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_softmax.py +0 -0
  170. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
  171. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/hyper_comm_grid.py +0 -0
  172. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/__init__.py +0 -0
  173. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/common_inference_params.py +0 -0
  174. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/communication_utils.py +0 -0
  175. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/contexts/__init__.py +0 -0
  176. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/contexts/base_context.py +0 -0
  177. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/contexts/static_context.py +0 -0
  178. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/engines/__init__.py +0 -0
  179. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/engines/abstract_engine.py +0 -0
  180. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/engines/mcore_engine.py +0 -0
  181. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/headers.py +0 -0
  182. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/inference_request.py +0 -0
  183. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  184. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  185. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  186. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  187. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  188. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  189. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  190. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  191. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/scheduler.py +0 -0
  192. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  193. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  194. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  195. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  196. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_server/__init__.py +0 -0
  197. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_server/endpoints/common.py +0 -0
  198. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_server/endpoints/completions.py +0 -0
  199. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_server/run_mcore_engine.py +0 -0
  200. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_server/text_generation_server.py +0 -0
  201. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference/text_generation_server/tokenization.py +0 -0
  202. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/inference_params.py +0 -0
  203. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/jit.py +0 -0
  204. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/model_parallel_config.py +0 -0
  205. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/T5/__init__.py +0 -0
  206. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/T5/t5_model.py +0 -0
  207. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/T5/t5_spec.py +0 -0
  208. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/__init__.py +0 -0
  209. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/backends.py +0 -0
  210. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/bert/__init__.py +0 -0
  211. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  212. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/bert/bert_lm_head.py +0 -0
  213. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/bert/bert_model.py +0 -0
  214. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/bert/pooler.py +0 -0
  215. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/__init__.py +0 -0
  216. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/embeddings/__init__.py +0 -0
  217. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  218. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  219. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  220. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  221. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  222. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/language_module/__init__.py +0 -0
  223. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/language_module/language_module.py +0 -0
  224. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  225. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/vision_module/__init__.py +0 -0
  226. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  227. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/gpt/__init__.py +0 -0
  228. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  229. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  230. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  231. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  232. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/huggingface/__init__.py +0 -0
  233. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/huggingface/clip_model.py +0 -0
  234. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/huggingface/module.py +0 -0
  235. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/huggingface/qwen_model.py +0 -0
  236. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mamba/__init__.py +0 -0
  237. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  238. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/__init__.py +0 -0
  239. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/config/__init__.py +0 -0
  240. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/config/base_configs.py +0 -0
  241. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/model/__init__.py +0 -0
  242. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/model/base.py +0 -0
  243. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/submodules/audio.py +0 -0
  244. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/submodules/base.py +0 -0
  245. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/mimo/submodules/vision.py +0 -0
  246. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/multimodal/__init__.py +0 -0
  247. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/multimodal/context_parallel.py +0 -0
  248. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/multimodal/llava_model.py +0 -0
  249. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/multimodal/llava_spec.py +0 -0
  250. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/__init__.py +0 -0
  251. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/base_attention.py +0 -0
  252. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/config.py +0 -0
  253. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/decoder_attention.py +0 -0
  254. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/decoder_spec.py +0 -0
  255. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/encoder_attention.py +0 -0
  256. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/encoder_spec.py +0 -0
  257. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/model.py +0 -0
  258. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/retro/utils.py +0 -0
  259. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/vision/__init__.py +0 -0
  260. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/vision/clip_vit_model.py +0 -0
  261. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/vision/multimodal_projector.py +0 -0
  262. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/vision/radio.py +0 -0
  263. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  264. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/msc_utils.py +0 -0
  265. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/nccl_allocator.py +0 -0
  266. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/num_microbatches_calculator.py +0 -0
  267. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/__init__.py +0 -0
  268. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/clip_grads.py +0 -0
  269. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  270. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  271. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  272. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/grad_scaler.py +0 -0
  273. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/optimizer.py +0 -0
  274. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer/optimizer_config.py +0 -0
  275. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/optimizer_param_scheduler.py +0 -0
  276. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/packed_seq_params.py +0 -0
  277. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/parallel_state.py +0 -0
  278. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/pipeline_parallel/__init__.py +0 -0
  279. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/pipeline_parallel/bridge_communicator.py +0 -0
  280. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  281. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  282. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/pipeline_parallel/schedules.py +0 -0
  283. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/pipeline_parallel/utils.py +0 -0
  284. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/__init__.py +0 -0
  285. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/modelopt/__init__.py +0 -0
  286. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  287. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  288. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  289. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/modelopt/layers.py +0 -0
  290. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  291. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  292. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/process_groups_config.py +0 -0
  293. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/quantization/__init__.py +0 -0
  294. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/quantization/quant_config.py +0 -0
  295. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/quantization/utils.py +0 -0
  296. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/requirements.txt +0 -0
  297. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/rerun_state_machine.py +0 -0
  298. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/__init__.py +0 -0
  299. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  300. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/mlp_layer.py +0 -0
  301. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/ssm/triton_cache_manager.py +0 -0
  302. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tensor_parallel/__init__.py +0 -0
  303. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  304. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tensor_parallel/data.py +0 -0
  305. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tensor_parallel/layers.py +0 -0
  306. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tensor_parallel/mappings.py +0 -0
  307. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tensor_parallel/random.py +0 -0
  308. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tensor_parallel/utils.py +0 -0
  309. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/timers.py +0 -0
  310. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/__init__.py +0 -0
  311. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/base_tokenizer.py +0 -0
  312. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/megatron_tokenizer.py +0 -0
  313. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/__init__.py +0 -0
  314. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/__init__.py +0 -0
  315. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +0 -0
  316. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/bytelevel_tokenizer.py +0 -0
  317. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/chat_template.py +0 -0
  318. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +0 -0
  319. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +0 -0
  320. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/null_tokenizer.py +0 -0
  321. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +0 -0
  322. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +0 -0
  323. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/models/__init__.py +0 -0
  324. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/models/bert_tokenizer.py +0 -0
  325. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/models/default_tokenizer.py +0 -0
  326. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/models/gpt_tokenizer.py +0 -0
  327. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/models/mamba_tokenizer.py +0 -0
  328. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/models/retro_tokenizer.py +0 -0
  329. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/models/t5_tokenizer.py +0 -0
  330. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/text_tokenizer.py +0 -0
  331. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/tokenizers/text/utils/build_tokenizer.py +0 -0
  332. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/__init__.py +0 -0
  333. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  334. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  335. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/enums.py +0 -0
  336. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/fsdp_dtensor_checkpoint.py +0 -0
  337. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  338. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  339. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/identity_op.py +0 -0
  340. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/mlp.py +0 -0
  341. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/module.py +0 -0
  342. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/__init__.py +0 -0
  343. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/experts.py +0 -0
  344. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  345. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  346. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/moe_layer.py +0 -0
  347. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/moe_utils.py +0 -0
  348. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/shared_experts.py +0 -0
  349. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  350. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  351. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/multi_latent_attention.py +0 -0
  352. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/multi_token_prediction.py +0 -0
  353. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/spec_utils.py +0 -0
  354. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/torch_layer_norm.py +0 -0
  355. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/torch_norm.py +0 -0
  356. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/transformer_block.py +0 -0
  357. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/transformer_config.py +0 -0
  358. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/transformer_layer.py +0 -0
  359. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron/core/transformer/utils.py +0 -0
  360. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron_core.egg-info/dependency_links.txt +0 -0
  361. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/megatron_core.egg-info/top_level.txt +0 -0
  362. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/setup.cfg +0 -0
  363. {megatron_core-0.16.0rc0.dev122519 → megatron_core-0.16.0rc0.dev123313}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: megatron-core
3
- Version: 0.16.0rc0.dev122519
3
+ Version: 0.16.0rc0.dev123313
4
4
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
5
5
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
6
6
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -41,7 +41,7 @@ Requires-Dist: transformers; extra == "mlm"
41
41
  Provides-Extra: dev
42
42
  Requires-Dist: nvidia-modelopt[torch]; sys_platform != "darwin" and extra == "dev"
43
43
  Requires-Dist: transformer-engine[pytorch]<2.10.0,>=2.9.0a0; extra == "dev"
44
- Requires-Dist: nvidia-resiliency-ext<0.5.0,>=0.4.0a0; extra == "dev"
44
+ Requires-Dist: nvidia-resiliency-ext; extra == "dev"
45
45
  Requires-Dist: tqdm; extra == "dev"
46
46
  Requires-Dist: einops~=0.8; extra == "dev"
47
47
  Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
@@ -59,13 +59,20 @@ Requires-Dist: wget; extra == "dev"
59
59
  Requires-Dist: onnxscript; extra == "dev"
60
60
  Provides-Extra: lts
61
61
  Requires-Dist: tqdm; extra == "lts"
62
- Requires-Dist: einops; extra == "lts"
63
- Requires-Dist: tensorstore!=0.1.46,!=0.1.72; extra == "lts"
64
- Requires-Dist: nvtx; extra == "lts"
65
- Requires-Dist: transformers; extra == "lts"
66
- Requires-Dist: zarr; extra == "lts"
62
+ Requires-Dist: einops~=0.8; extra == "lts"
63
+ Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "lts"
64
+ Requires-Dist: nvtx~=0.2; extra == "lts"
65
+ Requires-Dist: multi-storage-client~=0.27; extra == "lts"
66
+ Requires-Dist: opentelemetry-api~=1.33.1; extra == "lts"
67
67
  Requires-Dist: setuptools<80.0.0; extra == "lts"
68
+ Requires-Dist: mamba-ssm~=2.2; extra == "lts"
69
+ Requires-Dist: causal-conv1d~=1.5; extra == "lts"
70
+ Requires-Dist: nv-grouped-gemm~=1.1; extra == "lts"
71
+ Requires-Dist: megatron-energon[av_decode]~=6.0; extra == "lts"
72
+ Requires-Dist: av<16.0.0; extra == "lts"
73
+ Requires-Dist: flashinfer-python; extra == "lts"
68
74
  Requires-Dist: wget; extra == "lts"
75
+ Requires-Dist: onnxscript; extra == "lts"
69
76
 
70
77
  <div align="center">
71
78
 
@@ -1,11 +1,14 @@
1
1
  # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
2
  import json
3
+ import logging
3
4
  from abc import ABC, abstractmethod
4
5
  from collections import OrderedDict
5
6
  from typing import Any
6
7
 
7
8
  import numpy
8
9
 
10
+ logger = logging.getLogger(__name__)
11
+
9
12
 
10
13
  class MegatronLegacyTokenizer(ABC):
11
14
  """Abstract class for tokenizer
@@ -20,6 +23,12 @@ class MegatronLegacyTokenizer(ABC):
20
23
  """
21
24
 
22
25
  def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any):
26
+ # Deprecation warning
27
+ logger.warning(
28
+ "You’re using the legacy tokenizer system, which is deprecated "
29
+ "and will be removed in a future release. Please migrate to the new tokenizer system "
30
+ "(`megatron.core.tokenizers.MegatronTokenizer`)."
31
+ )
23
32
  self.unique_identifiers = OrderedDict()
24
33
  self.unique_identifiers["class"] = type(self).__name__
25
34
  self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths)
@@ -10,6 +10,12 @@ from typing import List, Optional
10
10
  import torch
11
11
 
12
12
  from megatron.core.enums import Fp8Recipe
13
+ from megatron.core.tensor_parallel import (
14
+ ColumnParallelLinear,
15
+ RowParallelLinear,
16
+ gather_from_sequence_parallel_region,
17
+ reduce_scatter_to_sequence_parallel_region,
18
+ )
13
19
  from megatron.core.transformer.transformer_config import TransformerConfig
14
20
  from megatron.core.utils import get_te_version, is_te_min_version
15
21
 
@@ -112,6 +118,27 @@ def get_fp8_align_size(fp8_recipe: Fp8Recipe) -> int:
112
118
  return 16
113
119
 
114
120
 
121
+ def is_column_parallel_linear(module):
122
+ """Returns whether the given module is a ColumnParallelLinear layer."""
123
+ if HAVE_TE and (
124
+ isinstance(module, TEColumnParallelLinear)
125
+ or isinstance(module, TELayerNormColumnParallelLinear)
126
+ ):
127
+ return True
128
+ elif isinstance(module, ColumnParallelLinear):
129
+ return True
130
+ return False
131
+
132
+
133
+ def is_row_parallel_linear(module):
134
+ """Returns whether the given module is a RowParallelLinear layer."""
135
+ if HAVE_TE and isinstance(module, TERowParallelLinear):
136
+ return True
137
+ elif isinstance(module, RowParallelLinear):
138
+ return True
139
+ return False
140
+
141
+
115
142
  """
116
143
  The code below abstracts the functionalities needed for implementing "--fp8-param-gather" into
117
144
  several functions. It provides different implementations for each function based on different
@@ -587,6 +614,18 @@ if HAVE_TE:
587
614
  if not FP8GlobalStateManager.is_fp8_enabled():
588
615
  return original_forward(input_tensor, *args, **kwargs)
589
616
 
617
+ # With sequence parallelism we need to all-gather before padding
618
+ # and reduce-scatter after unpadding
619
+ if is_sequence_parallel := getattr(module, "sequence_parallel", False):
620
+ if is_column_parallel_linear(module):
621
+ input_tensor = gather_from_sequence_parallel_region(
622
+ input_tensor, group=module.tp_group
623
+ )
624
+
625
+ # Disable sequence parallelism on the module because we are handling the
626
+ # all-gather and reduce-scatter externally
627
+ module.sequence_parallel = False
628
+
590
629
  seq_len, batch_size, hidden_size = input_tensor.shape
591
630
  # Reshape to (S, B*H) to pad sequence dimension
592
631
  input_2d = input_tensor.reshape(seq_len, -1)
@@ -612,6 +651,16 @@ if HAVE_TE:
612
651
  unpadded_output_2d = _unpad_func(output_2d, [seq_len])
613
652
  unpadded_output = unpadded_output_2d.reshape(seq_len, batch_size, output_hidden_size)
614
653
 
654
+ if is_sequence_parallel:
655
+ # Reduce-scatter after unpadding
656
+ if is_row_parallel_linear(module):
657
+ unpadded_output = reduce_scatter_to_sequence_parallel_region(
658
+ unpadded_output, group=module.tp_group
659
+ )
660
+
661
+ # Reset sequence parallelism flag on the module
662
+ module.sequence_parallel = True
663
+
615
664
  if other_outputs:
616
665
  return (unpadded_output,) + other_outputs
617
666
  else:
@@ -9,6 +9,7 @@ import asyncio
9
9
  from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
10
10
 
11
11
  from megatron.core.inference.inference_request import InferenceRequest
12
+ from megatron.core.utils import get_asyncio_loop
12
13
 
13
14
  STOP_ITERATION = Exception()
14
15
 
@@ -20,12 +21,17 @@ class AsyncStream:
20
21
  Adopted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py # pylint: disable=line-too-long
21
22
  """
22
23
 
23
- def __init__(self, request_id: int, cancel: Callable[[str], None]) -> None:
24
+ def __init__(
25
+ self,
26
+ request_id: int,
27
+ cancel: Callable[[str], None],
28
+ loop: Optional[asyncio.AbstractEventLoop] = None,
29
+ ) -> None:
24
30
  self._request_id = request_id
25
31
  self._cancel = cancel
26
32
  self._queue: asyncio.Queue = asyncio.Queue()
27
33
  self._finished = False
28
- self._loop = asyncio.get_running_loop()
34
+ self._loop = get_asyncio_loop(loop)
29
35
 
30
36
  def put(self, item: Union[InferenceRequest, Exception]) -> None:
31
37
  """Adds a new value to the stream"""
@@ -0,0 +1,106 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+
3
+ import torch
4
+
5
+
6
+ class MambaMetadata:
7
+ """Manages the metadata tensors required for Mamba layers during inference."""
8
+
9
+ def __init__(self, max_requests: int):
10
+ """
11
+ Initializes the Mamba slot allocator.
12
+
13
+ Args:
14
+ max_requests (int): The maximum number of concurrent requests.
15
+ """
16
+ self.max_requests = max_requests
17
+
18
+ # Metadata for mapping requests to slots in the static Mamba state buffer
19
+ self.request_to_mamba_state_idx = torch.full(
20
+ (self.max_requests,), -1, dtype=torch.int32, device=torch.cuda.current_device()
21
+ )
22
+
23
+ # Separate mapping used only for CUDA graph compatibility
24
+ self.request_to_mamba_state_idx_cudagraph_only = torch.full(
25
+ (self.max_requests,), -1, dtype=torch.int32, device=torch.cuda.current_device()
26
+ )
27
+
28
+ # Allocator for Mamba state slots
29
+ self.mamba_state_free_slots = torch.arange(
30
+ self.max_requests, dtype=torch.int32, device=torch.cuda.current_device()
31
+ )
32
+ self.mamba_state_free_slot_count = self.max_requests
33
+
34
+ def reset(self) -> None:
35
+ """
36
+ Resets all Mamba states and frees all allocated slots.
37
+ """
38
+ self.request_to_mamba_state_idx.fill_(-1)
39
+ self.request_to_mamba_state_idx_cudagraph_only.fill_(-1)
40
+
41
+ # Re-initialize the free slot pool
42
+ self.mamba_state_free_slots = torch.arange(
43
+ self.max_requests, dtype=torch.int32, device=torch.cuda.current_device()
44
+ )
45
+ self.mamba_state_free_slot_count = self.max_requests
46
+
47
+ def reset_cudagraph_mapping(self) -> None:
48
+ """
49
+ Resets only the CUDA graph mapping tensor.
50
+ """
51
+ self.request_to_mamba_state_idx_cudagraph_only.fill_(-1)
52
+
53
+ def update_cudagraph_mapping(
54
+ self, active_mamba_indices: torch.Tensor, num_active_requests: int
55
+ ) -> None:
56
+ """
57
+ Updates the dedicated CUDA graph mapping tensor with the indices
58
+ of currently active requests.
59
+
60
+ Args:
61
+ active_mamba_indices (Tensor): Tensor containing the Mamba slot indices
62
+ for active requests.
63
+ num_active_requests (int): The number of active requests.
64
+ """
65
+ self.request_to_mamba_state_idx_cudagraph_only[0:num_active_requests] = active_mamba_indices
66
+
67
+ def allocate_slot(self) -> int:
68
+ """
69
+ Allocates a new slot for a request in the Mamba state buffers.
70
+
71
+ Returns:
72
+ int: The index of the allocated slot.
73
+ Returns None if no slots are available.
74
+ """
75
+ if self.mamba_state_free_slot_count == 0:
76
+ return None
77
+
78
+ # Get a free slot
79
+ self.mamba_state_free_slot_count -= 1
80
+ mamba_idx = self.mamba_state_free_slots[self.mamba_state_free_slot_count]
81
+
82
+ return mamba_idx
83
+
84
+ def free_slots(self, request_indices: torch.Tensor) -> None:
85
+ """
86
+ Frees the Mamba state slots associated with the given request indices.
87
+
88
+ Args:
89
+ request_indices (Tensor): A 1D tensor of request indices to free.
90
+ """
91
+ # Get the Mamba state indices for finished requests
92
+ mamba_indices_to_free = self.request_to_mamba_state_idx[request_indices]
93
+
94
+ # Filter out any invalid indices (e.g., -1)
95
+ mamba_indices_to_free = mamba_indices_to_free[mamba_indices_to_free != -1]
96
+ num_to_free = len(mamba_indices_to_free)
97
+
98
+ if num_to_free > 0:
99
+ # Add the freed indices back to the free slot pool
100
+ start_idx = self.mamba_state_free_slot_count
101
+ end_idx = start_idx + num_to_free
102
+ self.mamba_state_free_slots[start_idx:end_idx] = mamba_indices_to_free
103
+ self.mamba_state_free_slot_count = end_idx
104
+
105
+ # Invalidate the Mamba state index for the finished requests
106
+ self.request_to_mamba_state_idx[request_indices] = -1
@@ -0,0 +1,72 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+
3
+
4
+ class MetadataBase:
5
+ """
6
+ Base class for attention metadata.
7
+ High-performance attention kernels often require input metadata in specific
8
+ formats—such as cumulative query lengths, cumulative key/value lengths,
9
+ and similar structures. Moreover, when using CUDA Graphs, these metadata
10
+ buffers must be statically allocated. This class serves as a unified container
11
+ that manages all such metadata in one place.
12
+ """
13
+
14
+ def __init__(self):
15
+ """
16
+ Initialize the metadata.
17
+ """
18
+ self.state_data = {}
19
+
20
+ def update(self, *args, **kwargs):
21
+ """
22
+ Construct the metadata from request states.
23
+ """
24
+ pass
25
+
26
+ def reset(self):
27
+ """
28
+ Reset the metadata.
29
+ """
30
+ pass
31
+
32
+ def tensor_copy_and_pad(
33
+ self,
34
+ tensor_buf,
35
+ unpadded_tensor,
36
+ real_batch_size,
37
+ padded_batch_size,
38
+ is_cumulative_tensor=False,
39
+ pad_value=0,
40
+ ):
41
+ """
42
+ Copy the unpadded tensor to the tensor_buf,
43
+ pad the tensor_buf with zero or the last value of the tensor,
44
+ depending on whether the tensor is cumulative.
45
+ Args:
46
+ tensor_buf: The destination tensor, at least padded_batch_size long.
47
+ unpadded_tensor: The tensor to copy, at least real_batch_size long.
48
+ real_batch_size: The real batch size.
49
+ padded_batch_size: Padded boundary of the tensor.
50
+ is_cumulative_tensor: Whether the tensor is cumulative.
51
+ If True, we pad the tensor_buf with the last value of the unpadded_tensor.
52
+ pad_value: The value to pad the tensor_buf with when the tensor is not cumulative.
53
+ """
54
+ assert real_batch_size <= padded_batch_size
55
+ assert tensor_buf.shape[0] >= padded_batch_size
56
+ assert unpadded_tensor.shape[0] >= real_batch_size
57
+ if is_cumulative_tensor:
58
+ if real_batch_size == 0:
59
+ value = pad_value
60
+ else:
61
+ value = unpadded_tensor[real_batch_size - 1]
62
+ else:
63
+ value = pad_value
64
+ tensor_buf[0:real_batch_size] = unpadded_tensor[:real_batch_size]
65
+ tensor_buf[real_batch_size:padded_batch_size] = value
66
+ return tensor_buf
67
+
68
+ def __str__(self):
69
+ """
70
+ Return a string representation of the metadata.
71
+ """
72
+ return "\n".join([f"{key}: {value}" for key, value in self.state_data.items()])
@@ -0,0 +1,220 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from .metadata_base import MetadataBase
8
+
9
+
10
class MHAMetadata(MetadataBase):
    """
    Metadata for MHA layer using flash-attention.

    Holds statically allocated GPU buffers (query/KV sequence lengths, their
    cumulative sums, and the KV block table) sized for the maximum batch, so
    the same storage can be reused across steps and under CUDA graphs.
    """

    def __init__(
        self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
    ):
        """
        Args:
            block_count_total: Total number of KV cache blocks.
            max_kv_block_count: Max KV blocks per request (block-table width).
            max_requests: Max concurrent requests (batch dimension).
            block_size_tokens: Tokens per KV block (kept for interface parity;
                not used directly here).
            max_seqlen: Max supported sequence length.
        """
        super().__init__()
        device = torch.cuda.current_device()
        self.device = device
        self.max_blocks = block_count_total
        self.max_kv_blocks = max_kv_block_count
        self.max_bs = max_requests
        self.max_seqlen = max_seqlen
        # Per-request query lengths and cumulative variants (varlen attention inputs).
        self._query_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device)
        self._cu_query_seq_lengths_buf = torch.zeros(
            self.max_bs + 1, dtype=torch.int32, device=device
        )
        self._cu_kv_seq_lengths_buf = torch.zeros(self.max_bs + 1, dtype=torch.int32, device=device)
        self._kv_seq_lengths_buf = torch.zeros(self.max_bs, dtype=torch.int32, device=device)
        # Maps each request slot to its KV block ids; padded rows hold -1.
        self._block_table_buf = torch.zeros(
            (self.max_bs, self.max_kv_blocks), dtype=torch.int32, device=device
        )
        self._max_seqlen_q = 0
        self._max_seqlen_k = 0
        self.state_data = {}

    def update(
        self,
        request_query_lengths: torch.Tensor,
        request_kv_length_offsets: torch.Tensor,
        request_to_kv_block_ids: torch.Tensor,
        padded_active_token_count: int,
        real_batch_size: int,
        padded_active_request_count: Optional[int] = None,
        decode_only: bool = False,
    ):
        """
        Populate the metadata buffers from the current request states.

        Args:
            request_query_lengths: (real_batch_size,) query tokens per request.
            request_kv_length_offsets: (real_batch_size,) prior KV tokens per request.
            request_to_kv_block_ids: (real_batch_size, max_kv_blocks) block table.
            padded_active_token_count: Active token count after padding.
            real_batch_size: Actual number of active requests.
            padded_active_request_count: Request count after padding; defaults
                to real_batch_size (i.e. no request padding).
            decode_only: True when every request contributes a single token.
        """
        if padded_active_request_count is None:
            padded_active_request_count = real_batch_size

        assert real_batch_size <= padded_active_request_count <= self.max_bs
        assert request_query_lengths.shape[0] == real_batch_size
        assert request_kv_length_offsets.shape[0] == real_batch_size
        assert request_to_kv_block_ids.shape[0] == real_batch_size

        self.tensor_copy_and_pad(
            self._query_lengths_buf,
            request_query_lengths,
            real_batch_size,
            padded_active_request_count,
        )
        self._cu_query_seq_lengths_buf[0] = 0
        self.tensor_copy_and_pad(
            self._cu_query_seq_lengths_buf[1:],
            torch.cumsum(request_query_lengths, dim=0),
            real_batch_size,
            padded_active_request_count,
            is_cumulative_tensor=True,
        )
        self.tensor_copy_and_pad(
            self._kv_seq_lengths_buf,
            request_kv_length_offsets + request_query_lengths,
            real_batch_size,
            padded_active_request_count,
        )
        # Pad unused block-table rows with -1 (invalid block id). A plain
        # scalar broadcasts over the slice; no per-call GPU tensor needed.
        self.tensor_copy_and_pad(
            self._block_table_buf,
            request_to_kv_block_ids,
            real_batch_size,
            padded_active_request_count,
            pad_value=-1,
        )
        self._cu_kv_seq_lengths_buf[0] = 0
        # Only the first real_batch_size entries of the KV-length buffer are
        # meaningful; restricting the cumsum to them avoids scanning padding.
        self.tensor_copy_and_pad(
            self._cu_kv_seq_lengths_buf[1:],
            torch.cumsum(self._kv_seq_lengths_buf[:real_batch_size], dim=0),
            real_batch_size,
            padded_active_request_count,
            is_cumulative_tensor=True,
        )

        if decode_only:
            self._max_seqlen_q = 1
        else:
            self._max_seqlen_q = max(2, padded_active_token_count)
        self._max_seqlen_k = self.max_seqlen

        self.state_data = {
            "query_lengths": self._query_lengths_buf[:padded_active_request_count],
            "cu_query_seq_lengths": self._cu_query_seq_lengths_buf[
                : padded_active_request_count + 1
            ],
            "cu_kv_seq_lengths": self._cu_kv_seq_lengths_buf[: padded_active_request_count + 1],
            "kv_seq_lengths": self._kv_seq_lengths_buf[:padded_active_request_count],
            "block_table": self._block_table_buf[0:padded_active_request_count, :],
            "max_seqlen_q": self._max_seqlen_q,
            "max_seqlen_k": self._max_seqlen_k,
        }

    def reset(self):
        """
        Zero all buffers and scalar state for the next batch.
        """
        self._query_lengths_buf.fill_(0)
        self._cu_query_seq_lengths_buf.fill_(0)
        self._cu_kv_seq_lengths_buf.fill_(0)
        self._kv_seq_lengths_buf.fill_(0)
        self._block_table_buf.fill_(0)
        self._max_seqlen_q = 0
        self._max_seqlen_k = 0
133
+
134
+
135
class GraphedMHAMetadata(MHAMetadata):
    """
    MHA flash-attention metadata variant used when running under CUDA graphs.

    Delegates entirely to ``MHAMetadata``: the base class already uses
    statically allocated buffers and padded sizes, which is what graph
    capture requires.
    """

    def __init__(
        self, block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
    ):
        super().__init__(
            block_count_total, max_kv_block_count, max_requests, block_size_tokens, max_seqlen
        )

    def update(
        self,
        request_query_lengths: torch.Tensor,
        request_kv_length_offsets: torch.Tensor,
        request_to_kv_block_ids: torch.Tensor,
        padded_active_token_count: int,
        real_batch_size: int,
        padded_active_request_count: Optional[int] = None,
        decode_only: bool = False,
    ):
        """
        Populate the metadata buffers; see ``MHAMetadata.update``.

        Args:
            request_query_lengths: (real_batch_size,) query tokens per request.
            request_kv_length_offsets: (real_batch_size,) prior KV tokens per request.
            request_to_kv_block_ids: (real_batch_size, max_kv_blocks) block table.
            padded_active_token_count: Active token count after padding.
            real_batch_size: Actual number of active requests.
            padded_active_request_count: Request count after padding.
            decode_only: True when every request contributes a single token.
        """
        super().update(
            request_query_lengths=request_query_lengths,
            request_kv_length_offsets=request_kv_length_offsets,
            request_to_kv_block_ids=request_to_kv_block_ids,
            padded_active_token_count=padded_active_token_count,
            real_batch_size=real_batch_size,
            padded_active_request_count=padded_active_request_count,
            decode_only=decode_only,
        )

    def reset(self):
        """Reset all metadata buffers; see ``MHAMetadata.reset``."""
        super().reset()
179
+
180
+
181
class NonGraphedMHAMetadata(MHAMetadata):
    """
    MHA flash-attention metadata variant for eager (non-CUDA-graph) execution.

    After the base update, the padded ``max_seqlen_q``/``max_seqlen_k``
    defaults are replaced with the exact maxima of the current batch, which
    is only possible when no graph capture constrains the values.
    """

    def update(
        self,
        request_query_lengths: torch.Tensor,
        request_kv_length_offsets: torch.Tensor,
        request_to_kv_block_ids: torch.Tensor,
        padded_active_token_count: int,
        real_batch_size: int,
        padded_active_request_count: Optional[int] = None,
        decode_only: bool = False,
    ):
        """
        Populate the metadata buffers, then tighten the max sequence lengths.

        Args:
            request_query_lengths: (real_batch_size,) query tokens per request.
            request_kv_length_offsets: (real_batch_size,) prior KV tokens per request.
            request_to_kv_block_ids: (real_batch_size, max_kv_blocks) block table.
            padded_active_token_count: Active token count after padding.
            real_batch_size: Actual number of active requests.
            padded_active_request_count: Request count after padding.
            decode_only: True when every request contributes a single token.
        """
        super().update(
            request_query_lengths,
            request_kv_length_offsets,
            request_to_kv_block_ids,
            padded_active_token_count,
            real_batch_size,
            padded_active_request_count,
            decode_only,
        )
        query_lengths = self.state_data["query_lengths"]
        if len(query_lengths) == 0:
            # Empty batch: fall back to a minimal, valid sequence length.
            self.state_data["max_seqlen_q"] = 1
            self.state_data["max_seqlen_k"] = 1
        else:
            # Exact maxima of the current batch (host sync via .item()).
            self.state_data["max_seqlen_q"] = query_lengths.max().item()
            self.state_data["max_seqlen_k"] = self.state_data["kv_seq_lengths"].max().item()
220
+ self.state_data["max_seqlen_k"] = 1
@@ -0,0 +1,118 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+
9
class BlockAllocator:
    """Allocator that manages blocks of memory for the KV cache.

    This allocator is responsible for:
    - Initializing a pool of block IDs
    - Allocating blocks from the pool
    - Releasing blocks back to the pool

    Args:
        context (DynamicInferenceContext): Dynamic inference context.
        active_count (int): Total number of active blocks available in the buffer.
            The full buffer size is 2*active_count, to accommodate an equal-size
            space for paused requests that live on the CPU.
    """

    def __init__(self, context: "DynamicInferenceContext", active_count: int):

        self.context = context

        # Reserve one slot for the dummy block, but keep at least one usable block.
        active_count = max(1, active_count - 1)
        self.total_count = 2 * active_count + 1  # +1 for dummy_block_idx
        self.total_avail = self.total_count - 1  # dummy block is never handed out
        self.active_count = active_count
        self.paused_count = self.total_count - self.active_count - 1  # -1 for dummy_block_idx
        self.dummy_block_idx = self.total_count - 1

        # Pool of block ids, used as a stack: the live top sits at index
        # `total_avail` and grows/shrinks as blocks are released/allocated.
        self.block_bag = torch.arange(
            self.total_count, dtype=torch.int32, device=torch.cuda.current_device()
        )

    def __str__(self):
        usable = self.total_count - 1
        return f"total avail {self.total_avail} / {usable}; active {self.active_count}"

    def get_active_used(self):
        """Compute number of active blocks used."""
        counts = self.context.request_kv_block_counts
        lo = self.context.paused_request_count
        hi = self.context.total_request_count
        return counts[lo:hi].sum().item()

    def get_paused_used(self):
        """Compute number of paused blocks used."""
        counts = self.context.request_kv_block_counts
        return counts[: self.context.paused_request_count].sum().item()

    def get_active_avail(self):
        """Compute number of active blocks available."""
        return self.active_count - self.get_active_used()

    def get_paused_avail(self):
        """Compute number of paused blocks available."""
        return self.paused_count - self.get_paused_used()

    def is_memory_available(self, num_blocks: int) -> bool:
        """Check if memory blocks are available.

        Args:
            num_blocks (int): Number of blocks to check.

        Return:
            (bool) Is memory available?
        """
        return num_blocks <= self.get_active_avail()

    def allocate_memory_blocks(self, num_blocks: int) -> Optional[Tensor]:
        """Allocate memory blocks if available, else return None.

        Args:
            num_blocks (int): Number of blocks to allocate.

        Return:
            (Optional[Tensor]) Allocated block IDs.
        """
        if not self.is_memory_available(num_blocks):
            return None
        # Pop `num_blocks` ids off the top of the stack.
        self.total_avail -= num_blocks
        top = self.total_avail
        block_ids = self.block_bag[top : top + num_blocks]
        assert block_ids.numel() == num_blocks
        # NOTE(review): this is a *view* into block_bag; presumably callers
        # copy the ids before later allocate/release calls mutate the pool —
        # confirm at call sites.
        return block_ids

    def release_memory_blocks(self, blocks: Tensor) -> None:
        """Release memory blocks.

        Args:
            blocks (Tensor): Block IDs to release.

        Return:
            None
        """
        # Push the ids back onto the top of the stack.
        count = blocks.size(dim=0)
        top = self.total_avail
        self.block_bag[top : top + count] = blocks
        self.total_avail += count

    def reset(self) -> None:
        """Reset the allocator to initial state.

        This resets the available block count to the entire memory pool
        (except for the dummy block).
        """
        self.total_avail = self.total_count - 1