megatron-core 0.14.0rc3__tar.gz → 0.14.0rc4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megatron-core might be problematic.

Files changed (310)
  1. {megatron_core-0.14.0rc3/megatron_core.egg-info → megatron_core-0.14.0rc4}/PKG-INFO +1 -1
  2. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/torch.py +2 -1
  3. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/validation.py +21 -15
  4. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine.py +5 -27
  5. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine_spec_provider.py +5 -0
  6. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +81 -16
  7. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_context.py +44 -28
  8. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +23 -2
  9. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/model_parallel_config.py +8 -3
  10. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rope_utils.py +20 -32
  11. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +11 -6
  12. megatron_core-0.14.0rc4/megatron/core/models/common/model_chunk_schedule_plan.py +502 -0
  13. megatron_core-0.14.0rc4/megatron/core/models/gpt/fine_grained_callables.py +474 -0
  14. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_layer_specs.py +2 -2
  15. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_model.py +62 -1
  16. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer.py +11 -1
  17. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer_config.py +1 -0
  18. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/package_info.py +1 -1
  19. megatron_core-0.14.0rc4/megatron/core/pipeline_parallel/combined_1f1b.py +331 -0
  20. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/schedules.py +169 -101
  21. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/utils.py +91 -0
  22. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_block.py +4 -1
  23. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_layer.py +1 -1
  24. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/layers.py +23 -12
  25. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/attention.py +1 -0
  26. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/mlp.py +20 -2
  27. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/experts.py +22 -0
  28. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_latent_attention.py +81 -9
  29. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_config.py +60 -7
  30. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_layer.py +11 -10
  31. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/utils.py +17 -11
  32. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/utils.py +27 -3
  33. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4/megatron_core.egg-info}/PKG-INFO +1 -1
  34. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/SOURCES.txt +2 -0
  35. megatron_core-0.14.0rc3/megatron/core/models/gpt/fine_grained_callables.py +0 -195
  36. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/LICENSE +0 -0
  37. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/MANIFEST.in +0 -0
  38. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/README.md +0 -0
  39. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/README.md +0 -0
  40. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/__init__.py +0 -0
  41. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/config.py +0 -0
  42. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/config_logger.py +0 -0
  43. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/__init__.py +0 -0
  44. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/bert_dataset.py +0 -0
  45. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_dataset.py +0 -0
  46. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  47. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  48. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/gpt_dataset.py +0 -0
  49. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.cpp +0 -0
  50. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.py +0 -0
  51. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/indexed_dataset.py +0 -0
  52. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/masked_dataset.py +0 -0
  53. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_dataset.py +0 -0
  54. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  55. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/multimodal_dataset.py +0 -0
  56. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/object_storage_utils.py +0 -0
  57. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/__init__.py +0 -0
  58. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/__init__.py +0 -0
  59. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  60. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/config.py +0 -0
  61. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  62. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  63. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/__init__.py +0 -0
  64. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/build.py +0 -0
  65. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/dataset.py +0 -0
  66. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/utils.py +0 -0
  67. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/external_libs.py +0 -0
  68. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/__init__.py +0 -0
  69. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/build.py +0 -0
  70. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/factory.py +0 -0
  71. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/index.py +0 -0
  72. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  73. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  74. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  75. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/utils.py +0 -0
  76. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/validate.py +0 -0
  77. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/__init__.py +0 -0
  78. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  79. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  80. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/query.py +0 -0
  81. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  82. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/utils.py +0 -0
  83. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/utils.py +0 -0
  84. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/t5_dataset.py +0 -0
  85. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils.py +0 -0
  86. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_object_storage.py +0 -0
  87. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_s3.py +0 -0
  88. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/__init__.py +0 -0
  89. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/core.py +0 -0
  90. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  91. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  92. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/mapping.py +0 -0
  93. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  94. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/serialization.py +0 -0
  95. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  96. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  97. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  98. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  99. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  100. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  101. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  102. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  103. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  104. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  105. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  106. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  107. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  108. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  109. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/utils.py +0 -0
  110. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/__init__.py +0 -0
  111. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  112. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -0
  113. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
  114. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/data_parallel_base.py +0 -0
  115. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  116. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  117. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/finalize_model_grads.py +0 -0
  118. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  119. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  120. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  121. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/energy_monitor.py +0 -0
  122. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/enums.py +0 -0
  123. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/__init__.py +0 -0
  124. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/data_type.py +0 -0
  125. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/export_config.py +0 -0
  126. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/model_type.py +0 -0
  127. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/__init__.py +0 -0
  128. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  129. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  130. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  131. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  132. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  133. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  134. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  135. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  136. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  137. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  138. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  139. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  140. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/__init__.py +0 -0
  141. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/extensions/kitchen.py +0 -0
  142. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fp8_utils.py +0 -0
  143. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/__init__.py +0 -0
  144. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  145. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  146. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  147. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  148. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  149. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_indices_converter.py +0 -0
  150. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_layer_norm.py +0 -0
  151. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  152. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_softmax.py +0 -0
  153. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/hyper_comm_grid.py +0 -0
  154. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/__init__.py +0 -0
  155. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/async_stream.py +0 -0
  156. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/common_inference_params.py +0 -0
  157. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/communication_utils.py +0 -0
  158. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/__init__.py +0 -0
  159. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/base_context.py +0 -0
  160. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  161. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/static_context.py +0 -0
  162. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/__init__.py +0 -0
  163. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/abstract_engine.py +0 -0
  164. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/dynamic_engine.py +0 -0
  165. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/mcore_engine.py +0 -0
  166. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/static_engine.py +0 -0
  167. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/inference_request.py +0 -0
  168. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  169. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  170. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  171. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  172. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  173. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  174. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  175. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  176. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/sampling_params.py +0 -0
  177. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/scheduler.py +0 -0
  178. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  179. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  180. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  181. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  182. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference/utils.py +0 -0
  183. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/inference_params.py +0 -0
  184. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/jit.py +0 -0
  185. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/T5/__init__.py +0 -0
  186. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_model.py +0 -0
  187. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_spec.py +0 -0
  188. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/__init__.py +0 -0
  189. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/backends.py +0 -0
  190. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/__init__.py +0 -0
  191. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  192. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_lm_head.py +0 -0
  193. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_model.py +0 -0
  194. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/bert/pooler.py +0 -0
  195. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/__init__.py +0 -0
  196. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/__init__.py +0 -0
  197. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  198. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  199. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  200. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/__init__.py +0 -0
  201. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/language_module.py +0 -0
  202. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/__init__.py +0 -0
  203. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  204. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/__init__.py +0 -0
  205. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  206. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  207. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/__init__.py +0 -0
  208. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/clip_model.py +0 -0
  209. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/module.py +0 -0
  210. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/qwen_model.py +0 -0
  211. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/__init__.py +0 -0
  212. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  213. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_model.py +0 -0
  214. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/__init__.py +0 -0
  215. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/__init__.py +0 -0
  216. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/base_configs.py +0 -0
  217. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/__init__.py +0 -0
  218. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/base.py +0 -0
  219. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/audio.py +0 -0
  220. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/base.py +0 -0
  221. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/vision.py +0 -0
  222. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/__init__.py +0 -0
  223. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/context_parallel.py +0 -0
  224. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_model.py +0 -0
  225. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_spec.py +0 -0
  226. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/__init__.py +0 -0
  227. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/base_attention.py +0 -0
  228. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/config.py +0 -0
  229. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_attention.py +0 -0
  230. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_spec.py +0 -0
  231. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_attention.py +0 -0
  232. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_spec.py +0 -0
  233. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/model.py +0 -0
  234. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/retro/utils.py +0 -0
  235. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/__init__.py +0 -0
  236. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/clip_vit_model.py +0 -0
  237. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/multimodal_projector.py +0 -0
  238. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/radio.py +0 -0
  239. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  240. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/msc_utils.py +0 -0
  241. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/num_microbatches_calculator.py +0 -0
  242. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/__init__.py +0 -0
  243. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/clip_grads.py +0 -0
  244. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  245. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  246. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  247. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer/grad_scaler.py +0 -0
  248. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/optimizer_param_scheduler.py +0 -0
  249. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/packed_seq_params.py +0 -0
  250. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/parallel_state.py +0 -0
  251. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/__init__.py +0 -0
  252. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  253. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/__init__.py +0 -0
  254. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/__init__.py +0 -0
  255. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  256. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  257. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  258. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/layers.py +0 -0
  259. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  260. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  261. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/process_groups_config.py +0 -0
  262. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/quantization/__init__.py +0 -0
  263. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/quantization/quant_config.py +0 -0
  264. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/quantization/utils.py +0 -0
  265. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/requirements.txt +0 -0
  266. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/rerun_state_machine.py +0 -0
  267. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/__init__.py +0 -0
  268. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  269. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  270. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_mixer.py +0 -0
  271. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/mlp_layer.py +0 -0
  272. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/ssm/triton_cache_manager.py +0 -0
  273. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/__init__.py +0 -0
  274. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  275. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/data.py +0 -0
  276. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/mappings.py +0 -0
  277. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/random.py +0 -0
  278. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/utils.py +0 -0
  279. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/timers.py +0 -0
  280. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/__init__.py +0 -0
  281. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/cuda_graphs.py +0 -0
  282. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  283. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  284. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/dot_product_attention.py +0 -0
  285. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/enums.py +0 -0
  286. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  287. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  288. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/identity_op.py +0 -0
  289. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/module.py +0 -0
  290. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/__init__.py +0 -0
  291. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  292. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  293. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_layer.py +0 -0
  294. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_utils.py +0 -0
  295. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/router.py +0 -0
  296. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/shared_experts.py +0 -0
  297. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  298. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  299. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_token_prediction.py +0 -0
  300. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  301. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/spec_utils.py +0 -0
  302. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_layer_norm.py +0 -0
  303. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_norm.py +0 -0
  304. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_block.py +0 -0
  305. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/dependency_links.txt +0 -0
  306. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/requires.txt +0 -0
  307. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/megatron_core.egg-info/top_level.txt +0 -0
  308. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/pyproject.toml +0 -0
  309. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/setup.cfg +0 -0
  310. {megatron_core-0.14.0rc3 → megatron_core-0.14.0rc4}/setup.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc3
+Version: 0.14.0rc4
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -374,7 +374,8 @@ def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]
             ten = ten.view(-1)
         else:
             for _ in range(mcore_sh_ten.prepend_axis_num):
-                ten = ten.squeeze(0)
+                assert ten.size(0) == 1
+                ten = ten[0]  # NOTE: ten.squeeze(0) uses more memory for FP8 tensors
         ret_tensors.append(ten)
     return ret_tensors

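Side note (not part of the diff): for an ordinary dense tensor, indexing the leading axis and squeeze(0) are equivalent view operations; the memory saving called out in the NOTE above is specific to FP8 tensor wrappers. A minimal sketch of the equivalence, with made-up shapes:

import torch

x = torch.zeros(1, 1, 4, 8)   # two prepended singleton axes
prepend_axis_num = 2

a = x
for _ in range(prepend_axis_num):
    a = a.squeeze(0)          # old approach: squeeze the leading axis

b = x
for _ in range(prepend_axis_num):
    assert b.size(0) == 1
    b = b[0]                  # new approach: index the leading axis

assert a.shape == b.shape == (4, 8)
assert torch.equal(a, b)      # identical result; both are views of the same storage here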
@@ -375,28 +375,34 @@ def maybe_report_missing_and_unexpected_keys(
 def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
     """Validate consistancy across ranks for the common state dict

-    We save the common state dict only on rank 0. We validate to make sure that the common dict is consistant across ranks before saving.
+    We save the common state dict only on rank 0. We validate to make sure that the common dict is consistent across ranks before saving.

     Args:
         common_state_dict: The common state dict present in all ransk
     """
+    if not torch.distributed.is_initialized():
+        return

-    # Gather the common state dict across ranks onto rank 0 for comparison
+    # Broadcast the common state dict from rank 0 to all other ranks
+    # Each rank will do a comparison with its local rank vs the broadcasted state dict from rank 0
     rank = torch.distributed.get_rank()
-    other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None
-    torch.distributed.gather_object(common_state_dict, other_rank_state_dicts)
-    common_state_dict_diff = {}
-    if rank == 0:
-        assert other_rank_state_dicts
-        main_rank_state_dict = common_state_dict
-        for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
-            only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)
-            if only_left or only_right or mismatch:
-                common_state_dict_diff[rank] = (only_left, only_right, mismatch)
-
-        if len(common_state_dict_diff) != 0:
+
+    object_list = [common_state_dict] if rank == 0 else [None]
+    torch.distributed.broadcast_object_list(object_list, src=0)
+    rank0_state_dict = object_list[0]
+
+    # Skip comparing rank 0 with itself
+    if rank > 0:
+        current_rank_state_dict = common_state_dict
+        only_in_rank0, only_in_current_rank, mismatch = diff(
+            rank0_state_dict, current_rank_state_dict
+        )
+        if only_in_rank0 or only_in_current_rank or mismatch:
             logger.warning(
-                f"There is difference in the common state dict in different ranks. The differences are {common_state_dict_diff}"
+                f"Rank {rank} common state dict differs from rank 0 common state dict. "
+                f"Keys only on rank 0: {only_in_rank0}, "
+                f"Keys only on {rank}: {only_in_current_rank}, "
+                f"Mismatched keys: {mismatch}"
             )


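Side note (not part of the diff): the rewritten validation broadcasts rank 0's common state dict and lets every other rank compute the diff locally, instead of gathering every rank's dict onto rank 0. A simplified, standalone rendering of the pattern, using a hypothetical dict_diff helper in place of Megatron's diff:

import torch.distributed as dist

def dict_diff(a, b):
    # Hypothetical stand-in for megatron.core.dist_checkpointing.dict_utils.diff
    only_a = [k for k in a if k not in b]
    only_b = [k for k in b if k not in a]
    mismatch = [k for k in a if k in b and a[k] != b[k]]
    return only_a, only_b, mismatch

def validate_common_state_dict(common_state_dict):
    if not dist.is_initialized():
        return
    rank = dist.get_rank()
    # Rank 0 supplies the reference copy; every other rank receives it.
    obj = [common_state_dict] if rank == 0 else [None]
    dist.broadcast_object_list(obj, src=0)
    rank0_state_dict = obj[0]
    # Each non-zero rank compares its local dict against rank 0's copy.
    if rank > 0:
        only_rank0, only_here, mismatch = dict_diff(rank0_state_dict, common_state_dict)
        if only_rank0 or only_here or mismatch:
            print(f"rank {rank} differs from rank 0: {only_rank0} {only_here} {mismatch}")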
@@ -889,25 +889,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
             if packed_seq_params is not None
             else {}
         )
-        # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
-        # after init
-        if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
-            self.qkv_format = "bshd"
-
-        qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
-
-        # WAR for peak memory usage.
-        # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
-        if self.config.apply_rope_fusion and qkv_format == "bshd":
-            query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
-            # In PyTorch, the following two tensors are in fact the same:
-            # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
-            # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
-            # Stride for a dimension that is 1 has no meaning, so tensors created two different ways
-            # can have same shape but different strides.
-            # We unify them to the first one to pass the stride check in TE
-            if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
-                value = value.as_strided(value.shape, key.stride())
+        qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)

         attention_bias_kwargs = {}
         if attention_bias is not None:
@@ -942,10 +924,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
                 query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
             )

-        if self.config.apply_rope_fusion and qkv_format == "bshd":
-            return core_attn_out.transpose(0, 1)
-        else:
-            return core_attn_out
+        return core_attn_out


 if HAVE_TE and is_te_min_version("1.9.0.dev0"):
@@ -1633,10 +1612,8 @@ try:
         else:
             if interleaved:
                 raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.")
-            if is_te_min_version("1.4.0.dev0"):
-                return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)
-            else:
-                raise ValueError("Only TE >= 1.4.0.dev0 supports fused RoPE.")
+
+            return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)

     def fused_apply_rotary_pos_emb_thd(
         t: torch.Tensor,
@@ -1659,6 +1636,7 @@ try:
                 cp_rank=cp_rank,
             )
         else:
+            assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP."
             return apply_rotary_pos_emb(
                 t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens
             )
@@ -8,6 +8,7 @@ from megatron.core.extensions.transformer_engine import (
     TEColumnParallelLinear,
     TEDotProductAttention,
     TELayerNormColumnParallelLinear,
+    TELinear,
     TENorm,
     TERowParallelGroupedLinear,
     TERowParallelLinear,
@@ -23,6 +24,10 @@ from megatron.core.utils import get_te_version, is_te_min_version
 class TESpecProvider(BackendSpecProvider):
     """A protocol for providing the submodules used in Spec building."""

+    def linear(self) -> type:
+        """Which linear module TE backend uses"""
+        return TELinear
+
     def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear
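Side note (not part of the diff): a spec provider hands back the concrete module classes a spec builder should plug in; linear() is the newly added accessor. A hedged sketch of a consumer (collect_linear_classes is an invented helper name; TESpecProvider, linear(), and column_parallel_linear() are the names shown in the diff above):

from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider

def collect_linear_classes(provider=None):
    provider = provider or TESpecProvider()
    return {
        "linear": provider.linear(),                                  # TELinear (new in rc4)
        "column_parallel_linear": provider.column_parallel_linear(),  # TEColumnParallelLinear
    }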
@@ -28,16 +28,25 @@ if not HAVE_TRITON:


 @triton.jit
-def _get_thd_token_idx(cu_seqlens, pid_m, seq_num):
+def _get_thd_token_idx(cu_seqlens, pid_m, seq_num, cp_rank, cp_size):
     token_idx = -1
+    this_seq_len = 0
     seq_idx = 0
-    last_cum_seqlen = tl.load(cu_seqlens)
+    last_cum_seqlen = tl.load(cu_seqlens) // cp_size
     while seq_idx < seq_num:
-        cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1)
+        cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1) // cp_size
         if token_idx == -1 and cur_cum_seqlen > pid_m:
             token_idx = pid_m - last_cum_seqlen
+            this_seq_len = cur_cum_seqlen - last_cum_seqlen
         last_cum_seqlen = cur_cum_seqlen
         seq_idx += 1
+    if cp_size > 1:
+        if token_idx < this_seq_len // 2:
+            token_idx = token_idx + cp_rank * this_seq_len // 2
+        else:
+            token_idx = (token_idx - this_seq_len // 2) + (
+                2 * cp_size - cp_rank - 1
+            ) * this_seq_len // 2
     return token_idx


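Side note (not part of the diff): with context parallelism each rank holds the first and last half-chunks of every sequence (the load-balanced split used for CP attention), so a local token index must be mapped back to its global position before the rotary table is read. A pure-Python rendering of the remapping added above, with a worked example:

def local_to_global_token_idx(token_idx, local_seq_len, cp_rank, cp_size):
    # Mirrors the cp_size > 1 branch added to _get_thd_token_idx above.
    if cp_size == 1:
        return token_idx
    half = local_seq_len // 2
    if token_idx < half:
        # first half-chunk held by this rank
        return token_idx + cp_rank * half
    # second half-chunk held by this rank (stored in reverse rank order)
    return (token_idx - half) + (2 * cp_size - cp_rank - 1) * half

# Worked example: one sequence of 8 tokens, cp_size=2, so each rank holds 4 local tokens.
# Rank 0 holds global tokens [0, 1, 6, 7]; rank 1 holds global tokens [2, 3, 4, 5].
assert [local_to_global_token_idx(i, 4, cp_rank=0, cp_size=2) for i in range(4)] == [0, 1, 6, 7]
assert [local_to_global_token_idx(i, 4, cp_rank=1, cp_size=2) for i in range(4)] == [2, 3, 4, 5]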
@@ -68,6 +77,8 @@ def rotary_fwd_q_kernel(
     cu_seqlens_q,
     stride_x_seq,
     stride_x_nheads,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -89,7 +100,7 @@
     if cu_seqlens_q is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)

     cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
     sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -146,6 +157,8 @@ def rotary_bwd_q_kernel(
     cu_seqlens_q,
     stride_x_seq,
     stride_x_nheads,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -165,7 +178,7 @@
     if cu_seqlens_q is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)

     cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
     sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -200,7 +213,18 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, q, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, rotary_interleaved=False):
+    def forward(
+        ctx,
+        q,
+        cos,
+        sin,
+        qk_head_dim,
+        emb_dim,
+        cu_seqlens_q,
+        cp_rank,
+        cp_size,
+        rotary_interleaved=False,
+    ):
         """
         Forward function for ApplyMLARotaryEmbQ.

@@ -243,12 +267,16 @@
             cu_seqlens_q,
             q.stride(0),
             q.stride(1),
+            cp_rank,
+            cp_size,
         )
         ctx.save_for_backward(cos, sin)
         ctx.qk_head_dim = qk_head_dim
         ctx.emb_dim = emb_dim
         ctx.cu_seqlens_q = cu_seqlens_q
         ctx.rotary_interleaved = rotary_interleaved
+        ctx.cp_rank = cp_rank
+        ctx.cp_size = cp_size
         if cu_seqlens_q is None:
             q = q.view(max_seqlen, batch_size, nheads, headdim)
         return q
@@ -268,7 +296,7 @@
         seq_num = None
         if ctx.cu_seqlens_q is None:
             max_seqlen, batch_size, nheads, headdim = grad.shape
-            grad = grad.view(-1, nheads, headdim)
+            grad = grad.contiguous().view(-1, nheads, headdim)
             total_seqlen = grad.shape[0]
         else:
             seq_num = len(ctx.cu_seqlens_q) - 1
@@ -288,10 +316,12 @@
             ctx.cu_seqlens_q,
             grad.stride(0),
             grad.stride(1),
+            ctx.cp_rank,
+            ctx.cp_size,
         )
         if ctx.cu_seqlens_q is None:
             grad = grad.view(max_seqlen, batch_size, nheads, headdim)
-        return grad, None, None, None, None, None, None
+        return grad, None, None, None, None, None, None, None, None


 @experimental_fn(introduced_with_version="0.13.0")
@@ -302,6 +332,8 @@ def fused_apply_mla_rope_for_q(
     qk_head_dim: int,
     emb_dim: int,
     cu_seqlens_q: Optional[torch.Tensor] = None,
+    cp_rank: int = 0,
+    cp_size: int = 1,
     rotary_interleaved: bool = False,
 ):
     """
@@ -327,7 +359,7 @@
         t: inplace modified input tensor
     """
     return ApplyMLARotaryEmbQ.apply(
-        t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, rotary_interleaved
+        t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, cp_rank, cp_size, rotary_interleaved
     )


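Side note (not part of the diff): callers that shard sequences across context-parallel ranks now pass their CP rank and size so the fused kernel can index the correct rotary rows; the defaults (cp_rank=0, cp_size=1) keep the old behavior. A hedged usage sketch (the apply_q_rope wrapper and its arguments are illustrative; fused_apply_mla_rope_for_q is the experimental function shown above):

from megatron.core import parallel_state
from megatron.core.fusions.fused_mla_yarn_rope_apply import fused_apply_mla_rope_for_q

def apply_q_rope(q, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q=None):
    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = parallel_state.get_context_parallel_rank() if cp_size > 1 else 0
    return fused_apply_mla_rope_for_q(
        q, cos, sin, qk_head_dim, emb_dim,
        cu_seqlens_q=cu_seqlens_q,
        cp_rank=cp_rank,   # defaults (cp_rank=0, cp_size=1) reproduce the pre-CP behavior
        cp_size=cp_size,
    )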
@@ -366,6 +398,8 @@ def rotary_fwd_kv_kernel(
     stride_k_nheads,
     stride_v_seq,
     stride_v_nheads,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -394,7 +428,7 @@
     if cu_seqlens_kv is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)

     cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
     sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -472,6 +506,8 @@ def rotary_bwd_kv_kernel(
     stride_dkv_seq,
     stride_dkv_nheads,
     stride_demb_seq,
+    cp_rank,
+    cp_size,
     BLOCK_H: tl.constexpr,
 ):
     """
@@ -496,7 +532,7 @@
     if cu_seqlens_kv is None:
         token_idx = pid_m // batch_size
     else:
-        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
+        token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)

     dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
     dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
@@ -550,7 +586,18 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx, kv, k_pos_emb, cos, sin, emb_dim, k_dim, v_dim, cu_seqlens_kv, rotary_interleaved=False
+        ctx,
+        kv,
+        k_pos_emb,
+        cos,
+        sin,
+        emb_dim,
+        k_dim,
+        v_dim,
+        cu_seqlens_kv,
+        cp_rank,
+        cp_size,
+        rotary_interleaved=False,
     ):
         """
         Forward function for ApplyMLARotaryEmbKV.
@@ -609,6 +656,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
             o_key.stride(1),
             o_value.stride(0),
             o_value.stride(1),
+            cp_rank,
+            cp_size,
         )
         ctx.save_for_backward(cos, sin)
         ctx.rotary_interleaved = rotary_interleaved
@@ -616,6 +665,8 @@
         ctx.k_dim = k_dim
         ctx.v_dim = v_dim
         ctx.cu_seqlens_kv = cu_seqlens_kv
+        ctx.cp_rank = cp_rank
+        ctx.cp_size = cp_size
         if cu_seqlens_kv is None:
             o_key = o_key.view(max_seqlen, -1, nheads, emb_dim + k_dim)
             o_value = o_value.view(max_seqlen, -1, nheads, v_dim)
@@ -638,8 +689,8 @@
         if ctx.cu_seqlens_kv is None:
             # sbhd
             max_seqlen, batch_size, nheads, _ = dk.shape
-            dk = dk.view(-1, nheads, ctx.emb_dim + ctx.k_dim)
-            dv = dv.view(-1, nheads, ctx.v_dim)
+            dk = dk.contiguous().view(-1, nheads, ctx.emb_dim + ctx.k_dim)
+            dv = dv.contiguous().view(-1, nheads, ctx.v_dim)
             total_seqlen = dk.shape[0]
         else:
             # thd
@@ -673,11 +724,13 @@
             d_kv.stride(0),
             d_kv.stride(1),
             d_emb.stride(0),
+            ctx.cp_rank,
+            ctx.cp_size,
         )
         if ctx.cu_seqlens_kv is None:
             d_kv = d_kv.view(max_seqlen, batch_size, nheads, ctx.k_dim + ctx.v_dim)
             d_emb = d_emb.view(max_seqlen, batch_size, 1, ctx.emb_dim)
-        return d_kv, d_emb, None, None, None, None, None, None, None
+        return d_kv, d_emb, None, None, None, None, None, None, None, None, None


 @experimental_fn(introduced_with_version="0.13.0")
@@ -690,6 +743,8 @@ def fused_apply_mla_rope_for_kv(
     k_dim: int,
     v_dim: int,
     cu_seqlens_kv: Optional[torch.Tensor] = None,
+    cp_rank: int = 0,
+    cp_size: int = 1,
     rotary_interleaved: bool = False,
 ):
     """
@@ -715,5 +770,15 @@
         value: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
     """
     return ApplyMLARotaryEmbKV.apply(
-        kv, k_pos_emb, cos, sin, emb_dim, k_dim, v_dim, cu_seqlens_kv, rotary_interleaved
+        kv,
+        k_pos_emb,
+        cos,
+        sin,
+        emb_dim,
+        k_dim,
+        v_dim,
+        cu_seqlens_kv,
+        cp_rank,
+        cp_size,
+        rotary_interleaved,
     )
@@ -155,7 +155,6 @@ class DynamicInferenceContext(BaseInferenceContext):
         tp_size = tensor_model_parallel_size
         hidden_size_per_attention_head = core_divide(projection_size, num_attention_heads)
         num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size)
-
         # Chunk size tokens, bytes.
         dtype_size_bytes = params_dtype.itemsize
         self.chunk_size_tokens = chunk_size_tokens
@@ -177,23 +176,24 @@
         def bytes_to_max_requests_and_tokens(n_bytes):
             n_tokens = n_bytes / self.chunk_size_bytes * self.chunk_size_tokens
             n_requests = n_tokens / max_sequence_length
-            return int(n_requests), int(n_tokens)
+            return self.round_up_requests(int(n_requests), tp_size=tp_size), self.round_up_tokens(
+                int(n_tokens), tp_size=tp_size
+            )

         self.max_requests, self.max_tokens = bytes_to_max_requests_and_tokens(buffer_size_bytes)
-
         if buffer_overflow_factor is not None:
             self.max_requests = self.round_up_requests(
-                int(self.max_requests * buffer_overflow_factor)
+                int(self.max_requests * buffer_overflow_factor), tp_size=tp_size
             )
             self.max_tokens = self.round_up_tokens(
-                int(self.max_tokens * buffer_overflow_factor / 50.0)
+                int(self.max_tokens * buffer_overflow_factor / 50.0), tp_size=tp_size
             )

         if max_requests_override is not None:
-            self.max_requests = self.round_up_requests(max_requests_override)
+            self.max_requests = self.round_up_requests(max_requests_override, tp_size=tp_size)

         if max_tokens_override is not None:
-            self.max_tokens = self.round_up_tokens(max_tokens_override)
+            self.max_tokens = self.round_up_tokens(max_tokens_override, tp_size=tp_size)

         self.max_requests = min(self.max_requests, self.max_tokens)  # e.g., decode only.

@@ -277,7 +277,8 @@ class DynamicInferenceContext(BaseInferenceContext):
         self.cuda_graph_step_size = cuda_graph_rounder * int(
             math.ceil(int(self.cuda_graph_step_size) / cuda_graph_rounder)
         )
-
+        # Make sure divisble by TP size
+        self.cuda_graph_step_size = math.ceil(self.cuda_graph_step_size / tp_size) * tp_size
         # Cuda graph request counts.
         if num_cuda_graphs == 1:
             self.cuda_graph_request_counts = [self.max_requests]
@@ -355,26 +356,46 @@
     REQUEST_ROUNDER = 4

     @classmethod
-    def round_up_tokens(cls, value):
-        """Round up to nearest multiple of `TOKEN_ROUNDER` (above)."""
+    def round_up_tokens(cls, value, tp_size=None):
+        """Round up to nearest multiple of `TOKEN_ROUNDER` (above) that is also divisible by tensor model parallel size."""
         if not HAVE_PACKAGING:
             raise ImportError(
                 "`packaging` is required for this functionality, please install it with `pip install packaging`"
             )
         if PkgVersion(mcore_version) < PkgVersion("0.13"):
             return cls.round_up(value)
-        return cls.TOKEN_ROUNDER * int(math.ceil(int(value) / cls.TOKEN_ROUNDER))
+
+        # Make sure divisible by TP size
+        if tp_size is None:
+            # Check if parallel state is initialized before trying to get TP size
+            if parallel_state.is_initialized():
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+            else:
+                tp_size = 1
+        token_rounder = math.ceil(cls.TOKEN_ROUNDER / tp_size) * tp_size
+
+        return token_rounder * int(math.ceil(int(value) / token_rounder))

     @classmethod
-    def round_up_requests(cls, value):
-        """Round up to nearest multiple of `REQUEST_ROUNDER` (above)."""
+    def round_up_requests(cls, value, tp_size=None):
+        """Round up to nearest multiple of `REQUEST_ROUNDER` (above) that is also divisible by tensor model parallel size."""
         if not HAVE_PACKAGING:
             raise ImportError(
                 "`packaging` is required for this functionality, please install it with `pip install packaging`"
             )
         if PkgVersion(mcore_version) < PkgVersion("0.13"):
             return cls.round_up(value)
-        return cls.REQUEST_ROUNDER * int(math.ceil(int(value) / cls.REQUEST_ROUNDER))
+
+        # Make sure divisible by TP size
+        if tp_size is None:
+            # Check if parallel state is initialized before trying to get TP size
+            if parallel_state.is_initialized():
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+            else:
+                tp_size = 1
+        request_rounder = math.ceil(cls.REQUEST_ROUNDER / tp_size) * tp_size
+
+        return request_rounder * int(math.ceil(int(value) / request_rounder))

     @classmethod
     def round_up(cls, value):
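Side note (not part of the diff): the rounding helpers now use an effective rounder equal to the smallest multiple of the tensor-parallel size that is at least the class rounder, so rounded counts stay divisible by TP. A standalone rendition with made-up numbers:

import math

def round_up_divisible(value, base_rounder, tp_size):
    # Effective rounder: smallest multiple of tp_size that is >= base_rounder.
    rounder = math.ceil(base_rounder / tp_size) * tp_size
    return rounder * math.ceil(value / rounder)

# With REQUEST_ROUNDER = 4 (as in the class) and tp_size = 8 the effective rounder is 8,
# so 10 requests round up to 16; with tp_size = 2 the effective rounder stays 4 and 10 -> 12.
assert round_up_divisible(10, 4, tp_size=8) == 16
assert round_up_divisible(10, 4, tp_size=2) == 12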
@@ -1043,21 +1064,16 @@ class DynamicInferenceContext(BaseInferenceContext):
         # We determine how many requests we can resume and resume them
         # Assign released chunks to paused requests.
         # todo: @shanmugamr, un-pause requests using FIFO, rather than LIFO.
-        if (
-            self.chunk_allocator.chunk_count_avail
-            <= self.paused_request_count + self.gtd_chunk_count
-        ):
-            if active_request_count < self.gtd_request_count:
-                resume_request_count = min(
-                    self.paused_request_count, self.gtd_request_count - active_request_count
-                )
-            else:
-                # If there are more active requests than gtd requests and not enough
-                # chunks available, no requests can be resumed
-                resume_request_count = 0
+        num_non_gtd_chunks = max(0, self.chunk_allocator.chunk_count_avail - self.gtd_chunk_count)
+        if num_non_gtd_chunks:
+            # if we have non-gtd chunks, use them. Do not dip into the gtd-chunk pool
+            resume_request_count = min(num_non_gtd_chunks, self.paused_request_count)
         else:
-            # If there are more available chunks than (paused + gtd requests), resume all paused requests
-            resume_request_count = self.paused_request_count
+            # only dip into the gtd-chunk pool if we have run out of non-gtd-chunks and the active
+            # request count has fallen below a certain threshold.
+            resume_request_count = min(
+                max(self.gtd_request_count - active_request_count, 0), self.paused_request_count
+            )

         self.paused_request_count -= resume_request_count
         active_request_count += resume_request_count
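Side note (not part of the diff): the new resume policy consumes spare (non-guaranteed) chunks first and only taps the guaranteed pool when active requests have fallen below the guaranteed request count. A toy rendition of the decision:

def resume_count(chunks_avail, gtd_chunk_count, gtd_request_count,
                 paused_request_count, active_request_count):
    # Chunks that can be handed out without touching the guaranteed (gtd) reserve.
    num_non_gtd_chunks = max(0, chunks_avail - gtd_chunk_count)
    if num_non_gtd_chunks:
        return min(num_non_gtd_chunks, paused_request_count)
    # Reserve pool: only resume enough requests to climb back to gtd_request_count.
    return min(max(gtd_request_count - active_request_count, 0), paused_request_count)

# Spare chunks available: resume as many paused requests as the spare chunks allow.
assert resume_count(10, 4, 2, paused_request_count=3, active_request_count=5) == 3
# No spare chunks and enough active requests: nothing is resumed from the reserve.
assert resume_count(4, 4, 2, paused_request_count=3, active_request_count=5) == 0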
@@ -26,6 +26,8 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_w
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.utils import get_attention_mask
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
+from megatron.core.transformer.moe.moe_layer import BaseMoELayer
+from megatron.core.transformer.utils import set_model_to_sequence_parallel
 from megatron.core.utils import get_model_config

 try:
@@ -429,9 +431,11 @@ class TextGenerationController:
         # Get flat tokens, position ids.
         input_ids, position_ids = context.current_input_and_position_ids()

+        model_config = get_model_config(self.inference_wrapped_model.model)
+
         # If using symmetric kernels and we are using using nccl
         # for prefill turn off symmetric kernels
-        symmetric_ar_type = get_model_config(self.inference_wrapped_model.model).symmetric_ar_type
+        symmetric_ar_type = model_config.symmetric_ar_type
         nccl_all_reduce_for_prefill = (
             self.inference_wrapped_model.inference_wrapper_config.nccl_all_reduce_for_prefill
         )
@@ -588,7 +592,9 @@
         )

         # Check whether CUDA graphs are enabled
-        enable_cuda_graph = model_config.enable_cuda_graph
+        enable_cuda_graph = (
+            model_config.enable_cuda_graph and model_config.cuda_graph_scope != "full_iteration"
+        )

         # Pad batch tokens if necessary
         batch_size = len(active_requests)
@@ -681,6 +687,21 @@
             not self.inference_wrapped_model.inference_context.is_decode_only()
         ), f"Generation must start in prefill mode"

+        # Sequence parallelism is required for MoE layers when using expert parallelism (EP)
+        # becausethe expert routing mechanism relies on sequence parallelism's communication
+        # infrastructure to distribute tokens across expert ranks. However, sequence parallelism
+        # is not currently supported for non-MoE layers during inference,so we selectively
+        # disable it for all other layer types. This is safe because MoE layers perform an
+        # all-gather operation on sequences before passing data to subsequent layers, ensuring
+        # that each rank has the complete sequence data needed for the next non-MoE layer.
+        tp_size = model_config.tensor_model_parallel_size
+        ep_size = model_config.expert_model_parallel_size
+        model_is_tp_ep = tp_size > 1 and ep_size > 1
+        if model_is_tp_ep:
+            set_model_to_sequence_parallel(
+                self.inference_wrapped_model.model.module, False, exclude_modules=[BaseMoELayer]
+            )
+
         # If using symmetric kernels and we are using using nccl
         # for prefill turn off symmetric kernels
         symmetric_ar_type = model_config.symmetric_ar_type
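Side note (not part of the diff): set_model_to_sequence_parallel(model, False, exclude_modules=[BaseMoELayer]) turns sequence parallelism off everywhere except inside MoE layers. The real helper lives in megatron/core/transformer/utils.py; the walk-and-toggle idea can be sketched roughly like this (names and details here are illustrative, not the actual implementation):

import torch.nn as nn

def set_sequence_parallel_except(model: nn.Module, value: bool, exclude_types=()):
    # Toy stand-in: set `sequence_parallel` on every submodule that has the attribute,
    # but leave any subtree rooted at an excluded module type (e.g. MoE layers) untouched.
    for child in model.children():
        if isinstance(child, tuple(exclude_types)):
            continue
        if hasattr(child, "sequence_parallel"):
            child.sequence_parallel = value
        set_sequence_parallel_except(child, value, exclude_types)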
@@ -237,6 +237,14 @@ class ModelParallelConfig:
     Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
     """

+    overlap_moe_expert_parallel_comm: bool = False
+    """Overlap EP A2A communications with independent computations of different micro-batches
+    in 1f1b phase of pipelining or non-pipelining schedule.
+    """
+
+    delay_wgrad_compute: bool = False
+    """Delay the weight gradient computation to improve batch-level communication overlapping"""
+
     ###################
     # Pipeline Parallel
     ###################
@@ -307,9 +315,6 @@
     rank 1 | 0 1 2 0 1 2 3 4 3 4
     """

-    delay_wgrad_compute: bool = False
-    """If true, delay the wgrad compute for better overlapping in combined 1F1B."""
-

     ###################
     # CPU Offloading
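Side note (not part of the diff): both options are plain dataclass fields on ModelParallelConfig (and are therefore also available on TransformerConfig, which inherits from it). A minimal, hedged example of enabling them; whether a given schedule actually supports the combination is validated inside Megatron, not here:

from megatron.core.model_parallel_config import ModelParallelConfig

config = ModelParallelConfig(
    overlap_moe_expert_parallel_comm=True,  # overlap EP all-to-all with other micro-batches' compute
    delay_wgrad_compute=True,               # defer weight-gradient GEMMs for better comm overlap
)
assert config.overlap_moe_expert_parallel_comm and config.delay_wgrad_compute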