megatron-core 0.12.0rc2__tar.gz → 0.12.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of megatron-core has been flagged as potentially problematic.

Files changed (291):
  1. {megatron_core-0.12.0rc2/megatron_core.egg-info → megatron_core-0.12.1}/PKG-INFO +1 -1
  2. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/async_utils.py +29 -11
  3. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +214 -159
  4. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +275 -186
  5. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/distributed_data_parallel_config.py +7 -3
  6. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/finalize_model_grads.py +3 -2
  7. megatron_core-0.12.1/megatron/core/export/model_type.py +8 -0
  8. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +6 -0
  9. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +10 -0
  10. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trt_model_config.py +1 -0
  11. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trt_model_type.py +1 -0
  12. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_helper.py +20 -2
  13. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_layers.py +9 -0
  14. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +10 -4
  15. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +17 -5
  16. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/extensions/transformer_engine.py +34 -8
  17. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fp8_utils.py +15 -8
  18. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_swiglu.py +57 -0
  19. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/dynamic_context.py +19 -1
  20. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/dynamic_engine.py +8 -2
  21. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +6 -3
  22. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/rope_utils.py +20 -3
  23. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/gpt_layer_specs.py +27 -6
  24. megatron_core-0.12.1/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +209 -0
  25. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/__init__.py +16 -5
  26. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/optimizer.py +99 -21
  27. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/package_info.py +2 -2
  28. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/gpt/model_specs.py +12 -4
  29. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/__init__.py +2 -0
  30. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/random.py +149 -23
  31. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/attention.py +4 -1
  32. megatron_core-0.12.1/megatron/core/transformer/heterogeneous/heterogeneous_config.py +267 -0
  33. megatron_core-0.12.1/megatron/core/transformer/heterogeneous/linear_replacements.py +111 -0
  34. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/mlp.py +37 -18
  35. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/experts.py +166 -60
  36. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py +18 -11
  37. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/moe_layer.py +20 -10
  38. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/moe_utils.py +91 -19
  39. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/shared_experts.py +4 -0
  40. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/token_dispatcher.py +63 -64
  41. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/multi_latent_attention.py +121 -69
  42. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/torch_norm.py +49 -1
  43. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/transformer_block.py +22 -5
  44. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/transformer_config.py +88 -14
  45. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/transformer_layer.py +51 -5
  46. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/utils.py +25 -0
  47. {megatron_core-0.12.0rc2 → megatron_core-0.12.1/megatron_core.egg-info}/PKG-INFO +1 -1
  48. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/SOURCES.txt +3 -0
  49. megatron_core-0.12.0rc2/megatron/core/export/model_type.py +0 -7
  50. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/LICENSE +0 -0
  51. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/MANIFEST.in +0 -0
  52. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/README.md +0 -0
  53. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/README.md +0 -0
  54. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/__init__.py +0 -0
  55. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/config.py +0 -0
  56. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/config_logger.py +0 -0
  57. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/__init__.py +0 -0
  58. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/bert_dataset.py +0 -0
  59. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/blended_dataset.py +0 -0
  60. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  61. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  62. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/gpt_dataset.py +0 -0
  63. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/helpers.cpp +0 -0
  64. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/helpers.py +0 -0
  65. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/indexed_dataset.py +0 -0
  66. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/masked_dataset.py +0 -0
  67. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/megatron_dataset.py +0 -0
  68. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  69. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/multimodal_dataset.py +0 -0
  70. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/__init__.py +0 -0
  71. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/__init__.py +0 -0
  72. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  73. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/config.py +0 -0
  74. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  75. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  76. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/__init__.py +0 -0
  77. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/build.py +0 -0
  78. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/dataset.py +0 -0
  79. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/db/utils.py +0 -0
  80. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/external_libs.py +0 -0
  81. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/__init__.py +0 -0
  82. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/build.py +0 -0
  83. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/factory.py +0 -0
  84. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/index.py +0 -0
  85. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  86. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  87. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  88. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/utils.py +0 -0
  89. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/index/validate.py +0 -0
  90. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/__init__.py +0 -0
  91. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  92. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  93. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/query.py +0 -0
  94. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  95. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/query/utils.py +0 -0
  96. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/retro/utils.py +0 -0
  97. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/t5_dataset.py +0 -0
  98. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/utils.py +0 -0
  99. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/datasets/utils_s3.py +0 -0
  100. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/__init__.py +0 -0
  101. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/core.py +0 -0
  102. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  103. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  104. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/mapping.py +0 -0
  105. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  106. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/serialization.py +0 -0
  107. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  108. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  109. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  110. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  111. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  112. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  113. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  114. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  115. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  116. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  117. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  118. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  119. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  120. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  121. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/utils.py +0 -0
  122. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/validation.py +0 -0
  123. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/__init__.py +0 -0
  124. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  125. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/data_parallel_base.py +0 -0
  126. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  127. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/param_and_grad_buffer.py +0 -0
  128. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  129. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/enums.py +0 -0
  130. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/__init__.py +0 -0
  131. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/data_type.py +0 -0
  132. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/export_config.py +0 -0
  133. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/__init__.py +0 -0
  134. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  135. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  136. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  137. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/extensions/__init__.py +0 -0
  138. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/__init__.py +0 -0
  139. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  140. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  141. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  142. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  143. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_layer_norm.py +0 -0
  144. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/fusions/fused_softmax.py +0 -0
  145. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/__init__.py +0 -0
  146. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/async_stream.py +0 -0
  147. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/common_inference_params.py +0 -0
  148. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/communication_utils.py +0 -0
  149. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/__init__.py +0 -0
  150. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/base_context.py +0 -0
  151. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/contexts/static_context.py +0 -0
  152. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/__init__.py +0 -0
  153. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/abstract_engine.py +0 -0
  154. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/mcore_engine.py +0 -0
  155. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/engines/static_engine.py +0 -0
  156. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/inference_request.py +0 -0
  157. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  158. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  159. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  160. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  161. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  162. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  163. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  164. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  165. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/__init__.py +0 -0
  166. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/gpt/__init__.py +0 -0
  167. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/gpt/model_specs.py +0 -0
  168. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py +0 -0
  169. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/mamba/__init__.py +0 -0
  170. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/modelopt_support/mamba/model_specs.py +0 -0
  171. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/sampling_params.py +0 -0
  172. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/scheduler.py +0 -0
  173. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  174. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  175. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  176. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  177. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference/utils.py +0 -0
  178. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/inference_params.py +0 -0
  179. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/jit.py +0 -0
  180. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/model_parallel_config.py +0 -0
  181. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/T5/__init__.py +0 -0
  182. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/T5/t5_model.py +0 -0
  183. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/T5/t5_spec.py +0 -0
  184. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/__init__.py +0 -0
  185. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/__init__.py +0 -0
  186. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  187. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/bert_lm_head.py +0 -0
  188. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/bert_model.py +0 -0
  189. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/bert/pooler.py +0 -0
  190. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/__init__.py +0 -0
  191. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/__init__.py +0 -0
  192. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  193. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  194. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  195. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  196. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/language_module/__init__.py +0 -0
  197. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/language_module/language_module.py +0 -0
  198. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/vision_module/__init__.py +0 -0
  199. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  200. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/__init__.py +0 -0
  201. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/gpt_model.py +0 -0
  202. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  203. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/__init__.py +0 -0
  204. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/clip_model.py +0 -0
  205. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/module.py +0 -0
  206. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/huggingface/qwen_model.py +0 -0
  207. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/mamba/__init__.py +0 -0
  208. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  209. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/mamba/mamba_model.py +0 -0
  210. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/__init__.py +0 -0
  211. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/context_parallel.py +0 -0
  212. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/llava_model.py +0 -0
  213. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/multimodal/llava_spec.py +0 -0
  214. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/__init__.py +0 -0
  215. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/base_attention.py +0 -0
  216. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/config.py +0 -0
  217. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/decoder_attention.py +0 -0
  218. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/decoder_spec.py +0 -0
  219. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/encoder_attention.py +0 -0
  220. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/encoder_spec.py +0 -0
  221. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/model.py +0 -0
  222. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/retro/utils.py +0 -0
  223. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/__init__.py +0 -0
  224. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/clip_vit_model.py +0 -0
  225. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/multimodal_projector.py +0 -0
  226. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/radio.py +0 -0
  227. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  228. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/num_microbatches_calculator.py +0 -0
  229. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/clip_grads.py +0 -0
  230. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  231. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  232. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  233. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/grad_scaler.py +0 -0
  234. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer/optimizer_config.py +0 -0
  235. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/optimizer_param_scheduler.py +0 -0
  236. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/packed_seq_params.py +0 -0
  237. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/parallel_state.py +0 -0
  238. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/pipeline_parallel/__init__.py +0 -0
  239. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  240. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/pipeline_parallel/schedules.py +0 -0
  241. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/__init__.py +0 -0
  242. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/__init__.py +0 -0
  243. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  244. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  245. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/layers.py +0 -0
  246. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  247. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  248. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/process_groups_config.py +0 -0
  249. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/requirements.txt +0 -0
  250. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/rerun_state_machine.py +0 -0
  251. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/__init__.py +0 -0
  252. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_block.py +0 -0
  253. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_config.py +0 -0
  254. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  255. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_layer.py +0 -0
  256. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mamba_mixer.py +0 -0
  257. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/mlp_layer.py +0 -0
  258. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/ssm/triton_cache_manager.py +0 -0
  259. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  260. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/data.py +0 -0
  261. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/layers.py +0 -0
  262. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/mappings.py +0 -0
  263. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/tensor_parallel/utils.py +0 -0
  264. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/timers.py +0 -0
  265. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/__init__.py +0 -0
  266. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/cuda_graphs.py +0 -0
  267. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  268. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  269. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/dot_product_attention.py +0 -0
  270. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/enums.py +0 -0
  271. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/identity_op.py +0 -0
  272. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/module.py +0 -0
  273. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/__init__.py +0 -0
  274. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  275. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  276. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/router.py +0 -0
  277. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  278. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/multi_token_prediction.py +0 -0
  279. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/spec_utils.py +0 -0
  280. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/torch_layer_norm.py +0 -0
  281. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/transformer/utils.py +0 -0
  282. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/dependency_links.txt +0 -0
  283. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/requires.txt +0 -0
  284. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron_core.egg-info/top_level.txt +0 -0
  285. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/pyproject.toml +0 -0
  286. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_24.01/requirements.txt +0 -0
  287. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_24.07/requirements.txt +0 -0
  288. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_24.10/requirements.txt +0 -0
  289. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/requirements/pytorch_25.03/requirements.txt +0 -0
  290. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/setup.cfg +0 -0
  291. {megatron_core-0.12.0rc2 → megatron_core-0.12.1}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.12.0rc2
+Version: 0.12.1
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Home-page: https://github.com/NVIDIA/Megatron-LM/megatron/core
 Download-URL: https://github.com/NVIDIA/Megatron-LM/releases
{megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/dist_checkpointing/strategies/async_utils.py
@@ -155,7 +155,7 @@ class AsyncCaller(ABC):
         logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")

     def __del__(self):
-        self.close()
+        raise NotImplementedError("This should be implemented")


 class TemporalAsyncCaller(AsyncCaller):
@@ -227,12 +227,22 @@ class TemporalAsyncCaller(AsyncCaller):
         is_alive = int(self.process.is_alive()) if self.process is not None else 0
         is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)

-        if not is_done and blocking:
+        if is_done or blocking:
+            # Process join is called in the following cases
+            # 1. blocking == True -> regardless of is_done
+            # 2. blocking == False (non-blocking)
+            #    -> is_done == True: async requests on all ranks are identified to be finished
+            # `self.close()` makes sure the async callers terminated
             self.close()
             is_done = True
         return is_done

     def close(self):
+        """For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`
+
+        This method make sure the TemporalAsyncCaller terminated
+        with all its assigned async request completed
+        """
         if self.process:
             logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
             self.process.join()
@@ -243,6 +253,9 @@ class TemporalAsyncCaller(AsyncCaller):
             )
         self.start_time = None

+    def __del__(self):
+        pass
+

 class PersistentAsyncCaller(AsyncCaller):
     """Wrapper around mp.Process that ensures correct semantic of distributed finalization.
@@ -376,6 +389,10 @@ class PersistentAsyncCaller(AsyncCaller):
         return is_done

     def close(self):
+        """Wait on the left async requests and terminate the PersistentAsyncCaller
+
+        Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated
+        """
         logger.info(
             f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
         )
@@ -385,6 +402,9 @@ class PersistentAsyncCaller(AsyncCaller):
             self.process.join()
             self.process = None

+    def __del__(self):
+        self.close()
+
     @staticmethod
     @_disable_gc()
     def async_loop(
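Taken together, the `__del__` changes in the hunks above follow one pattern: the abstract base class no longer picks a destructor policy, and each subclass declares its own. A minimal sketch of that pattern with simplified stand-in class names (not the real megatron-core classes):

from abc import ABC

class BaseCaller(ABC):
    def close(self):
        print("joining worker process")

    def __del__(self):
        # The base class refuses to guess a teardown policy.
        raise NotImplementedError("This should be implemented")

class TemporalCaller(BaseCaller):
    def __del__(self):
        # close() is already invoked explicitly once its async calls are done,
        # so the destructor is intentionally a no-op.
        pass

class PersistentCaller(BaseCaller):
    def __del__(self):
        # Long-lived worker: shut it down when the caller is garbage collected.
        self.close()

caller = PersistentCaller()
del caller  # drops the last reference; __del__ runs close() and prints "joining worker process"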
@@ -492,13 +512,11 @@ class AsyncCallsQueue:
         # Backward compatibility for local checkpointing built with the old AsyncRequest
         if len(async_request._fields) != len(AsyncRequest._fields):
             async_request = AsyncRequest(**async_request._asdict())
-
-        async_request = async_request._replace(call_idx=self.call_idx)
-        finalize_fns = async_request.finalize_fns
-        async_request = async_request._replace(finalize_fns=None)
         async_request = async_request.freeze()
-        async_caller.schedule_async_call(async_request)
-        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, finalize_fns))
+        async_caller.schedule_async_call(
+            async_request._replace(call_idx=self.call_idx, finalize_fns=[])
+        )
+        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
         return self.call_idx

     def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
@@ -522,13 +540,13 @@ class AsyncCallsQueue:
             if not next_async_done:
                 break
             with debug_time("finalize", logger):
-                call_idx, _, finalize_fns = self.async_calls.popleft()
+                call_idx, _, async_request = self.async_calls.popleft()
+                for finalize_fn in async_request.finalize_fns:
+                    finalize_fn()
                 ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
                 torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
                 assert ten.item() == call_idx, 'Unmatched async calls. '
                 'That probably means not all ranks are participating in async finalization'
-                for finalize_fn in finalize_fns:
-                    finalize_fn()
                 call_idx_finalized.append(call_idx)
         return call_idx_finalized

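The two AsyncCallsQueue hunks above change what is tracked per in-flight call (the frozen request itself instead of only its finalize_fns) and run the finalize functions before the cross-rank call_idx consistency check. A self-contained sketch of the new ordering, using simplified stand-in types and a local assert in place of the torch.distributed MAX all-reduce:

from collections import deque, namedtuple

# Stand-ins for the real AsyncRequest/_ActiveAsyncRequest types.
AsyncRequest = namedtuple("AsyncRequest", ["call_idx", "finalize_fns"])
_ActiveAsyncRequest = namedtuple("_ActiveAsyncRequest", ["call_idx", "async_caller", "async_request"])

async_calls = deque()
request = AsyncRequest(call_idx=0, finalize_fns=[lambda: print("finalize checkpoint metadata")])
async_calls.append(_ActiveAsyncRequest(0, None, request))

def maybe_finalize_async_calls():
    call_idx_finalized = []
    while async_calls:
        call_idx, _, async_request = async_calls.popleft()
        # 0.12.1 ordering: run the finalize functions first ...
        for finalize_fn in async_request.finalize_fns:
            finalize_fn()
        # ... then check that every rank finalized the same call
        # (torch.distributed.all_reduce with ReduceOp.MAX in the real code).
        assert call_idx == async_request.call_idx
        call_idx_finalized.append(call_idx)
    return call_idx_finalized

print(maybe_finalize_async_calls())  # prints "finalize checkpoint metadata", then [0]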
{megatron_core-0.12.0rc2 → megatron_core-0.12.1}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py
@@ -22,7 +22,6 @@ from megatron.core.distributed.custom_fsdp.param_and_grad_buffer import (
 from megatron.core.distributed.data_parallel_base import _BaseDataParallel
 from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
 from megatron.core.fp8_utils import is_float8tensor
-from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer
 from megatron.core.utils import is_submodule, log_single_rank
@@ -77,7 +76,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         module: Underlying model.
         fsdp_unit_modules: List of modules that should be treated as FSDP Unit,
             i.e., the minimum releasable model unit. If not provided, defaults to
-            [TransformerLayer, LanguageModelEmbedding] for GPT-like models.
+            [TransformerLayer, LanguageModelEmbedding] for GPT-like models. In
+            addition to this, it affects the granularity of the communication
+            parameter grouping and triggers aggregate collective communication
+            in fp8 mixed precision training.
         disable_bucketing: If true, force assign all parameters to a single bucket. If false,
             use standard bucketing policy: assign parameters to smaller buckets and all-reduce
             per bucket.
@@ -123,9 +125,10 @@ class FullyShardedDataParallel(_BaseDataParallel):
         if fsdp_unit_modules is not None:
             self.fsdp_unit_modules = fsdp_unit_modules
         else:
-            self.fsdp_unit_modules = [TransformerLayer]
-            if not getattr(self.module, "share_embeddings_and_output_weights", False):
-                self.fsdp_unit_modules.append(LanguageModelEmbedding)
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                self.fsdp_unit_modules = [TransformerLayer]
+            else:
+                self.fsdp_unit_modules = []
         self.main_weights = True
         self.data_parallel_group = parallel_state.get_data_parallel_group(
             with_context_parallel=True
@@ -180,14 +183,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
             self.module,
             bucketing_policy=BucketingPolicy(
                 suggested_bucket_size=self.bucket_size,
-                fsdp_unit_modules=(
-                    # Only when model weights need to be sharded, we need to
-                    # identify the minimum releasable model unit, which is the
-                    # FSDP Unit Module.
-                    self.fsdp_unit_modules
-                    if self.data_parallel_sharding_strategy == "optim_grads_params"
-                    else []
-                ),
+                fsdp_unit_modules=self.fsdp_unit_modules,
                 data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
             ),
             data_parallel_group=self.data_parallel_group,
@@ -211,8 +207,24 @@ class FullyShardedDataParallel(_BaseDataParallel):
         # Initialize the all-gather pipeline.
         self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer)

-        self.suggested_RS_queue_capacity = self.ddp_config.suggested_communication_unit_size
-        self.suggested_AG_prefetch_size = self.ddp_config.suggested_communication_unit_size
+        suggested_communication_unit_size = self.ddp_config.suggested_communication_unit_size
+        if suggested_communication_unit_size is None:
+            if self.data_parallel_sharding_strategy == "optim_grads_params":
+                total_param_elements = 0
+                total_fsdp_module = 0
+                for module in self.module.modules():
+                    if isinstance(module, tuple(self.fsdp_unit_modules)):
+                        total_fsdp_module += 1
+                        total_param_elements += sum(p.numel() for p in module.parameters())
+                # The suggested size is twice the number of elements in the FSDP modules.
+                # This ensures we process the current FSDP module and attempt to prefetch
+                # the next FSDP module, making the flow of communication better.
+                suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
+            elif self.bucket_size is not None:
+                suggested_communication_unit_size = self.bucket_size * 2
+
+        self.suggested_RS_queue_capacity = suggested_communication_unit_size
+        self.suggested_AG_prefetch_size = suggested_communication_unit_size

     def _register_fsdp_hooks(self, root_module):
         """Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
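A quick back-of-the-envelope illustration of the default sizing heuristic introduced above, using hypothetical numbers (24 TransformerLayer FSDP units of roughly 300M parameters each):

# Hypothetical model: 24 FSDP unit modules, ~300M elements each.
total_fsdp_module = 24
total_param_elements = 24 * 300_000_000

# Twice the average FSDP-unit size: room for the current layer plus one prefetched layer.
suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
print(suggested_communication_unit_size)  # 600000000 elements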
@@ -222,8 +234,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         - Pre-forward hook: Unshards parameters before forward pass
         - Post-forward hook: Reshards parameters after forward pass
         - Pre-backward hook: Unshards parameters before backward pass
-        - Post-backward hook: Reshards parameters after backward pass
-        - Gradient accumulation hook: Handles gradient accumulation and reduction across devices
+        - Post-backward hook: Reshards parameters and reduces gradients after backward pass

         Args:
             root_module: The PyTorch module to register FSDP hooks on
@@ -257,10 +268,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         `optim` and `optim_grads` do not require FSDP units because they do not
         shard model parameters.
         """
-        if self.data_parallel_sharding_strategy != "optim_grads_params":
-            fsdp_unit_modules = []
-        else:
-            fsdp_unit_modules = self.fsdp_unit_modules
+        fsdp_unit_modules = self.fsdp_unit_modules

         def release_module_parameters(module, *unused):
             for param in module.parameters():
@@ -283,27 +291,74 @@ class FullyShardedDataParallel(_BaseDataParallel):
             prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
             wait_bucket_ready=True,
         ):
-            wait_list = []
             ag_pipeline = self.all_gather_pipeline
-            for param in module.parameters():
-                bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                ag_pipeline.queue_bucket_to_all_gather(
-                    bucket_id,
-                    prefetch=prefetch,
-                    prefetch_order=prefetch_order,
-                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                )
-                wait_list.append(bucket_id)
-
+            ag_pipeline.all_gather_params(
+                params=list(module.parameters()),
+                prefetch=prefetch,
+                prefetch_order=prefetch_order,
+                suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+            )
             if wait_bucket_ready:
-                for bucket_id in wait_list:
+                for param in module.parameters():
+                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                     ag_pipeline.wait_bucket_ready(bucket_id)

+        def _grad_acc(param):
+            """
+            Accumulate the gradient in the main_grad buffer.
+            """
+            group_id = self.param_and_grad_buffer.param_to_param_group[param]
+            group = self.param_and_grad_buffer.parameter_groups[group_id]
+            if not group.requires_grad:
+                return
+
+            overwrite_main_grad = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if overwrite_main_grad:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.copy_(param.grad)
+                        del param.grad
+                    else:
+                        param.main_grad.zero_()
+            else:
+                if not param.grad_added_to_main_grad:
+                    if param.grad is not None:
+                        param.main_grad.add_(param.grad)
+                        del param.grad
+            # Reset the grad accumulate flag.
+            param.grad_added_to_main_grad = False
+
+        self._params_require_handle_grad = set()
+
         def _post_backward(module, *unused):
-            release_module_parameters(module)
-            module._training_state = TrainingState.IDLE
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                    release_module_parameters(module)
+                    module._training_state = TrainingState.IDLE
+                param_list = list(module.parameters())
+            else:
+                param_list = list(module.parameters(recurse=False))
+
+            for param in param_list:
+                _grad_acc(param)
+                self._params_require_handle_grad.discard(param)
+
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    param_list, suggested_queue_capacity=self.suggested_RS_queue_capacity
+                )

-        def _pre_forward(module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
+        def _pre_forward_param_unshard(
+            module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+        ):
+            # Unshard the parameters before the forward pass.
             input_training_state = module._training_state
             fsdp_forward_prefetch = True
             if input_training_state == TrainingState.PRE_BACKWARD:
@@ -313,72 +368,104 @@ class FullyShardedDataParallel(_BaseDataParallel):
                 module._training_state = TrainingState.FORWARD

             if isinstance(module, tuple(fsdp_unit_modules)):
-                wait_list = []
-                for param in module.parameters():
+                param_list = list(module.parameters())
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                     bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                    self.all_gather_pipeline.queue_bucket_to_all_gather(
-                        bucket_id,
-                        prefetch=fsdp_forward_prefetch,
-                        suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                    )
-                    wait_list.append(bucket_id)
-                for bucket_id in wait_list:
                     self.all_gather_pipeline.wait_bucket_ready(bucket_id)
-
-                if not torch.is_grad_enabled():
-                    return args, kwargs
-
-                # Register the backward function to release the parameters.
-                args_list, args_spec = tree_flatten(args)
-                kwargs_list, kwargs_spec = tree_flatten(kwargs)
-                args_kwargs_list = list(args_list) + list(kwargs_list)
-                inp_tensor_indices: List[int] = []
-                inp_tensors: List[torch.Tensor] = []
-                for i, obj in enumerate(args_kwargs_list):
-                    if torch.is_tensor(obj) and obj.requires_grad:
-                        inp_tensor_indices.append(i)
-                        inp_tensors.append(obj)
-                if len(inp_tensors) == 0:
-                    return args, kwargs
-                inp_tensors = RegisterFSDPBackwardFunction.apply(
-                    functools.partial(_post_backward, module), *inp_tensors
-                )
-                for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
-                    args_kwargs_list[inp_tensor_idx] = inp_tensor
-                args_list = args_kwargs_list[: len(args_list)]
-                kwargs_list = args_kwargs_list[len(args_list) :]
-                args = tree_unflatten(args_list, args_spec)
-                kwargs = tree_unflatten(kwargs_list, kwargs_spec)
-
-                return args, kwargs
             else:
                 # All-gather the parameters in every forward pass for FSDP.
-                for param in module.parameters(recurse=False):
-                    bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                    self.all_gather_pipeline.queue_bucket_to_all_gather(
-                        bucket_id,
-                        prefetch=fsdp_forward_prefetch,
-                        suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
-                    )
-                for param in module.parameters(recurse=False):
+                param_list = list(module.parameters(recurse=False))
+                self.all_gather_pipeline.all_gather_params(
+                    params=param_list,
+                    prefetch=fsdp_forward_prefetch,
+                    suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
+                )
+                for param in param_list:
                     bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
                     self.all_gather_pipeline.wait_bucket_ready(bucket_id)
+            return args, kwargs
+
+        def _register_post_backward_hook(
+            post_backward_hook: callable,
+            module: nn.Module,
+            args: Tuple[Any, ...],
+            kwargs: Dict[str, Any],
+        ):
+            # Register the backward function to reduce gradients after the backward pass.
+            # And for optim_grads_params, we need to release the parameters after the backward pass.
+            if not torch.is_grad_enabled():
+                return args, kwargs
+
+            args_list, args_spec = tree_flatten(args)
+            kwargs_list, kwargs_spec = tree_flatten(kwargs)
+            args_kwargs_list = list(args_list) + list(kwargs_list)
+            inp_tensor_indices: List[int] = []
+            inp_tensors: List[torch.Tensor] = []
+            for i, obj in enumerate(args_kwargs_list):
+                if torch.is_tensor(obj) and obj.requires_grad:
+                    inp_tensor_indices.append(i)
+                    inp_tensors.append(obj)
+
+            if len(inp_tensors) == 0:
+                return args, kwargs
+
+            inp_tensors = RegisterFSDPBackwardFunction.apply(
+                functools.partial(post_backward_hook, module), *inp_tensors
+            )
+
+            for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
+                args_kwargs_list[inp_tensor_idx] = inp_tensor
+            args_list = args_kwargs_list[: len(args_list)]
+            kwargs_list = args_kwargs_list[len(args_list) :]
+            args = tree_unflatten(args_list, args_spec)
+            kwargs = tree_unflatten(kwargs_list, kwargs_spec)

             return args, kwargs

-        if self.ddp_config.overlap_param_gather:
-            fsdp_modules = []
-            for name, module in root_module.named_modules():
-                if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
-                    if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
-                        continue
+        fsdp_modules = []
+        for name, module in root_module.named_modules():
+            if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
+                continue

-                if isinstance(module, tuple(fsdp_unit_modules)):
-                    fsdp_modules.append(module)
+            if isinstance(module, tuple(fsdp_unit_modules)):
+                fsdp_modules.append(module)
+
+            self.forward_pre_hooks[f'module {name} parameter unshard'] = (
+                module.register_forward_pre_hook(
+                    _pre_forward_param_unshard, prepend=True, with_kwargs=True
+                )
+            )
+            self.forward_pre_hooks[f"module {name} register post-backward hook"] = (
+                module.register_forward_pre_hook(
+                    functools.partial(_register_post_backward_hook, _post_backward),
+                    with_kwargs=True,
+                )
+            )

-                self.forward_pre_hooks[f'module {name} parameter all-gather'] = (
-                    module.register_forward_pre_hook(_pre_forward, prepend=True, with_kwargs=True)
+        def _root_post_backward(*unused):
+            # Make sure all the gradients are handled.
+            for param in self._params_require_handle_grad:
+                _grad_acc(param)
+
+            # Reduce the remain gradients.
+            grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
+                "optim_grads",
+                "optim_grads_params",
+            ]
+            if grad_reduce_every_bprop or self.is_last_microbatch:
+                self.grad_reduce_pipeline.reduce_gradients(
+                    list(self._params_require_handle_grad),
+                    suggested_queue_capacity=self.suggested_RS_queue_capacity,
                 )
+                self.grad_reduce_pipeline.reset()
+
+            # Reset root_pre_backward_hook_issued flag.
+            self._root_pre_backward_hook_issued = False

         def _pre_backward(module: nn.Module, *unused):
             module._training_state = TrainingState.PRE_BACKWARD
@@ -387,6 +474,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
                     module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
                 )

+        self._root_pre_backward_hook_issued = False
+
         def _root_pre_backward(module: nn.Module, *unused):
             """Marks the module's training state as 'pre_backward' before the
             backprop, this function is registered on the root module.
@@ -395,13 +484,26 @@ class FullyShardedDataParallel(_BaseDataParallel):
             perform reshard/unshard operations in activation recomputation
             scenarios.
             """
-            for module in root_module.modules():
-                if isinstance(module, tuple(fsdp_unit_modules)):
-                    module._training_state = TrainingState.PRE_BACKWARD
-                    for param in module.parameters():
-                        bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
-                        self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
-                        self.all_gather_pipeline.release_bucket(bucket_id)
+            if self._root_pre_backward_hook_issued:
+                return
+            self._root_pre_backward_hook_issued = True
+
+            if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
+                for module in root_module.modules():
+                    if isinstance(module, tuple(fsdp_unit_modules)):
+                        module._training_state = TrainingState.PRE_BACKWARD
+                        for param in module.parameters():
+                            bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
+                            self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
+                            self.all_gather_pipeline.release_bucket(bucket_id)
+            self._params_require_handle_grad = set()
+            for param_group in self.param_and_grad_buffer.parameter_groups:
+                if not param_group.requires_grad:
+                    continue
+                self._params_require_handle_grad |= set(param_group.params)
+                for param in param_group.params:
+                    param.grad_added_to_main_grad = False
+            torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward)

         def _post_forward(module: nn.Module, input: Any, output: Any):
             # When composing with module-hook-based activation checkpointing, the
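The reworked _root_pre_backward above relies on PyTorch's internal end-of-backward callback queue. A standalone sketch (not megatron-core code, PyTorch 2.x) of that mechanism on a toy model: a full backward pre-hook marks the start of backprop, and queue_callback schedules work that runs once the whole backward pass has finished:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))

def _root_post_backward(*unused):
    # Runs once autograd has finished the entire backward pass.
    print("backward finished: reduce any remaining gradients here")

def _root_pre_backward(module, grad_output):
    # Fires when backprop first reaches the root module's output.
    print("backward starting")
    torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward)

model.register_full_backward_pre_hook(_root_pre_backward)
loss = model(torch.randn(2, 4)).sum()
loss.backward()  # prints "backward starting", then "backward finished: ..."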
@@ -417,7 +519,7 @@ class FullyShardedDataParallel(_BaseDataParallel):
         def _release_module_fp8_transpose_cache(module: nn.Module, *unused):
             release_params_fp8_transpose_cache(module.parameters(recurse=False))

-        if self.data_parallel_sharding_strategy == "optim_grads_params":
+        if len(fsdp_unit_modules) != 0:
             fsdp_modules = []
             for name, module in root_module.named_modules():
                 if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
437
539
  _release_module_fp8_transpose_cache, prepend=False
438
540
  )
439
541
  )
440
- self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
441
- _root_pre_backward
442
- )
443
-
444
- def _make_param_hook(param: torch.nn.Parameter):
445
- """
446
- Creates the all-reduce / reduce-scatter hook for backprop.
447
- """
448
542
 
449
- wait_previous_grad_reduce = not self.is_delay_grad_reduce
450
-
451
- # FIXME: Use insert forward op to replace grad acc hook, which will
452
- # be lost after parameter data movement. For example, module.cuda()
453
- # will cause the registered grad acc hook to be lost.
454
- def param_hook(*unused):
455
- if param.requires_grad:
456
- if self.ddp_config.overlap_grad_reduce:
457
- assert (
458
- param.grad is not None
459
- ), 'param.grad being None is not safe when overlap_grad_reduce is True'
460
-
461
- if param.grad is not None and (
462
- not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
463
- ):
464
- if self.is_delay_grad_reduce:
465
- param.main_grad.add_(param.grad.data)
466
- else:
467
- param.main_grad.copy_(param.grad.data)
468
- param.grad = None
469
-
470
- if self.ddp_config.overlap_grad_reduce and (
471
- not self.is_delay_grad_reduce or self.is_last_microbatch
472
- ):
473
- gr_pipeline = self.grad_reduce_pipeline
474
- bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
475
- gr_pipeline.place_bucket(bucket_id)
476
- go_rs = gr_pipeline.mark_item_ready(param, async_rs=True)
477
- if go_rs and wait_previous_grad_reduce:
478
- gr_pipeline.wait_for_previous_grad_reduce(
479
- recommeded_queue_capacity=self.suggested_RS_queue_capacity
480
- )
543
+ # Registering all models with all parameters is to handle some special cases
544
+ # where the forward function of root_module is not called, but the forward
545
+ # functions of these equivalent modules are called instead.
546
+ for name, module in root_module.named_modules():
547
+ if len(list(module.parameters())) != len(list(root_module.parameters())):
548
+ continue
481
549
 
482
- return param_hook
483
-
484
- # Register backward gradient accumulation hook for each parameter.
485
- self.grad_accs = []
486
- for param in root_module.parameters():
487
- bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
488
- wbuf = self.param_and_grad_buffer.parameter_groups[bucket_id].model_weight_buffer
489
- if param.requires_grad:
490
- if wbuf and wbuf.is_data_distributed:
491
- wbuf.fetch_bucket(and_allocate_params_data=True)
492
-
493
- # Expand so we get access to grad_fn.
494
- param_tmp = param.expand_as(param)
495
- # Get the gradient accumulator function.
496
- grad_acc = param_tmp.grad_fn.next_functions[0][0]
497
- grad_acc.register_hook(_make_param_hook(param))
498
- self.grad_accs.append(grad_acc)
499
-
500
- if wbuf and wbuf.is_data_distributed:
501
- wbuf.free_bucket_storage()
550
+ self.backward_pre_hooks[f"{name} _root_pre_backward"] = (
551
+ module.register_full_backward_pre_hook(_root_pre_backward)
552
+ )
553
+ self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
554
+ _root_pre_backward
555
+ )
502
556
 
503
557
  @contextmanager
504
558
  def no_sync(self):
@@ -529,7 +583,8 @@ class FullyShardedDataParallel(_BaseDataParallel):
         """
         if not force_sync and self.ddp_config.overlap_param_gather:
             # All-gather the first bucket before the forward pass.
-            self.all_gather_pipeline.queue_bucket_to_all_gather(bucket_id=0, prefetch=False)
+            first_param = list(self.module.parameters())[0]
+            self.all_gather_pipeline.all_gather_params(params=[first_param], prefetch=False)
         else:
             self.all_gather_pipeline.reset()
             for bucket_id in range(self.all_gather_pipeline.num_buckets):