megatron-core 0.15.0rc5__tar.gz → 0.15.0rc6__tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.

Files changed (351)
  1. {megatron_core-0.15.0rc5/megatron_core.egg-info → megatron_core-0.15.0rc6}/PKG-INFO +1 -1
  2. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/distributed_data_parallel.py +7 -12
  3. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/finalize_model_grads.py +10 -12
  4. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +14 -3
  5. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py +1 -1
  6. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/param_and_grad_buffer.py +5 -2
  7. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/extensions/transformer_engine.py +96 -6
  8. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fp8_utils.py +22 -17
  9. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/__init__.py +20 -2
  10. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/distrib_optimizer.py +6 -3
  11. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/optimizer_config.py +5 -0
  12. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/package_info.py +1 -1
  13. megatron_core-0.15.0rc6/megatron/core/pipeline_parallel/bridge_communicator.py +399 -0
  14. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/mamba_layer.py +32 -21
  15. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tensor_parallel/layers.py +13 -10
  16. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/cuda_graphs.py +92 -49
  17. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/mlp.py +5 -2
  18. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/module.py +172 -0
  19. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/experts.py +32 -27
  20. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/moe_utils.py +17 -8
  21. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/router.py +13 -1
  22. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/transformer_config.py +4 -2
  23. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/transformer_layer.py +114 -172
  24. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6/megatron_core.egg-info}/PKG-INFO +1 -1
  25. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron_core.egg-info/SOURCES.txt +1 -0
  26. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/pyproject.toml +21 -7
  27. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/LICENSE +0 -0
  28. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/MANIFEST.in +0 -0
  29. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/README.md +0 -0
  30. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/README.md +0 -0
  31. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/__init__.py +0 -0
  32. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/activations.py +0 -0
  33. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/config.py +0 -0
  34. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/config_logger.py +0 -0
  35. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/__init__.py +0 -0
  36. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/bert_dataset.py +0 -0
  37. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/blended_dataset.py +0 -0
  38. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  39. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  40. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/gpt_dataset.py +0 -0
  41. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/helpers.cpp +0 -0
  42. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/helpers.py +0 -0
  43. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/indexed_dataset.py +0 -0
  44. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/masked_dataset.py +0 -0
  45. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/megatron_dataset.py +0 -0
  46. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  47. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/multimodal_dataset.py +0 -0
  48. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/object_storage_utils.py +0 -0
  49. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/__init__.py +0 -0
  50. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/config/__init__.py +0 -0
  51. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  52. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/config/config.py +0 -0
  53. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  54. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  55. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/db/__init__.py +0 -0
  56. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/db/build.py +0 -0
  57. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/db/dataset.py +0 -0
  58. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/db/utils.py +0 -0
  59. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/external_libs.py +0 -0
  60. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/__init__.py +0 -0
  61. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/build.py +0 -0
  62. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/factory.py +0 -0
  63. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/index.py +0 -0
  64. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  65. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  66. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  67. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/utils.py +0 -0
  68. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/index/validate.py +0 -0
  69. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/query/__init__.py +0 -0
  70. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  71. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  72. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/query/query.py +0 -0
  73. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  74. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/query/utils.py +0 -0
  75. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/retro/utils.py +0 -0
  76. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/t5_dataset.py +0 -0
  77. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/utils.py +0 -0
  78. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/utils_object_storage.py +0 -0
  79. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/datasets/utils_s3.py +0 -0
  80. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/__init__.py +0 -0
  81. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/core.py +0 -0
  82. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  83. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  84. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/mapping.py +0 -0
  85. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  86. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/serialization.py +0 -0
  87. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  88. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  89. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  90. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  91. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  92. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/checkpointable.py +0 -0
  93. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  94. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  95. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  96. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  97. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  98. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  99. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  100. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  101. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  102. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  103. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/utils.py +0 -0
  104. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/dist_checkpointing/validation.py +0 -0
  105. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/__init__.py +0 -0
  106. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/data_parallel_base.py +0 -0
  107. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  108. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/__init__.py +0 -0
  109. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +0 -0
  110. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/__init__.py +0 -0
  111. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +0 -0
  112. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +0 -0
  113. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +0 -0
  114. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +0 -0
  115. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +0 -0
  116. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +0 -0
  117. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  118. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  119. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/energy_monitor.py +0 -0
  120. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/enums.py +0 -0
  121. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/__init__.py +0 -0
  122. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/data_type.py +0 -0
  123. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/export_config.py +0 -0
  124. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/model_type.py +0 -0
  125. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/__init__.py +0 -0
  126. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  127. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  128. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  129. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  130. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  131. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  132. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  133. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  134. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  135. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  136. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  137. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  138. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/extensions/__init__.py +0 -0
  139. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/extensions/kitchen.py +0 -0
  140. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  141. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fp4_utils.py +0 -0
  142. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/full_cuda_graph.py +0 -0
  143. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/__init__.py +0 -0
  144. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  145. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  146. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  147. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  148. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  149. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_indices_converter.py +0 -0
  150. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_layer_norm.py +0 -0
  151. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  152. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  153. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_softmax.py +0 -0
  154. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
  155. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/hyper_comm_grid.py +0 -0
  156. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/__init__.py +0 -0
  157. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/async_stream.py +0 -0
  158. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/common_inference_params.py +0 -0
  159. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/communication_utils.py +0 -0
  160. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/contexts/__init__.py +0 -0
  161. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/contexts/base_context.py +0 -0
  162. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  163. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/contexts/dynamic_context.py +0 -0
  164. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/contexts/static_context.py +0 -0
  165. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/data_parallel_inference_coordinator.py +0 -0
  166. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/engines/__init__.py +0 -0
  167. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/engines/abstract_engine.py +0 -0
  168. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/engines/dynamic_engine.py +0 -0
  169. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/engines/mcore_engine.py +0 -0
  170. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/engines/static_engine.py +0 -0
  171. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/headers.py +0 -0
  172. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/inference_client.py +0 -0
  173. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/inference_request.py +0 -0
  174. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  175. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  176. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  177. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  178. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  179. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  180. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  181. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  182. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/sampling_params.py +0 -0
  183. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/scheduler.py +0 -0
  184. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  185. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  186. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  187. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +0 -0
  188. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  189. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference/utils.py +0 -0
  190. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/inference_params.py +0 -0
  191. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/jit.py +0 -0
  192. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/model_parallel_config.py +0 -0
  193. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/T5/__init__.py +0 -0
  194. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/T5/t5_model.py +0 -0
  195. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/T5/t5_spec.py +0 -0
  196. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/__init__.py +0 -0
  197. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/backends.py +0 -0
  198. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/bert/__init__.py +0 -0
  199. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  200. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/bert/bert_lm_head.py +0 -0
  201. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/bert/bert_model.py +0 -0
  202. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/bert/pooler.py +0 -0
  203. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/__init__.py +0 -0
  204. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/embeddings/__init__.py +0 -0
  205. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  206. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  207. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  208. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  209. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  210. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/language_module/__init__.py +0 -0
  211. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/language_module/language_module.py +0 -0
  212. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  213. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/vision_module/__init__.py +0 -0
  214. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  215. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/gpt/__init__.py +0 -0
  216. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  217. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  218. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/gpt/gpt_model.py +0 -0
  219. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  220. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  221. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/huggingface/__init__.py +0 -0
  222. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/huggingface/clip_model.py +0 -0
  223. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/huggingface/module.py +0 -0
  224. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/huggingface/qwen_model.py +0 -0
  225. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mamba/__init__.py +0 -0
  226. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  227. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mamba/mamba_model.py +0 -0
  228. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/__init__.py +0 -0
  229. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/config/__init__.py +0 -0
  230. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/config/base_configs.py +0 -0
  231. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/model/__init__.py +0 -0
  232. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/model/base.py +0 -0
  233. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/submodules/audio.py +0 -0
  234. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/submodules/base.py +0 -0
  235. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/mimo/submodules/vision.py +0 -0
  236. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/multimodal/__init__.py +0 -0
  237. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/multimodal/context_parallel.py +0 -0
  238. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/multimodal/llava_model.py +0 -0
  239. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/multimodal/llava_spec.py +0 -0
  240. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/__init__.py +0 -0
  241. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/base_attention.py +0 -0
  242. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/config.py +0 -0
  243. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/decoder_attention.py +0 -0
  244. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/decoder_spec.py +0 -0
  245. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/encoder_attention.py +0 -0
  246. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/encoder_spec.py +0 -0
  247. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/model.py +0 -0
  248. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/retro/utils.py +0 -0
  249. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/vision/__init__.py +0 -0
  250. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/vision/clip_vit_model.py +0 -0
  251. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/vision/multimodal_projector.py +0 -0
  252. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/vision/radio.py +0 -0
  253. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  254. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/msc_utils.py +0 -0
  255. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/nccl_allocator.py +0 -0
  256. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/num_microbatches_calculator.py +0 -0
  257. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/clip_grads.py +0 -0
  258. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  259. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  260. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/grad_scaler.py +0 -0
  261. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer/optimizer.py +0 -0
  262. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/optimizer_param_scheduler.py +0 -0
  263. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/packed_seq_params.py +0 -0
  264. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/parallel_state.py +0 -0
  265. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/pipeline_parallel/__init__.py +0 -0
  266. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  267. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  268. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/pipeline_parallel/schedules.py +0 -0
  269. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/pipeline_parallel/utils.py +0 -0
  270. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/__init__.py +0 -0
  271. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/modelopt/__init__.py +0 -0
  272. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  273. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  274. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  275. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/modelopt/layers.py +0 -0
  276. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  277. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  278. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/process_groups_config.py +0 -0
  279. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/quantization/__init__.py +0 -0
  280. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/quantization/quant_config.py +0 -0
  281. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/quantization/utils.py +0 -0
  282. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/requirements.txt +0 -0
  283. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/rerun_state_machine.py +0 -0
  284. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/safe_globals.py +0 -0
  285. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/__init__.py +0 -0
  286. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/mamba_block.py +0 -0
  287. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  288. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  289. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/mamba_mixer.py +0 -0
  290. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/mlp_layer.py +0 -0
  291. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/ssm/triton_cache_manager.py +0 -0
  292. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tensor_parallel/__init__.py +0 -0
  293. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  294. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tensor_parallel/data.py +0 -0
  295. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tensor_parallel/mappings.py +0 -0
  296. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tensor_parallel/random.py +0 -0
  297. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tensor_parallel/utils.py +0 -0
  298. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/timers.py +0 -0
  299. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/__init__.py +0 -0
  300. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/base_tokenizer.py +0 -0
  301. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/megatron_tokenizer.py +0 -0
  302. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/__init__.py +0 -0
  303. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/__init__.py +0 -0
  304. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +0 -0
  305. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/bytelevel_tokenizer.py +0 -0
  306. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/chat_template.py +0 -0
  307. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +0 -0
  308. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +0 -0
  309. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/null_tokenizer.py +0 -0
  310. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +0 -0
  311. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +0 -0
  312. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/models/__init__.py +0 -0
  313. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/models/bert_tokenizer.py +0 -0
  314. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/models/default_tokenizer.py +0 -0
  315. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/models/gpt_tokenizer.py +0 -0
  316. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/models/mamba_tokenizer.py +0 -0
  317. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/models/retro_tokenizer.py +0 -0
  318. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/models/t5_tokenizer.py +0 -0
  319. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/text_tokenizer.py +0 -0
  320. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/tokenizers/text/utils/build_tokenizer.py +0 -0
  321. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/__init__.py +0 -0
  322. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/attention.py +0 -0
  323. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  324. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  325. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/dot_product_attention.py +0 -0
  326. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/enums.py +0 -0
  327. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/fsdp_dtensor_checkpoint.py +0 -0
  328. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  329. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  330. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/identity_op.py +0 -0
  331. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/__init__.py +0 -0
  332. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  333. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  334. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/moe_layer.py +0 -0
  335. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/shared_experts.py +0 -0
  336. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  337. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  338. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/multi_latent_attention.py +0 -0
  339. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/multi_token_prediction.py +0 -0
  340. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  341. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/spec_utils.py +0 -0
  342. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/torch_layer_norm.py +0 -0
  343. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/torch_norm.py +0 -0
  344. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/transformer_block.py +0 -0
  345. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/transformer/utils.py +0 -0
  346. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron/core/utils.py +0 -0
  347. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron_core.egg-info/dependency_links.txt +0 -0
  348. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron_core.egg-info/requires.txt +0 -0
  349. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/megatron_core.egg-info/top_level.txt +0 -0
  350. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/setup.cfg +0 -0
  351. {megatron_core-0.15.0rc5 → megatron_core-0.15.0rc6}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: megatron-core
- Version: 0.15.0rc5
+ Version: 0.15.0rc6
  Summary: Megatron Core - a library for efficient and scalable training of transformer based models
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
megatron/core/distributed/distributed_data_parallel.py
@@ -519,8 +519,11 @@ class DistributedDataParallel(_BaseDataParallel):
  param_slice = bucket.param_data.view(-1)[param_start:param_end]
  param.data.copy_(param_slice.view(param.data.shape))
  # All-gathered params are not needed after being copied to param.data.
- # Zero out the grad buffer (shared with param buffer) for gradient accumulation.
- bucket.grad_data.zero_()
+ # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
+ # We cannot zero out the entire grad buffer because one grad buffer may
+ # correspond to multiple param buffers. If we zero out the entire grad buffer,
+ # it would clear the data of those param buffers that have not yet completed AG.
+ bucket.param_data.zero_()

  def start_grad_sync(self, *unused):
  """
@@ -562,16 +565,8 @@ class DistributedDataParallel(_BaseDataParallel):
  # to True, and there will be a double-GA.
  for param in self.params_with_grad:
  param.grad_added_to_main_grad = False
- # In the case of "reuse_grad_buf_for_mxfp8_param_ag=True & overlap_param_gather=True",
- # the grad buffer is not reset here because the grad buffer is shared with the param buffer.
- # The grad buffer is zeroed by "bucket.grad_data.zero_()" in the "finish_param_sync" stage
- # after the param all-gather.
- if not (
- self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag
- and self.ddp_config.overlap_param_gather
- ):
- for buffer in self.buffers + self.expert_parallel_buffers:
- buffer.reset()
+ for buffer in self.buffers + self.expert_parallel_buffers:
+ buffer.reset()

  for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
  bucket_group.reset()
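The new comments describe an aliasing constraint: when the MXFP8 param all-gather reuses the grad buffer, several buckets' param views share one flat grad buffer, so only the bucket that has finished copying its params may be cleared. A minimal standalone sketch of that aliasing (buffer size and bucket split are made up for illustration):

import torch

# One flat buffer shared by the grad and param storage, split into two buckets.
flat = torch.zeros(8)
bucket_a_param = flat[0:4]  # bucket A's param view; its all-gather has finished
bucket_b_param = flat[4:8]  # bucket B's param view; its all-gather is still in flight
bucket_b_param.copy_(torch.arange(4.0))  # data bucket B still needs to copy out

# Zeroing the whole shared buffer (the old behavior) would wipe bucket B's params:
# flat.zero_()
# Zeroing only the finished bucket's param slice (the new behavior) is safe:
bucket_a_param.zero_()
assert bucket_b_param.tolist() == [0.0, 1.0, 2.0, 3.0]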
megatron/core/distributed/finalize_model_grads.py
@@ -267,13 +267,18 @@ def _allreduce_position_embedding_grads(
  )


- def _reset_global_aux_loss_tracker(model: List[torch.nn.Module]):
+ def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.nn.Module]):
  """
- Reset the global aux loss tracker.
+ Reset the temporary tensors of the model.
  """
  for model_chunk in model:
  for module in get_attr_wrapped_model(model_chunk, 'modules')():
- if hasattr(module, 'reset_global_aux_loss_tracker'):
+ if config.moe_router_enable_expert_bias and hasattr(module, 'expert_bias'):
+ module.local_tokens_per_expert.zero_()
+ if (
+ config.moe_router_load_balancing_type == "global_aux_loss"
+ or "global_aux_loss" in config.moe_router_load_balancing_type
+ ) and hasattr(module, 'reset_global_aux_loss_tracker'):
  module.reset_global_aux_loss_tracker()


@@ -298,10 +303,7 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: Transformer
  stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
  )

- for tokens_per_expert, expert_bias, updated_expert_bias in zip(
- tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
- ):
- tokens_per_expert.zero_()
+ for expert_bias, updated_expert_bias in zip(expert_bias_list, stacked_updated_expert_bias):
  expert_bias.copy_(updated_expert_bias)


@@ -465,11 +467,7 @@ def finalize_model_grads(
  if config.moe_router_enable_expert_bias:
  _update_router_expert_bias(model, config)

- if (
- config.moe_router_load_balancing_type == "global_aux_loss"
- or "global_aux_loss" in config.moe_router_load_balancing_type
- ):
- _reset_global_aux_loss_tracker(model)
+ reset_model_temporary_tensors(config, model)

  # normalize gradients for per-token loss normalization.
  # if we are using by the number of tokens, then we use that as a divisor. this number
megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
@@ -224,7 +224,7 @@ class MegatronFSDP(torch.nn.Module):
  # step of the model will reduce all gradients and gather all parameters
  # for synchronized operations such as distributed optimization and
  # distributed checkpointing particularly sharding with HSDP / DP-Outer.
- self.model_auto_sync = self.set_model_auto_sync(sync_model_each_microbatch)
+ self.set_model_auto_sync(sync_model_each_microbatch)

  # Check if the module contains (Megatron-Core) expert parallel parameters or DTensors.
  has_expert_parameters = self._check_module_parameter_types()
@@ -307,8 +307,11 @@ class MegatronFSDP(torch.nn.Module):
  expert_gradient_scaling_factor = None
  else:
  if self.ddp_config.average_in_collective:
- # FIXME(@jianbinc): Will fix this issue based on Parallel Folding's EDP patch MR.
- raise Exception("Not supported")
+ gradient_scaling_factor = 1.0
+ expert_gradient_scaling_factor = (
+ self.dist_index.get_dp_group(is_expert_parallel=True).size()
+ / self.dist_index.get_dp_group().size()
+ )
  else:
  data_parallel_world_size = self.dist_index.get_dp_group().size()
  gradient_scaling_factor = 1.0 / data_parallel_world_size
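The new branch fills in the previously unsupported average_in_collective case. My reading of the factor, as a sketch of the arithmetic with made-up group sizes: an averaging collective already divides by its group size, and expert gradients are reduced over the smaller expert-data-parallel group, so pre-scaling them by edp_size / dp_size lands both dense and expert gradients at sum / dp_size.

dp_size = 8    # full data-parallel group size (illustrative)
edp_size = 2   # expert data-parallel group size (illustrative)
expert_gradient_scaling_factor = edp_size / dp_size

grad_sum = 16.0                 # sum of an expert gradient over the EDP ranks
averaged = grad_sum / edp_size  # what the averaging collective produces
assert averaged * expert_gradient_scaling_factor == grad_sum / dp_size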
@@ -426,6 +429,14 @@ class MegatronFSDP(torch.nn.Module):
  bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
  ag_pipeline.wait_bucket_ready(bucket_id)

+ for param in params:
+ # This setting is needed to make FSDP store the weight object when used
+ # with TE's activation offloading for the first global batch.
+ param.grad_added_to_main_grad = False
+ # This setting is needed to have this attribute present after every
+ # un-shard of the FSDP params.
+ param.__fsdp_param__ = True
+
  def _register_fsdp_hooks(self, root_module):
  """Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py
@@ -4,7 +4,7 @@
  MAJOR = 0
  MINOR = 1
  PATCH = 0
- PRE_RELEASE = 'rc3'
+ PRE_RELEASE = 'rc4'

  # Use the following formatting: (major, minor, patch, pre-release)
  VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
megatron/core/distributed/param_and_grad_buffer.py
@@ -313,8 +313,11 @@ class _ParamAndGradBucketGroup:
  param_slice = bucket.param_data.view(-1)[param_start:param_end]
  param.data.copy_(param_slice.view(param.data.shape))
  # All-gathered params are not needed after being copied to param.data.
- # Zero out the grad buffer (shared with param buffer) for gradient accumulation.
- bucket.grad_data.zero_()
+ # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
+ # We cannot zero out the entire grad buffer because one grad buffer may
+ # correspond to multiple param buffers. If we zero out the entire grad buffer,
+ # it would clear the data of those param buffers that have not yet completed AG.
+ bucket.param_data.zero_()

  def start_grad_sync(self):
  """
megatron/core/extensions/transformer_engine.py
@@ -1,6 +1,7 @@
  # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

  import dataclasses
+ import inspect
  import io
  import os
  import pickle
@@ -1591,21 +1592,21 @@ if HAVE_TE and is_te_min_version("1.13.0"):
  if self.linear_fc2.config.tp_comm_overlap and self.linear_fc2.ub_name is not None:
  userbuffers_options = {"comm_name": self.linear_fc2.ub_name}
  op = te.pytorch.ops.BasicLinear(
- weight.size(1) * tp_world_size,
+ weight.size(1),
  weight.size(0),
  device="meta",
  dtype=weight.dtype,
- tensor_parallel_mode="row" if tp_world_size > 1 else None,
- tensor_parallel_group=tp_group,
- sequence_parallel=self.linear_fc2.sequence_parallel,
  rng_state_tracker_function=rng_state_tracker_function,
  accumulate_into_main_grad=self.linear_fc2.fuse_wgrad_accumulation,
  userbuffers_options=userbuffers_options,
  )
  op.weight = weight
  fused_impl.append(op)
- if tp_world_size > 1 and self.linear_fc2.sequence_parallel:
- fused_impl.append(te.pytorch.ops.ReduceScatter(tp_group))
+ if tp_world_size > 1:
+ if self.linear_fc2.sequence_parallel:
+ fused_impl.append(te.pytorch.ops.ReduceScatter(tp_group))
+ else:
+ fused_impl.append(te.pytorch.ops.AllReduce(tp_group))

  # FC2 bias op
  if not self.linear_fc2.te_return_bias:
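For context on the new else branch: with the FC2 weight sharded along its input dimension across tensor-parallel ranks, each rank's local matmul yields only a partial sum of the output, which must be combined with a reduce-scatter under sequence parallelism or an all-reduce otherwise. A single-process sketch of what the two collectives compute, with the ranks simulated as two tensors rather than a real process group:

import torch

partial_rank0 = torch.tensor([1.0, 2.0, 3.0, 4.0])      # rank 0's partial FC2 output
partial_rank1 = torch.tensor([10.0, 20.0, 30.0, 40.0])   # rank 1's partial FC2 output

all_reduced = partial_rank0 + partial_rank1             # all-reduce: every rank holds the full sum
reduce_scattered = [all_reduced[:2], all_reduced[2:]]   # reduce-scatter: each rank keeps its sequence chunk
print(all_reduced, reduce_scattered)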
@@ -1617,6 +1618,9 @@ if HAVE_TE and is_te_min_version("1.13.0"):
  op.bias = bias
  fused_impl.append(op)

+ # Emulate submodule forward hooks if needed
+ self._register_hooks_on_fused_impl(fused_impl)
+
  return fused_impl

  def _make_activation_op(
@@ -1655,6 +1659,92 @@ if HAVE_TE and is_te_min_version("1.13.0"):
  kwargs["cache_quantized_input"] = cache_quantized_input
  return op_type(**kwargs)

+ def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None:
+ """Attempt to emulate submodule callback hooks.
+
+ This is not always possible because Transformer Engine's
+ op fuser does not expose intermediate tensors. Depending
+ on what kernel fusions the op fuser chooses, the
+ intermediate tensors may not even exist. Hooks that modify
+ tensors will result in incorrect behavior.
+
+ """
+
+ # Get submodule hooks
+ forward_pre_hooks = []
+ forward_post_hooks = []
+ backward_pre_hooks = []
+ backward_post_hooks = []
+ for submodule in self.modules():
+ for hook in submodule._forward_pre_hooks.values():
+ forward_pre_hooks.append((submodule, hook))
+ for hook in submodule._forward_hooks.values():
+ forward_post_hooks.append((submodule, hook))
+ for hook in submodule._backward_pre_hooks.values():
+ backward_pre_hooks.append((submodule, hook))
+ for hook in submodule._backward_hooks.values():
+ backward_post_hooks.append((submodule, hook))
+
+ # Pre-forward hooks
+ # Note: DDP pre-forward hooks are safe since they do not
+ # interact with input tensor.
+ if forward_pre_hooks:
+ from megatron.core.distributed import distributed_data_parallel
+
+ if any(
+ inspect.getmodule(hook) != distributed_data_parallel
+ for _, hook in forward_pre_hooks
+ ):
+ warnings.warn(
+ "TEFusedMLP module has a submodule with a pre-forward hook. "
+ "TEFusedMLP module does not expose intermediate tensors, "
+ "so the hook may have incorrect behavior if it attempts to "
+ "access the input tensor."
+ )
+
+ def forward_pre_hook(module, *_) -> None:
+ for submodule, hook in forward_pre_hooks:
+ # Assume that hook does not interact with input
+ ret = hook(submodule, None)
+ if ret is not None:
+ raise RuntimeError(
+ "TEFusedMLP module does not expose intermediate tensors, but "
+ "submodule has pre-forward hook that modifies input tensor."
+ )
+
+ fused_impl.register_forward_pre_hook(forward_pre_hook)
+
+ # Post-forward hooks
+ if forward_post_hooks:
+ warnings.warn(
+ "TEFusedMLP module has a submodule with a post-forward hook. "
+ "TEFusedMLP module does not expose intermediate tensors, "
+ "so the hook may have incorrect behavior if it attempts to "
+ "access the input or output tensors."
+ )
+
+ def forward_post_hook(module, *_) -> None:
+ for submodule, hook in forward_post_hooks:
+ # Assume that hook does not interact with input or output
+ ret = hook(submodule, None, None)
+ if ret is not None:
+ raise RuntimeError(
+ "TEFusedMLP module does not expose intermediate tensors, but "
+ "submodule has post-forward hook that modifies output tensor."
+ )
+
+ fused_impl.register_forward_hook(forward_post_hook)
+
+ # Backward hooks
+ if backward_pre_hooks:
+ raise RuntimeError(
+ "TEFusedMLP module does not support submodules with pre-backward hooks"
+ )
+ if backward_post_hooks:
+ raise RuntimeError(
+ "TEFusedMLP module does not support submodules with post-backward hooks"
+ )
+
  def forward(self, hidden_states: torch.Tensor) -> Tuple[Tensor, Optional[Tensor]]:
  """Forward."""
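For reference, the hooks being collected and re-registered above are ordinary torch.nn.Module hooks with the (module, input) and (module, input, output) signatures; the fused wrapper invokes them with None in place of the tensors, so only hooks that ignore their tensor arguments keep working, which is what the warnings guard against. A minimal illustration with a plain PyTorch module:

import torch

linear = torch.nn.Linear(4, 4)

def logging_post_hook(module, inputs, output):
    # A hook like this, which never touches inputs/output, stays safe under the fused path.
    print(f"forward of {module.__class__.__name__} ran")

linear.register_forward_hook(logging_post_hook)
linear(torch.randn(2, 4))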
megatron/core/fp8_utils.py
@@ -406,6 +406,25 @@ def correct_amax_history_if_needed(model: List[torch.nn.Module]):
  _correct_amax_history_if_needed_impl(model)


+ def is_first_last_bf16_layer(config: TransformerConfig, layer_no: int):
+ """Check if the layer is in bf16."""
+ num_bf16_layers_at_start = (
+ config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
+ )
+ num_bf16_layers_at_end = (
+ config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
+ )
+ # Since layer_no is a global layer index, additional checks on whether
+ # we are in the first or last pipeline-parallel rank are not needed.
+ is_first_layer = layer_no < num_bf16_layers_at_start
+ is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end
+
+ if layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer):
+ return True
+ else:
+ return False
+
+
  if HAVE_TE:
  from megatron.core import parallel_state
  from megatron.core.extensions.transformer_engine import TEDelayedScaling
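A small trace of the extracted helper using a stand-in config object (SimpleNamespace instead of the real TransformerConfig, with illustrative field values): with 12 layers and one bf16 layer kept at each end, only global layer indices 0 and 11 fall back to bf16.

from types import SimpleNamespace

cfg = SimpleNamespace(
    first_last_layers_bf16=True,
    num_layers=12,
    num_layers_at_start_in_bf16=1,
    num_layers_at_end_in_bf16=1,
)

def is_first_last_bf16_layer(config, layer_no):
    # Mirrors the helper above, operating on the stand-in config.
    start = config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
    end = config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
    is_first = layer_no < start
    is_last = layer_no >= config.num_layers - end
    return layer_no >= 0 and config.first_last_layers_bf16 and (is_first or is_last)

assert [n for n in range(cfg.num_layers) if is_first_last_bf16_layer(cfg, n)] == [0, 11]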
@@ -437,7 +456,7 @@ if HAVE_TE:
  )
  elif config.fp8_recipe == Fp8Recipe.tensorwise and is_te_min_version("2.2.0.dev0"):
  fp8_recipe = transformer_engine.common.recipe.Float8CurrentScaling(
- fp8_format=fp8_format
+ fp8_format=fp8_format, fp8_dpa=config.fp8_dot_product_attention
  )
  elif config.fp8_recipe == Fp8Recipe.blockwise and is_te_min_version("2.3.0.dev0"):
  fp8_recipe = transformer_engine.common.recipe.Float8BlockScaling(
@@ -483,24 +502,10 @@ if HAVE_TE:
  that needs to be trained in bf16.
  """

- num_bf16_layers_at_start = (
- config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
- )
- num_bf16_layers_at_end = (
- config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
- )
- # Since layer_no is a global layer index, additional checks on whether
- # we are in the first or last pipeline-parallel rank are not needed.
- is_first_layer = layer_no < num_bf16_layers_at_start
- is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end
-
  need_fp8_context = config.fp8 if not is_init else config.fp8_param

- if not need_fp8_context:
- # bf16 training
- fp8_context = nullcontext()
- elif layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer):
- # fp8 training but this layer_no should be bf16
+ if not need_fp8_context or is_first_last_bf16_layer(config, layer_no):
+ # bf16 training or bf16 layer in fp8 training
  fp8_context = nullcontext()
  else:
  # fp8 training and this layer_no is in fp8
megatron/core/optimizer/__init__.py
@@ -10,10 +10,14 @@ from torch.optim import AdamW as CPUAdam
  try:
  from transformer_engine.pytorch.optimizers import FusedAdam as Adam
  from transformer_engine.pytorch.optimizers import FusedSGD as SGD
+
+ USING_PYTORCH_OPTIMIZER = False
  except ImportError:
  try:
  from apex.optimizers import FusedAdam as Adam
  from apex.optimizers import FusedSGD as SGD
+
+ USING_PYTORCH_OPTIMIZER = False
  except ImportError:
  warnings.warn(
  f'Transformer Engine and Apex are not installed. Falling back to Torch optimizers.'
@@ -22,7 +26,10 @@ except ImportError:
  # Apex's FusedAdam is a drop-in replacement for torch's AdamW.
  # pylint: disable-next=line-too-long.
  # See https://github.com/NVIDIA/apex/blob/7b73b12361068a10b0f44844534613f252a5ea75/apex/optimizers/fused_adam.py#L16.
- from torch.optim import AdamW as Adam, SGD
+ from torch.optim import SGD
+ from torch.optim import AdamW as Adam
+
+ USING_PYTORCH_OPTIMIZER = True

  from megatron.core import parallel_state
  from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer
@@ -305,6 +312,9 @@ def _get_megatron_optimizer_based_on_param_groups(
  "CPU offload is recommended for PyTorch >= 2.3.0, "
  "untested versions below this may have convergence issues."
  )
+ assert (
+ config.decoupled_weight_decay
+ ), "CPU offloading only supported with decoupled_weight_decay enabled (AdamW mode)."
  gpu_optimizer_cls = Adam if config.optimizer == 'adam' else SGD
  cpu_optimizer_cls = CPUAdam if config.optimizer == 'adam' else CPUSGD
  if config.use_torch_optimizer_for_cpu_offload:
@@ -347,6 +357,14 @@ def _get_megatron_optimizer_based_on_param_groups(
  "eps": config.adam_eps,
  }

+ # set Adam class and weight decay mode depending
+ # on source of optimizer (Torch or TE/Apex)
+ if USING_PYTORCH_OPTIMIZER:
+ adam_cls = torch.optim.AdamW if config.decoupled_weight_decay else torch.optim.Adam
+ else:
+ kwargs["adam_w_mode"] = config.decoupled_weight_decay
+ adam_cls = Adam
+
  if config.use_precision_aware_optimizer:
  kwargs.update(
  {
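A condensed sketch of how the new import sentinel and flag interact (stand-in values instead of the real OptimizerConfig plumbing): when only the plain PyTorch optimizers are available, decoupled_weight_decay selects torch.optim.AdamW versus torch.optim.Adam, whereas the fused TE/Apex Adam receives the same choice through its adam_w_mode argument.

import torch

USING_PYTORCH_OPTIMIZER = True  # i.e. neither Transformer Engine nor Apex was importable
decoupled_weight_decay = True   # the new OptimizerConfig field
kwargs = {"lr": 1e-3, "weight_decay": 0.01}

if USING_PYTORCH_OPTIMIZER:
    adam_cls = torch.optim.AdamW if decoupled_weight_decay else torch.optim.Adam
else:
    kwargs["adam_w_mode"] = decoupled_weight_decay  # flag accepted by the fused TE/Apex Adam
    adam_cls = None  # would be the fused FusedAdam class on this path

params = [torch.nn.Parameter(torch.zeros(2))]
optimizer = adam_cls(params, **kwargs)
print(type(optimizer).__name__)  # AdamW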
@@ -371,7 +389,7 @@ def _get_megatron_optimizer_based_on_param_groups(
  if is_te_min_version("2.1.0.dev0"):
  kwargs.update({"store_param_remainders": config.store_param_remainders})

- optimizer = Adam(**kwargs)
+ optimizer = adam_cls(**kwargs)

  def init_state_fn(opt, config=None):
  for group in opt.param_groups:
megatron/core/optimizer/distrib_optimizer.py
@@ -28,7 +28,7 @@ except ImportError:

  USING_APEX_OPTIMIZER = True
  except ImportError:
- from torch.optim import AdamW as Adam
+ from torch.optim import Adam as Adam

  HAVE_APEX_OR_TE = False

@@ -507,7 +507,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
  assert self.ddp_config == model_chunk.ddp_config
  self.distributed_optimizer_instance_id = distributed_optimizer_instance_id

- assert isinstance(optimizer, (Adam, HybridDeviceOptimizer)) or optimizer is None, (
+ assert (
+ isinstance(optimizer, (Adam, torch.optim.AdamW, HybridDeviceOptimizer))
+ or optimizer is None
+ ), (
  "Only Adam and HybridDeviceOptimizer currently supported, "
  "due to checkpointing requirements."
  )
@@ -637,7 +640,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
  elif isinstance(self.optimizer, HybridDeviceOptimizer):
  step = None
  for optimizer in self.optimizer.sub_optimizers:
- if isinstance(optimizer, torch.optim.AdamW):
+ if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
  if len(optimizer.state) == 0:
  continue
  steps = list(set([s["step"].item() for s in optimizer.state.values()]))
megatron/core/optimizer/optimizer_config.py
@@ -115,6 +115,11 @@ class OptimizerConfig:
  adam_eps: float = 1e-08
  """Term added to the denominator to improve numerical stability in Adam optimizer."""

+ decoupled_weight_decay: bool = True
+ """If true, decouples weight decay from the gradient update, equivalent to AdamW. If false,
+ original Adam update rule will be used. Defaults to True.
+ """
+
  # SGD.
  sgd_momentum: float = 0.9
  """Momentum factor for SGD optimizer."""
megatron/core/package_info.py
@@ -4,7 +4,7 @@
  MAJOR = 0
  MINOR = 15
  PATCH = 0
- PRE_RELEASE = 'rc5'
+ PRE_RELEASE = 'rc6'

  # Use the following formatting: (major, minor, patch, pre-release)
  VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)