megatron-core 0.12.0rc3__tar.gz → 0.12.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megatron-core might be problematic.

Files changed (291)
  1. {megatron_core-0.12.0rc3/megatron_core.egg-info → megatron_core-0.12.2}/PKG-INFO +1 -1
  2. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/async_utils.py +29 -11
  3. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +214 -156
  4. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +275 -186
  5. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/distributed_data_parallel_config.py +7 -3
  6. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/finalize_model_grads.py +3 -2
  7. megatron_core-0.12.2/megatron/core/export/model_type.py +8 -0
  8. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +6 -0
  9. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +10 -0
  10. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trt_model_config.py +1 -0
  11. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trt_model_type.py +1 -0
  12. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_helper.py +20 -2
  13. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_layers.py +9 -0
  14. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +10 -4
  15. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +17 -5
  16. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/extensions/transformer_engine.py +34 -8
  17. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fp8_utils.py +15 -8
  18. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_swiglu.py +57 -0
  19. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/rope_utils.py +20 -3
  20. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/gpt_layer_specs.py +27 -6
  21. megatron_core-0.12.2/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +209 -0
  22. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/__init__.py +16 -5
  23. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/optimizer.py +99 -21
  24. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/package_info.py +2 -2
  25. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/gpt/model_specs.py +12 -4
  26. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/__init__.py +2 -0
  27. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/random.py +149 -23
  28. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/attention.py +4 -1
  29. megatron_core-0.12.2/megatron/core/transformer/heterogeneous/heterogeneous_config.py +267 -0
  30. megatron_core-0.12.2/megatron/core/transformer/heterogeneous/linear_replacements.py +111 -0
  31. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/mlp.py +37 -18
  32. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/experts.py +166 -60
  33. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py +18 -11
  34. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/moe_layer.py +20 -10
  35. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/moe_utils.py +91 -19
  36. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/shared_experts.py +4 -0
  37. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/token_dispatcher.py +63 -64
  38. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/multi_latent_attention.py +121 -69
  39. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/torch_norm.py +49 -1
  40. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/transformer_block.py +22 -5
  41. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/transformer_config.py +88 -14
  42. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/transformer_layer.py +51 -5
  43. {megatron_core-0.12.0rc3 → megatron_core-0.12.2/megatron_core.egg-info}/PKG-INFO +1 -1
  44. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/SOURCES.txt +3 -0
  45. megatron_core-0.12.0rc3/megatron/core/export/model_type.py +0 -7
  46. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/LICENSE +0 -0
  47. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/MANIFEST.in +0 -0
  48. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/README.md +0 -0
  49. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/README.md +0 -0
  50. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/__init__.py +0 -0
  51. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/config.py +0 -0
  52. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/config_logger.py +0 -0
  53. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/__init__.py +0 -0
  54. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/bert_dataset.py +0 -0
  55. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/blended_dataset.py +0 -0
  56. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  57. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  58. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/gpt_dataset.py +0 -0
  59. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/helpers.cpp +0 -0
  60. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/helpers.py +0 -0
  61. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/indexed_dataset.py +0 -0
  62. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/masked_dataset.py +0 -0
  63. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/megatron_dataset.py +0 -0
  64. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  65. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/multimodal_dataset.py +0 -0
  66. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/__init__.py +0 -0
  67. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/__init__.py +0 -0
  68. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  69. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/config.py +0 -0
  70. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  71. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  72. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/__init__.py +0 -0
  73. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/build.py +0 -0
  74. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/dataset.py +0 -0
  75. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/db/utils.py +0 -0
  76. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/external_libs.py +0 -0
  77. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/__init__.py +0 -0
  78. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/build.py +0 -0
  79. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/factory.py +0 -0
  80. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/index.py +0 -0
  81. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  82. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  83. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  84. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/utils.py +0 -0
  85. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/index/validate.py +0 -0
  86. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/__init__.py +0 -0
  87. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  88. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  89. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/query.py +0 -0
  90. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  91. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/query/utils.py +0 -0
  92. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/retro/utils.py +0 -0
  93. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/t5_dataset.py +0 -0
  94. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/utils.py +0 -0
  95. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/datasets/utils_s3.py +0 -0
  96. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/__init__.py +0 -0
  97. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/core.py +0 -0
  98. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  99. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  100. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/mapping.py +0 -0
  101. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  102. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/serialization.py +0 -0
  103. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  104. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  105. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  106. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  107. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  108. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  109. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  110. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  111. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  112. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  113. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  114. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  115. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  116. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  117. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/utils.py +0 -0
  118. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/dist_checkpointing/validation.py +0 -0
  119. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/__init__.py +0 -0
  120. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  121. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/data_parallel_base.py +0 -0
  122. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  123. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  124. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  125. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/enums.py +0 -0
  126. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/__init__.py +0 -0
  127. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/data_type.py +0 -0
  128. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/export_config.py +0 -0
  129. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/__init__.py +0 -0
  130. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  131. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  132. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  133. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/extensions/__init__.py +0 -0
  134. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/__init__.py +0 -0
  135. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  136. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  137. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  138. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  139. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_layer_norm.py +0 -0
  140. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/fusions/fused_softmax.py +0 -0
  141. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/__init__.py +0 -0
  142. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/async_stream.py +0 -0
  143. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/common_inference_params.py +0 -0
  144. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/communication_utils.py +0 -0
  145. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/__init__.py +0 -0
  146. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/base_context.py +0 -0
  147. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/dynamic_context.py +0 -0
  148. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/contexts/static_context.py +0 -0
  149. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/__init__.py +0 -0
  150. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/abstract_engine.py +0 -0
  151. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/dynamic_engine.py +0 -0
  152. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/mcore_engine.py +0 -0
  153. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/engines/static_engine.py +0 -0
  154. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/inference_request.py +0 -0
  155. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  156. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  157. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  158. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  159. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  160. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  161. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  162. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  163. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/__init__.py +0 -0
  164. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/gpt/__init__.py +0 -0
  165. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/gpt/model_specs.py +0 -0
  166. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py +0 -0
  167. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/mamba/__init__.py +0 -0
  168. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/modelopt_support/mamba/model_specs.py +0 -0
  169. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/sampling_params.py +0 -0
  170. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/scheduler.py +0 -0
  171. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  172. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  173. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  174. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +0 -0
  175. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  176. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference/utils.py +0 -0
  177. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/inference_params.py +0 -0
  178. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/jit.py +0 -0
  179. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/model_parallel_config.py +0 -0
  180. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/T5/__init__.py +0 -0
  181. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/T5/t5_model.py +0 -0
  182. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/T5/t5_spec.py +0 -0
  183. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/__init__.py +0 -0
  184. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/__init__.py +0 -0
  185. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  186. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/bert_lm_head.py +0 -0
  187. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/bert_model.py +0 -0
  188. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/bert/pooler.py +0 -0
  189. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/__init__.py +0 -0
  190. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/__init__.py +0 -0
  191. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  192. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  193. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  194. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  195. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/language_module/__init__.py +0 -0
  196. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/language_module/language_module.py +0 -0
  197. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/vision_module/__init__.py +0 -0
  198. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  199. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/__init__.py +0 -0
  200. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/gpt_model.py +0 -0
  201. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  202. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/__init__.py +0 -0
  203. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/clip_model.py +0 -0
  204. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/module.py +0 -0
  205. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/huggingface/qwen_model.py +0 -0
  206. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/mamba/__init__.py +0 -0
  207. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  208. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/mamba/mamba_model.py +0 -0
  209. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/__init__.py +0 -0
  210. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/context_parallel.py +0 -0
  211. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/llava_model.py +0 -0
  212. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/multimodal/llava_spec.py +0 -0
  213. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/__init__.py +0 -0
  214. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/base_attention.py +0 -0
  215. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/config.py +0 -0
  216. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/decoder_attention.py +0 -0
  217. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/decoder_spec.py +0 -0
  218. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/encoder_attention.py +0 -0
  219. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/encoder_spec.py +0 -0
  220. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/model.py +0 -0
  221. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/retro/utils.py +0 -0
  222. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/__init__.py +0 -0
  223. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/clip_vit_model.py +0 -0
  224. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/multimodal_projector.py +0 -0
  225. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/radio.py +0 -0
  226. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  227. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/num_microbatches_calculator.py +0 -0
  228. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/clip_grads.py +0 -0
  229. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  230. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  231. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  232. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/grad_scaler.py +0 -0
  233. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer/optimizer_config.py +0 -0
  234. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/optimizer_param_scheduler.py +0 -0
  235. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/packed_seq_params.py +0 -0
  236. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/parallel_state.py +0 -0
  237. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/pipeline_parallel/__init__.py +0 -0
  238. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  239. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/pipeline_parallel/schedules.py +0 -0
  240. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/__init__.py +0 -0
  241. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/__init__.py +0 -0
  242. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  243. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  244. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/layers.py +0 -0
  245. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  246. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  247. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/process_groups_config.py +0 -0
  248. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/requirements.txt +0 -0
  249. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/rerun_state_machine.py +0 -0
  250. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/__init__.py +0 -0
  251. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_block.py +0 -0
  252. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_config.py +0 -0
  253. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  254. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_layer.py +0 -0
  255. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mamba_mixer.py +0 -0
  256. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/mlp_layer.py +0 -0
  257. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/ssm/triton_cache_manager.py +0 -0
  258. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  259. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/data.py +0 -0
  260. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/layers.py +0 -0
  261. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/mappings.py +0 -0
  262. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/tensor_parallel/utils.py +0 -0
  263. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/timers.py +0 -0
  264. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/__init__.py +0 -0
  265. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/cuda_graphs.py +0 -0
  266. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  267. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  268. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/dot_product_attention.py +0 -0
  269. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/enums.py +0 -0
  270. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/identity_op.py +0 -0
  271. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/module.py +0 -0
  272. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/__init__.py +0 -0
  273. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  274. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  275. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/router.py +0 -0
  276. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  277. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/multi_token_prediction.py +0 -0
  278. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/spec_utils.py +0 -0
  279. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/torch_layer_norm.py +0 -0
  280. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/transformer/utils.py +0 -0
  281. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron/core/utils.py +0 -0
  282. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/dependency_links.txt +0 -0
  283. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/requires.txt +0 -0
  284. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/megatron_core.egg-info/top_level.txt +0 -0
  285. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/pyproject.toml +0 -0
  286. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_24.01/requirements.txt +0 -0
  287. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_24.07/requirements.txt +0 -0
  288. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_24.10/requirements.txt +0 -0
  289. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/requirements/pytorch_25.03/requirements.txt +0 -0
  290. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/setup.cfg +0 -0
  291. {megatron_core-0.12.0rc3 → megatron_core-0.12.2}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.12.0rc3
+Version: 0.12.2
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Home-page: https://github.com/NVIDIA/Megatron-LM/megatron/core
 Download-URL: https://github.com/NVIDIA/Megatron-LM/releases
megatron/core/dist_checkpointing/strategies/async_utils.py
@@ -155,7 +155,7 @@ class AsyncCaller(ABC):
         logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")

     def __del__(self):
-        self.close()
+        raise NotImplementedError("This should be implemented")


 class TemporalAsyncCaller(AsyncCaller):
@@ -227,12 +227,22 @@ class TemporalAsyncCaller(AsyncCaller):
         is_alive = int(self.process.is_alive()) if self.process is not None else 0
         is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)

-        if not is_done and blocking:
+        if is_done or blocking:
+            # Process join is called in the following cases
+            # 1. blocking == True -> regardless of is_done
+            # 2. blocking == False (non-blocking)
+            #    -> is_done == True: async requests on all ranks are identified to be finished
+            # `self.close()` makes sure the async callers terminated
             self.close()
             is_done = True
         return is_done

     def close(self):
+        """For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`
+
+        This method make sure the TemporalAsyncCaller terminated
+        with all its assigned async request completed
+        """
         if self.process:
             logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
             self.process.join()
@@ -243,6 +253,9 @@ class TemporalAsyncCaller(AsyncCaller):
             )
             self.start_time = None

+    def __del__(self):
+        pass
+

 class PersistentAsyncCaller(AsyncCaller):
     """Wrapper around mp.Process that ensures correct semantic of distributed finalization.
@@ -376,6 +389,10 @@ class PersistentAsyncCaller(AsyncCaller):
         return is_done

     def close(self):
+        """Wait on the left async requests and terminate the PersistentAsyncCaller
+
+        Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated
+        """
         logger.info(
             f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
         )
@@ -385,6 +402,9 @@ class PersistentAsyncCaller(AsyncCaller):
             self.process.join()
             self.process = None

+    def __del__(self):
+        self.close()
+
     @staticmethod
     @_disable_gc()
     def async_loop(
@@ -492,13 +512,11 @@ class AsyncCallsQueue:
         # Backward compatibility for local checkpointing built with the old AsyncRequest
         if len(async_request._fields) != len(AsyncRequest._fields):
             async_request = AsyncRequest(**async_request._asdict())
-
-        async_request = async_request._replace(call_idx=self.call_idx)
-        finalize_fns = async_request.finalize_fns
-        async_request = async_request._replace(finalize_fns=None)
         async_request = async_request.freeze()
-        async_caller.schedule_async_call(async_request)
-        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, finalize_fns))
+        async_caller.schedule_async_call(
+            async_request._replace(call_idx=self.call_idx, finalize_fns=[])
+        )
+        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
         return self.call_idx

     def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
@@ -522,13 +540,13 @@ class AsyncCallsQueue:
             if not next_async_done:
                 break
             with debug_time("finalize", logger):
-                call_idx, _, finalize_fns = self.async_calls.popleft()
+                call_idx, _, async_request = self.async_calls.popleft()
+                for finalize_fn in async_request.finalize_fns:
+                    finalize_fn()
                 ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
                 torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
                 assert ten.item() == call_idx, 'Unmatched async calls. '
                 'That probably means not all ranks are participating in async finalization'
-                for finalize_fn in finalize_fns:
-                    finalize_fn()
                 call_idx_finalized.append(call_idx)
         return call_idx_finalized

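The `is_done or blocking` branch above changes when the temporal caller joins and closes its worker process, and `maybe_finalize_async_calls` now runs each request's `finalize_fns` before the cross-rank `call_idx` check while the scheduled caller receives a copy with `finalize_fns=[]`. A minimal, self-contained sketch of the join decision (a toy helper, not part of the megatron-core API):

    # Toy mirror of the join condition in TemporalAsyncCaller.is_current_async_calls_done:
    # blocking callers always join; non-blocking callers join only once every rank
    # reports its async request as finished.
    def should_join(is_done: bool, blocking: bool) -> bool:
        return is_done or blocking

    assert should_join(is_done=False, blocking=True)       # blocking: join regardless
    assert should_join(is_done=True, blocking=False)       # finished on all ranks: join and finalize
    assert not should_join(is_done=False, blocking=False)  # still running: return without joining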
megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py
@@ -76,7 +76,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         module: Underlying model.
         fsdp_unit_modules: List of modules that should be treated as FSDP Unit,
             i.e., the minimum releasable model unit. If not provided, defaults to
-            [TransformerLayer, LanguageModelEmbedding] for GPT-like models.
+            [TransformerLayer, LanguageModelEmbedding] for GPT-like models. In
+            addition to this, it affects the granularity of the communication
+            parameter grouping and triggers aggregate collective communication
+            in fp8 mixed precision training.
         disable_bucketing: If true, force assign all parameters to a single bucket. If false,
             use standard bucketing policy: assign parameters to smaller buckets and all-reduce
             per bucket.
@@ -122,7 +125,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         if fsdp_unit_modules is not None:
             self.fsdp_unit_modules = fsdp_unit_modules
         else:
-            self.fsdp_unit_modules = [TransformerLayer]
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                self.fsdp_unit_modules = [TransformerLayer]
+            else:
+                self.fsdp_unit_modules = []
         self.main_weights = True
         self.data_parallel_group = parallel_state.get_data_parallel_group(
             with_context_parallel=True
@@ -177,14 +183,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
             self.module,
             bucketing_policy=BucketingPolicy(
                 suggested_bucket_size=self.bucket_size,
-                fsdp_unit_modules=(
-                    # Only when model weights need to be sharded, we need to
-                    # identify the minimum releasable model unit, which is the
-                    # FSDP Unit Module.
-                    self.fsdp_unit_modules
-                    if self.data_parallel_sharding_strategy == "optim_grads_params"
-                    else []
-                ),
+                fsdp_unit_modules=self.fsdp_unit_modules,
                 data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
             ),
             data_parallel_group=self.data_parallel_group,
@@ -208,8 +207,24 @@ class FullyShardedDataParallel(_BaseDataParallel):
         # Initialize the all-gather pipeline.
         self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer)

-        self.suggested_RS_queue_capacity = self.ddp_config.suggested_communication_unit_size
-        self.suggested_AG_prefetch_size = self.ddp_config.suggested_communication_unit_size
+        suggested_communication_unit_size = self.ddp_config.suggested_communication_unit_size
+        if suggested_communication_unit_size is None:
+            if self.data_parallel_sharding_strategy == "optim_grads_params":
+                total_param_elements = 0
+                total_fsdp_module = 0
+                for module in self.module.modules():
+                    if isinstance(module, tuple(self.fsdp_unit_modules)):
+                        total_fsdp_module += 1
+                        total_param_elements += sum(p.numel() for p in module.parameters())
+                # The suggested size is twice the number of elements in the FSDP modules.
+                # This ensures we process the current FSDP module and attempt to prefetch
+                # the next FSDP module, making the flow of communication better.
+                suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
+            elif self.bucket_size is not None:
+                suggested_communication_unit_size = self.bucket_size * 2
+
+        self.suggested_RS_queue_capacity = suggested_communication_unit_size
+        self.suggested_AG_prefetch_size = suggested_communication_unit_size

     def _register_fsdp_hooks(self, root_module):
         """Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
@@ -219,8 +234,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         - Pre-forward hook: Unshards parameters before forward pass
         - Post-forward hook: Reshards parameters after forward pass
         - Pre-backward hook: Unshards parameters before backward pass
-        - Post-backward hook: Reshards parameters after backward pass
-        - Gradient accumulation hook: Handles gradient accumulation and reduction across devices
+        - Post-backward hook: Reshards parameters and reduces gradients after backward pass

         Args:
             root_module: The PyTorch module to register FSDP hooks on
@@ -254,10 +268,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         `optim` and `optim_grads` do not require FSDP units because they do not
         shard model parameters.
         """
-        if self.data_parallel_sharding_strategy != "optim_grads_params":
-            fsdp_unit_modules = []
-        else:
-            fsdp_unit_modules = self.fsdp_unit_modules
+        fsdp_unit_modules = self.fsdp_unit_modules

         def release_module_parameters(module, *unused):
             for param in module.parameters():
@@ -280,27 +291,74 @@ class FullyShardedDataParallel(_BaseDataParallel):
            prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
            wait_bucket_ready=True,
        ):
-            wait_list = []
            ag_pipeline = self.all_gather_pipeline
-            for param in module.parameters():
-                bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                ag_pipeline.queue_bucket_to_all_gather(
-                    bucket_id,
-                    prefetch=prefetch,
-                    prefetch_order=prefetch_order,
-                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                )
-                wait_list.append(bucket_id)
-
+            ag_pipeline.all_gather_params(
+                params=list(module.parameters()),
+                prefetch=prefetch,
+                prefetch_order=prefetch_order,
+                suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+            )
            if wait_bucket_ready:
-                for bucket_id in wait_list:
+                for param in module.parameters():
+                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                    ag_pipeline.wait_bucket_ready(bucket_id)

+        def _grad_acc(param):
+            """
+            Accumulate the gradient in the main_grad buffer.
+            """
+            group_id = self.param_and_grad_buffer.param_to_param_group[param]
+            group = self.param_and_grad_buffer.parameter_groups[group_id]
+            if not group.requires_grad:
+                return
+
+            overwrite_main_grad = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if overwrite_main_grad:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.copy_(param.grad)
+                        del param.grad
+                    else:
+                        param.main_grad.zero_()
+            else:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.add_(param.grad)
+                        del param.grad
+            # Reset the grad accumulate flag.
+            param.grad_added_to_main_grad = False
+
+        self._params_require_handle_grad = set()
+
        def _post_backward(module, *unused):
-            release_module_parameters(module)
-            module._training_state = TrainingState.IDLE
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                    release_module_parameters(module)
+                    module._training_state = TrainingState.IDLE
+                param_list = list(module.parameters())
+            else:
+                param_list = list(module.parameters(recurse=False))
+
+            for param in param_list:
+                _grad_acc(param)
+                self._params_require_handle_grad.discard(param)
+
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    param_list, suggested_queue_capacity=self.suggested_RS_queue_capacity
+                )

-        def _pre_forward(module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
+        def _pre_forward_param_unshard(
+            module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+        ):
+            # Unshard the parameters before the forward pass.
            input_training_state = module._training_state
            fsdp_forward_prefetch = True
            if input_training_state == TrainingState.PRE_BACKWARD:
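In `_grad_acc` above, `main_grad` is overwritten when gradients are reduce-scattered every backward pass (`optim_grads`, `optim_grads_params`) and accumulated otherwise. A toy tensor-level illustration of the two branches (plain tensors, not the actual Megatron parameter and buffer objects):

    import torch

    main_grad, grad = torch.zeros(4), torch.ones(4)

    # "optim_grads" / "optim_grads_params": grads leave via reduce-scatter each
    # backward pass, so the fresh gradient replaces main_grad.
    main_grad.copy_(grad)

    # Other strategies: gradients accumulate in main_grad across microbatches and
    # are only reduced on the last microbatch.
    main_grad.add_(grad)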
@@ -310,72 +368,104 @@ class FullyShardedDataParallel(_BaseDataParallel):
            module._training_state = TrainingState.FORWARD

            if isinstance(module, tuple(fsdp_unit_modules)):
-                wait_list = []
-                for param in module.parameters():
+                param_list = list(module.parameters())
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                    self.all_gather_pipeline.queue_bucket_to_all_gather(
-                        bucket_id,
-                        prefetch=fsdp_forward_prefetch,
-                        suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                    )
-                    wait_list.append(bucket_id)
-                for bucket_id in wait_list:
                    self.all_gather_pipeline.wait_bucket_ready(bucket_id)
-
-                if not torch.is_grad_enabled():
-                    return args, kwargs
-
-                # Register the backward function to release the parameters.
-                args_list, args_spec = tree_flatten(args)
-                kwargs_list, kwargs_spec = tree_flatten(kwargs)
-                args_kwargs_list = list(args_list) + list(kwargs_list)
-                inp_tensor_indices: List[int] = []
-                inp_tensors: List[torch.Tensor] = []
-                for i, obj in enumerate(args_kwargs_list):
-                    if torch.is_tensor(obj) and obj.requires_grad:
-                        inp_tensor_indices.append(i)
-                        inp_tensors.append(obj)
-                if len(inp_tensors) == 0:
-                    return args, kwargs
-                inp_tensors = RegisterFSDPBackwardFunction.apply(
-                    functools.partial(_post_backward, module), *inp_tensors
-                )
-                for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
-                    args_kwargs_list[inp_tensor_idx] = inp_tensor
-                args_list = args_kwargs_list[: len(args_list)]
-                kwargs_list = args_kwargs_list[len(args_list) :]
-                args = tree_unflatten(args_list, args_spec)
-                kwargs = tree_unflatten(kwargs_list, kwargs_spec)
-
-                return args, kwargs
            else:
                # All-gather the parameters in every forward pass for FSDP.
-                for param in module.parameters(recurse=False):
-                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                    self.all_gather_pipeline.queue_bucket_to_all_gather(
-                        bucket_id,
-                        prefetch=fsdp_forward_prefetch,
-                        suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                    )
-                for param in module.parameters(recurse=False):
+                param_list = list(module.parameters(recurse=False))
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                    self.all_gather_pipeline.wait_bucket_ready(bucket_id)
+            return args, kwargs
+
+        def _register_post_backward_hook(
+            post_backward_hook: callable,
+            module: nn.Module,
+            args: Tuple[Any, ...],
+            kwargs: Dict[str, Any],
+        ):
+            # Register the backward function to reduce gradients after the backward pass.
+            # And for optim_grads_params, we need to release the parameters after the backward pass.
+            if not torch.is_grad_enabled():
+                return args, kwargs
+
+            args_list, args_spec = tree_flatten(args)
+            kwargs_list, kwargs_spec = tree_flatten(kwargs)
+            args_kwargs_list = list(args_list) + list(kwargs_list)
+            inp_tensor_indices: List[int] = []
+            inp_tensors: List[torch.Tensor] = []
+            for i, obj in enumerate(args_kwargs_list):
+                if torch.is_tensor(obj) and obj.requires_grad:
+                    inp_tensor_indices.append(i)
+                    inp_tensors.append(obj)
+
+            if len(inp_tensors) == 0:
+                return args, kwargs
+
+            inp_tensors = RegisterFSDPBackwardFunction.apply(
+                functools.partial(post_backward_hook, module), *inp_tensors
+            )
+
+            for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
+                args_kwargs_list[inp_tensor_idx] = inp_tensor
+            args_list = args_kwargs_list[: len(args_list)]
+            kwargs_list = args_kwargs_list[len(args_list) :]
+            args = tree_unflatten(args_list, args_spec)
+            kwargs = tree_unflatten(kwargs_list, kwargs_spec)

            return args, kwargs

-        if self.ddp_config.overlap_param_gather:
-            fsdp_modules = []
-            for name, module in root_module.named_modules():
-                if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
-                    if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
-                        continue
+        fsdp_modules = []
+        for name, module in root_module.named_modules():
+            if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
+                continue

-                if isinstance(module, tuple(fsdp_unit_modules)):
-                    fsdp_modules.append(module)
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                fsdp_modules.append(module)
+
+            self.forward_pre_hooks[f'module {name} parameter unshard'] = (
+                module.register_forward_pre_hook(
+                    _pre_forward_param_unshard, prepend=True, with_kwargs=True
+                )
+            )
+            self.forward_pre_hooks[f"module {name} register post-backward hook"] = (
+                module.register_forward_pre_hook(
+                    functools.partial(_register_post_backward_hook, _post_backward),
+                    with_kwargs=True,
+                )
+            )

-                self.forward_pre_hooks[f'module {name} parameter all-gather'] = (
-                    module.register_forward_pre_hook(_pre_forward, prepend=True, with_kwargs=True)
+        def _root_post_backward(*unused):
+            # Make sure all the gradients are handled.
+            for param in self._params_require_handle_grad:
+                _grad_acc(param)
+
+            # Reduce the remain gradients.
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    list(self._params_require_handle_grad),
+                    suggested_queue_capacity=self.suggested_RS_queue_capacity,
                )
+                self.grad_reduce_pipeline.reset()
+
+            # Reset root_pre_backward_hook_issued flag.
+            self._root_pre_backward_hook_issued = False

        def _pre_backward(module: nn.Module, *unused):
            module._training_state = TrainingState.PRE_BACKWARD
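`_register_post_backward_hook` above splices an identity `RegisterFSDPBackwardFunction` over the module inputs so that its backward, which runs after the module's own backward, can trigger gradient handling. A minimal sketch of that trick with a stand-in class (not the actual megatron-core implementation):

    import torch

    class _RunAfterBackward(torch.autograd.Function):
        """Identity on the inputs; its backward fires after the wrapped module's backward."""

        @staticmethod
        def forward(ctx, callback, *inputs):
            ctx.callback = callback
            return inputs

        @staticmethod
        def backward(ctx, *grad_outputs):
            ctx.callback()  # e.g. gradient accumulation / reduce-scatter
            return (None, *grad_outputs)

    x = torch.ones(2, requires_grad=True)
    (x2,) = _RunAfterBackward.apply(lambda: print("post-backward"), x)
    (x2 * 3).sum().backward()  # the callback runs once gradients have flowed back past x2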
@@ -384,6 +474,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
                module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
            )

+            self._root_pre_backward_hook_issued = False
+
        def _root_pre_backward(module: nn.Module, *unused):
            """Marks the module's training state as 'pre_backward' before the
            backprop, this function is registered on the root module.
@@ -392,13 +484,26 @@ class FullyShardedDataParallel(_BaseDataParallel):
            perform reshard/unshard operations in activation recomputation
            scenarios.
            """
-            for module in root_module.modules():
-                if isinstance(module, tuple(fsdp_unit_modules)):
-                    module._training_state = TrainingState.PRE_BACKWARD
-                    for param in module.parameters():
-                        bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                        self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
-                        self.all_gather_pipeline.release_bucket(bucket_id)
+            if self._root_pre_backward_hook_issued:
+                return
+            self._root_pre_backward_hook_issued = True
+
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                for module in root_module.modules():
+                    if isinstance(module, tuple(fsdp_unit_modules)):
+                        module._training_state = TrainingState.PRE_BACKWARD
+                        for param in module.parameters():
+                            bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
+                            self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
+                            self.all_gather_pipeline.release_bucket(bucket_id)
+            self._params_require_handle_grad = set()
+            for param_group in self.param_and_grad_buffer.parameter_groups:
+                if not param_group.requires_grad:
+                    continue
+                self._params_require_handle_grad |= set(param_group.params)
+                for param in param_group.params:
+                    param.grad_added_to_main_grad = False
+            torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward)

        def _post_forward(module: nn.Module, input: Any, output: Any):
            # When composing with module-hook-based activation checkpointing, the
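The rewritten `_root_pre_backward` above runs at most once per backward pass and defers the final gradient handling by queueing `_root_post_backward` on the autograd engine, which invokes it after backprop finishes. A small sketch of that `queue_callback` pattern (a private PyTorch API; simplified, single-tensor example):

    import torch

    done = []

    def _after_backward():
        done.append("all grads for this backward pass are ready")

    def _queue_once(grad):
        # Must run while backward is executing (here via a grad hook); the engine
        # then invokes the callback after the whole backward graph has run.
        torch.autograd.Variable._execution_engine.queue_callback(_after_backward)
        return grad

    x = torch.ones(3, requires_grad=True)
    y = (x * 2).sum()
    y.register_hook(_queue_once)
    y.backward()
    assert done  # the callback ran before backward() returned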
@@ -414,7 +519,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
        def _release_module_fp8_transpose_cache(module: nn.Module, *unused):
            release_params_fp8_transpose_cache(module.parameters(recurse=False))

-        if self.data_parallel_sharding_strategy == "optim_grads_params":
+        if len(fsdp_unit_modules) != 0:
            fsdp_modules = []
            for name, module in root_module.named_modules():
                if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
@@ -434,68 +539,20 @@ class FullyShardedDataParallel(_BaseDataParallel):
                        _release_module_fp8_transpose_cache, prepend=False
                    )
                )
-        self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
-            _root_pre_backward
-        )
-
-        def _make_param_hook(param: torch.nn.Parameter):
-            """
-            Creates the all-reduce / reduce-scatter hook for backprop.
-            """

-            wait_previous_grad_reduce = not self.is_delay_grad_reduce
-
-            # FIXME: Use insert forward op to replace grad acc hook, which will
-            # be lost after parameter data movement. For example, module.cuda()
-            # will cause the registered grad acc hook to be lost.
-            def param_hook(*unused):
-                if param.requires_grad:
-                    if self.ddp_config.overlap_grad_reduce:
-                        assert (
-                            param.grad is not None
-                        ), 'param.grad being None is not safe when overlap_grad_reduce is True'
-
-                    if param.grad is not None and (
-                        not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
-                    ):
-                        if self.is_delay_grad_reduce:
-                            param.main_grad.add_(param.grad.data)
-                        else:
-                            param.main_grad.copy_(param.grad.data)
-                        param.grad = None
-
-                    if self.ddp_config.overlap_grad_reduce and (
-                        not self.is_delay_grad_reduce or self.is_last_microbatch
-                    ):
-                        gr_pipeline = self.grad_reduce_pipeline
-                        bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                        gr_pipeline.place_bucket(bucket_id)
-                        go_rs = gr_pipeline.mark_item_ready(param, async_rs=True)
-                        if go_rs and wait_previous_grad_reduce:
-                            gr_pipeline.wait_for_previous_grad_reduce(
-                                recommeded_queue_capacity=self.suggested_RS_queue_capacity
-                            )
+        # Registering all models with all parameters is to handle some special cases
+        # where the forward function of root_module is not called, but the forward
+        # functions of these equivalent modules are called instead.
+        for name, module in root_module.named_modules():
+            if len(list(module.parameters())) != len(list(root_module.parameters())):
+                continue

-            return param_hook
-
-        # Register backward gradient accumulation hook for each parameter.
-        self.grad_accs = []
-        for param in root_module.parameters():
-            bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-            wbuf = self.param_and_grad_buffer.parameter_groups[bucket_id].model_weight_buffer
-            if param.requires_grad:
-                if wbuf and wbuf.is_data_distributed:
-                    wbuf.fetch_bucket(and_allocate_params_data=True)
-
-                # Expand so we get access to grad_fn.
-                param_tmp = param.expand_as(param)
-                # Get the gradient accumulator function.
-                grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(_make_param_hook(param))
-                self.grad_accs.append(grad_acc)
-
-                if wbuf and wbuf.is_data_distributed:
-                    wbuf.free_bucket_storage()
+            self.backward_pre_hooks[f"{name} _root_pre_backward"] = (
+                module.register_full_backward_pre_hook(_root_pre_backward)
+            )
+        self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
+            _root_pre_backward
+        )

    @contextmanager
    def no_sync(self):
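The registration above now installs `_root_pre_backward` as a full backward pre-hook on every module that sees the complete parameter set (plus the root module itself), so per-iteration FSDP state is reset even when a wrapper's forward is bypassed. A tiny sketch of the underlying PyTorch hook (generic `torch.nn` example, not megatron-core code):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))

    def _root_pre_backward(module, grad_output):
        # Fires right before gradients for `module` start being computed; the real
        # hook uses this point to reset per-iteration FSDP state exactly once.
        print("root pre-backward fired")

    net.register_full_backward_pre_hook(_root_pre_backward)
    net(torch.randn(2, 4)).sum().backward()  # prints before any grads are computed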
@@ -526,7 +583,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
        """
        if not force_sync and self.ddp_config.overlap_param_gather:
            # All-gather the first bucket before the forward pass.
-            self.all_gather_pipeline.queue_bucket_to_all_gather(bucket_id=0, prefetch=False)
+            first_param = list(self.module.parameters())[0]
+            self.all_gather_pipeline.all_gather_params(params=[first_param], prefetch=False)
        else:
            self.all_gather_pipeline.reset()
            for bucket_id in range(self.all_gather_pipeline.num_buckets):