megatron-core 0.14.0rc6__tar.gz → 0.14.0rc7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megatron-core might be problematic.

Files changed (324)
  1. {megatron_core-0.14.0rc6/megatron_core.egg-info → megatron_core-0.14.0rc7}/PKG-INFO +1 -1
  2. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/__init__.py +6 -0
  3. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/mapping.py +0 -6
  4. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/common.py +6 -6
  5. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/__init__.py +1 -0
  6. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/distributed_data_parallel_config.py +20 -6
  7. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/finalize_model_grads.py +27 -14
  8. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/__init__.py +3 -0
  9. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +317 -0
  10. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/__init__.py +13 -0
  11. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +22 -0
  12. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +141 -0
  13. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +387 -0
  14. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +1107 -0
  15. {megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp → megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp}/param_and_grad_buffer.py +1658 -522
  16. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +458 -0
  17. megatron_core-0.14.0rc7/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +908 -0
  18. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/param_and_grad_buffer.py +6 -7
  19. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +8 -0
  20. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/transformer_engine.py +14 -2
  21. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/language_module/language_module.py +19 -2
  22. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/clip_model.py +1 -1
  23. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/qwen_model.py +1 -1
  24. megatron_core-0.14.0rc7/megatron/core/nccl_allocator.py +249 -0
  25. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/__init__.py +3 -22
  26. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/clip_grads.py +15 -0
  27. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/distrib_optimizer.py +155 -129
  28. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/optimizer.py +3 -1
  29. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/optimizer_config.py +6 -0
  30. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/package_info.py +1 -1
  31. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/parallel_state.py +6 -3
  32. megatron_core-0.14.0rc7/megatron/core/safe_globals.py +33 -0
  33. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/layers.py +8 -8
  34. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/cuda_graphs.py +318 -7
  35. megatron_core-0.14.0rc7/megatron/core/transformer/fsdp_dtensor_checkpoint.py +195 -0
  36. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/experts.py +1 -25
  37. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/transformer_config.py +5 -1
  38. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/transformer_layer.py +1 -1
  39. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/utils.py +0 -3
  40. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/utils.py +4 -41
  41. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7/megatron_core.egg-info}/PKG-INFO +1 -1
  42. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/SOURCES.txt +13 -3
  43. megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/__init__.py +0 -3
  44. megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -835
  45. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/LICENSE +0 -0
  46. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/MANIFEST.in +0 -0
  47. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/README.md +0 -0
  48. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/README.md +0 -0
  49. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/activations.py +0 -0
  50. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/config.py +0 -0
  51. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/config_logger.py +0 -0
  52. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/__init__.py +0 -0
  53. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/bert_dataset.py +0 -0
  54. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/blended_dataset.py +0 -0
  55. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  56. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  57. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/gpt_dataset.py +0 -0
  58. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/helpers.cpp +0 -0
  59. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/helpers.py +0 -0
  60. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/indexed_dataset.py +0 -0
  61. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/masked_dataset.py +0 -0
  62. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/megatron_dataset.py +0 -0
  63. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  64. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/multimodal_dataset.py +0 -0
  65. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/object_storage_utils.py +0 -0
  66. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/__init__.py +0 -0
  67. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/__init__.py +0 -0
  68. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  69. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/config.py +0 -0
  70. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  71. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  72. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/__init__.py +0 -0
  73. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/build.py +0 -0
  74. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/dataset.py +0 -0
  75. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/db/utils.py +0 -0
  76. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/external_libs.py +0 -0
  77. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/__init__.py +0 -0
  78. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/build.py +0 -0
  79. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/factory.py +0 -0
  80. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/index.py +0 -0
  81. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  82. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  83. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  84. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/utils.py +0 -0
  85. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/index/validate.py +0 -0
  86. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/__init__.py +0 -0
  87. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  88. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  89. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/query.py +0 -0
  90. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  91. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/query/utils.py +0 -0
  92. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/retro/utils.py +0 -0
  93. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/t5_dataset.py +0 -0
  94. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/utils.py +0 -0
  95. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/utils_object_storage.py +0 -0
  96. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/datasets/utils_s3.py +0 -0
  97. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/__init__.py +0 -0
  98. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/core.py +0 -0
  99. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  100. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  101. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  102. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/serialization.py +0 -0
  103. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  104. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  105. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  106. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  107. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  108. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  109. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  110. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  111. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  112. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  113. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  114. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  115. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  116. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  117. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/utils.py +0 -0
  118. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/dist_checkpointing/validation.py +0 -0
  119. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/data_parallel_base.py +0 -0
  120. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  121. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  122. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/energy_monitor.py +0 -0
  123. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/enums.py +0 -0
  124. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/__init__.py +0 -0
  125. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/data_type.py +0 -0
  126. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/export_config.py +0 -0
  127. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/model_type.py +0 -0
  128. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/__init__.py +0 -0
  129. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  130. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  131. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  132. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  133. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  134. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  135. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  136. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  137. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  138. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  139. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  140. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  141. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/__init__.py +0 -0
  142. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/kitchen.py +0 -0
  143. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  144. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fp8_utils.py +0 -0
  145. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/full_cuda_graph.py +0 -0
  146. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/__init__.py +0 -0
  147. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  148. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  149. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  150. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  151. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  152. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_indices_converter.py +0 -0
  153. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_layer_norm.py +0 -0
  154. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  155. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  156. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_softmax.py +0 -0
  157. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
  158. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/hyper_comm_grid.py +0 -0
  159. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/__init__.py +0 -0
  160. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/async_stream.py +0 -0
  161. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/common_inference_params.py +0 -0
  162. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/communication_utils.py +0 -0
  163. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/__init__.py +0 -0
  164. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/base_context.py +0 -0
  165. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  166. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/dynamic_context.py +0 -0
  167. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/contexts/static_context.py +0 -0
  168. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/__init__.py +0 -0
  169. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/abstract_engine.py +0 -0
  170. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/dynamic_engine.py +0 -0
  171. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/mcore_engine.py +0 -0
  172. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/engines/static_engine.py +0 -0
  173. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/inference_request.py +0 -0
  174. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  175. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  176. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  177. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  178. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  179. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  180. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  181. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  182. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/sampling_params.py +0 -0
  183. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/scheduler.py +0 -0
  184. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  185. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  186. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  187. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +0 -0
  188. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  189. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference/utils.py +0 -0
  190. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/inference_params.py +0 -0
  191. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/jit.py +0 -0
  192. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/model_parallel_config.py +0 -0
  193. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/T5/__init__.py +0 -0
  194. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/T5/t5_model.py +0 -0
  195. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/T5/t5_spec.py +0 -0
  196. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/__init__.py +0 -0
  197. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/backends.py +0 -0
  198. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/__init__.py +0 -0
  199. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  200. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/bert_lm_head.py +0 -0
  201. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/bert_model.py +0 -0
  202. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/bert/pooler.py +0 -0
  203. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/__init__.py +0 -0
  204. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/__init__.py +0 -0
  205. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  206. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  207. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  208. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  209. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  210. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/language_module/__init__.py +0 -0
  211. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  212. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/vision_module/__init__.py +0 -0
  213. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  214. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/__init__.py +0 -0
  215. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  216. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  217. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/gpt_model.py +0 -0
  218. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  219. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  220. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/__init__.py +0 -0
  221. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/huggingface/module.py +0 -0
  222. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mamba/__init__.py +0 -0
  223. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  224. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mamba/mamba_model.py +0 -0
  225. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/__init__.py +0 -0
  226. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/config/__init__.py +0 -0
  227. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/config/base_configs.py +0 -0
  228. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/model/__init__.py +0 -0
  229. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/model/base.py +0 -0
  230. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/submodules/audio.py +0 -0
  231. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/submodules/base.py +0 -0
  232. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/mimo/submodules/vision.py +0 -0
  233. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/__init__.py +0 -0
  234. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/context_parallel.py +0 -0
  235. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/llava_model.py +0 -0
  236. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/multimodal/llava_spec.py +0 -0
  237. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/__init__.py +0 -0
  238. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/base_attention.py +0 -0
  239. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/config.py +0 -0
  240. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/decoder_attention.py +0 -0
  241. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/decoder_spec.py +0 -0
  242. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/encoder_attention.py +0 -0
  243. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/encoder_spec.py +0 -0
  244. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/model.py +0 -0
  245. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/retro/utils.py +0 -0
  246. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/__init__.py +0 -0
  247. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/clip_vit_model.py +0 -0
  248. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/multimodal_projector.py +0 -0
  249. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/radio.py +0 -0
  250. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  251. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/msc_utils.py +0 -0
  252. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/num_microbatches_calculator.py +0 -0
  253. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  254. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  255. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer/grad_scaler.py +0 -0
  256. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/optimizer_param_scheduler.py +0 -0
  257. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/packed_seq_params.py +0 -0
  258. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/__init__.py +0 -0
  259. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  260. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  261. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/schedules.py +0 -0
  262. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/pipeline_parallel/utils.py +0 -0
  263. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/__init__.py +0 -0
  264. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/__init__.py +0 -0
  265. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  266. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  267. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  268. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/layers.py +0 -0
  269. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  270. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  271. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/process_groups_config.py +0 -0
  272. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/quantization/__init__.py +0 -0
  273. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/quantization/quant_config.py +0 -0
  274. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/quantization/utils.py +0 -0
  275. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/requirements.txt +0 -0
  276. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/rerun_state_machine.py +0 -0
  277. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/__init__.py +0 -0
  278. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_block.py +0 -0
  279. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  280. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  281. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_layer.py +0 -0
  282. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mamba_mixer.py +0 -0
  283. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/mlp_layer.py +0 -0
  284. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/ssm/triton_cache_manager.py +0 -0
  285. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/__init__.py +0 -0
  286. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  287. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/data.py +0 -0
  288. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/mappings.py +0 -0
  289. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/random.py +0 -0
  290. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/tensor_parallel/utils.py +0 -0
  291. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/timers.py +0 -0
  292. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/__init__.py +0 -0
  293. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/attention.py +0 -0
  294. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  295. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  296. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/dot_product_attention.py +0 -0
  297. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/enums.py +0 -0
  298. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  299. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  300. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/identity_op.py +0 -0
  301. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/mlp.py +0 -0
  302. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/module.py +0 -0
  303. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/__init__.py +0 -0
  304. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  305. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  306. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/moe_layer.py +0 -0
  307. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/moe_utils.py +0 -0
  308. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/router.py +0 -0
  309. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/shared_experts.py +0 -0
  310. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  311. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  312. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/multi_latent_attention.py +0 -0
  313. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/multi_token_prediction.py +0 -0
  314. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  315. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/spec_utils.py +0 -0
  316. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/torch_layer_norm.py +0 -0
  317. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/torch_norm.py +0 -0
  318. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron/core/transformer/transformer_block.py +0 -0
  319. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/dependency_links.txt +0 -0
  320. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/requires.txt +0 -0
  321. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/megatron_core.egg-info/top_level.txt +0 -0
  322. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/pyproject.toml +0 -0
  323. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/setup.cfg +0 -0
  324. {megatron_core-0.14.0rc6 → megatron_core-0.14.0rc7}/setup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: megatron-core
- Version: 0.14.0rc6
+ Version: 0.14.0rc7
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -20,6 +20,7 @@ from megatron.core.package_info import (
      __version__,
  )
  from megatron.core.timers import Timers
+ from megatron.core.utils import is_torch_min_version

  # Alias parallel_state as mpu, its legacy name
  mpu = parallel_state
@@ -33,3 +34,8 @@ __all__ = [
      "ModelParallelConfig",
      "Timers",
  ]
+
+ from .safe_globals import register_safe_globals
+
+ if is_torch_min_version("2.6a0"):
+     register_safe_globals()
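
Note: the new megatron/core/safe_globals.py module (file 32 in the list above) is not included in this diff, so the snippet below is only a hedged sketch of what registering safe globals typically involves. It assumes torch.serialization.add_safe_globals, the PyTorch API for allow-listing types so that torch.load with the default weights_only=True (PyTorch 2.6 and later) can deserialize them; the listed types are illustrative, not taken from the package.

    # Hedged sketch only; the real safe_globals.py is not shown in this diff.
    import argparse
    import io

    import torch


    def register_safe_globals():
        # Allow-list non-tensor types that checkpoints may contain, so that
        # torch.load(..., weights_only=True), the default in PyTorch 2.6+, accepts them.
        torch.serialization.add_safe_globals([argparse.Namespace, io.BytesIO])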
@@ -136,12 +136,6 @@ class ShardedTensor(ShardedBase):
          )

          for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
-             # NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
-             # For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
-             # GPU0 receives p0 and a portion of p1, while GPU1 receives the
-             # remaining portion of p1 and p2.
-             # As a result, there is no parameter shard of p2 on GPU0, and
-             # the shape of p2 on GPU0 is zero.
              if sh != 0 and off % sh != 0:
                  raise CheckpointingException(
                      f"Global offset ({off}) must be divisible by local shape ({sh}) for {self}."
@@ -84,9 +84,9 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
  try:
      if MultiStorageClientFeature.is_enabled():
          msc = MultiStorageClientFeature.import_package()
-         return msc.torch.load(load_path, map_location='cpu', weights_only=False)
+         return msc.torch.load(load_path, map_location='cpu')
      else:
-         return torch.load(load_path, map_location='cpu', weights_only=False)
+         return torch.load(load_path, map_location='cpu')
  except FileNotFoundError as e:
      err_msg = f'Common file {load_path} does not exist'
      if MultiStorageClientFeature.is_enabled():
@@ -118,9 +118,9 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
  try:
      if MultiStorageClientFeature.is_enabled():
          msc = MultiStorageClientFeature.import_package()
-         loaded_obj = msc.torch.load(load_path, weights_only=False)
+         loaded_obj = msc.torch.load(load_path)
      else:
-         loaded_obj = torch.load(load_path, weights_only=False)
+         loaded_obj = torch.load(load_path)
  except FileNotFoundError as e:
      # Backward compatible logic: previously the save format was incorrect
      base, _ = os.path.splitext(sh_obj.unique_key)
@@ -128,9 +128,9 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
  try:
      if MultiStorageClientFeature.is_enabled():
          msc = MultiStorageClientFeature.import_package()
-         loaded_obj = msc.torch.load(old_load_path, weights_only=False)
+         loaded_obj = msc.torch.load(old_load_path)
      else:
-         loaded_obj = torch.load(old_load_path, weights_only=False)
+         loaded_obj = torch.load(old_load_path)
  except FileNotFoundError:
      err_msg = f'Object shard {load_path} not found'
      obj_subdir = os.path.join(checkpoint_dir, sh_obj.key)
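
For context, not part of the diff itself: with the explicit weights_only=False argument removed, these calls now follow PyTorch's default, which is weights_only=True as of PyTorch 2.6, so only tensors and allow-listed globals (such as those registered by register_safe_globals() above) can be unpickled. A minimal illustration, with the path as a placeholder:

    import torch

    obj = torch.load("common.pt")                      # PyTorch >= 2.6: weights_only=True by default
    obj = torch.load("common.pt", weights_only=False)  # previous behavior: unrestricted unpickling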
@@ -8,5 +8,6 @@ except ImportError:
  from .distributed_data_parallel import DistributedDataParallel
  from .distributed_data_parallel_config import DistributedDataParallelConfig
  from .finalize_model_grads import finalize_model_grads
+ from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
  from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
  from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig
@@ -61,9 +61,16 @@ class DistributedDataParallelConfig:
      """If true, reuse the grad buffer for param AG when using mxfp8 recipe. Should be
      set to True only when fp8_recipe is mxfp8 and fp8_param_gather is True."""

-     use_custom_fsdp: bool = False
+     use_megatron_fsdp: bool = False
      """If true, use the FSDP code path for DDP."""

+     use_custom_fsdp: bool = False
+     """
+     NOTE: The flag `use_custom_fsdp` is deprecated and will be removed in future versions.
+     Please use `use_megatron_fsdp` instead, as all functionality will be migrated there.
+     Future updates will drop support for `use_custom_fsdp` to avoid confusion.
+     """
+
      data_parallel_sharding_strategy: str = 'no_shard'
      """Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
      'optim_grads', 'optim_grads_params'."""
@@ -80,10 +87,10 @@ class DistributedDataParallelConfig:
      based on your system's memory and performance requirements."""

      preserve_fp32_weights: bool = True
-     """If true, preserve fp32 weights in the custom FSDP ParamAndGradBuffer."""
+     """If true, preserve fp32 weights in the Megatron FSDP ParamAndGradBuffer."""

-     keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False
-     """If true, keep the fp8 transpose cache when using custom FSDP."""
+     keep_fp8_transpose_cache: bool = False
+     """If true, keep the fp8 transpose cache when using Megatron FSDP."""

      nccl_ub: bool = False
      """If true, allocate and register NCCL userbuffer for param and grad buffer.
@@ -106,12 +113,19 @@ class DistributedDataParallelConfig:

      fsdp_double_buffer: bool = False
      """If true, use persistently allocated double buffers for the
-     temporary memory needed in the custom FSDP communications.
+     temporary memory needed in the Megatron FSDP communications.
      This option will cause additional memory overhead, however, it is necessary for
-     to register user buffer (nccl_ub=True) for the custom FSDP.
+     to register user buffer (nccl_ub=True) for the Megatron FSDP.
      This option will be automatically set to True when nccl_ub=True.
      """

+     outer_dp_sharding_strategy: str = 'no_shard'
+     """
+     Sharding strategy for outer data parallel group in Hybrid Sharded Data Parallel (HSDP) mode.
+     Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'.
+     This option is only effective when Hybrid FSDP is enabled.
+     """
+
      def __post_init__(self):
          import os
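
Taken together, these config changes rename the FSDP switch from use_custom_fsdp to use_megatron_fsdp (keeping the old flag only as a deprecated alias) and add an outer-DP sharding knob for Hybrid Sharded Data Parallel. A hedged configuration sketch; the field names come from this diff, while the chosen values are purely illustrative:

    from megatron.core.distributed import DistributedDataParallelConfig

    ddp_config = DistributedDataParallelConfig(
        use_megatron_fsdp=True,                 # replaces the deprecated use_custom_fsdp
        data_parallel_sharding_strategy="optim_grads_params",
        outer_dp_sharding_strategy="no_shard",  # only takes effect when Hybrid FSDP is enabled
        keep_fp8_transpose_cache=False,         # renamed from keep_fp8_transpose_cache_when_using_custom_fsdp
    )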
 
@@ -31,9 +31,7 @@ from ..utils import (
  )


- def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False):
-     if use_custom_fsdp:
-         return "fsdp_managed_main_grad"
+ def _get_main_grad_attr(param: torch.nn.Parameter, use_megatron_fsdp: bool = False):
      if hasattr(param, "main_grad"):
          return "main_grad"
      return "grad"
@@ -241,8 +239,10 @@ def _allreduce_embedding_grad(
      if weight is None and skip_if_none:
          return

-     grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
+     grad_attr = _get_main_grad_attr(weight, ddp_config.use_megatron_fsdp)
      orig_grad = getattr(weight, grad_attr)
+     if ddp_config.use_megatron_fsdp:
+         orig_grad = orig_grad._local_tensor if orig_grad is not None else None
      grad = _unshard_if_dtensor(orig_grad)
      # When the embedding is frozen, the grad is None.
      if grad is None and skip_if_none:
@@ -320,20 +320,30 @@ def _allreduce_non_tensor_model_parallel_grads(
          if param.requires_grad:
              # Check if this param needs average reduction (average_gradients_across_tp_domain)
              if getattr(param, "average_gradients_across_tp_domain", False):
-                 params_avg.append(param)
-                 grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                 grad_attr = _get_main_grad_attr(param, ddp_config.use_megatron_fsdp)
                  grad = getattr(param, grad_attr)
-                 grad = _unshard_if_dtensor(grad)
-                 grads_avg.append(grad.data)
+                 if grad is None:
+                     continue
+                 params_avg.append(param)
+                 if ddp_config.use_megatron_fsdp:
+                     grads_avg.append(grad._local_tensor.data)
+                 else:
+                     grad = _unshard_if_dtensor(grad)
+                     grads_avg.append(grad.data)
              # Check if this param needs sum reduction (sequence parallel or qk_layernorm)
              elif (config.sequence_parallel and getattr(param, "sequence_parallel", False)) or (
                  config.qk_layernorm and ("q_layernorm" in name or "k_layernorm" in name)
              ):
-                 params_sum.append(param)
-                 grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                 grad_attr = _get_main_grad_attr(param, ddp_config.use_megatron_fsdp)
                  grad = getattr(param, grad_attr)
-                 grad = _unshard_if_dtensor(grad)
-                 grads_sum.append(grad.data)
+                 if grad is None:
+                     continue
+                 params_sum.append(param)
+                 if ddp_config.use_megatron_fsdp:
+                     grads_sum.append(grad._local_tensor.data)
+                 else:
+                     grad = _unshard_if_dtensor(grad)
+                     grads_sum.append(grad.data)

      # Loop grads and perform correct all-reduce
      for params, grads, all_reduce_op in zip(
@@ -348,9 +358,12 @@ def _allreduce_non_tensor_model_parallel_grads(
              params, grads, _unflatten_dense_tensors(coalesced, grads)
          ):
              buf.copy_(synced)
-             grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+             grad_attr = _get_main_grad_attr(param, ddp_config.use_megatron_fsdp)
              orig_grad = getattr(param, grad_attr)
-             setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
+             if ddp_config.use_megatron_fsdp:
+                 setattr(param, grad_attr, orig_grad)
+             else:
+                 setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))


  """
@@ -0,0 +1,3 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ from .mcore_fsdp_adapter import FullyShardedDataParallel
@@ -0,0 +1,317 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import List, Optional
+
+ try:
+     import einops
+
+     HAVE_EINOPS = True
+ except ImportError:
+     HAVE_EINOPS = False
+
+ import torch
+ import torch.distributed as dist
+
+ try:
+     from torch.distributed import DeviceMesh
+
+     HAVE_DTENSOR = True
+ except ImportError:
+     HAVE_DTENSOR = False
+
+ from megatron.core import parallel_state
+ from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+ from megatron.core.distributed.data_parallel_base import _BaseDataParallel
+ from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
+ from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups
+ from megatron.core.transformer.transformer_config import TransformerConfig
+ from megatron.core.transformer.transformer_layer import TransformerLayer
+ from megatron.core.utils import log_single_rank
+
+ try:
+     from megatron.core.distributed.fsdp.src.megatron_fsdp import FSDPDistributedIndex, MegatronFSDP
+
+     HAVE_MEGATRON_FSDP = True
+ except ImportError as import_megatron_fsdp_error:
+     IMPORT_MEGATRON_FSDP_ERROR = import_megatron_fsdp_error
+     HAVE_MEGATRON_FSDP = False
+
+ logger = logging.getLogger(__name__)
+
+
+ class FullyShardedDataParallel(_BaseDataParallel):
+     """
+     Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.
+     """
+
+     def __init__(
+         self,
+         config: TransformerConfig,
+         ddp_config: DistributedDataParallelConfig,
+         module: torch.nn.Module,
+         fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
+         disable_bucketing: bool = False,
+         device: Optional[torch.device] = None,
+         grad_comm_pgs: Optional[GradCommProcessGroups] = None,
+         model_comm_pgs: Optional[ModelCommProcessGroups] = None,
+     ):
+         if not HAVE_MEGATRON_FSDP:
+             raise IMPORT_MEGATRON_FSDP_ERROR
+
+         if has_config_logger_enabled(config):
+             log_config_to_disk(config, locals(), prefix=type(self).__name__)
+
+         self.ddp_config = ddp_config
+         log_single_rank(
+             logger,
+             logging.INFO,
+             f'Setting up DistributedDataParallel with config {self.ddp_config}',
+         )
+         self.megatron_fsdp_dist_index = self._init_dist_index(grad_comm_pgs, model_comm_pgs)
+
+         self.bucket_size = self.ddp_config.bucket_size
+         if disable_bucketing:
+             self.bucket_size = None
+         self.device = device if device else torch.device(f'cuda:{torch.cuda.current_device()}')
+
+         if fsdp_unit_modules is not None:
+             self.fsdp_unit_modules = fsdp_unit_modules
+         else:
+             if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                 self.fsdp_unit_modules = [TransformerLayer]
+             else:
+                 self.fsdp_unit_modules = []
+
+         super().__init__(
+             config=config,
+             module=MegatronFSDP(
+                 ddp_config=ddp_config,
+                 module=module,
+                 fsdp_unit_modules=self.fsdp_unit_modules,
+                 disable_bucketing=disable_bucketing,
+                 device=self.device,
+                 dist_index=self.megatron_fsdp_dist_index,
+                 calculate_per_token_loss=config.calculate_per_token_loss,
+                 init_model_with_meta_device=config.init_model_with_meta_device,
+             ),
+         )
+         self.param_and_grad_buffer = self.module.param_and_grad_buffer
+         self.no_sync = self.module.no_sync
+         self.start_param_sync = self.module.start_param_sync
+         self.start_grad_sync = self.module.start_grad_sync
+         self.finish_grad_sync = self.module.finish_grad_sync
+         self.scale_gradients = self.module.scale_gradients
+         self.zero_grad_buffer = self.module.zero_grad_buffer
+         self.broadcast_params = self.module.broadcast_params
+         self.module.state_dict_for_save_checkpoint = self.module.state_dict
+         self.state_dict_for_save_checkpoint = self.state_dict
+
+     def load_state_dict(self, state_dict, strict=True):
+         """
+         Load the state dictionary into the module.
+         """
+         custom_state_dict = {}
+         for key, value in state_dict.items():
+             if self.config.fp8 and key.endswith('._extra_state'):
+                 # Skip extra state keys
+                 continue
+             custom_state_dict[f"module.{key}"] = value
+
+         if self.config.fp8 or self.config.gated_linear_unit:
+             strict = False
+             log_single_rank(
+                 logger,
+                 logging.WARNING,
+                 "Loading state_dict with strict=False due to fp8 configuration. "
+                 "This is expected as some keys may not match exactly.",
+             )
+
+         self.module.load_state_dict(custom_state_dict, strict=strict)
+
+     def _init_dist_index(self, grad_comm_pgs, model_comm_pgs):
+         """
+         Initialize the distributed index for the module.
+         """
+         if not HAVE_DTENSOR:
+             raise ImportError(
+                 "This module requires PyTorch with DTensor support. "
+                 "Please install a compatible version of PyTorch."
+             )
+
+         enable_hsdp = self.ddp_config.num_distributed_optimizer_instances > 1
+         if grad_comm_pgs is None and model_comm_pgs is None:
+             tp_group = parallel_state.get_tensor_model_parallel_group()
+             if enable_hsdp:
+                 dp_cp_group = parallel_state.get_data_parallel_group(
+                     with_context_parallel=True, partial_data_parallel=True
+                 )
+                 inter_fsdp_group = parallel_state.get_inter_distributed_optimizer_instance_group()
+                 hybrid_fsdp_group = parallel_state.get_data_parallel_group(
+                     with_context_parallel=True, partial_data_parallel=False
+                 )
+             else:
+                 dp_cp_group = parallel_state.get_data_parallel_group(
+                     with_context_parallel=True, partial_data_parallel=False
+                 )
+                 inter_fsdp_group = None
+                 hybrid_fsdp_group = None
+         elif grad_comm_pgs is not None and model_comm_pgs is not None:
+             tp_group = getattr(model_comm_pgs, 'tp', None)
+             if enable_hsdp:
+                 dp_cp_group = grad_comm_pgs.intra_dp_cp
+                 inter_fsdp_group = grad_comm_pgs.inter_dist_opt
+                 hybrid_fsdp_group = grad_comm_pgs.dp_cp
+             else:
+                 dp_cp_group = grad_comm_pgs.dp_cp
+                 inter_fsdp_group = None
+                 hybrid_fsdp_group = None
+         else:
+             raise ValueError(
+                 "Both grad_comm_pgs and model_comm_pgs must be either None or provided together."
+             )
+
+         if tp_group is None:
+             single_rank_group = dist.new_group(ranks=[dist.get_rank()])
+             tp_group = single_rank_group
+
+         if enable_hsdp:
+             mesh = _get_hsdp_tp_mesh(inter_fsdp_group, dp_cp_group, tp_group)
+             dist_index = FSDPDistributedIndex(
+                 use_hybrid_fsdp=True,
+                 hsdp_outer_dp_shard=self.ddp_config.outer_dp_sharding_strategy != "no_shard",
+                 device_mesh=DeviceMesh.from_group(
+                     [inter_fsdp_group, dp_cp_group, tp_group],
+                     device_type="cuda",
+                     mesh=mesh.tolist(),
+                     mesh_dim_names=["inter_fsdp_dp", "dp_cp", "tp"],
+                 ),
+                 dp_inter_dim="inter_fsdp_dp",
+                 dp_shard_dim="dp_cp",
+                 tp_dim="tp",
+                 hybrid_fsdp_group=hybrid_fsdp_group,
+             )
+         else:
+             mesh = _get_dp_tp_mesh(dp_cp_group, tp_group)
+             dist_index = FSDPDistributedIndex(
+                 device_mesh=DeviceMesh.from_group(
+                     [dp_cp_group, tp_group],
+                     device_type="cuda",
+                     mesh=mesh.tolist(),
+                     mesh_dim_names=["dp_cp", "tp"],
+                 ),
+                 dp_shard_dim="dp_cp",
+                 tp_dim="tp",
+             )
+
+         return dist_index
+
+     def stop_communication(self):
+         """
+         Stop communication for the module.
+         """
+         self.module.synchronize_gradient_reduce()
+         self.module.synchronize_param_gather()
+
+
+ def _get_hsdp_tp_mesh(inter_fsdp_dp_group, dp_cp_group, tp_group):
+     assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
+     world_size = dist.get_world_size()
+
+     mesh = einops.rearrange(
+         torch.arange(world_size),
+         "(inter_fsdp_dp fsdp tp) -> inter_fsdp_dp fsdp tp",
+         inter_fsdp_dp=inter_fsdp_dp_group.size(),
+         tp=tp_group.size(),
+     )
+
+     mesh_fsdp_ranks = einops.rearrange(
+         mesh,
+         'inter_fsdp_dp fsdp tp -> (inter_fsdp_dp tp) fsdp',
+         tp=tp_group.size(),
+         fsdp=dp_cp_group.size(),
+     )
+     fsdp_group_ranks = dist.get_process_group_ranks(dp_cp_group)
+     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_fsdp_ranks, fsdp_group_ranks), (
+         f"[Megatron-FSDP] FSDP ranks in the mesh {mesh_fsdp_ranks} "
+         f"do not match the ranks in the FSDP group {fsdp_group_ranks}."
+     )
+
+     mesh_tp_ranks = einops.rearrange(
+         mesh,
+         'inter_fsdp_dp fsdp tp -> (inter_fsdp_dp fsdp) tp',
+         tp=tp_group.size(),
+         fsdp=dp_cp_group.size(),
+     )
+     tp_group_ranks = dist.get_process_group_ranks(tp_group)
+     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_tp_ranks, tp_group_ranks), (
+         f"[Megatron-FSDP] Tensor Parallel ranks in the mesh {mesh_tp_ranks} "
+         f"do not match the ranks in the TP group {tp_group_ranks}."
+     )
+
+     mesh_inter_fsdp_dp_ranks = einops.rearrange(
+         mesh,
+         'inter_fsdp_dp fsdp tp -> (fsdp tp) inter_fsdp_dp',
+         tp=tp_group.size(),
+         fsdp=dp_cp_group.size(),
+     )
+     inter_fsdp_dp_group_ranks = dist.get_process_group_ranks(inter_fsdp_dp_group)
+     assert _check_mesh_ranks_and_group_ranks_are_consistent(
+         mesh_inter_fsdp_dp_ranks, inter_fsdp_dp_group_ranks
+     ), (
+         f"[Megatron-FSDP] Inter FSDP Data Parallel ranks in the mesh {mesh_inter_fsdp_dp_ranks} "
+         f"do not match the ranks in the Inter FSDP DP group {inter_fsdp_dp_group_ranks}."
+     )
+
+     return mesh
+
+
+ def _get_dp_tp_mesh(dp_cp_group, tp_group):
+     assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
+     world_size = dist.get_world_size()
+
+     tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
+     # TODO: Supports configurable (dp, cp, tp) order.
+     mesh = einops.rearrange(torch.arange(world_size), "(dp_cp tp) -> dp_cp tp", tp=tp_size)
+
+     mesh_dp_ranks = einops.rearrange(mesh, 'dp_cp tp -> tp dp_cp', tp=tp_size)
+     dp_cp_group_ranks = dist.get_process_group_ranks(dp_cp_group)
+     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_dp_ranks, dp_cp_group_ranks), (
+         f"[Megatron-FSDP] Data Parallel ranks in the mesh {mesh_dp_ranks} "
+         f"do not match the ranks in the DP group {dp_cp_group_ranks}."
+     )
+
+     mesh_tp_ranks = einops.rearrange(mesh, 'dp_cp tp -> (dp_cp) tp', tp=tp_size)
+     tp_group_ranks = dist.get_process_group_ranks(tp_group)
+     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_tp_ranks, tp_group_ranks), (
+         f"[Megatron-FSDP] Tensor Parallel ranks in the mesh {mesh_tp_ranks} "
+         f"do not match the ranks in the TP group {tp_group_ranks}."
+     )
+
+     return mesh
+
+
+ def _check_mesh_ranks_and_group_ranks_are_consistent(mesh_ranks, group_ranks):
+     current_rank = dist.get_rank()
+     current_ranks = list(filter(lambda ranks: current_rank in ranks, mesh_ranks.tolist()))
+     assert len(current_ranks) == 1, (
+         f"[Megatron-FSDP] Current rank {current_rank} is not unique in "
+         f"the mesh ranks {mesh_ranks.tolist()}."
+     )
+     assert sorted(current_ranks[0]) == sorted(group_ranks), (
+         f"[Megatron-FSDP] Current rank {current_rank} in the mesh ranks "
+         f"{mesh_ranks.tolist()} does not match the group ranks {group_ranks}."
+     )
+     return sorted(current_ranks[0]) == sorted(group_ranks)
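
As a usage sketch for the adapter above (hedged: only the constructor arguments and the re-exported sync methods come from this diff; model, config, and ddp_config are assumed to already exist):

    from megatron.core.distributed import FullyShardedDataParallel

    fsdp_model = FullyShardedDataParallel(
        config=config,          # TransformerConfig of the wrapped model
        ddp_config=ddp_config,  # DistributedDataParallelConfig with use_megatron_fsdp=True
        module=model,           # the unwrapped Megatron model chunk
    )
    # Gradient/parameter synchronization entry points are re-exported from MegatronFSDP:
    fsdp_model.zero_grad_buffer()
    fsdp_model.finish_grad_sync()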
@@ -0,0 +1,13 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,22 @@
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .distributed_data_parallel_config import DistributedDataParallelConfig
+ from .megatron_fsdp import MegatronFSDP
+ from .utils import FSDPDistributedIndex
+
+ try:
+     from .fully_shard import fully_shard
+ except ImportError as e:
+     print(f"Failed to import fully_shard: {e}")
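
This __init__.py exposes MegatronFSDP, FSDPDistributedIndex, and the standalone fully_shard entry point from the vendored package. The signature of fully_shard lives in fully_shard.py (file 13 above) and is not shown in this diff, so only the import surface is illustrated here:

    from megatron.core.distributed.fsdp.src.megatron_fsdp import (
        FSDPDistributedIndex,
        MegatronFSDP,
        fully_shard,  # only present when its optional dependencies import cleanly
    )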