megatron-core 0.14.0rc1__tar.gz → 0.14.0rc3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megatron-core might be problematic. Click here for more details.

Files changed (308)
  1. {megatron_core-0.14.0rc1/megatron_core.egg-info → megatron_core-0.14.0rc3}/PKG-INFO +11 -8
  2. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +10 -0
  3. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/enums.py +10 -3
  4. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fp8_utils.py +125 -2
  5. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/__init__.py +1 -0
  6. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/dynamic_context.py +200 -65
  7. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/static_context.py +1 -1
  8. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/dynamic_engine.py +97 -21
  9. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +6 -10
  10. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +2 -6
  11. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +6 -0
  12. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +2 -9
  13. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +15 -2
  14. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +59 -49
  15. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +15 -2
  16. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/utils.py +16 -0
  17. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/model_parallel_config.py +0 -5
  18. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/T5/t5_model.py +2 -7
  19. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/T5/t5_spec.py +2 -0
  20. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_layer_specs.py +2 -0
  21. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/language_model_embedding.py +3 -3
  22. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +11 -5
  23. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/language_module/language_module.py +57 -17
  24. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/gpt_layer_specs.py +4 -0
  25. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/gpt_model.py +19 -15
  26. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +2 -0
  27. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/moe_module_specs.py +2 -0
  28. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/mamba_model.py +12 -16
  29. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/audio.py +1 -0
  30. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/llava_model.py +19 -4
  31. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/decoder_spec.py +2 -0
  32. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/encoder_spec.py +2 -0
  33. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/vision/clip_vit_model.py +9 -0
  34. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/vision/multimodal_projector.py +10 -1
  35. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/vision/radio.py +7 -0
  36. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/__init__.py +181 -48
  37. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/distrib_optimizer.py +54 -6
  38. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/optimizer.py +27 -4
  39. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/package_info.py +1 -1
  40. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/packed_seq_params.py +2 -2
  41. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/parallel_state.py +42 -451
  42. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/p2p_communication.py +25 -68
  43. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/schedules.py +12 -73
  44. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/utils.py +57 -1
  45. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/rerun_state_machine.py +123 -86
  46. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/random.py +4 -1
  47. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/attention.py +2 -7
  48. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/cuda_graphs.py +239 -87
  49. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/enums.py +8 -1
  50. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/heterogeneous/linear_replacements.py +4 -0
  51. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/experts.py +1 -0
  52. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/moe_layer.py +2 -0
  53. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/moe_utils.py +6 -0
  54. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/router.py +23 -2
  55. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/multi_latent_attention.py +9 -3
  56. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/multi_token_prediction.py +10 -3
  57. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_block.py +22 -11
  58. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_config.py +31 -2
  59. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/transformer_layer.py +0 -4
  60. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3/megatron_core.egg-info}/PKG-INFO +11 -8
  61. megatron_core-0.14.0rc3/megatron_core.egg-info/requires.txt +33 -0
  62. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/pyproject.toml +25 -12
  63. megatron_core-0.14.0rc1/megatron_core.egg-info/requires.txt +0 -30
  64. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/LICENSE +0 -0
  65. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/MANIFEST.in +0 -0
  66. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/README.md +0 -0
  67. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/README.md +0 -0
  68. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/__init__.py +0 -0
  69. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/config.py +0 -0
  70. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/config_logger.py +0 -0
  71. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/__init__.py +0 -0
  72. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/bert_dataset.py +0 -0
  73. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_dataset.py +0 -0
  74. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  75. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  76. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/gpt_dataset.py +0 -0
  77. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/helpers.cpp +0 -0
  78. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/helpers.py +0 -0
  79. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/indexed_dataset.py +0 -0
  80. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/masked_dataset.py +0 -0
  81. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/megatron_dataset.py +0 -0
  82. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  83. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/multimodal_dataset.py +0 -0
  84. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/object_storage_utils.py +0 -0
  85. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/__init__.py +0 -0
  86. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/__init__.py +0 -0
  87. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  88. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/config.py +0 -0
  89. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  90. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  91. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/__init__.py +0 -0
  92. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/build.py +0 -0
  93. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/dataset.py +0 -0
  94. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/db/utils.py +0 -0
  95. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/external_libs.py +0 -0
  96. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/__init__.py +0 -0
  97. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/build.py +0 -0
  98. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/factory.py +0 -0
  99. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/index.py +0 -0
  100. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  101. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  102. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  103. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/utils.py +0 -0
  104. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/index/validate.py +0 -0
  105. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/__init__.py +0 -0
  106. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  107. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  108. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/query.py +0 -0
  109. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  110. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/query/utils.py +0 -0
  111. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/retro/utils.py +0 -0
  112. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/t5_dataset.py +0 -0
  113. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils.py +0 -0
  114. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils_object_storage.py +0 -0
  115. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/datasets/utils_s3.py +0 -0
  116. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/__init__.py +0 -0
  117. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/core.py +0 -0
  118. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  119. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  120. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/mapping.py +0 -0
  121. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  122. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/serialization.py +0 -0
  123. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  124. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  125. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  126. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  127. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  128. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  129. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  130. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  131. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  132. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  133. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  134. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  135. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  136. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  137. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  138. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/utils.py +0 -0
  139. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/dist_checkpointing/validation.py +0 -0
  140. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/__init__.py +0 -0
  141. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  142. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
  143. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/data_parallel_base.py +0 -0
  144. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  145. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  146. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/finalize_model_grads.py +0 -0
  147. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  148. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  149. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  150. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/energy_monitor.py +0 -0
  151. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/__init__.py +0 -0
  152. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/data_type.py +0 -0
  153. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/export_config.py +0 -0
  154. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/model_type.py +0 -0
  155. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/__init__.py +0 -0
  156. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  157. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  158. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  159. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  160. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  161. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  162. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  163. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  164. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  165. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  166. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  167. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  168. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/extensions/__init__.py +0 -0
  169. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/extensions/kitchen.py +0 -0
  170. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/extensions/transformer_engine.py +0 -0
  171. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  172. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/__init__.py +0 -0
  173. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  174. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  175. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  176. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  177. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  178. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_indices_converter.py +0 -0
  179. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_layer_norm.py +0 -0
  180. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  181. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  182. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/fusions/fused_softmax.py +0 -0
  183. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/hyper_comm_grid.py +0 -0
  184. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/__init__.py +0 -0
  185. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/async_stream.py +0 -0
  186. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/common_inference_params.py +0 -0
  187. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/communication_utils.py +0 -0
  188. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/base_context.py +0 -0
  189. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  190. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/__init__.py +0 -0
  191. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/abstract_engine.py +0 -0
  192. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/mcore_engine.py +0 -0
  193. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/engines/static_engine.py +0 -0
  194. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/inference_request.py +0 -0
  195. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  196. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  197. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  198. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  199. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/sampling_params.py +0 -0
  200. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/scheduler.py +0 -0
  201. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  202. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  203. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/inference_params.py +0 -0
  204. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/jit.py +0 -0
  205. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/T5/__init__.py +0 -0
  206. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/__init__.py +0 -0
  207. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/backends.py +0 -0
  208. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/bert/__init__.py +0 -0
  209. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_lm_head.py +0 -0
  210. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/bert/bert_model.py +0 -0
  211. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/bert/pooler.py +0 -0
  212. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/__init__.py +0 -0
  213. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/__init__.py +0 -0
  214. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  215. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  216. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  217. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/language_module/__init__.py +0 -0
  218. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/vision_module/__init__.py +0 -0
  219. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  220. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/__init__.py +0 -0
  221. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  222. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/__init__.py +0 -0
  223. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/clip_model.py +0 -0
  224. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/module.py +0 -0
  225. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/huggingface/qwen_model.py +0 -0
  226. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/__init__.py +0 -0
  227. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  228. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/__init__.py +0 -0
  229. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/config/__init__.py +0 -0
  230. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/config/base_configs.py +0 -0
  231. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/model/__init__.py +0 -0
  232. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/model/base.py +0 -0
  233. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/base.py +0 -0
  234. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/mimo/submodules/vision.py +0 -0
  235. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/__init__.py +0 -0
  236. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/context_parallel.py +0 -0
  237. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/multimodal/llava_spec.py +0 -0
  238. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/__init__.py +0 -0
  239. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/base_attention.py +0 -0
  240. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/config.py +0 -0
  241. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/decoder_attention.py +0 -0
  242. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/encoder_attention.py +0 -0
  243. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/model.py +0 -0
  244. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/retro/utils.py +0 -0
  245. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/vision/__init__.py +0 -0
  246. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  247. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/msc_utils.py +0 -0
  248. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/num_microbatches_calculator.py +0 -0
  249. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/clip_grads.py +0 -0
  250. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  251. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  252. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/grad_scaler.py +0 -0
  253. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer/optimizer_config.py +0 -0
  254. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/optimizer_param_scheduler.py +0 -0
  255. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/pipeline_parallel/__init__.py +0 -0
  256. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/__init__.py +0 -0
  257. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/__init__.py +0 -0
  258. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  259. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  260. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  261. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/layers.py +0 -0
  262. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  263. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  264. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/process_groups_config.py +0 -0
  265. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/quantization/__init__.py +0 -0
  266. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/quantization/quant_config.py +0 -0
  267. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/quantization/utils.py +0 -0
  268. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/requirements.txt +0 -0
  269. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/__init__.py +0 -0
  270. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_block.py +0 -0
  271. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  272. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  273. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_layer.py +0 -0
  274. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/mamba_mixer.py +0 -0
  275. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/mlp_layer.py +0 -0
  276. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/ssm/triton_cache_manager.py +0 -0
  277. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/__init__.py +0 -0
  278. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  279. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/data.py +0 -0
  280. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/layers.py +0 -0
  281. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/mappings.py +0 -0
  282. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/tensor_parallel/utils.py +0 -0
  283. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/timers.py +0 -0
  284. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/__init__.py +0 -0
  285. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  286. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  287. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/dot_product_attention.py +0 -0
  288. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  289. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/identity_op.py +0 -0
  290. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/mlp.py +0 -0
  291. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/module.py +0 -0
  292. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/__init__.py +0 -0
  293. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  294. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  295. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/shared_experts.py +0 -0
  296. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  297. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  298. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  299. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/spec_utils.py +0 -0
  300. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/torch_layer_norm.py +0 -0
  301. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/torch_norm.py +0 -0
  302. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/transformer/utils.py +0 -0
  303. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron/core/utils.py +0 -0
  304. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron_core.egg-info/SOURCES.txt +0 -0
  305. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron_core.egg-info/dependency_links.txt +0 -0
  306. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/megatron_core.egg-info/top_level.txt +0 -0
  307. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/setup.cfg +0 -0
  308. {megatron_core-0.14.0rc1 → megatron_core-0.14.0rc3}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: megatron-core
3
- Version: 0.14.0rc1
3
+ Version: 0.14.0rc3
4
4
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
5
5
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
6
6
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -31,6 +31,7 @@ Description-Content-Type: text/markdown
31
31
  License-File: LICENSE
32
32
  Requires-Dist: torch
33
33
  Requires-Dist: numpy<2.0.0
34
+ Requires-Dist: packaging~=25.0
34
35
  Provides-Extra: mlm
35
36
  Requires-Dist: flask-restful; extra == "mlm"
36
37
  Requires-Dist: sentencepiece; extra == "mlm"
@@ -38,14 +39,16 @@ Requires-Dist: tiktoken; extra == "mlm"
38
39
  Requires-Dist: wandb; extra == "mlm"
39
40
  Provides-Extra: dev
40
41
  Requires-Dist: tqdm; extra == "dev"
41
- Requires-Dist: einops; extra == "dev"
42
- Requires-Dist: tensorstore!=0.1.46,!=0.1.72; extra == "dev"
43
- Requires-Dist: nvtx; extra == "dev"
44
- Requires-Dist: transformers; extra == "dev"
45
- Requires-Dist: multi-storage-client; extra == "dev"
42
+ Requires-Dist: einops~=0.8; extra == "dev"
43
+ Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
44
+ Requires-Dist: nvtx~=0.2; extra == "dev"
45
+ Requires-Dist: transformers~=4.53; extra == "dev"
46
+ Requires-Dist: multi-storage-client~=0.20.3; extra == "dev"
47
+ Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
46
48
  Requires-Dist: setuptools<80.0.0; extra == "dev"
47
- Requires-Dist: nvidia-modelopt[torch]; sys_platform != "darwin" and extra == "dev"
48
- Requires-Dist: megatron-energon[av_decode]<7; extra == "dev"
49
+ Requires-Dist: nvidia-modelopt[torch]<0.32.0,>=0.31.0a0; sys_platform != "darwin" and extra == "dev"
50
+ Requires-Dist: megatron-energon[av_decode]~=6.0; extra == "dev"
51
+ Requires-Dist: flashinfer-python; extra == "dev"
49
52
  Provides-Extra: lts
50
53
  Requires-Dist: tqdm; extra == "lts"
51
54
  Requires-Dist: einops; extra == "lts"
@@ -217,6 +217,16 @@ class FullyShardedDataParallel(_BaseDataParallel):
217
217
 
218
218
  self.module.apply(unmap_weight_tensor)
219
219
 
220
+ for param in self.module.parameters():
221
+ if not hasattr(param, 'grad_added_to_main_grad'):
222
+ # This is to ensure that the param.grad_added_to_main_grad is set to False
223
+ # when the parameter is created.
224
+ param.grad_added_to_main_grad = False
225
+ if not hasattr(param, '__fsdp_param__'):
226
+ # This is to ensure that the param.__fsdp_param__ is set to True
227
+ # when the parameter is created.
228
+ param.__fsdp_param__ = True
229
+
220
230
  def _init_fsdp_param_and_grad_buffer(self):
221
231
  if self.config.calculate_per_token_loss:
222
232
  # We don't need to scale the gradients in this case.
@@ -7,9 +7,16 @@ class ModelType(enum.Enum):
7
7
  """Model type."""
8
8
 
9
9
  encoder_or_decoder = 1
10
- encoder_and_decoder = 2
11
- retro_encoder = 3
12
- retro_decoder = 4
10
+ retro_encoder = 2
11
+ retro_decoder = 3
12
+
13
+ @property
14
+ def encoder_and_decoder(self):
15
+ """Deprecated property - use encoder_or_decoder instead."""
16
+ raise ValueError(
17
+ "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder "
18
+ "instead."
19
+ )
13
20
 
14
21
 
15
22
  class Fp8Recipe(str, enum.Enum):
@@ -2,7 +2,9 @@
2
2
 
3
3
  """Utility functions related to FP8 that are used throughout Megatron core"""
4
4
 
5
+ import weakref
5
6
  from contextlib import nullcontext
7
+ from functools import wraps
6
8
  from typing import List, Optional
7
9
 
8
10
  import torch
@@ -53,6 +55,29 @@ except (ImportError, ModuleNotFoundError):
53
55
  # MXFP8Tensor not found
54
56
  HAVE_TE_MXFP8TENSOR = False
55
57
 
58
+ if HAVE_TE:
59
+ from megatron.core.extensions.transformer_engine import (
60
+ TEColumnParallelLinear,
61
+ TELayerNormColumnParallelLinear,
62
+ TELinear,
63
+ TERowParallelLinear,
64
+ )
65
+
66
+ TE_LINEAR_TYPES = (
67
+ TELinear,
68
+ TEColumnParallelLinear,
69
+ TERowParallelLinear,
70
+ TELayerNormColumnParallelLinear,
71
+ )
72
+ else:
73
+ TE_LINEAR_TYPES = ()
74
+
75
+ try:
76
+ from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
77
+ except ImportError:
78
+ Fp8Padding = None
79
+ Fp8Unpadding = None
80
+
56
81
 
57
82
  def is_float8tensor(tensor: torch.Tensor) -> bool:
58
83
  """Check if a tensor is a Transformer Engine Float8Tensor.
@@ -346,8 +371,12 @@ else:
346
371
  def _modify_underlying_storage_impl(*args, **kwargs):
347
372
  raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
348
373
 
349
- def _quantize_param_shard_impl(*args, **kwargs):
350
- raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
374
+ def _quantize_param_shard_impl(model_params, *args, **kwargs):
375
+ if len(model_params) == 0:
376
+ return
377
+ else:
378
+ # If TE is not installed, there shouldn't be any fp8 params.
379
+ raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
351
380
 
352
381
  def _correct_amax_history_if_needed_impl(*args, **kwargs):
353
382
  # If TE is not installed, we are definitely not using fp8 for training, so no correction
@@ -507,3 +536,97 @@ else:
507
536
  def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
508
537
  """Returns dummy fp8 context manager since TE is not available."""
509
538
  return nullcontext()
539
+
540
+
541
+ if HAVE_TE:
542
+ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
543
+
544
+ # Modules that have been wrapped for inference for fp8
545
+ _fp8_inference_wrapped_modules = weakref.WeakSet()
546
+
547
+ def _wrap_te_linear_for_padding(module: torch.nn.Module):
548
+ """Wrap a TE linear module to automatically pad sequences for FP8 inference.
549
+
550
+ Modifies the module's forward method to:
551
+ 1. Pad input sequences to FP8 alignment requirements
552
+ 2. Run the original forward pass
553
+ 3. Unpad outputs to original sequence length
554
+
555
+ Args:
556
+ module: A Transformer Engine linear layer (TELinear, TEColumnParallelLinear, etc.)
557
+ """
558
+ if module in _fp8_inference_wrapped_modules:
559
+ return
560
+ _pad_func = Fp8Padding(1)
561
+ _unpad_func = Fp8Unpadding(1)
562
+
563
+ original_forward = module.forward
564
+
565
+ @wraps(original_forward)
566
+ def padded_forward(input_tensor, *args, **kwargs):
567
+ # Only do padding for fp8 if we are in fp8 context
568
+ if not FP8GlobalStateManager.is_fp8_enabled():
569
+ return original_forward(input_tensor, *args, **kwargs)
570
+
571
+ seq_len, batch_size, hidden_size = input_tensor.shape
572
+ # Reshape to (S, B*H) to pad sequence dimension
573
+ input_2d = input_tensor.reshape(seq_len, -1)
574
+ # Pad the sequence dimension
575
+ padded_input_2d, _ = _pad_func(input_2d, [seq_len])
576
+ padded_seq_len = padded_input_2d.shape[0]
577
+
578
+ # Reshape back to (padded_S, B, H)
579
+ padded_input_3d = padded_input_2d.view(padded_seq_len, batch_size, hidden_size)
580
+ output = original_forward(padded_input_3d, *args, **kwargs)
581
+
582
+ # Handle output
583
+ if isinstance(output, tuple):
584
+ output_tensor = output[0]
585
+ other_outputs = output[1:]
586
+ else:
587
+ output_tensor = output
588
+ other_outputs = ()
589
+
590
+ # Unpad output - reshape to 2D, unpad, reshape back
591
+ _, _, output_hidden_size = output_tensor.shape
592
+ output_2d = output_tensor.reshape(padded_seq_len, -1)
593
+ unpadded_output_2d = _unpad_func(output_2d, [seq_len])
594
+ unpadded_output = unpadded_output_2d.reshape(seq_len, batch_size, output_hidden_size)
595
+
596
+ if other_outputs:
597
+ return (unpadded_output,) + other_outputs
598
+ else:
599
+ return unpadded_output
600
+
601
+ module.forward = padded_forward
602
+ _fp8_inference_wrapped_modules.add(module)
603
+
604
+ def prepare_model_for_fp8_inference(model):
605
+ """Prepare a model for FP8 inference by wrapping TE linear layers with padding support.
606
+
607
+ FP8 TE Gemms have specific shape requirements. This function wraps all Transformer
608
+ Engine linear layers in the model to automatically pad/unpad sequences during inference.
609
+
610
+ Args:
611
+ model (GPTModel): Model containing TE linear layers.
612
+
613
+ Returns:
614
+ GPTModel: The same model with wrapped linear layers (modified in-place).
615
+
616
+ """
617
+ assert Fp8Padding and Fp8Unpadding, "TE version does not have FP8 padding functions"
618
+ # Find and wrap all TE linear layers
619
+ for module in model.modules():
620
+ if isinstance(module, TE_LINEAR_TYPES):
621
+ _wrap_te_linear_for_padding(module)
622
+
623
+ return model
624
+
625
+ else:
626
+
627
+ def prepare_model_for_fp8_inference(model):
628
+ """Raise an error if prepare_model_for_fp8_inference is used without TE installed."""
629
+ raise RuntimeError(
630
+ "prepare_model_for_fp8_inference requires Transformer Engine to be installed. "
631
+ "Please install transformer-engine to use FP8 inference."
632
+ )
@@ -14,6 +14,7 @@ warnings.warn(
14
14
  DeprecationWarning,
15
15
  )
16
16
  from .dynamic_context import (
17
+ ActiveRequestCountOverflowError,
17
18
  ChunkOverflowError,
18
19
  ContextOverflowError,
19
20
  DynamicInferenceContext,