megatron-core 0.14.0rc2__tar.gz → 0.14.0rc4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of megatron-core might be problematic.

Files changed (311)
  1. {megatron_core-0.14.0rc2/megatron_core.egg-info → megatron_core-0.14.0rc4}/PKG-INFO +11 -8
  2. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/torch.py +2 -1
  3. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/validation.py +21 -15
  4. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +10 -0
  5. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine.py +5 -27
  6. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/transformer_engine_spec_provider.py +5 -0
  7. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fp8_utils.py +119 -0
  8. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +81 -16
  9. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/__init__.py +1 -0
  10. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_context.py +191 -86
  11. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/dynamic_engine.py +79 -18
  12. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +4 -0
  13. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +6 -0
  14. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +26 -39
  15. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/model_parallel_config.py +8 -3
  16. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rope_utils.py +20 -32
  17. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +10 -4
  18. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +11 -6
  19. megatron_core-0.14.0rc4/megatron/core/models/common/model_chunk_schedule_plan.py +502 -0
  20. megatron_core-0.14.0rc4/megatron/core/models/gpt/fine_grained_callables.py +474 -0
  21. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_layer_specs.py +2 -2
  22. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/gpt_model.py +62 -1
  23. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/__init__.py +143 -44
  24. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer.py +11 -4
  25. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/optimizer_config.py +1 -0
  26. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/package_info.py +1 -1
  27. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/packed_seq_params.py +2 -2
  28. megatron_core-0.14.0rc4/megatron/core/pipeline_parallel/combined_1f1b.py +331 -0
  29. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/schedules.py +169 -101
  30. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/utils.py +91 -0
  31. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_block.py +4 -1
  32. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_layer.py +1 -1
  33. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/layers.py +23 -12
  34. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/random.py +4 -1
  35. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/attention.py +3 -7
  36. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/cuda_graphs.py +178 -43
  37. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/mlp.py +20 -2
  38. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/experts.py +22 -0
  39. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_latent_attention.py +81 -9
  40. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_config.py +60 -7
  41. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_layer.py +11 -10
  42. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/utils.py +17 -11
  43. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/utils.py +27 -3
  44. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4/megatron_core.egg-info}/PKG-INFO +11 -8
  45. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron_core.egg-info/SOURCES.txt +2 -0
  46. megatron_core-0.14.0rc4/megatron_core.egg-info/requires.txt +33 -0
  47. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/pyproject.toml +13 -10
  48. megatron_core-0.14.0rc2/megatron/core/models/gpt/fine_grained_callables.py +0 -195
  49. megatron_core-0.14.0rc2/megatron_core.egg-info/requires.txt +0 -30
  50. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/LICENSE +0 -0
  51. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/MANIFEST.in +0 -0
  52. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/README.md +0 -0
  53. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/README.md +0 -0
  54. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/__init__.py +0 -0
  55. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/config.py +0 -0
  56. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/config_logger.py +0 -0
  57. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/__init__.py +0 -0
  58. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/bert_dataset.py +0 -0
  59. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_dataset.py +0 -0
  60. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  61. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  62. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/gpt_dataset.py +0 -0
  63. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.cpp +0 -0
  64. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/helpers.py +0 -0
  65. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/indexed_dataset.py +0 -0
  66. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/masked_dataset.py +0 -0
  67. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_dataset.py +0 -0
  68. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  69. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/multimodal_dataset.py +0 -0
  70. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/object_storage_utils.py +0 -0
  71. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/__init__.py +0 -0
  72. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/__init__.py +0 -0
  73. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  74. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/config.py +0 -0
  75. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  76. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  77. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/__init__.py +0 -0
  78. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/build.py +0 -0
  79. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/dataset.py +0 -0
  80. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/db/utils.py +0 -0
  81. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/external_libs.py +0 -0
  82. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/__init__.py +0 -0
  83. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/build.py +0 -0
  84. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/factory.py +0 -0
  85. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/index.py +0 -0
  86. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  87. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  88. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  89. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/utils.py +0 -0
  90. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/index/validate.py +0 -0
  91. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/__init__.py +0 -0
  92. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  93. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  94. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/query.py +0 -0
  95. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  96. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/query/utils.py +0 -0
  97. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/retro/utils.py +0 -0
  98. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/t5_dataset.py +0 -0
  99. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils.py +0 -0
  100. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_object_storage.py +0 -0
  101. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/datasets/utils_s3.py +0 -0
  102. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/__init__.py +0 -0
  103. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/core.py +0 -0
  104. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  105. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  106. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/mapping.py +0 -0
  107. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  108. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/serialization.py +0 -0
  109. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  110. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  111. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  112. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  113. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  114. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  115. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  116. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  117. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  118. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  119. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  120. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  121. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  122. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  123. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/dist_checkpointing/utils.py +0 -0
  124. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/__init__.py +0 -0
  125. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  126. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
  127. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/data_parallel_base.py +0 -0
  128. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  129. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  130. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/finalize_model_grads.py +0 -0
  131. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  132. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  133. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  134. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/energy_monitor.py +0 -0
  135. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/enums.py +0 -0
  136. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/__init__.py +0 -0
  137. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/data_type.py +0 -0
  138. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/export_config.py +0 -0
  139. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/model_type.py +0 -0
  140. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/__init__.py +0 -0
  141. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  142. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  143. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  144. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  145. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  146. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  147. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  148. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  149. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  150. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  151. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  152. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  153. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/__init__.py +0 -0
  154. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/extensions/kitchen.py +0 -0
  155. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/__init__.py +0 -0
  156. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  157. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  158. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  159. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  160. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  161. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_indices_converter.py +0 -0
  162. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_layer_norm.py +0 -0
  163. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  164. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/fusions/fused_softmax.py +0 -0
  165. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/hyper_comm_grid.py +0 -0
  166. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/__init__.py +0 -0
  167. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/async_stream.py +0 -0
  168. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/common_inference_params.py +0 -0
  169. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/communication_utils.py +0 -0
  170. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/base_context.py +0 -0
  171. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  172. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/contexts/static_context.py +0 -0
  173. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/__init__.py +0 -0
  174. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/abstract_engine.py +0 -0
  175. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/mcore_engine.py +0 -0
  176. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/engines/static_engine.py +0 -0
  177. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/inference_request.py +0 -0
  178. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  179. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  180. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  181. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  182. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  183. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  184. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/sampling_params.py +0 -0
  185. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/scheduler.py +0 -0
  186. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  187. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  188. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  189. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  190. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference/utils.py +0 -0
  191. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/inference_params.py +0 -0
  192. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/jit.py +0 -0
  193. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/T5/__init__.py +0 -0
  194. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_model.py +0 -0
  195. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/T5/t5_spec.py +0 -0
  196. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/__init__.py +0 -0
  197. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/backends.py +0 -0
  198. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/__init__.py +0 -0
  199. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  200. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_lm_head.py +0 -0
  201. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/bert_model.py +0 -0
  202. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/bert/pooler.py +0 -0
  203. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/__init__.py +0 -0
  204. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/__init__.py +0 -0
  205. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  206. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  207. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/__init__.py +0 -0
  208. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/language_module/language_module.py +0 -0
  209. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/__init__.py +0 -0
  210. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  211. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/__init__.py +0 -0
  212. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  213. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  214. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/__init__.py +0 -0
  215. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/clip_model.py +0 -0
  216. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/module.py +0 -0
  217. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/huggingface/qwen_model.py +0 -0
  218. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/__init__.py +0 -0
  219. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  220. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mamba/mamba_model.py +0 -0
  221. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/__init__.py +0 -0
  222. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/__init__.py +0 -0
  223. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/config/base_configs.py +0 -0
  224. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/__init__.py +0 -0
  225. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/model/base.py +0 -0
  226. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/audio.py +0 -0
  227. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/base.py +0 -0
  228. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/mimo/submodules/vision.py +0 -0
  229. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/__init__.py +0 -0
  230. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/context_parallel.py +0 -0
  231. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_model.py +0 -0
  232. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/multimodal/llava_spec.py +0 -0
  233. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/__init__.py +0 -0
  234. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/base_attention.py +0 -0
  235. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/config.py +0 -0
  236. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_attention.py +0 -0
  237. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/decoder_spec.py +0 -0
  238. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_attention.py +0 -0
  239. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/encoder_spec.py +0 -0
  240. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/model.py +0 -0
  241. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/retro/utils.py +0 -0
  242. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/__init__.py +0 -0
  243. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/clip_vit_model.py +0 -0
  244. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/multimodal_projector.py +0 -0
  245. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/radio.py +0 -0
  246. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  247. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/msc_utils.py +0 -0
  248. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/num_microbatches_calculator.py +0 -0
  249. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/clip_grads.py +0 -0
  250. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  251. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  252. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  253. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer/grad_scaler.py +0 -0
  254. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/optimizer_param_scheduler.py +0 -0
  255. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/parallel_state.py +0 -0
  256. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/__init__.py +0 -0
  257. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  258. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/__init__.py +0 -0
  259. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/__init__.py +0 -0
  260. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  261. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  262. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  263. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/layers.py +0 -0
  264. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  265. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  266. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/process_groups_config.py +0 -0
  267. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/quantization/__init__.py +0 -0
  268. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/quantization/quant_config.py +0 -0
  269. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/quantization/utils.py +0 -0
  270. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/requirements.txt +0 -0
  271. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/rerun_state_machine.py +0 -0
  272. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/__init__.py +0 -0
  273. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  274. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  275. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mamba_mixer.py +0 -0
  276. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/mlp_layer.py +0 -0
  277. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/ssm/triton_cache_manager.py +0 -0
  278. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/__init__.py +0 -0
  279. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  280. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/data.py +0 -0
  281. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/mappings.py +0 -0
  282. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/tensor_parallel/utils.py +0 -0
  283. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/timers.py +0 -0
  284. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/__init__.py +0 -0
  285. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  286. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  287. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/dot_product_attention.py +0 -0
  288. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/enums.py +0 -0
  289. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  290. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  291. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/identity_op.py +0 -0
  292. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/module.py +0 -0
  293. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/__init__.py +0 -0
  294. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  295. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  296. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_layer.py +0 -0
  297. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/moe_utils.py +0 -0
  298. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/router.py +0 -0
  299. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/shared_experts.py +0 -0
  300. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  301. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  302. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/multi_token_prediction.py +0 -0
  303. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  304. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/spec_utils.py +0 -0
  305. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_layer_norm.py +0 -0
  306. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/torch_norm.py +0 -0
  307. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron/core/transformer/transformer_block.py +0 -0
  308. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron_core.egg-info/dependency_links.txt +0 -0
  309. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/megatron_core.egg-info/top_level.txt +0 -0
  310. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/setup.cfg +0 -0
  311. {megatron_core-0.14.0rc2 → megatron_core-0.14.0rc4}/setup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: megatron-core
- Version: 0.14.0rc2
+ Version: 0.14.0rc4
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -31,6 +31,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: torch
  Requires-Dist: numpy<2.0.0
+ Requires-Dist: packaging~=25.0
  Provides-Extra: mlm
  Requires-Dist: flask-restful; extra == "mlm"
  Requires-Dist: sentencepiece; extra == "mlm"
@@ -38,14 +39,16 @@ Requires-Dist: tiktoken; extra == "mlm"
  Requires-Dist: wandb; extra == "mlm"
  Provides-Extra: dev
  Requires-Dist: tqdm; extra == "dev"
- Requires-Dist: einops; extra == "dev"
- Requires-Dist: tensorstore!=0.1.46,!=0.1.72; extra == "dev"
- Requires-Dist: nvtx; extra == "dev"
- Requires-Dist: transformers; extra == "dev"
- Requires-Dist: multi-storage-client; extra == "dev"
+ Requires-Dist: einops~=0.8; extra == "dev"
+ Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
+ Requires-Dist: nvtx~=0.2; extra == "dev"
+ Requires-Dist: transformers~=4.53; extra == "dev"
+ Requires-Dist: multi-storage-client~=0.20.3; extra == "dev"
+ Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
  Requires-Dist: setuptools<80.0.0; extra == "dev"
- Requires-Dist: nvidia-modelopt[torch]~=0.31.0; sys_platform != "darwin" and extra == "dev"
- Requires-Dist: megatron-energon[av_decode]<7; extra == "dev"
+ Requires-Dist: nvidia-modelopt[torch]<0.32.0,>=0.31.0a0; sys_platform != "darwin" and extra == "dev"
+ Requires-Dist: megatron-energon[av_decode]~=6.0; extra == "dev"
+ Requires-Dist: flashinfer-python; extra == "dev"
  Provides-Extra: lts
  Requires-Dist: tqdm; extra == "lts"
  Requires-Dist: einops; extra == "lts"
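Note on the new dev pins above: most dependencies move from unpinned to PEP 440 compatible-release ("~=") specifiers. A quick standalone check of what such a pin admits, using the packaging library (the version strings tested here are illustrative, not taken from this diff):

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=4.53")   # as in transformers~=4.53 above
    print("4.53.1" in spec)  # True: at least 4.53, within the 4.x series
    print("4.99.0" in spec)  # True
    print("5.0.0" in spec)   # False: a compatible-release pin excludes the next major version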
@@ -374,7 +374,8 @@ def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]
  ten = ten.view(-1)
  else:
  for _ in range(mcore_sh_ten.prepend_axis_num):
- ten = ten.squeeze(0)
+ assert ten.size(0) == 1
+ ten = ten[0] # NOTE: ten.squeeze(0) uses more memory for FP8 tensors
  ret_tensors.append(ten)
  return ret_tensors

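For context, replacing squeeze(0) with plain indexing keeps the same result when the leading axis is known to be 1. A standalone sketch of the unwrapping loop (the shapes are assumptions, not taken from the checkpoint code):

    import torch

    def drop_prepended_axes(ten: torch.Tensor, prepend_axis_num: int) -> torch.Tensor:
        # Index away each prepended singleton axis instead of calling squeeze(0),
        # mirroring the change above.
        for _ in range(prepend_axis_num):
            assert ten.size(0) == 1
            ten = ten[0]
        return ten

    print(drop_prepended_axes(torch.zeros(1, 1, 4, 8), 2).shape)  # torch.Size([4, 8])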
@@ -375,28 +375,34 @@ def maybe_report_missing_and_unexpected_keys(
  def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
  """Validate consistancy across ranks for the common state dict

- We save the common state dict only on rank 0. We validate to make sure that the common dict is consistant across ranks before saving.
+ We save the common state dict only on rank 0. We validate to make sure that the common dict is consistent across ranks before saving.

  Args:
  common_state_dict: The common state dict present in all ransk
  """
+ if not torch.distributed.is_initialized():
+ return

- # Gather the common state dict across ranks onto rank 0 for comparison
+ # Broadcast the common state dict from rank 0 to all other ranks
+ # Each rank will do a comparison with its local rank vs the broadcasted state dict from rank 0
  rank = torch.distributed.get_rank()
- other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None
- torch.distributed.gather_object(common_state_dict, other_rank_state_dicts)
- common_state_dict_diff = {}
- if rank == 0:
- assert other_rank_state_dicts
- main_rank_state_dict = common_state_dict
- for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
- only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)
- if only_left or only_right or mismatch:
- common_state_dict_diff[rank] = (only_left, only_right, mismatch)
-
- if len(common_state_dict_diff) != 0:
+
+ object_list = [common_state_dict] if rank == 0 else [None]
+ torch.distributed.broadcast_object_list(object_list, src=0)
+ rank0_state_dict = object_list[0]
+
+ # Skip comparing rank 0 with itself
+ if rank > 0:
+ current_rank_state_dict = common_state_dict
+ only_in_rank0, only_in_current_rank, mismatch = diff(
+ rank0_state_dict, current_rank_state_dict
+ )
+ if only_in_rank0 or only_in_current_rank or mismatch:
  logger.warning(
- f"There is difference in the common state dict in different ranks. The differences are {common_state_dict_diff}"
+ f"Rank {rank} common state dict differs from rank 0 common state dict. "
+ f"Keys only on rank 0: {only_in_rank0}, "
+ f"Keys only on {rank}: {only_in_current_rank}, "
+ f"Mismatched keys: {mismatch}"
  )

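The reworked _validate_common_state_dict above broadcasts rank 0's common state dict and lets every other rank compare locally, instead of gathering all dicts onto rank 0. A minimal standalone sketch of that pattern (plain equality stands in for Megatron's diff() helper, and a process group must already be initialized):

    import torch.distributed as dist

    def validate_common_state_dict(common_state_dict):
        if not dist.is_initialized():
            return
        rank = dist.get_rank()
        # Rank 0 contributes its dict; every rank receives a copy of it.
        object_list = [common_state_dict] if rank == 0 else [None]
        dist.broadcast_object_list(object_list, src=0)
        rank0_state_dict = object_list[0]
        # Non-zero ranks compare their local dict against rank 0's copy.
        if rank > 0 and rank0_state_dict != common_state_dict:
            print(f"Rank {rank}: common state dict differs from rank 0")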
@@ -217,6 +217,16 @@ class FullyShardedDataParallel(_BaseDataParallel):

  self.module.apply(unmap_weight_tensor)

+ for param in self.module.parameters():
+ if not hasattr(param, 'grad_added_to_main_grad'):
+ # This is to ensure that the param.grad_added_to_main_grad is set to False
+ # when the parameter is created.
+ param.grad_added_to_main_grad = False
+ if not hasattr(param, '__fsdp_param__'):
+ # This is to ensure that the param.__fsdp_param__ is set to True
+ # when the parameter is created.
+ param.__fsdp_param__ = True
+
  def _init_fsdp_param_and_grad_buffer(self):
  if self.config.calculate_per_token_loss:
  # We don't need to scale the gradients in this case.
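The loop above only guarantees that two bookkeeping attributes exist on every parameter. A standalone sketch of the same hasattr-guarded initialization (the attribute names come from the hunk; the toy module is an assumption):

    import torch

    module = torch.nn.Linear(4, 4)
    for param in module.parameters():
        if not hasattr(param, 'grad_added_to_main_grad'):
            param.grad_added_to_main_grad = False   # default until a backward pass sets it
        if not hasattr(param, '__fsdp_param__'):
            param.__fsdp_param__ = True             # mark the parameter as FSDP-managed

    print(all(p.__fsdp_param__ for p in module.parameters()))  # True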
@@ -889,25 +889,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
  if packed_seq_params is not None
  else {}
  )
- # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
- # after init
- if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
- self.qkv_format = "bshd"
-
- qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
-
- # WAR for peak memory usage.
- # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
- if self.config.apply_rope_fusion and qkv_format == "bshd":
- query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
- # In PyTorch, the following two tensors are in fact the same:
- # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
- # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
- # Stride for a dimension that is 1 has no meaning, so tensors created two different ways
- # can have same shape but different strides.
- # We unify them to the first one to pass the stride check in TE
- if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
- value = value.as_strided(value.shape, key.stride())
+ qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)

  attention_bias_kwargs = {}
  if attention_bias is not None:
@@ -942,10 +924,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
  query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
  )

- if self.config.apply_rope_fusion and qkv_format == "bshd":
- return core_attn_out.transpose(0, 1)
- else:
- return core_attn_out
+ return core_attn_out


  if HAVE_TE and is_te_min_version("1.9.0.dev0"):
@@ -1633,10 +1612,8 @@ try:
  else:
  if interleaved:
  raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.")
- if is_te_min_version("1.4.0.dev0"):
- return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)
- else:
- raise ValueError("Only TE >= 1.4.0.dev0 supports fused RoPE.")
+
+ return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)

  def fused_apply_rotary_pos_emb_thd(
  t: torch.Tensor,
@@ -1659,6 +1636,7 @@ try:
  cp_rank=cp_rank,
  )
  else:
+ assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP."
  return apply_rotary_pos_emb(
  t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens
  )
@@ -8,6 +8,7 @@ from megatron.core.extensions.transformer_engine import (
  TEColumnParallelLinear,
  TEDotProductAttention,
  TELayerNormColumnParallelLinear,
+ TELinear,
  TENorm,
  TERowParallelGroupedLinear,
  TERowParallelLinear,
@@ -23,6 +24,10 @@ from megatron.core.utils import get_te_version, is_te_min_version
  class TESpecProvider(BackendSpecProvider):
  """A protocol for providing the submodules used in Spec building."""

+ def linear(self) -> type:
+ """Which linear module TE backend uses"""
+ return TELinear
+
  def column_parallel_linear(self) -> type:
  """Which column parallel linear module TE backend uses"""
  return TEColumnParallelLinear
@@ -2,7 +2,9 @@

  """Utility functions related to FP8 that are used throughout Megatron core"""

+ import weakref
  from contextlib import nullcontext
+ from functools import wraps
  from typing import List, Optional

  import torch
@@ -53,6 +55,29 @@ except (ImportError, ModuleNotFoundError):
  # MXFP8Tensor not found
  HAVE_TE_MXFP8TENSOR = False

+ if HAVE_TE:
+ from megatron.core.extensions.transformer_engine import (
+ TEColumnParallelLinear,
+ TELayerNormColumnParallelLinear,
+ TELinear,
+ TERowParallelLinear,
+ )
+
+ TE_LINEAR_TYPES = (
+ TELinear,
+ TEColumnParallelLinear,
+ TERowParallelLinear,
+ TELayerNormColumnParallelLinear,
+ )
+ else:
+ TE_LINEAR_TYPES = ()
+
+ try:
+ from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
+ except ImportError:
+ Fp8Padding = None
+ Fp8Unpadding = None
+

  def is_float8tensor(tensor: torch.Tensor) -> bool:
  """Check if a tensor is a Transformer Engine Float8Tensor.
@@ -511,3 +536,97 @@ else:
  def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
  """Returns dummy fp8 context manager since TE is not available."""
  return nullcontext()
+
+
+ if HAVE_TE:
+ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+ # Modules that have been wrapped for inference for fp8
+ _fp8_inference_wrapped_modules = weakref.WeakSet()
+
+ def _wrap_te_linear_for_padding(module: torch.nn.Module):
+ """Wrap a TE linear module to automatically pad sequences for FP8 inference.
+
+ Modifies the module's forward method to:
+ 1. Pad input sequences to FP8 alignment requirements
+ 2. Run the original forward pass
+ 3. Unpad outputs to original sequence length
+
+ Args:
+ module: A Transformer Engine linear layer (TELinear, TEColumnParallelLinear, etc.)
+ """
+ if module in _fp8_inference_wrapped_modules:
+ return
+ _pad_func = Fp8Padding(1)
+ _unpad_func = Fp8Unpadding(1)
+
+ original_forward = module.forward
+
+ @wraps(original_forward)
+ def padded_forward(input_tensor, *args, **kwargs):
+ # Only do padding for fp8 if we are in fp8 context
+ if not FP8GlobalStateManager.is_fp8_enabled():
+ return original_forward(input_tensor, *args, **kwargs)
+
+ seq_len, batch_size, hidden_size = input_tensor.shape
+ # Reshape to (S, B*H) to pad sequence dimension
+ input_2d = input_tensor.reshape(seq_len, -1)
+ # Pad the sequence dimension
+ padded_input_2d, _ = _pad_func(input_2d, [seq_len])
+ padded_seq_len = padded_input_2d.shape[0]
+
+ # Reshape back to (padded_S, B, H)
+ padded_input_3d = padded_input_2d.view(padded_seq_len, batch_size, hidden_size)
+ output = original_forward(padded_input_3d, *args, **kwargs)
+
+ # Handle output
+ if isinstance(output, tuple):
+ output_tensor = output[0]
+ other_outputs = output[1:]
+ else:
+ output_tensor = output
+ other_outputs = ()
+
+ # Unpad output - reshape to 2D, unpad, reshape back
+ _, _, output_hidden_size = output_tensor.shape
+ output_2d = output_tensor.reshape(padded_seq_len, -1)
+ unpadded_output_2d = _unpad_func(output_2d, [seq_len])
+ unpadded_output = unpadded_output_2d.reshape(seq_len, batch_size, output_hidden_size)
+
+ if other_outputs:
+ return (unpadded_output,) + other_outputs
+ else:
+ return unpadded_output
+
+ module.forward = padded_forward
+ _fp8_inference_wrapped_modules.add(module)
+
+ def prepare_model_for_fp8_inference(model):
+ """Prepare a model for FP8 inference by wrapping TE linear layers with padding support.
+
+ FP8 TE Gemms have specific shape requirements. This function wraps all Transformer
+ Engine linear layers in the model to automatically pad/unpad sequences during inference.
+
+ Args:
+ model (model (GPTModel): Model containing TE linear layers.
+
+ Returns:
+ GPTModel: The same model with wrapped linear layers (modified in-place).
+
+ """
+ assert Fp8Padding and Fp8Unpadding, "TE version does not have FP8 padding functions"
+ # Find and wrap all TE linear layers
+ for module in model.modules():
+ if isinstance(module, TE_LINEAR_TYPES):
+ _wrap_te_linear_for_padding(module)
+
+ return model
+
+ else:
+
+ def prepare_model_for_fp8_inference(model):
+ """If trys using prepare_model_for_fp8_inference without TE we error"""
+ raise RuntimeError(
+ "prepare_model_for_fp8_inference requires Transformer Engine to be installed. "
+ "Please install transformer-engine to use FP8 inference."
+ )
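The wrapper added above pads the sequence dimension before each TE linear forward and strips the padding afterwards, because FP8 GEMMs need aligned shapes. A standalone sketch of the same pad-run-unpad pattern without Transformer Engine (padding to a multiple of 16 is an assumption here; the real code defers to TE's Fp8Padding/Fp8Unpadding helpers):

    import torch

    def pad_seq_to_multiple(x: torch.Tensor, multiple: int = 16) -> torch.Tensor:
        pad = (-x.shape[0]) % multiple
        if pad == 0:
            return x
        return torch.cat([x, x.new_zeros(pad, *x.shape[1:])], dim=0)

    def padded_forward(linear: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[0]                  # x is (S, B, H)
        out = linear(pad_seq_to_multiple(x))  # run the layer on the padded input
        return out[:seq_len]                  # drop the padded rows again

    layer = torch.nn.Linear(8, 8)
    x = torch.randn(5, 2, 8)                  # S=5 is not FP8-aligned
    print(padded_forward(layer, x).shape)     # torch.Size([5, 2, 8])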
@@ -28,16 +28,25 @@ if not HAVE_TRITON:


  @triton.jit
- def _get_thd_token_idx(cu_seqlens, pid_m, seq_num):
+ def _get_thd_token_idx(cu_seqlens, pid_m, seq_num, cp_rank, cp_size):
  token_idx = -1
+ this_seq_len = 0
  seq_idx = 0
- last_cum_seqlen = tl.load(cu_seqlens)
+ last_cum_seqlen = tl.load(cu_seqlens) // cp_size
  while seq_idx < seq_num:
- cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1)
+ cur_cum_seqlen = tl.load(cu_seqlens + seq_idx + 1) // cp_size
  if token_idx == -1 and cur_cum_seqlen > pid_m:
  token_idx = pid_m - last_cum_seqlen
+ this_seq_len = cur_cum_seqlen - last_cum_seqlen
  last_cum_seqlen = cur_cum_seqlen
  seq_idx += 1
+ if cp_size > 1:
+ if token_idx < this_seq_len // 2:
+ token_idx = token_idx + cp_rank * this_seq_len // 2
+ else:
+ token_idx = (token_idx - this_seq_len // 2) + (
+ 2 * cp_size - cp_rank - 1
+ ) * this_seq_len // 2
  return token_idx

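The new cp_rank/cp_size logic above maps a token index that is local to one context-parallel rank back to its global position, assuming the load-balanced split implied by the formula, where each rank holds one chunk from the front half of the sequence and the mirrored chunk from the back half. A plain-Python sketch of that mapping (this_seq_len is the per-rank sequence length):

    def cp_local_to_global(token_idx: int, this_seq_len: int, cp_rank: int, cp_size: int) -> int:
        if cp_size == 1:
            return token_idx
        half = this_seq_len // 2
        if token_idx < half:
            # First local chunk comes from the front half of the global sequence.
            return token_idx + cp_rank * half
        # Second local chunk comes from the mirrored position in the back half.
        return (token_idx - half) + (2 * cp_size - cp_rank - 1) * half

    # 8 local tokens on cp_rank=0 with cp_size=2 (global length 16):
    print([cp_local_to_global(i, 8, 0, 2) for i in range(8)])  # [0, 1, 2, 3, 12, 13, 14, 15]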
@@ -68,6 +77,8 @@ def rotary_fwd_q_kernel(
  cu_seqlens_q,
  stride_x_seq,
  stride_x_nheads,
+ cp_rank,
+ cp_size,
  BLOCK_H: tl.constexpr,
  ):
  """
@@ -89,7 +100,7 @@ def rotary_fwd_q_kernel(
  if cu_seqlens_q is None:
  token_idx = pid_m // batch_size
  else:
- token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
+ token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)

  cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
  sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -146,6 +157,8 @@ def rotary_bwd_q_kernel(
  cu_seqlens_q,
  stride_x_seq,
  stride_x_nheads,
+ cp_rank,
+ cp_size,
  BLOCK_H: tl.constexpr,
  ):
  """
@@ -165,7 +178,7 @@ def rotary_bwd_q_kernel(
  if cu_seqlens_q is None:
  token_idx = pid_m // batch_size
  else:
- token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num)
+ token_idx = _get_thd_token_idx(cu_seqlens_q, pid_m, seq_num, cp_rank, cp_size)

  cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
  sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -200,7 +213,18 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
  """

  @staticmethod
- def forward(ctx, q, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, rotary_interleaved=False):
+ def forward(
+ ctx,
+ q,
+ cos,
+ sin,
+ qk_head_dim,
+ emb_dim,
+ cu_seqlens_q,
+ cp_rank,
+ cp_size,
+ rotary_interleaved=False,
+ ):
  """
  Forward function for ApplyMLARotaryEmbQ.

@@ -243,12 +267,16 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
  cu_seqlens_q,
  q.stride(0),
  q.stride(1),
+ cp_rank,
+ cp_size,
  )
  ctx.save_for_backward(cos, sin)
  ctx.qk_head_dim = qk_head_dim
  ctx.emb_dim = emb_dim
  ctx.cu_seqlens_q = cu_seqlens_q
  ctx.rotary_interleaved = rotary_interleaved
+ ctx.cp_rank = cp_rank
+ ctx.cp_size = cp_size
  if cu_seqlens_q is None:
  q = q.view(max_seqlen, batch_size, nheads, headdim)
  return q
@@ -268,7 +296,7 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
  seq_num = None
  if ctx.cu_seqlens_q is None:
  max_seqlen, batch_size, nheads, headdim = grad.shape
- grad = grad.view(-1, nheads, headdim)
+ grad = grad.contiguous().view(-1, nheads, headdim)
  total_seqlen = grad.shape[0]
  else:
  seq_num = len(ctx.cu_seqlens_q) - 1
@@ -288,10 +316,12 @@ class ApplyMLARotaryEmbQ(torch.autograd.Function):
  ctx.cu_seqlens_q,
  grad.stride(0),
  grad.stride(1),
+ ctx.cp_rank,
+ ctx.cp_size,
  )
  if ctx.cu_seqlens_q is None:
  grad = grad.view(max_seqlen, batch_size, nheads, headdim)
- return grad, None, None, None, None, None, None
+ return grad, None, None, None, None, None, None, None, None


  @experimental_fn(introduced_with_version="0.13.0")
@@ -302,6 +332,8 @@ def fused_apply_mla_rope_for_q(
  qk_head_dim: int,
  emb_dim: int,
  cu_seqlens_q: Optional[torch.Tensor] = None,
+ cp_rank: int = 0,
+ cp_size: int = 1,
  rotary_interleaved: bool = False,
  ):
  """
@@ -327,7 +359,7 @@ def fused_apply_mla_rope_for_q(
  t: inplace modified input tensor
  """
  return ApplyMLARotaryEmbQ.apply(
- t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, rotary_interleaved
+ t, cos, sin, qk_head_dim, emb_dim, cu_seqlens_q, cp_rank, cp_size, rotary_interleaved
  )


@@ -366,6 +398,8 @@ def rotary_fwd_kv_kernel(
  stride_k_nheads,
  stride_v_seq,
  stride_v_nheads,
+ cp_rank,
+ cp_size,
  BLOCK_H: tl.constexpr,
  ):
  """
@@ -394,7 +428,7 @@ def rotary_fwd_kv_kernel(
  if cu_seqlens_kv is None:
  token_idx = pid_m // batch_size
  else:
- token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
+ token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)

  cos_left = tl.load(COS + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
  sin_left = tl.load(SIN + token_idx * emb_dim + tl.arange(0, emb_dim // 2))
@@ -472,6 +506,8 @@ def rotary_bwd_kv_kernel(
  stride_dkv_seq,
  stride_dkv_nheads,
  stride_demb_seq,
+ cp_rank,
+ cp_size,
  BLOCK_H: tl.constexpr,
  ):
  """
@@ -496,7 +532,7 @@ def rotary_bwd_kv_kernel(
  if cu_seqlens_kv is None:
  token_idx = pid_m // batch_size
  else:
- token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num)
+ token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)

  dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
  dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
@@ -550,7 +586,18 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):

  @staticmethod
  def forward(
- ctx, kv, k_pos_emb, cos, sin, emb_dim, k_dim, v_dim, cu_seqlens_kv, rotary_interleaved=False
+ ctx,
+ kv,
+ k_pos_emb,
+ cos,
+ sin,
+ emb_dim,
+ k_dim,
+ v_dim,
+ cu_seqlens_kv,
+ cp_rank,
+ cp_size,
+ rotary_interleaved=False,
  ):
  """
  Forward function for ApplyMLARotaryEmbKV.
@@ -609,6 +656,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
  o_key.stride(1),
  o_value.stride(0),
  o_value.stride(1),
+ cp_rank,
+ cp_size,
  )
  ctx.save_for_backward(cos, sin)
  ctx.rotary_interleaved = rotary_interleaved
@@ -616,6 +665,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
  ctx.k_dim = k_dim
  ctx.v_dim = v_dim
  ctx.cu_seqlens_kv = cu_seqlens_kv
+ ctx.cp_rank = cp_rank
+ ctx.cp_size = cp_size
  if cu_seqlens_kv is None:
  o_key = o_key.view(max_seqlen, -1, nheads, emb_dim + k_dim)
  o_value = o_value.view(max_seqlen, -1, nheads, v_dim)
@@ -638,8 +689,8 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
  if ctx.cu_seqlens_kv is None:
  # sbhd
  max_seqlen, batch_size, nheads, _ = dk.shape
- dk = dk.view(-1, nheads, ctx.emb_dim + ctx.k_dim)
- dv = dv.view(-1, nheads, ctx.v_dim)
+ dk = dk.contiguous().view(-1, nheads, ctx.emb_dim + ctx.k_dim)
+ dv = dv.contiguous().view(-1, nheads, ctx.v_dim)
  total_seqlen = dk.shape[0]
  else:
  # thd
@@ -673,11 +724,13 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
  d_kv.stride(0),
  d_kv.stride(1),
  d_emb.stride(0),
+ ctx.cp_rank,
+ ctx.cp_size,
  )
  if ctx.cu_seqlens_kv is None:
  d_kv = d_kv.view(max_seqlen, batch_size, nheads, ctx.k_dim + ctx.v_dim)
  d_emb = d_emb.view(max_seqlen, batch_size, 1, ctx.emb_dim)
- return d_kv, d_emb, None, None, None, None, None, None, None
+ return d_kv, d_emb, None, None, None, None, None, None, None, None, None


  @experimental_fn(introduced_with_version="0.13.0")
@@ -690,6 +743,8 @@ def fused_apply_mla_rope_for_kv(
  k_dim: int,
  v_dim: int,
  cu_seqlens_kv: Optional[torch.Tensor] = None,
+ cp_rank: int = 0,
+ cp_size: int = 1,
  rotary_interleaved: bool = False,
  ):
  """
@@ -715,5 +770,15 @@ def fused_apply_mla_rope_for_kv(
  value: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
  """
  return ApplyMLARotaryEmbKV.apply(
- kv, k_pos_emb, cos, sin, emb_dim, k_dim, v_dim, cu_seqlens_kv, rotary_interleaved
+ kv,
+ k_pos_emb,
+ cos,
+ sin,
+ emb_dim,
+ k_dim,
+ v_dim,
+ cu_seqlens_kv,
+ cp_rank,
+ cp_size,
+ rotary_interleaved,
  )
@@ -14,6 +14,7 @@ warnings.warn(
  DeprecationWarning,
  )
  from .dynamic_context import (
+ ActiveRequestCountOverflowError,
  ChunkOverflowError,
  ContextOverflowError,
  DynamicInferenceContext,