megatron-core 0.14.0rc2__tar.gz → 0.14.0rc3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of megatron-core has been flagged as potentially problematic.

Files changed (308)
  1. {megatron_core-0.14.0rc2/megatron_core.egg-info → megatron_core-0.14.0rc3}/PKG-INFO +11 -8
  2. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +10 -0
  3. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fp8_utils.py +119 -0
  4. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/__init__.py +1 -0
  5. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/dynamic_context.py +148 -59
  6. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/dynamic_engine.py +79 -18
  7. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +4 -0
  8. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +6 -0
  9. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +3 -37
  10. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +10 -4
  11. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/__init__.py +143 -44
  12. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/optimizer.py +0 -3
  13. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/package_info.py +1 -1
  14. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/packed_seq_params.py +2 -2
  15. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/random.py +4 -1
  16. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/attention.py +2 -7
  17. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/cuda_graphs.py +178 -43
  18. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3/megatron_core.egg-info}/PKG-INFO +11 -8
  19. megatron_core-0.14.0rc3/megatron_core.egg-info/requires.txt +33 -0
  20. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/pyproject.toml +13 -10
  21. megatron_core-0.14.0rc2/megatron_core.egg-info/requires.txt +0 -30
  22. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/LICENSE +0 -0
  23. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/MANIFEST.in +0 -0
  24. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/README.md +0 -0
  25. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/README.md +0 -0
  26. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/__init__.py +0 -0
  27. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/config.py +0 -0
  28. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/config_logger.py +0 -0
  29. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/__init__.py +0 -0
  30. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/bert_dataset.py +0 -0
  31. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_dataset.py +0 -0
  32. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  33. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  34. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/gpt_dataset.py +0 -0
  35. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/helpers.cpp +0 -0
  36. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/helpers.py +0 -0
  37. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/indexed_dataset.py +0 -0
  38. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/masked_dataset.py +0 -0
  39. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/megatron_dataset.py +0 -0
  40. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  41. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/multimodal_dataset.py +0 -0
  42. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/object_storage_utils.py +0 -0
  43. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/__init__.py +0 -0
  44. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/__init__.py +0 -0
  45. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  46. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/config.py +0 -0
  47. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  48. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  49. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/__init__.py +0 -0
  50. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/build.py +0 -0
  51. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/dataset.py +0 -0
  52. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/utils.py +0 -0
  53. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/external_libs.py +0 -0
  54. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/__init__.py +0 -0
  55. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/build.py +0 -0
  56. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/factory.py +0 -0
  57. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/index.py +0 -0
  58. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  59. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  60. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  61. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/utils.py +0 -0
  62. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/validate.py +0 -0
  63. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/__init__.py +0 -0
  64. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  65. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  66. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/query.py +0 -0
  67. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  68. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/utils.py +0 -0
  69. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/utils.py +0 -0
  70. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/t5_dataset.py +0 -0
  71. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils.py +0 -0
  72. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils_object_storage.py +0 -0
  73. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils_s3.py +0 -0
  74. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/__init__.py +0 -0
  75. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/core.py +0 -0
  76. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  77. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  78. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/mapping.py +0 -0
  79. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  80. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/serialization.py +0 -0
  81. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  82. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  83. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  84. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  85. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  86. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  87. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  88. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  89. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  90. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  91. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  92. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  93. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  94. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  95. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  96. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/utils.py +0 -0
  97. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/validation.py +0 -0
  98. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/__init__.py +0 -0
  99. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  100. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
  101. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/data_parallel_base.py +0 -0
  102. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  103. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  104. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/finalize_model_grads.py +0 -0
  105. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  106. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  107. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  108. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/energy_monitor.py +0 -0
  109. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/enums.py +0 -0
  110. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/__init__.py +0 -0
  111. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/data_type.py +0 -0
  112. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/export_config.py +0 -0
  113. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/model_type.py +0 -0
  114. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/__init__.py +0 -0
  115. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  116. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  117. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  118. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  119. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  120. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  121. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  122. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  123. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  124. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  125. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  126. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  127. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/__init__.py +0 -0
  128. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/kitchen.py +0 -0
  129. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/transformer_engine.py +0 -0
  130. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  131. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/__init__.py +0 -0
  132. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  133. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  134. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  135. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  136. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  137. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_indices_converter.py +0 -0
  138. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_layer_norm.py +0 -0
  139. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  140. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  141. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_softmax.py +0 -0
  142. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/hyper_comm_grid.py +0 -0
  143. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/__init__.py +0 -0
  144. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/async_stream.py +0 -0
  145. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/common_inference_params.py +0 -0
  146. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/communication_utils.py +0 -0
  147. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/base_context.py +0 -0
  148. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  149. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/static_context.py +0 -0
  150. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/__init__.py +0 -0
  151. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/abstract_engine.py +0 -0
  152. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/mcore_engine.py +0 -0
  153. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/static_engine.py +0 -0
  154. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/inference_request.py +0 -0
  155. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  156. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  157. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  158. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  159. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  160. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  161. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/sampling_params.py +0 -0
  162. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/scheduler.py +0 -0
  163. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  164. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  165. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  166. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  167. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference/utils.py +0 -0
  168. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/inference_params.py +0 -0
  169. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/jit.py +0 -0
  170. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/model_parallel_config.py +0 -0
  171. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/T5/__init__.py +0 -0
  172. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/T5/t5_model.py +0 -0
  173. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/T5/t5_spec.py +0 -0
  174. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/__init__.py +0 -0
  175. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/backends.py +0 -0
  176. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/__init__.py +0 -0
  177. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  178. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_lm_head.py +0 -0
  179. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_model.py +0 -0
  180. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/bert/pooler.py +0 -0
  181. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/__init__.py +0 -0
  182. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/__init__.py +0 -0
  183. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  184. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  185. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  186. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  187. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/language_module/__init__.py +0 -0
  188. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/language_module/language_module.py +0 -0
  189. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/vision_module/__init__.py +0 -0
  190. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  191. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/__init__.py +0 -0
  192. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  193. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  194. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/gpt_model.py +0 -0
  195. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  196. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  197. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/__init__.py +0 -0
  198. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/clip_model.py +0 -0
  199. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/module.py +0 -0
  200. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/qwen_model.py +0 -0
  201. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/__init__.py +0 -0
  202. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  203. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/mamba_model.py +0 -0
  204. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/__init__.py +0 -0
  205. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/config/__init__.py +0 -0
  206. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/config/base_configs.py +0 -0
  207. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/model/__init__.py +0 -0
  208. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/model/base.py +0 -0
  209. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/audio.py +0 -0
  210. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/base.py +0 -0
  211. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/vision.py +0 -0
  212. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/__init__.py +0 -0
  213. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/context_parallel.py +0 -0
  214. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/llava_model.py +0 -0
  215. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/llava_spec.py +0 -0
  216. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/__init__.py +0 -0
  217. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/base_attention.py +0 -0
  218. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/config.py +0 -0
  219. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/decoder_attention.py +0 -0
  220. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/decoder_spec.py +0 -0
  221. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/encoder_attention.py +0 -0
  222. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/encoder_spec.py +0 -0
  223. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/model.py +0 -0
  224. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/retro/utils.py +0 -0
  225. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/__init__.py +0 -0
  226. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/clip_vit_model.py +0 -0
  227. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/multimodal_projector.py +0 -0
  228. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/radio.py +0 -0
  229. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  230. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/msc_utils.py +0 -0
  231. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/num_microbatches_calculator.py +0 -0
  232. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/clip_grads.py +0 -0
  233. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  234. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  235. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  236. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/grad_scaler.py +0 -0
  237. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer/optimizer_config.py +0 -0
  238. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/optimizer_param_scheduler.py +0 -0
  239. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/parallel_state.py +0 -0
  240. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/__init__.py +0 -0
  241. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  242. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/schedules.py +0 -0
  243. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/utils.py +0 -0
  244. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/__init__.py +0 -0
  245. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/__init__.py +0 -0
  246. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  247. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  248. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  249. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/layers.py +0 -0
  250. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  251. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  252. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/process_groups_config.py +0 -0
  253. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/quantization/__init__.py +0 -0
  254. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/quantization/quant_config.py +0 -0
  255. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/quantization/utils.py +0 -0
  256. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/requirements.txt +0 -0
  257. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/rerun_state_machine.py +0 -0
  258. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/__init__.py +0 -0
  259. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_block.py +0 -0
  260. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  261. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  262. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_layer.py +0 -0
  263. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_mixer.py +0 -0
  264. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/mlp_layer.py +0 -0
  265. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/ssm/triton_cache_manager.py +0 -0
  266. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/__init__.py +0 -0
  267. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  268. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/data.py +0 -0
  269. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/layers.py +0 -0
  270. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/mappings.py +0 -0
  271. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/utils.py +0 -0
  272. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/timers.py +0 -0
  273. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/__init__.py +0 -0
  274. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  275. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  276. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/dot_product_attention.py +0 -0
  277. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/enums.py +0 -0
  278. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  279. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  280. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/identity_op.py +0 -0
  281. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/mlp.py +0 -0
  282. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/module.py +0 -0
  283. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/__init__.py +0 -0
  284. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/experts.py +0 -0
  285. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  286. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  287. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/moe_layer.py +0 -0
  288. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/moe_utils.py +0 -0
  289. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/router.py +0 -0
  290. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/shared_experts.py +0 -0
  291. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  292. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  293. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/multi_latent_attention.py +0 -0
  294. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/multi_token_prediction.py +0 -0
  295. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  296. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/spec_utils.py +0 -0
  297. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/torch_layer_norm.py +0 -0
  298. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/torch_norm.py +0 -0
  299. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_block.py +0 -0
  300. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_config.py +0 -0
  301. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_layer.py +0 -0
  302. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/transformer/utils.py +0 -0
  303. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron/core/utils.py +0 -0
  304. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron_core.egg-info/SOURCES.txt +0 -0
  305. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron_core.egg-info/dependency_links.txt +0 -0
  306. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/megatron_core.egg-info/top_level.txt +0 -0
  307. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/setup.cfg +0 -0
  308. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc3}/setup.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc2
+Version: 0.14.0rc3
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -31,6 +31,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch
 Requires-Dist: numpy<2.0.0
+Requires-Dist: packaging~=25.0
 Provides-Extra: mlm
 Requires-Dist: flask-restful; extra == "mlm"
 Requires-Dist: sentencepiece; extra == "mlm"
@@ -38,14 +39,16 @@ Requires-Dist: tiktoken; extra == "mlm"
 Requires-Dist: wandb; extra == "mlm"
 Provides-Extra: dev
 Requires-Dist: tqdm; extra == "dev"
-Requires-Dist: einops; extra == "dev"
-Requires-Dist: tensorstore!=0.1.46,!=0.1.72; extra == "dev"
-Requires-Dist: nvtx; extra == "dev"
-Requires-Dist: transformers; extra == "dev"
-Requires-Dist: multi-storage-client; extra == "dev"
+Requires-Dist: einops~=0.8; extra == "dev"
+Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
+Requires-Dist: nvtx~=0.2; extra == "dev"
+Requires-Dist: transformers~=4.53; extra == "dev"
+Requires-Dist: multi-storage-client~=0.20.3; extra == "dev"
+Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
 Requires-Dist: setuptools<80.0.0; extra == "dev"
-Requires-Dist: nvidia-modelopt[torch]~=0.31.0; sys_platform != "darwin" and extra == "dev"
-Requires-Dist: megatron-energon[av_decode]<7; extra == "dev"
+Requires-Dist: nvidia-modelopt[torch]<0.32.0,>=0.31.0a0; sys_platform != "darwin" and extra == "dev"
+Requires-Dist: megatron-energon[av_decode]~=6.0; extra == "dev"
+Requires-Dist: flashinfer-python; extra == "dev"
 Provides-Extra: lts
 Requires-Dist: tqdm; extra == "lts"
 Requires-Dist: einops; extra == "lts"
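Most of the new dev-extra pins switch from unbounded requirements to PEP 440 compatible-release (`~=`) specifiers, and `packaging` itself becomes a runtime dependency. As a quick sanity check of what `~=` accepts, here is a small sketch using the `packaging` library (the version strings tested are arbitrary examples):

```python
from packaging.specifiers import SpecifierSet

# "~=25.0" is equivalent to ">=25.0, ==25.*", i.e. any 25.x release.
print(SpecifierSet("~=25.0").contains("25.9"))    # True
print(SpecifierSet("~=25.0").contains("26.0"))    # False

# "~=0.20.3" pins the minor series: ">=0.20.3, ==0.20.*".
print(SpecifierSet("~=0.20.3").contains("0.20.7"))  # True
print(SpecifierSet("~=0.20.3").contains("0.21.0"))  # False
```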
megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py

@@ -217,6 +217,16 @@ class FullyShardedDataParallel(_BaseDataParallel):

         self.module.apply(unmap_weight_tensor)

+        for param in self.module.parameters():
+            if not hasattr(param, 'grad_added_to_main_grad'):
+                # This is to ensure that the param.grad_added_to_main_grad is set to False
+                # when the parameter is created.
+                param.grad_added_to_main_grad = False
+            if not hasattr(param, '__fsdp_param__'):
+                # This is to ensure that the param.__fsdp_param__ is set to True
+                # when the parameter is created.
+                param.__fsdp_param__ = True
+
     def _init_fsdp_param_and_grad_buffer(self):
         if self.config.calculate_per_token_loss:
             # We don't need to scale the gradients in this case.
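Background for the change above: `torch.nn.Parameter` is a `Tensor` subclass, so arbitrary Python attributes can be attached to individual parameters and later read back as cheap per-parameter flags. A standalone sketch of the same tagging pattern (illustrative, not Megatron code):

```python
import torch

layer = torch.nn.Linear(4, 4)

# Tag every parameter with bookkeeping flags, mirroring the FSDP change above.
for param in layer.parameters():
    if not hasattr(param, 'grad_added_to_main_grad'):
        param.grad_added_to_main_grad = False
    if not hasattr(param, '__fsdp_param__'):
        param.__fsdp_param__ = True

# Downstream code can now branch on the flags per parameter.
print(all(p.__fsdp_param__ for p in layer.parameters()))            # True
print(any(p.grad_added_to_main_grad for p in layer.parameters()))   # False
```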
megatron/core/fp8_utils.py

@@ -2,7 +2,9 @@

 """Utility functions related to FP8 that are used throughout Megatron core"""

+import weakref
 from contextlib import nullcontext
+from functools import wraps
 from typing import List, Optional

 import torch
@@ -53,6 +55,29 @@ except (ImportError, ModuleNotFoundError):
     # MXFP8Tensor not found
     HAVE_TE_MXFP8TENSOR = False

+if HAVE_TE:
+    from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelLinear,
+        TELayerNormColumnParallelLinear,
+        TELinear,
+        TERowParallelLinear,
+    )
+
+    TE_LINEAR_TYPES = (
+        TELinear,
+        TEColumnParallelLinear,
+        TERowParallelLinear,
+        TELayerNormColumnParallelLinear,
+    )
+else:
+    TE_LINEAR_TYPES = ()
+
+try:
+    from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
+except ImportError:
+    Fp8Padding = None
+    Fp8Unpadding = None
+

 def is_float8tensor(tensor: torch.Tensor) -> bool:
     """Check if a tensor is a Transformer Engine Float8Tensor.
@@ -511,3 +536,97 @@ else:
     def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
         """Returns dummy fp8 context manager since TE is not available."""
         return nullcontext()
+
+
+if HAVE_TE:
+    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+    # Modules that have been wrapped for FP8 inference.
+    _fp8_inference_wrapped_modules = weakref.WeakSet()
+
+    def _wrap_te_linear_for_padding(module: torch.nn.Module):
+        """Wrap a TE linear module to automatically pad sequences for FP8 inference.
+
+        Modifies the module's forward method to:
+        1. Pad input sequences to FP8 alignment requirements
+        2. Run the original forward pass
+        3. Unpad outputs to the original sequence length
+
+        Args:
+            module: A Transformer Engine linear layer (TELinear, TEColumnParallelLinear, etc.)
+        """
+        if module in _fp8_inference_wrapped_modules:
+            return
+        _pad_func = Fp8Padding(1)
+        _unpad_func = Fp8Unpadding(1)
+
+        original_forward = module.forward
+
+        @wraps(original_forward)
+        def padded_forward(input_tensor, *args, **kwargs):
+            # Only do padding for fp8 if we are in an fp8 context.
+            if not FP8GlobalStateManager.is_fp8_enabled():
+                return original_forward(input_tensor, *args, **kwargs)
+
+            seq_len, batch_size, hidden_size = input_tensor.shape
+            # Reshape to (S, B*H) to pad the sequence dimension.
+            input_2d = input_tensor.reshape(seq_len, -1)
+            # Pad the sequence dimension.
+            padded_input_2d, _ = _pad_func(input_2d, [seq_len])
+            padded_seq_len = padded_input_2d.shape[0]
+
+            # Reshape back to (padded_S, B, H).
+            padded_input_3d = padded_input_2d.view(padded_seq_len, batch_size, hidden_size)
+            output = original_forward(padded_input_3d, *args, **kwargs)
+
+            # Handle output.
+            if isinstance(output, tuple):
+                output_tensor = output[0]
+                other_outputs = output[1:]
+            else:
+                output_tensor = output
+                other_outputs = ()
+
+            # Unpad output - reshape to 2D, unpad, reshape back.
+            _, _, output_hidden_size = output_tensor.shape
+            output_2d = output_tensor.reshape(padded_seq_len, -1)
+            unpadded_output_2d = _unpad_func(output_2d, [seq_len])
+            unpadded_output = unpadded_output_2d.reshape(seq_len, batch_size, output_hidden_size)
+
+            if other_outputs:
+                return (unpadded_output,) + other_outputs
+            else:
+                return unpadded_output
+
+        module.forward = padded_forward
+        _fp8_inference_wrapped_modules.add(module)
+
+    def prepare_model_for_fp8_inference(model):
+        """Prepare a model for FP8 inference by wrapping TE linear layers with padding support.
+
+        FP8 TE GEMMs have specific shape requirements. This function wraps all Transformer
+        Engine linear layers in the model to automatically pad/unpad sequences during inference.
+
+        Args:
+            model (GPTModel): Model containing TE linear layers.
+
+        Returns:
+            GPTModel: The same model with wrapped linear layers (modified in-place).
+        """
+        assert Fp8Padding and Fp8Unpadding, "TE version does not have FP8 padding functions"
+        # Find and wrap all TE linear layers.
+        for module in model.modules():
+            if isinstance(module, TE_LINEAR_TYPES):
+                _wrap_te_linear_for_padding(module)
+
+        return model
+
+else:
+
+    def prepare_model_for_fp8_inference(model):
+        """Raise an error if prepare_model_for_fp8_inference is used without TE."""
+        raise RuntimeError(
+            "prepare_model_for_fp8_inference requires Transformer Engine to be installed. "
+            "Please install transformer-engine to use FP8 inference."
+        )
megatron/core/inference/contexts/__init__.py

@@ -14,6 +14,7 @@ warnings.warn(
     DeprecationWarning,
 )
 from .dynamic_context import (
+    ActiveRequestCountOverflowError,
     ChunkOverflowError,
     ContextOverflowError,
     DynamicInferenceContext,
megatron/core/inference/contexts/dynamic_context.py

@@ -56,6 +56,18 @@ class ChunkOverflowError(ContextOverflowError):
     pass


+class ActiveRequestCountOverflowError(ContextOverflowError):
+    '''Used when `initialize_attention_state()` is called with
+    `num_warmup_requests > max_requests`.'''
+
+    def __init__(self, max_request_count, active_request_count):
+        assert active_request_count > max_request_count
+        super().__init__(
+            "active_request_count (%d) > max_request_count (%d)."
+            % (active_request_count, max_request_count)
+        )
+
+
 # pylint: disable=line-too-long
 class DynamicInferenceContext(BaseInferenceContext):
     """Inference context that is passed to the main model in order
@@ -108,6 +120,11 @@ class DynamicInferenceContext(BaseInferenceContext):
             from `buffer_overflow_factor`.
         max_tokens_override (Optional[int]): If set, overrides value computed
             from `buffer_overflow_factor`.
+        tensor_model_parallel_size (Optional[int]): Tensor model parallel size.
+        num_cuda_graphs (Optional[int]): Maximum number of cuda graphs to capture,
+            where the cuda graph batch sizes range from 1 to `max_requests` (as
+            computed below). Due to rounding, the actual number of cuda graphs may
+            not equal this argument.
     """

     def __init__(
@@ -125,6 +142,7 @@ class DynamicInferenceContext(BaseInferenceContext):
         max_requests_override: Optional[int] = None,
         max_tokens_override: Optional[int] = None,
         tensor_model_parallel_size: Optional[int] = None,
+        num_cuda_graphs: Optional[int] = None,
         materialize_only_last_token_logits: bool = True,
     ):

@@ -188,7 +206,7 @@
         self.active_token_count = 0
         self.paused_request_count = 0
         self.padded_active_token_count = None
-        self.padded_active_sample_count = None
+        self.padded_active_request_count = None
         self.paused_tokens = None

         # Per-request state.
@@ -246,6 +264,34 @@
             device=torch.cuda.current_device(),
         )

+        # Cuda graph request counts (i.e., batch sizes used for decode-only steps).
+        self.cuda_graph_request_counts = None
+        if num_cuda_graphs is not None:
+
+            # Ensure valid num_cuda_graphs.
+            num_cuda_graphs = min(max(num_cuda_graphs, 1), self.max_requests)
+
+            # Cuda graph step size.
+            cuda_graph_rounder = 8
+            self.cuda_graph_step_size = self.max_requests / num_cuda_graphs
+            self.cuda_graph_step_size = cuda_graph_rounder * int(
+                math.ceil(int(self.cuda_graph_step_size) / cuda_graph_rounder)
+            )
+
+            # Cuda graph request counts.
+            if num_cuda_graphs == 1:
+                self.cuda_graph_request_counts = [self.max_requests]
+            else:
+                self.cuda_graph_request_counts = list(
+                    range(self.cuda_graph_step_size, self.max_requests, self.cuda_graph_step_size)
+                )
+                if self.cuda_graph_request_counts[-1] != self.max_requests:
+                    self.cuda_graph_request_counts.append(self.max_requests)
+                self.cuda_graph_request_counts.reverse()
+
+            # Set used for validating active cuda graph request count.
+            self.cuda_graph_request_counts_set = set(self.cuda_graph_request_counts)
+
         # `*_decode_only` tensors are for use with cuda graphs to maintain
         # consistent input shapes, which is required to use cuda graphs. Cuda
         # graphs are used only during decode-only steps (i.e., no requests are in
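To see what ladder of batch sizes this produces, here is the same arithmetic as a standalone sketch with illustrative inputs (`max_requests=100`, `num_cuda_graphs=4` are arbitrary):

```python
import math

def cuda_graph_request_counts(max_requests: int, num_cuda_graphs: int) -> list:
    """Replicates the batch-size ladder computed in DynamicInferenceContext above."""
    num_cuda_graphs = min(max(num_cuda_graphs, 1), max_requests)
    rounder = 8
    step = max_requests / num_cuda_graphs
    step = rounder * int(math.ceil(int(step) / rounder))  # round step up to a multiple of 8
    if num_cuda_graphs == 1:
        return [max_requests]
    counts = list(range(step, max_requests, step))
    if counts[-1] != max_requests:
        counts.append(max_requests)
    counts.reverse()
    return counts

# 100/4 = 25, rounded up to 32 -> ladder [32, 64, 96, 100], largest first.
print(cuda_graph_request_counts(100, 4))  # [100, 96, 64, 32]
```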
@@ -269,7 +315,7 @@
             (self.max_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
         )

-        self.kv_memory_decode_only = torch.full(
+        self.request_to_kv_chunk_ids_decode_only = torch.full(
             (self.max_requests, self.max_kv_chunk_count),
             0,
             dtype=torch.int,
@@ -278,27 +324,22 @@
         )

         # Guaranteed active requests.
         # * See details in the class docstring above. `gtd_request_fraction` is
-        #   the fraction of the memory buffer that is reserved for guaranteeing
-        #   that some number of active requests can always proceed with their
-        #   generations. The number of bytes defined by `gtd_request_fraction *
-        #   buffer_size_gb` is converted to a number of requests that this
-        #   reserved space can handle (`gtd_request_count`), and rounded to be an
-        #   exact multiple of `max_sequence_length`. This is then converted into
-        #   the number of reserved chunks (`gtd_chunk_count`) and bytes
-        #   (`gtd_byte_count`).
-        # Chunk ids.
-        self.max_kv_chunk_count = math.ceil(self.max_sequence_length / self.chunk_size_tokens)
-        gtd_byte_count = buffer_guaranteed_fraction * buffer_size_bytes
-        gtd_request_count, _ = bytes_to_max_requests_and_tokens(gtd_byte_count)
-        if buffer_guaranteed_fraction > 0:
-            gtd_request_count = max(1, gtd_request_count)
-        gtd_request_count = self.round_up_requests(min(gtd_request_count, self.max_requests))
-        gtd_chunk_count = gtd_request_count * self.max_kv_chunk_count
-        assert (
-            gtd_request_count <= self.max_requests
-        ), "gtd_request_count (%d) > max_requests (%d)." % (gtd_request_count, self.max_requests)
-        self.gtd_request_count = gtd_request_count
-        self.gtd_chunk_count = gtd_chunk_count
+        #   the fraction of chunks in the memory buffer that are reserved for
+        #   guaranteeing that some number of active requests can always proceed
+        #   with their generations. The number of chunks defined by
+        #   `buffer_guaranteed_fraction * chunk_count_total` is converted to a
+        #   number of requests that this reserved space can safely handle
+        #   (`gtd_request_count`).
+        # * Note: computing the size of this guaranteed space from chunks rather
+        #   than bytes is safer due to the non-linear impacts of a large
+        #   `chunk_size_tokens` or `max_kv_chunk_count`. When computing from
+        #   chunks, this space will always be less than `chunk_count_total`. When
+        #   computing from bytes, this space can unexpectedly be much larger than
+        #   `chunk_count_total`, resulting in stalled generations.
+        gtd_chunk_count = int(buffer_guaranteed_fraction * chunk_count_total)
+        gtd_chunk_count = min(gtd_chunk_count, chunk_count_total)
+        self.gtd_request_count = max(1, gtd_chunk_count // self.max_kv_chunk_count)
+        self.gtd_chunk_count = self.gtd_request_count * self.max_kv_chunk_count

         # Initialize chunk allocator
         self.chunk_allocator = ChunkAllocator(
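The chunk-based reservation is easy to check by hand. A sketch of the arithmetic with made-up values (`buffer_guaranteed_fraction=0.1`, `chunk_count_total=10_000`, `max_kv_chunk_count=32` are all illustrative):

```python
buffer_guaranteed_fraction = 0.1   # illustrative values, not defaults
chunk_count_total = 10_000
max_kv_chunk_count = 32            # chunks needed by one max-length request

gtd_chunk_count = int(buffer_guaranteed_fraction * chunk_count_total)  # 1000
gtd_chunk_count = min(gtd_chunk_count, chunk_count_total)              # 1000
gtd_request_count = max(1, gtd_chunk_count // max_kv_chunk_count)      # 31
gtd_chunk_count = gtd_request_count * max_kv_chunk_count               # 992

# 31 max-length requests are always guaranteed to be able to finish;
# computed from chunks, the reservation can never exceed chunk_count_total.
print(gtd_request_count, gtd_chunk_count)  # 31 992
```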
@@ -368,12 +409,7 @@

     def cu_kv_lengths(self) -> Tensor:
         """Cumulative key/value sequence lengths."""
-        return (
-            self.cu_kv_seq_lengths,
-            self.kv_seq_lengths,
-            self.kv_seq_lengths_decode_only,
-            self.max_seqlen_k,
-        )
+        return (self.cu_kv_seq_lengths, self.kv_seq_lengths, self.max_seqlen_k)

     def get_active_sequence_lengths(self) -> Tensor:
         """Total sequence length (query + key) for active requests."""
@@ -487,7 +523,7 @@
            key_seq_idx = self.token_to_position_in_request[:n]
            key_emb = key_emb[key_seq_idx]
        if self.is_decode_only():
-            assert key.shape[0] == n == self.max_requests
+            assert key.shape[0] == n
            key = apply_rotary_pos_emb(
                t=key[:n], freqs=key_emb[:n], config=config, cp_group=cp_group
            )
@@ -506,23 +542,65 @@
         self.query_seq_lengths_decode_only.fill_(0)
         self.cu_kv_seq_lengths = None
         self.cu_kv_seq_lengths_decode_only.fill_(0)
+        self.kv_seq_lengths = None
         self.kv_seq_lengths_decode_only.fill_(0)
-        self.kv_memory_decode_only.fill_(0)
+        self.request_to_kv_chunk_ids_decode_only.fill_(0)
         self.block_table = None

-    def initialize_attention_state(self) -> None:
-        """Initialize attention state so that every layer can use it"""
+    def initialize_attention_state(self, *, num_warmup_requests: Optional[int] = None) -> None:
+        """Initialize attention state so that every layer can use it.
+
+        Args:
+            num_warmup_requests (Optional[int]): Number of requests to use for
+                warming up cuda graphs. Must be less than or equal to
+                `max_requests`.
+
+        Return:
+            None.
+        """

+        # Use of num_warmup_requests only for decode-only.
+        if num_warmup_requests is not None:
+            assert self.is_decode_only(), "cuda graph warmup requires decode-only mode."
+
+        # Active request count.
+        active_request_count = (
+            self.total_request_count - self.paused_request_count
+            if num_warmup_requests is None
+            else num_warmup_requests
+        )
+
+        # Active cuda graph count (if decode-only).
+        active_cuda_graph_request_count = None
+        if self.is_decode_only():
+            if active_request_count > self.max_requests:
+                raise ActiveRequestCountOverflowError(self.max_requests, active_request_count)
+
+            if self.cuda_graph_request_counts:
+                active_cuda_graph_request_count = (
+                    math.ceil(active_request_count / self.cuda_graph_step_size)
+                    * self.cuda_graph_step_size
+                )
+                active_cuda_graph_request_count = min(
+                    active_cuda_graph_request_count, self.max_requests
+                )
+                assert active_cuda_graph_request_count in self.cuda_graph_request_counts_set
+            else:
+                active_cuda_graph_request_count = self.max_requests
+
+        # Padded active token/request counts.
         self.padded_active_token_count = (
-            self.max_requests
+            active_cuda_graph_request_count
             if self.is_decode_only()
             else self.round_up_tokens(self.active_token_count)
         )
-        self.padded_active_sample_count = (
-            self.max_requests
+        self.padded_active_request_count = (
+            active_cuda_graph_request_count
             if self.is_decode_only()
             else (self.total_request_count - self.paused_request_count)
         )
+
+        # Update token position indexes.
         self.token_to_chunk_idx[self.active_token_count : self.padded_active_token_count] = (
             self.dummy_chunk_idx
         )
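At each decode step, the active request count is thus rounded up to the nearest captured graph size and capped at `max_requests`. A sketch of that rounding in isolation (step and capacity values are illustrative):

```python
import math

def round_to_graph_size(active: int, step: int, max_requests: int) -> int:
    """Round an active request count up to the nearest captured graph size."""
    return min(math.ceil(active / step) * step, max_requests)

# With the ladder [100, 96, 64, 32] from the earlier sketch (step=32, cap=100):
print(round_to_graph_size(active=50, step=32, max_requests=100))  # 64
print(round_to_graph_size(active=97, step=32, max_requests=100))  # 100
```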
@@ -533,6 +611,7 @@
             self.active_token_count : self.padded_active_token_count
         ] = 0

+        # Update cu_query_seq_lengths, max_seqlen_q.
         query_lengths = self.request_query_lengths[
             self.paused_request_count : self.total_request_count
         ]
@@ -540,9 +619,7 @@
             self.query_seq_lengths_decode_only[
                 0 : self.total_request_count - self.paused_request_count
             ] = query_lengths
-            cu_query_lengths_decode_only = torch.cumsum(self.query_seq_lengths_decode_only, dim=0)
-            self.cu_query_seq_lengths_decode_only[1:] = cu_query_lengths_decode_only
-            self.cu_query_seq_lengths = self.cu_query_seq_lengths_decode_only
+            self.cu_query_seq_lengths = None  # ensure no accidental use
             self.max_seqlen_q = 1
         else:
             cu_query_lengths = torch.cumsum(query_lengths, dim=0)
@@ -558,12 +635,18 @@
         kv_seq_lengths = self.request_kv_length_offsets + self.request_query_lengths
         self.kv_seq_lengths = kv_seq_lengths[self.paused_request_count : self.total_request_count]
         if self.is_decode_only():
+            # Re-assign `kv_seq_lengths` to be a view of the first
+            # `active_cuda_graph_request_count` tokens of `kv_seq_lengths_decode_only`,
+            # such that `kv_seq_lengths` has a static memory address and is therefore
+            # cuda graph compatible. This allows `kv_seq_lengths` to transition between
+            # cuda graph sizes, which makes multi-batch-size cuda graphs possible.
            self.kv_seq_lengths_decode_only[
                0 : self.total_request_count - self.paused_request_count
            ] = self.kv_seq_lengths
-            cu_kv_lengths_decode_only = torch.cumsum(self.kv_seq_lengths_decode_only, dim=0)
-            self.cu_kv_seq_lengths_decode_only[1:] = cu_kv_lengths_decode_only
-            self.cu_kv_seq_lengths = self.cu_kv_seq_lengths_decode_only
+            self.kv_seq_lengths = self.kv_seq_lengths_decode_only[
+                : self.padded_active_request_count
+            ]
+            self.cu_kv_seq_lengths = None  # ensure no accidental use
            self.max_seqlen_k = self.max_sequence_length
        else:
            self.cu_kv_seq_lengths = torch.full(
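The comment above describes a general CUDA-graph constraint: a captured graph replays fixed device pointers, so per-step inputs must live in preallocated buffers, and varying batch sizes are expressed as prefix views of the same buffer. A minimal sketch of the pattern in plain PyTorch (sizes are illustrative):

```python
import torch

MAX_REQUESTS = 64  # illustrative capacity

# One persistent buffer, allocated once; its storage address never changes.
kv_seq_lengths_buf = torch.zeros(MAX_REQUESTS, dtype=torch.int32)

def view_for_step(active: int, padded: int) -> torch.Tensor:
    """Copy live lengths into the buffer prefix and hand out a padded view."""
    kv_seq_lengths_buf[:active] = torch.arange(1, active + 1, dtype=torch.int32)
    # Slicing returns a view: same storage, so a captured CUDA graph that reads
    # this region keeps seeing valid data as the padded size changes.
    return kv_seq_lengths_buf[:padded]

v32 = view_for_step(active=20, padded=32)
print(v32.data_ptr() == kv_seq_lengths_buf.data_ptr())  # True: same memory
```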
@@ -575,14 +658,17 @@
         self.cu_kv_seq_lengths[1:] = torch.cumsum(self.kv_seq_lengths, dim=0)
         self.max_seqlen_k = self.kv_seq_lengths.max().item()

-        kv_memory = self.request_to_kv_chunk_ids[
+        # Update KV chunk IDs, block table.
+        request_to_kv_chunk_ids = self.request_to_kv_chunk_ids[
             self.paused_request_count : self.total_request_count
         ]
         if self.is_decode_only():
-            self.kv_memory_decode_only[0 : self.total_request_count - self.paused_request_count] = (
-                kv_memory
-            )
-            self.block_table = self.kv_memory_decode_only
+            self.request_to_kv_chunk_ids_decode_only[
+                0 : self.total_request_count - self.paused_request_count
+            ] = request_to_kv_chunk_ids
+            self.block_table = self.request_to_kv_chunk_ids_decode_only[
+                : self.padded_active_request_count
+            ]
         else:
             self.block_table = self.request_to_kv_chunk_ids[
                 self.paused_request_count : self.total_request_count
@@ -606,7 +692,7 @@
         self.active_token_count = 0
         self.paused_request_count = 0
         self.padded_active_token_count = 0
-        self.padded_active_sample_count = 0
+        self.padded_active_request_count = 0
         self.paused_tokens = None

         # Reset request indexes.
@@ -632,21 +718,24 @@
         self.chunk_allocator.reset()
         self.request_to_kv_chunk_ids.fill_(-1)

-    def current_input_ids(self) -> Tensor:
-        """Flattened input IDs for forward pass.
-
-        Return:
-            (Tensor) Flattened active input IDs.
-        """
-        return self.token_to_input_ids[: self.padded_active_token_count].unsqueeze(0)
+    def current_input_and_position_ids(
+        self, *, num_warmup_tokens: Optional[int] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Flattened input and position IDs for forward pass.

-    def current_position_ids(self) -> Tensor:
-        """Flattened position IDs for forward pass.
+        Args:
+            num_warmup_tokens (Optional[int]): Number of tokens to return for
+                warming up cuda graphs. Must be less than or equal to
+                `max_tokens`.

         Return:
-            (Tensor) Flattened active position IDs.
+            (Tuple[Tensor, Tensor]) Flattened active input and position IDs.
         """
-        return self.token_to_pos_ids[: self.padded_active_token_count].unsqueeze(0)
+        num_tokens = num_warmup_tokens or self.padded_active_token_count
+        return (
+            self.token_to_input_ids[:num_tokens].unsqueeze(0),
+            self.token_to_pos_ids[:num_tokens].unsqueeze(0),
+        )

     def last_token_logits(self, logits: Tensor) -> Tensor:
         """Last tokens of logits.