megatron-core 0.15.0rc5__tar.gz → 0.15.0rc7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of megatron-core might be problematic.
Files changed (353)
  1. {megatron_core-0.15.0rc5/megatron_core.egg-info → megatron_core-0.15.0rc7}/PKG-INFO +4 -3
  2. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/async_utils.py +3 -0
  3. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +17 -9
  4. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +13 -2
  5. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/torch.py +0 -1
  6. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/distributed_data_parallel.py +8 -17
  7. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/finalize_model_grads.py +10 -12
  8. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +14 -3
  9. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py +1 -1
  10. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/param_and_grad_buffer.py +5 -2
  11. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/extensions/transformer_engine.py +96 -6
  12. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fp8_utils.py +22 -17
  13. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/async_stream.py +1 -1
  14. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/contexts/dynamic_context.py +267 -104
  15. megatron_core-0.15.0rc7/megatron/core/inference/contexts/fused_kv_append_kernel.py +174 -0
  16. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/contexts/static_context.py +3 -1
  17. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/data_parallel_inference_coordinator.py +73 -16
  18. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/engines/dynamic_engine.py +229 -80
  19. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/engines/static_engine.py +7 -6
  20. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/inference_request.py +17 -1
  21. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/sampling_params.py +3 -0
  22. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/scheduler.py +12 -12
  23. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +22 -17
  24. megatron_core-0.15.0rc7/megatron/core/inference/unified_memory.py +89 -0
  25. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/utils.py +7 -0
  26. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +37 -5
  27. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/gpt/gpt_model.py +90 -19
  28. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/__init__.py +20 -2
  29. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/distrib_optimizer.py +6 -3
  30. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/optimizer_config.py +5 -0
  31. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/package_info.py +1 -1
  32. megatron_core-0.15.0rc7/megatron/core/pipeline_parallel/bridge_communicator.py +922 -0
  33. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/mamba_layer.py +32 -21
  34. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tensor_parallel/layers.py +13 -10
  35. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/attention.py +119 -37
  36. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/cuda_graphs.py +92 -49
  37. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/mlp.py +5 -2
  38. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/module.py +172 -0
  39. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/experts.py +32 -27
  40. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/moe_utils.py +17 -8
  41. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/router.py +15 -3
  42. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/multi_latent_attention.py +2 -0
  43. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/multi_token_prediction.py +13 -10
  44. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/transformer_block.py +6 -0
  45. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/transformer_config.py +7 -19
  46. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/transformer_layer.py +120 -172
  47. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/utils.py +16 -4
  48. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7/megatron_core.egg-info}/PKG-INFO +4 -3
  49. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron_core.egg-info/SOURCES.txt +3 -0
  50. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/pyproject.toml +24 -9
  51. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/LICENSE +0 -0
  52. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/MANIFEST.in +0 -0
  53. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/README.md +0 -0
  54. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/README.md +0 -0
  55. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/__init__.py +0 -0
  56. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/activations.py +0 -0
  57. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/config.py +0 -0
  58. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/config_logger.py +0 -0
  59. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/__init__.py +0 -0
  60. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/bert_dataset.py +0 -0
  61. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/blended_dataset.py +0 -0
  62. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  63. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  64. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/gpt_dataset.py +0 -0
  65. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/helpers.cpp +0 -0
  66. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/helpers.py +0 -0
  67. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/indexed_dataset.py +0 -0
  68. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/masked_dataset.py +0 -0
  69. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/megatron_dataset.py +0 -0
  70. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  71. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/multimodal_dataset.py +0 -0
  72. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/object_storage_utils.py +0 -0
  73. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/__init__.py +0 -0
  74. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/config/__init__.py +0 -0
  75. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  76. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/config/config.py +0 -0
  77. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  78. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  79. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/db/__init__.py +0 -0
  80. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/db/build.py +0 -0
  81. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/db/dataset.py +0 -0
  82. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/db/utils.py +0 -0
  83. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/external_libs.py +0 -0
  84. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/__init__.py +0 -0
  85. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/build.py +0 -0
  86. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/factory.py +0 -0
  87. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/index.py +0 -0
  88. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  89. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  90. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  91. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/utils.py +0 -0
  92. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/index/validate.py +0 -0
  93. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/query/__init__.py +0 -0
  94. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  95. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  96. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/query/query.py +0 -0
  97. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  98. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/query/utils.py +0 -0
  99. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/retro/utils.py +0 -0
  100. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/t5_dataset.py +0 -0
  101. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/utils.py +0 -0
  102. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/utils_object_storage.py +0 -0
  103. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/datasets/utils_s3.py +0 -0
  104. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/__init__.py +0 -0
  105. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/core.py +0 -0
  106. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  107. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  108. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/mapping.py +0 -0
  109. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  110. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/serialization.py +0 -0
  111. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  112. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  113. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  114. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  115. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/checkpointable.py +0 -0
  116. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  117. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  118. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  119. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  120. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  121. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  122. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  123. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/utils.py +0 -0
  124. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/dist_checkpointing/validation.py +0 -0
  125. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/__init__.py +0 -0
  126. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/data_parallel_base.py +0 -0
  127. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  128. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/__init__.py +0 -0
  129. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +0 -0
  130. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/__init__.py +0 -0
  131. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +0 -0
  132. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +0 -0
  133. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +0 -0
  134. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +0 -0
  135. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +0 -0
  136. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +0 -0
  137. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  138. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  139. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/energy_monitor.py +0 -0
  140. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/enums.py +0 -0
  141. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/__init__.py +0 -0
  142. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/data_type.py +0 -0
  143. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/export_config.py +0 -0
  144. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/model_type.py +0 -0
  145. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/__init__.py +0 -0
  146. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  147. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  148. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  149. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  150. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  151. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  152. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  153. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  154. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  155. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  156. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  157. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  158. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/extensions/__init__.py +0 -0
  159. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/extensions/kitchen.py +0 -0
  160. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  161. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fp4_utils.py +0 -0
  162. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/full_cuda_graph.py +0 -0
  163. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/__init__.py +0 -0
  164. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  165. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  166. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  167. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  168. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  169. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_indices_converter.py +0 -0
  170. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_layer_norm.py +0 -0
  171. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  172. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  173. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_softmax.py +0 -0
  174. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
  175. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/hyper_comm_grid.py +0 -0
  176. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/__init__.py +0 -0
  177. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/common_inference_params.py +0 -0
  178. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/communication_utils.py +0 -0
  179. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/contexts/__init__.py +0 -0
  180. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/contexts/base_context.py +0 -0
  181. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  182. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/engines/__init__.py +0 -0
  183. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/engines/abstract_engine.py +0 -0
  184. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/engines/mcore_engine.py +0 -0
  185. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/headers.py +0 -0
  186. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/inference_client.py +0 -0
  187. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  188. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  189. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  190. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  191. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  192. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  193. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  194. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  195. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  196. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  197. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  198. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  199. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/inference_params.py +0 -0
  200. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/jit.py +0 -0
  201. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/model_parallel_config.py +0 -0
  202. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/T5/__init__.py +0 -0
  203. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/T5/t5_model.py +0 -0
  204. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/T5/t5_spec.py +0 -0
  205. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/__init__.py +0 -0
  206. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/backends.py +0 -0
  207. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/bert/__init__.py +0 -0
  208. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  209. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/bert/bert_lm_head.py +0 -0
  210. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/bert/bert_model.py +0 -0
  211. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/bert/pooler.py +0 -0
  212. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/__init__.py +0 -0
  213. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/embeddings/__init__.py +0 -0
  214. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  215. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  216. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  217. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  218. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/language_module/__init__.py +0 -0
  219. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/language_module/language_module.py +0 -0
  220. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  221. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/vision_module/__init__.py +0 -0
  222. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  223. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/gpt/__init__.py +0 -0
  224. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  225. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  226. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  227. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  228. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/huggingface/__init__.py +0 -0
  229. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/huggingface/clip_model.py +0 -0
  230. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/huggingface/module.py +0 -0
  231. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/huggingface/qwen_model.py +0 -0
  232. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mamba/__init__.py +0 -0
  233. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  234. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mamba/mamba_model.py +0 -0
  235. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/__init__.py +0 -0
  236. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/config/__init__.py +0 -0
  237. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/config/base_configs.py +0 -0
  238. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/model/__init__.py +0 -0
  239. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/model/base.py +0 -0
  240. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/submodules/audio.py +0 -0
  241. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/submodules/base.py +0 -0
  242. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/mimo/submodules/vision.py +0 -0
  243. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/multimodal/__init__.py +0 -0
  244. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/multimodal/context_parallel.py +0 -0
  245. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/multimodal/llava_model.py +0 -0
  246. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/multimodal/llava_spec.py +0 -0
  247. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/__init__.py +0 -0
  248. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/base_attention.py +0 -0
  249. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/config.py +0 -0
  250. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/decoder_attention.py +0 -0
  251. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/decoder_spec.py +0 -0
  252. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/encoder_attention.py +0 -0
  253. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/encoder_spec.py +0 -0
  254. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/model.py +0 -0
  255. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/retro/utils.py +0 -0
  256. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/vision/__init__.py +0 -0
  257. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/vision/clip_vit_model.py +0 -0
  258. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/vision/multimodal_projector.py +0 -0
  259. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/vision/radio.py +0 -0
  260. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  261. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/msc_utils.py +0 -0
  262. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/nccl_allocator.py +0 -0
  263. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/num_microbatches_calculator.py +0 -0
  264. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/clip_grads.py +0 -0
  265. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  266. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  267. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/grad_scaler.py +0 -0
  268. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer/optimizer.py +0 -0
  269. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/optimizer_param_scheduler.py +0 -0
  270. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/packed_seq_params.py +0 -0
  271. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/parallel_state.py +0 -0
  272. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/pipeline_parallel/__init__.py +0 -0
  273. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  274. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  275. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/pipeline_parallel/schedules.py +0 -0
  276. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/pipeline_parallel/utils.py +0 -0
  277. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/__init__.py +0 -0
  278. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/modelopt/__init__.py +0 -0
  279. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  280. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  281. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  282. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/modelopt/layers.py +0 -0
  283. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  284. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  285. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/process_groups_config.py +0 -0
  286. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/quantization/__init__.py +0 -0
  287. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/quantization/quant_config.py +0 -0
  288. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/quantization/utils.py +0 -0
  289. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/requirements.txt +0 -0
  290. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/rerun_state_machine.py +0 -0
  291. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/safe_globals.py +0 -0
  292. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/__init__.py +0 -0
  293. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/mamba_block.py +0 -0
  294. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  295. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  296. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/mamba_mixer.py +0 -0
  297. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/mlp_layer.py +0 -0
  298. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/ssm/triton_cache_manager.py +0 -0
  299. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tensor_parallel/__init__.py +0 -0
  300. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  301. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tensor_parallel/data.py +0 -0
  302. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tensor_parallel/mappings.py +0 -0
  303. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tensor_parallel/random.py +0 -0
  304. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tensor_parallel/utils.py +0 -0
  305. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/timers.py +0 -0
  306. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/__init__.py +0 -0
  307. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/base_tokenizer.py +0 -0
  308. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/megatron_tokenizer.py +0 -0
  309. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/__init__.py +0 -0
  310. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/__init__.py +0 -0
  311. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +0 -0
  312. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/bytelevel_tokenizer.py +0 -0
  313. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/chat_template.py +0 -0
  314. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +0 -0
  315. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +0 -0
  316. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/null_tokenizer.py +0 -0
  317. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +0 -0
  318. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +0 -0
  319. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/models/__init__.py +0 -0
  320. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/models/bert_tokenizer.py +0 -0
  321. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/models/default_tokenizer.py +0 -0
  322. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/models/gpt_tokenizer.py +0 -0
  323. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/models/mamba_tokenizer.py +0 -0
  324. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/models/retro_tokenizer.py +0 -0
  325. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/models/t5_tokenizer.py +0 -0
  326. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/text_tokenizer.py +0 -0
  327. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/tokenizers/text/utils/build_tokenizer.py +0 -0
  328. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/__init__.py +0 -0
  329. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  330. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  331. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/dot_product_attention.py +0 -0
  332. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/enums.py +0 -0
  333. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/fsdp_dtensor_checkpoint.py +0 -0
  334. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  335. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  336. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/identity_op.py +0 -0
  337. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/__init__.py +0 -0
  338. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  339. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  340. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/moe_layer.py +0 -0
  341. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/shared_experts.py +0 -0
  342. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  343. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  344. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  345. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/spec_utils.py +0 -0
  346. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/torch_layer_norm.py +0 -0
  347. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/torch_norm.py +0 -0
  348. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron/core/transformer/utils.py +0 -0
  349. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron_core.egg-info/dependency_links.txt +0 -0
  350. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron_core.egg-info/requires.txt +0 -0
  351. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/megatron_core.egg-info/top_level.txt +0 -0
  352. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/setup.cfg +0 -0
  353. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc7}/setup.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.15.0rc5
+Version: 0.15.0rc7
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -17,8 +17,9 @@ Classifier: License :: OSI Approved :: BSD License
 Classifier: Natural Language :: English
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Scientific/Engineering :: Image Recognition
 Classifier: Topic :: Scientific/Engineering :: Mathematics
@@ -564,6 +564,9 @@ class AsyncCallsQueue:
         Returns:
             List[int]: list of indices (as returned by `schedule_async_request`)
                 of async calls that have been successfully finalized.
+        Raises:
+            CheckpointException: if any rank(s) raised an exception during checkpoint
+                writing, the exceptions are wrapped and raised on all ranks.
         """
         call_idx_finalized = []
         while self.async_calls:
@@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import multiprocessing as mp
 from torch.distributed.checkpoint import FileSystemWriter
+from torch.distributed.checkpoint.api import WRAPPED_EXCEPTION, _wrap_exception
 from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
 from torch.distributed.checkpoint.metadata import Metadata
@@ -420,14 +421,14 @@ class FileSystemWriterAsync(FileSystemWriter):
         """Write all items from ``plan``."""
         raise NotImplementedError("write_data not implemented for FileSystemWriterAsync")

-    def retrieve_write_results(self) -> List[WriteResult]:
+    def retrieve_write_results(self) -> Union[List[WriteResult], WRAPPED_EXCEPTION]:
         """
         Turn the latest dict including write results from `self.results_queue`
         into a single results lists. Includes error check.

-        Returns (List[WriteResult]): the list of write results
-            from all local processes performing the save.
-
+        Returns (Union(List[WriteResult], WRAPPED_EXCEPTION): the list of write results
+            from all local processes performing the save, or a WRAPPED_EXCEPTION if
+            an exception was raised during the writing process.
         """
         assert self.write_buckets is not None
@@ -437,15 +438,22 @@ class FileSystemWriterAsync(FileSystemWriter):
         try:
             write_results_or_exc = self.results_queue.get_nowait()
         except queue.Empty:
-            raise RuntimeError("results_queue should not be empty")
+            return _wrap_exception(RuntimeError("results_queue should not be empty"))

         if isinstance(write_results_or_exc, Exception):
-            raise RuntimeError(f"Worker failure: {write_results_or_exc}") from write_results_or_exc
+            try:
+                raise RuntimeError(
+                    f"Worker failure: {write_results_or_exc}"
+                ) from write_results_or_exc
+            except Exception as e:
+                return _wrap_exception(e)
         write_results: dict = write_results_or_exc
         if len(write_results) != len(self.write_buckets):
-            raise RuntimeError(
-                f"Incomplete worker results (expected {len(self.write_buckets)},"
-                f" got {len(write_results)}. This probably indicates a worker failure."
+            return _wrap_exception(
+                RuntimeError(
+                    f"Incomplete worker results (expected {len(self.write_buckets)},"
+                    f" got {len(write_results)}. This probably indicates a worker failure."
+                )
             )
         return list(chain.from_iterable(write_results.values()))
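Note on this hunk: failures are now returned as wrapped exceptions instead of raised, so every rank still reaches the collective failure check (see the state_dict_saver.py hunk below), where a CheckpointException is raised everywhere. A minimal sketch of the pattern; do_local_write is a hypothetical stand-in for the per-rank write step:

from typing import List, Union

from torch.distributed.checkpoint.api import WRAPPED_EXCEPTION, _wrap_exception


def do_local_write() -> List[dict]:
    # Hypothetical per-rank write step; raises to exercise the wrapping path.
    raise RuntimeError("disk full")


def collect_write_results() -> Union[List[dict], WRAPPED_EXCEPTION]:
    try:
        return do_local_write()
    except Exception as exc:
        # _wrap_exception keeps the traceback so the coordinator can surface
        # it later inside a CheckpointException instead of crashing one rank.
        return _wrap_exception(exc)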
@@ -243,5 +243,16 @@ def save_state_dict_async_finalize(
         storage_writer.finish(global_metadata, all_results)
         write_end = time()
         logger.debug(f"{write_end}, metadata_write: {write_end - write_start}")
-    else:
-        raise CheckpointException("write", node_failures)
+    else:
+        node_failures = {}
+
+    # Broadcast failure status to all ranks to raise exceptions everywhere if needed.
+    # The failure details are only raised on the coordinator.
+    failures_occurred = torch.tensor(
+        [int(len(node_failures) > 0)], dtype=torch.int, device=torch.cuda.current_device()
+    )
+    torch.distributed.broadcast(
+        failures_occurred, src=dist_wrapper.coordinator_rank, group=dist_wrapper.group
+    )
+    if failures_occurred:
+        raise CheckpointException("write", node_failures)
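The pattern above makes failure handling collective: only the coordinator holds the failure details, but a one-element flag tensor is broadcast so that every rank learns whether to raise and none deadlocks in a later collective. A standalone sketch, assuming an initialized NCCL process group; broadcast_failure_flag is an illustrative name:

import torch
import torch.distributed as dist


def broadcast_failure_flag(node_failures: dict, coordinator_rank: int, group) -> bool:
    # Every rank allocates the flag; only the coordinator's value matters.
    flag = torch.tensor(
        [int(len(node_failures) > 0)], dtype=torch.int, device=torch.cuda.current_device()
    )
    dist.broadcast(flag, src=coordinator_rank, group=group)
    # Non-coordinator ranks raise with empty details; the coordinator
    # carries the actual per-node failure information.
    return bool(flag.item())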
@@ -830,7 +830,6 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):

         def finalize_fn():
             save_state_dict_async_finalize(*save_state_dict_ret)
-            torch.distributed.barrier()

         return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
@@ -51,8 +51,6 @@ class DistributedDataParallel(_BaseDataParallel):
         if has_config_logger_enabled(config):
             log_config_to_disk(config, locals(), prefix=type(self).__name__)

-        self.module = module
-
         # If bucket_size is not provided as an input, use sane default.
         # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
         # ring-reduce implementations are large enough to remain bandwidth-bound rather than
@@ -121,9 +119,7 @@ class DistributedDataParallel(_BaseDataParallel):
             pp_rank = self.pp_group[0].rank()
         else:
             pp_rank = self.pp_group.rank()
-        if pp_rank > 0:
-            self.bucket_size = None
-        if disable_bucketing:
+        if disable_bucketing or pp_rank > 0:
             self.bucket_size = None

         self.param_to_bucket_group = {}
@@ -519,8 +515,11 @@ class DistributedDataParallel(_BaseDataParallel):
                 param_slice = bucket.param_data.view(-1)[param_start:param_end]
                 param.data.copy_(param_slice.view(param.data.shape))
             # All-gathered params are not needed after being copied to param.data.
-            # Zero out the grad buffer (shared with param buffer) for gradient accumulation.
-            bucket.grad_data.zero_()
+            # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
+            # We cannot zero out the entire grad buffer because one grad buffer may
+            # correspond to multiple param buffers. If we zero out the entire grad buffer,
+            # it would clear the data of those param buffers that have not yet completed AG.
+            bucket.param_data.zero_()

     def start_grad_sync(self, *unused):
         """
@@ -562,16 +561,8 @@ class DistributedDataParallel(_BaseDataParallel):
         # to True, and there will be a double-GA.
         for param in self.params_with_grad:
             param.grad_added_to_main_grad = False
-        # In the case of "reuse_grad_buf_for_mxfp8_param_ag=True & overlap_param_gather=True",
-        # the grad buffer is not reset here because the grad buffer is shared with the param buffer.
-        # The grad buffer is zeroed by "bucket.grad_data.zero_()" in the "finish_param_sync" stage
-        # after the param all-gather.
-        if not (
-            self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag
-            and self.ddp_config.overlap_param_gather
-        ):
-            for buffer in self.buffers + self.expert_parallel_buffers:
-                buffer.reset()
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.reset()
         for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
             bucket_group.reset()
@@ -267,13 +267,18 @@ def _allreduce_position_embedding_grads(
     )


-def _reset_global_aux_loss_tracker(model: List[torch.nn.Module]):
+def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.nn.Module]):
     """
-    Reset the global aux loss tracker.
+    Reset the temporary tensors of the model.
     """
     for model_chunk in model:
         for module in get_attr_wrapped_model(model_chunk, 'modules')():
-            if hasattr(module, 'reset_global_aux_loss_tracker'):
+            if config.moe_router_enable_expert_bias and hasattr(module, 'expert_bias'):
+                module.local_tokens_per_expert.zero_()
+            if (
+                config.moe_router_load_balancing_type == "global_aux_loss"
+                or "global_aux_loss" in config.moe_router_load_balancing_type
+            ) and hasattr(module, 'reset_global_aux_loss_tracker'):
                 module.reset_global_aux_loss_tracker()
@@ -298,10 +303,7 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: Transformer
         stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
     )

-    for tokens_per_expert, expert_bias, updated_expert_bias in zip(
-        tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
-    ):
-        tokens_per_expert.zero_()
+    for expert_bias, updated_expert_bias in zip(expert_bias_list, stacked_updated_expert_bias):
         expert_bias.copy_(updated_expert_bias)
@@ -465,11 +467,7 @@ def finalize_model_grads(
     if config.moe_router_enable_expert_bias:
         _update_router_expert_bias(model, config)

-    if (
-        config.moe_router_load_balancing_type == "global_aux_loss"
-        or "global_aux_loss" in config.moe_router_load_balancing_type
-    ):
-        _reset_global_aux_loss_tracker(model)
+    reset_model_temporary_tensors(config, model)

     # normalize gradients for per-token loss normalization.
     # if we are using by the number of tokens, then we use that as a divisor. this number
@@ -224,7 +224,7 @@ class MegatronFSDP(torch.nn.Module):
         # step of the model will reduce all gradients and gather all parameters
         # for synchronized operations such as distributed optimization and
         # distributed checkpointing particularly sharding with HSDP / DP-Outer.
-        self.model_auto_sync = self.set_model_auto_sync(sync_model_each_microbatch)
+        self.set_model_auto_sync(sync_model_each_microbatch)

         # Check if the module contains (Megatron-Core) expert parallel parameters or DTensors.
         has_expert_parameters = self._check_module_parameter_types()
@@ -307,8 +307,11 @@ class MegatronFSDP(torch.nn.Module):
             expert_gradient_scaling_factor = None
         else:
             if self.ddp_config.average_in_collective:
-                # FIXME(@jianbinc): Will fix this issue based on Parallel Folding's EDP patch MR.
-                raise Exception("Not supported")
+                gradient_scaling_factor = 1.0
+                expert_gradient_scaling_factor = (
+                    self.dist_index.get_dp_group(is_expert_parallel=True).size()
+                    / self.dist_index.get_dp_group().size()
+                )
             else:
                 data_parallel_world_size = self.dist_index.get_dp_group().size()
                 gradient_scaling_factor = 1.0 / data_parallel_world_size
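Note on the previously unsupported branch: with average_in_collective the reduction over the full data-parallel group already divides by its size, so non-expert gradients keep a factor of 1.0, while expert gradients, reduced over the smaller expert-DP group, pick up edp_size / dp_size to end up averaged over the same global batch. A worked example with illustrative sizes (not taken from the source):

# Illustrative sizes: 8 data-parallel ranks, expert-DP group of 2.
dp_size, edp_size = 8, 2
expert_gradient_scaling_factor = edp_size / dp_size  # 0.25

grad_sum = 16.0                            # sum of a gradient over the edp group
avg_in_collective = grad_sum / edp_size    # collective averages over edp ranks
rescaled = avg_in_collective * expert_gradient_scaling_factor
assert rescaled == grad_sum / dp_size      # matches an average over all dp ranks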
@@ -426,6 +429,14 @@ class MegatronFSDP(torch.nn.Module):
             bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
             ag_pipeline.wait_bucket_ready(bucket_id)

+        for param in params:
+            # This setting is needed to make FSDP store the weight object when used
+            # with TE's activation offloading for the first global batch.
+            param.grad_added_to_main_grad = False
+            # This setting is needed to have this attribute present after every
+            # un-shard of the FSDP params.
+            param.__fsdp_param__ = True
+
     def _register_fsdp_hooks(self, root_module):
         """Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
@@ -4,7 +4,7 @@
 MAJOR = 0
 MINOR = 1
 PATCH = 0
-PRE_RELEASE = 'rc3'
+PRE_RELEASE = 'rc5'

 # Use the following formatting: (major, minor, patch, pre-release)
 VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
@@ -313,8 +313,11 @@ class _ParamAndGradBucketGroup:
                 param_slice = bucket.param_data.view(-1)[param_start:param_end]
                 param.data.copy_(param_slice.view(param.data.shape))
             # All-gathered params are not needed after being copied to param.data.
-            # Zero out the grad buffer (shared with param buffer) for gradient accumulation.
-            bucket.grad_data.zero_()
+            # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
+            # We cannot zero out the entire grad buffer because one grad buffer may
+            # correspond to multiple param buffers. If we zero out the entire grad buffer,
+            # it would clear the data of those param buffers that have not yet completed AG.
+            bucket.param_data.zero_()

     def start_grad_sync(self):
         """
@@ -1,6 +1,7 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

 import dataclasses
+import inspect
 import io
 import os
 import pickle
@@ -1591,21 +1592,21 @@ if HAVE_TE and is_te_min_version("1.13.0"):
             if self.linear_fc2.config.tp_comm_overlap and self.linear_fc2.ub_name is not None:
                 userbuffers_options = {"comm_name": self.linear_fc2.ub_name}
             op = te.pytorch.ops.BasicLinear(
-                weight.size(1) * tp_world_size,
+                weight.size(1),
                 weight.size(0),
                 device="meta",
                 dtype=weight.dtype,
-                tensor_parallel_mode="row" if tp_world_size > 1 else None,
-                tensor_parallel_group=tp_group,
-                sequence_parallel=self.linear_fc2.sequence_parallel,
                 rng_state_tracker_function=rng_state_tracker_function,
                 accumulate_into_main_grad=self.linear_fc2.fuse_wgrad_accumulation,
                 userbuffers_options=userbuffers_options,
             )
             op.weight = weight
             fused_impl.append(op)
-            if tp_world_size > 1 and self.linear_fc2.sequence_parallel:
-                fused_impl.append(te.pytorch.ops.ReduceScatter(tp_group))
+            if tp_world_size > 1:
+                if self.linear_fc2.sequence_parallel:
+                    fused_impl.append(te.pytorch.ops.ReduceScatter(tp_group))
+                else:
+                    fused_impl.append(te.pytorch.ops.AllReduce(tp_group))

             # FC2 bias op
             if not self.linear_fc2.te_return_bias:
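In the hunk above, removing the tensor_parallel_* arguments from BasicLinear means the op now computes only the local partial product, and the communication is appended explicitly: reduce-scatter under sequence parallelism, all-reduce otherwise. A toy single-process sketch of why a row-parallel matmul needs that reduction (plain PyTorch, not the TE ops API):

import torch

# Two "ranks" each hold half the input features and the matching weight rows.
x = torch.randn(3, 8)
w = torch.randn(8, 4)
partial_0 = x[:, :4] @ w[:4]   # rank 0's partial output
partial_1 = x[:, 4:] @ w[4:]   # rank 1's partial output

# The partials must be summed (all-reduce across TP ranks) to recover the
# full product; with sequence parallelism the sum is instead combined with
# a scatter over the sequence dimension (reduce-scatter).
assert torch.allclose(partial_0 + partial_1, x @ w, atol=1e-5)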
@@ -1617,6 +1618,9 @@ if HAVE_TE and is_te_min_version("1.13.0"):
                 op.bias = bias
                 fused_impl.append(op)

+            # Emulate submodule forward hooks if needed
+            self._register_hooks_on_fused_impl(fused_impl)
+
             return fused_impl

         def _make_activation_op(
@@ -1655,6 +1659,92 @@ if HAVE_TE and is_te_min_version("1.13.0"):
                 kwargs["cache_quantized_input"] = cache_quantized_input
             return op_type(**kwargs)

+        def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None:
+            """Attempt to emulate submodule callback hooks.
+
+            This is not always possible because Transformer Engine's
+            op fuser does not expose intermediate tensors. Depending
+            on what kernel fusions the op fuser chooses, the
+            intermediate tensors may not even exist. Hooks that modify
+            tensors will result in incorrect behavior.
+
+            """
+
+            # Get submodule hooks
+            forward_pre_hooks = []
+            forward_post_hooks = []
+            backward_pre_hooks = []
+            backward_post_hooks = []
+            for submodule in self.modules():
+                for hook in submodule._forward_pre_hooks.values():
+                    forward_pre_hooks.append((submodule, hook))
+                for hook in submodule._forward_hooks.values():
+                    forward_post_hooks.append((submodule, hook))
+                for hook in submodule._backward_pre_hooks.values():
+                    backward_pre_hooks.append((submodule, hook))
+                for hook in submodule._backward_hooks.values():
+                    backward_post_hooks.append((submodule, hook))
+
+            # Pre-forward hooks
+            # Note: DDP pre-forward hooks are safe since they do not
+            # interact with input tensor.
+            if forward_pre_hooks:
+                from megatron.core.distributed import distributed_data_parallel
+
+                if any(
+                    inspect.getmodule(hook) != distributed_data_parallel
+                    for _, hook in forward_pre_hooks
+                ):
+                    warnings.warn(
+                        "TEFusedMLP module has a submodule with a pre-forward hook. "
+                        "TEFusedMLP module does not expose intermediate tensors, "
+                        "so the hook may have incorrect behavior if it attempts to "
+                        "access the input tensor."
+                    )
+
+                def forward_pre_hook(module, *_) -> None:
+                    for submodule, hook in forward_pre_hooks:
+                        # Assume that hook does not interact with input
+                        ret = hook(submodule, None)
+                        if ret is not None:
+                            raise RuntimeError(
+                                "TEFusedMLP module does not expose intermediate tensors, but "
+                                "submodule has pre-forward hook that modifies input tensor."
+                            )
+
+                fused_impl.register_forward_pre_hook(forward_pre_hook)
+
+            # Post-forward hooks
+            if forward_post_hooks:
+                warnings.warn(
+                    "TEFusedMLP module has a submodule with a post-forward hook. "
+                    "TEFusedMLP module does not expose intermediate tensors, "
+                    "so the hook may have incorrect behavior if it attempts to "
+                    "access the input or output tensors."
+                )
+
+                def forward_post_hook(module, *_) -> None:
+                    for submodule, hook in forward_post_hooks:
+                        # Assume that hook does not interact with input or output
+                        ret = hook(submodule, None, None)
+                        if ret is not None:
+                            raise RuntimeError(
+                                "TEFusedMLP module does not expose intermediate tensors, but "
+                                "submodule has post-forward hook that modifies output tensor."
+                            )
+
+                fused_impl.register_forward_hook(forward_post_hook)
+
+            # Backward hooks
+            if backward_pre_hooks:
+                raise RuntimeError(
+                    "TEFusedMLP module does not support submodules with pre-backward hooks"
+                )
+            if backward_post_hooks:
+                raise RuntimeError(
+                    "TEFusedMLP module does not support submodules with post-backward hooks"
+                )
+
         def forward(self, hidden_states: torch.Tensor) -> Tuple[Tensor, Optional[Tensor]]:
             """Forward."""
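The emulation above leans on a PyTorch hook convention: a forward or pre-forward hook that returns a non-None value is requesting to replace the input or output, which a fused module with no exposed intermediate tensors cannot honor. A small self-contained demonstration of the convention on a plain torch.nn.Linear (illustrative, unrelated to TEFusedMLP internals):

import torch

lin = torch.nn.Linear(4, 4)


def observing_hook(module, inputs):
    return None  # observe only: safe to replay against the fused module


def mutating_hook(module, inputs):
    return (inputs[0] * 2,)  # returns a replacement input: not emulatable


lin.register_forward_pre_hook(observing_hook)
lin.register_forward_pre_hook(mutating_hook)
out = lin(torch.randn(2, 4))
# Plain PyTorch honors the returned replacement above; the fused emulation
# instead raises RuntimeError whenever a replayed hook returns non-None.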
@@ -406,6 +406,25 @@ def correct_amax_history_if_needed(model: List[torch.nn.Module]):
     _correct_amax_history_if_needed_impl(model)


+def is_first_last_bf16_layer(config: TransformerConfig, layer_no: int):
+    """Check if the layer is in bf16."""
+    num_bf16_layers_at_start = (
+        config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
+    )
+    num_bf16_layers_at_end = (
+        config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
+    )
+    # Since layer_no is a global layer index, additional checks on whether
+    # we are in the first or last pipeline-parallel rank are not needed.
+    is_first_layer = layer_no < num_bf16_layers_at_start
+    is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end
+
+    if layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer):
+        return True
+    else:
+        return False
+
+
 if HAVE_TE:
     from megatron.core import parallel_state
     from megatron.core.extensions.transformer_engine import TEDelayedScaling
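The hunk above hoists the predicate out of get_fp8_context (compare the fp8_utils.py hunk further down) so it can be reused. A worked example with illustrative numbers: with 12 layers, 1 bf16 layer at the start, and 2 at the end, global layer indices 0, 10, and 11 stay in bf16. A standalone replica of the logic, with the config reduced to a dict for the sketch:

def is_first_last_bf16_layer(config: dict, layer_no: int) -> bool:
    # Standalone replica of the predicate; config reduced to a dict.
    if layer_no < 0 or not config["first_last_layers_bf16"]:
        return False
    is_first = layer_no < config["num_layers_at_start_in_bf16"]
    is_last = layer_no >= config["num_layers"] - config["num_layers_at_end_in_bf16"]
    return is_first or is_last


cfg = dict(
    first_last_layers_bf16=True,
    num_layers=12,
    num_layers_at_start_in_bf16=1,
    num_layers_at_end_in_bf16=2,
)
assert [n for n in range(12) if is_first_last_bf16_layer(cfg, n)] == [0, 10, 11]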
@@ -437,7 +456,7 @@ if HAVE_TE:
            )
        elif config.fp8_recipe == Fp8Recipe.tensorwise and is_te_min_version("2.2.0.dev0"):
            fp8_recipe = transformer_engine.common.recipe.Float8CurrentScaling(
-                fp8_format=fp8_format
+                fp8_format=fp8_format, fp8_dpa=config.fp8_dot_product_attention
            )
        elif config.fp8_recipe == Fp8Recipe.blockwise and is_te_min_version("2.3.0.dev0"):
            fp8_recipe = transformer_engine.common.recipe.Float8BlockScaling(
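The tensorwise recipe now forwards config.fp8_dot_product_attention as fp8_dpa, enabling FP8 dot-product attention under current scaling. A hedged construction sketch (TE >= 2.2.0 per the version gate above; Format.HYBRID is an illustrative choice, not taken from this diff):

from transformer_engine.common.recipe import Float8CurrentScaling, Format

# fp8_dpa=True additionally runs dot-product attention in FP8, mirroring
# fp8_dpa=config.fp8_dot_product_attention in the hunk above.
fp8_recipe = Float8CurrentScaling(fp8_format=Format.HYBRID, fp8_dpa=True)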
@@ -483,24 +502,10 @@ if HAVE_TE:
        that needs to be trained in bf16.
        """

-        num_bf16_layers_at_start = (
-            config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
-        )
-        num_bf16_layers_at_end = (
-            config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
-        )
-        # Since layer_no is a global layer index, additional checks on whether
-        # we are in the first or last pipeline-parallel rank are not needed.
-        is_first_layer = layer_no < num_bf16_layers_at_start
-        is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end
-
        need_fp8_context = config.fp8 if not is_init else config.fp8_param

-        if not need_fp8_context:
-            # bf16 training
-            fp8_context = nullcontext()
-        elif layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer):
-            # fp8 training but this layer_no should be bf16
+        if not need_fp8_context or is_first_last_bf16_layer(config, layer_no):
+            # bf16 training or bf16 layer in fp8 training
            fp8_context = nullcontext()
        else:
            # fp8 training and this layer_no is in fp8
@@ -20,7 +20,7 @@ class AsyncStream:
    Adopted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py # pylint: disable=line-too-long
    """

-    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
+    def __init__(self, request_id: int, cancel: Callable[[str], None]) -> None:
        self._request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()