megatron-core 0.14.0rc4__tar.gz → 0.14.0rc5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of megatron-core might be problematic.
Files changed (311)
  1. {megatron_core-0.14.0rc4/megatron_core.egg-info → megatron_core-0.14.0rc5}/PKG-INFO +10 -4
  2. megatron_core-0.14.0rc5/megatron/core/activations.py +23 -0
  3. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/param_and_grad_buffer.py +32 -12
  4. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/extensions/transformer_engine.py +13 -8
  5. megatron_core-0.14.0rc5/megatron/core/fusions/fused_weighted_squared_relu.py +110 -0
  6. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/package_info.py +1 -1
  7. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/tensor_parallel/layers.py +12 -14
  8. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/mlp.py +27 -6
  9. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/experts.py +161 -48
  10. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/multi_token_prediction.py +4 -1
  11. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/transformer_block.py +15 -1
  12. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/transformer_config.py +3 -0
  13. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5/megatron_core.egg-info}/PKG-INFO +10 -4
  14. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron_core.egg-info/SOURCES.txt +2 -0
  15. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron_core.egg-info/requires.txt +9 -3
  16. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/pyproject.toml +25 -4
  17. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/LICENSE +0 -0
  18. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/MANIFEST.in +0 -0
  19. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/README.md +0 -0
  20. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/README.md +0 -0
  21. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/__init__.py +0 -0
  22. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/config.py +0 -0
  23. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/config_logger.py +0 -0
  24. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/__init__.py +0 -0
  25. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/bert_dataset.py +0 -0
  26. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/blended_dataset.py +0 -0
  27. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  28. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  29. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/gpt_dataset.py +0 -0
  30. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/helpers.cpp +0 -0
  31. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/helpers.py +0 -0
  32. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/indexed_dataset.py +0 -0
  33. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/masked_dataset.py +0 -0
  34. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/megatron_dataset.py +0 -0
  35. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  36. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/multimodal_dataset.py +0 -0
  37. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/object_storage_utils.py +0 -0
  38. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/__init__.py +0 -0
  39. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/config/__init__.py +0 -0
  40. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  41. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/config/config.py +0 -0
  42. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  43. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  44. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/db/__init__.py +0 -0
  45. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/db/build.py +0 -0
  46. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/db/dataset.py +0 -0
  47. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/db/utils.py +0 -0
  48. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/external_libs.py +0 -0
  49. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/__init__.py +0 -0
  50. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/build.py +0 -0
  51. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/factory.py +0 -0
  52. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/index.py +0 -0
  53. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  54. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  55. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  56. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/utils.py +0 -0
  57. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/index/validate.py +0 -0
  58. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/query/__init__.py +0 -0
  59. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  60. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  61. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/query/query.py +0 -0
  62. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  63. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/query/utils.py +0 -0
  64. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/retro/utils.py +0 -0
  65. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/t5_dataset.py +0 -0
  66. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/utils.py +0 -0
  67. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/utils_object_storage.py +0 -0
  68. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/datasets/utils_s3.py +0 -0
  69. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/__init__.py +0 -0
  70. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/core.py +0 -0
  71. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/dict_utils.py +0 -0
  72. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  73. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/mapping.py +0 -0
  74. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/optimizer.py +0 -0
  75. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/serialization.py +0 -0
  76. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  77. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  78. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/async_utils.py +0 -0
  79. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/base.py +0 -0
  80. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  81. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  82. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  83. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  84. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  85. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  86. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  87. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/torch.py +0 -0
  88. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  89. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/strategies/zarr.py +0 -0
  90. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  91. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/utils.py +0 -0
  92. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/dist_checkpointing/validation.py +0 -0
  93. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/__init__.py +0 -0
  94. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/custom_fsdp/__init__.py +0 -0
  95. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -0
  96. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py +0 -0
  97. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/data_parallel_base.py +0 -0
  98. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  99. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/distributed_data_parallel_config.py +0 -0
  100. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/finalize_model_grads.py +0 -0
  101. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +0 -0
  102. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  103. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/energy_monitor.py +0 -0
  104. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/enums.py +0 -0
  105. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/__init__.py +0 -0
  106. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/data_type.py +0 -0
  107. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/export_config.py +0 -0
  108. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/model_type.py +0 -0
  109. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/__init__.py +0 -0
  110. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  111. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  112. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  113. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  114. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  115. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  116. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  117. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  118. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  119. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  120. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  121. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  122. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/extensions/__init__.py +0 -0
  123. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/extensions/kitchen.py +0 -0
  124. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/extensions/transformer_engine_spec_provider.py +0 -0
  125. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fp8_utils.py +0 -0
  126. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/__init__.py +0 -0
  127. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  128. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  129. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  130. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  131. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  132. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_indices_converter.py +0 -0
  133. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_layer_norm.py +0 -0
  134. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  135. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  136. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/fusions/fused_softmax.py +0 -0
  137. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/hyper_comm_grid.py +0 -0
  138. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/__init__.py +0 -0
  139. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/async_stream.py +0 -0
  140. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/common_inference_params.py +0 -0
  141. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/communication_utils.py +0 -0
  142. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/contexts/__init__.py +0 -0
  143. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/contexts/base_context.py +0 -0
  144. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  145. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/contexts/dynamic_context.py +0 -0
  146. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/contexts/static_context.py +0 -0
  147. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/engines/__init__.py +0 -0
  148. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/engines/abstract_engine.py +0 -0
  149. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/engines/dynamic_engine.py +0 -0
  150. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/engines/mcore_engine.py +0 -0
  151. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/engines/static_engine.py +0 -0
  152. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/inference_request.py +0 -0
  153. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  154. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  155. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  156. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  157. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  158. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  159. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  160. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  161. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/sampling_params.py +0 -0
  162. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/scheduler.py +0 -0
  163. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  164. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  165. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  166. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +0 -0
  167. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  168. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference/utils.py +0 -0
  169. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/inference_params.py +0 -0
  170. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/jit.py +0 -0
  171. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/model_parallel_config.py +0 -0
  172. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/T5/__init__.py +0 -0
  173. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/T5/t5_model.py +0 -0
  174. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/T5/t5_spec.py +0 -0
  175. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/__init__.py +0 -0
  176. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/backends.py +0 -0
  177. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/bert/__init__.py +0 -0
  178. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  179. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/bert/bert_lm_head.py +0 -0
  180. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/bert/bert_model.py +0 -0
  181. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/bert/pooler.py +0 -0
  182. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/__init__.py +0 -0
  183. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/embeddings/__init__.py +0 -0
  184. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  185. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  186. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  187. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  188. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +0 -0
  189. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/language_module/__init__.py +0 -0
  190. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/language_module/language_module.py +0 -0
  191. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  192. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/vision_module/__init__.py +0 -0
  193. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  194. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/gpt/__init__.py +0 -0
  195. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  196. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/gpt/gpt_layer_specs.py +0 -0
  197. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/gpt/gpt_model.py +0 -0
  198. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  199. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/gpt/moe_module_specs.py +0 -0
  200. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/huggingface/__init__.py +0 -0
  201. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/huggingface/clip_model.py +0 -0
  202. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/huggingface/module.py +0 -0
  203. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/huggingface/qwen_model.py +0 -0
  204. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mamba/__init__.py +0 -0
  205. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  206. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mamba/mamba_model.py +0 -0
  207. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/__init__.py +0 -0
  208. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/config/__init__.py +0 -0
  209. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/config/base_configs.py +0 -0
  210. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/model/__init__.py +0 -0
  211. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/model/base.py +0 -0
  212. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/submodules/audio.py +0 -0
  213. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/submodules/base.py +0 -0
  214. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/mimo/submodules/vision.py +0 -0
  215. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/multimodal/__init__.py +0 -0
  216. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/multimodal/context_parallel.py +0 -0
  217. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/multimodal/llava_model.py +0 -0
  218. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/multimodal/llava_spec.py +0 -0
  219. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/__init__.py +0 -0
  220. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/base_attention.py +0 -0
  221. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/config.py +0 -0
  222. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/decoder_attention.py +0 -0
  223. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/decoder_spec.py +0 -0
  224. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/encoder_attention.py +0 -0
  225. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/encoder_spec.py +0 -0
  226. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/model.py +0 -0
  227. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/retro/utils.py +0 -0
  228. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/vision/__init__.py +0 -0
  229. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/vision/clip_vit_model.py +0 -0
  230. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/vision/multimodal_projector.py +0 -0
  231. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/vision/radio.py +0 -0
  232. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  233. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/msc_utils.py +0 -0
  234. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/num_microbatches_calculator.py +0 -0
  235. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/__init__.py +0 -0
  236. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/clip_grads.py +0 -0
  237. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  238. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  239. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/distrib_optimizer.py +0 -0
  240. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/grad_scaler.py +0 -0
  241. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/optimizer.py +0 -0
  242. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer/optimizer_config.py +0 -0
  243. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/optimizer_param_scheduler.py +0 -0
  244. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/packed_seq_params.py +0 -0
  245. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/parallel_state.py +0 -0
  246. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/pipeline_parallel/__init__.py +0 -0
  247. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  248. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  249. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/pipeline_parallel/schedules.py +0 -0
  250. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/pipeline_parallel/utils.py +0 -0
  251. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/__init__.py +0 -0
  252. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/modelopt/__init__.py +0 -0
  253. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  254. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  255. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +0 -0
  256. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/modelopt/layers.py +0 -0
  257. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  258. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  259. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/process_groups_config.py +0 -0
  260. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/quantization/__init__.py +0 -0
  261. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/quantization/quant_config.py +0 -0
  262. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/quantization/utils.py +0 -0
  263. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/requirements.txt +0 -0
  264. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/rerun_state_machine.py +0 -0
  265. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/__init__.py +0 -0
  266. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/mamba_block.py +0 -0
  267. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  268. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  269. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/mamba_layer.py +0 -0
  270. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/mamba_mixer.py +0 -0
  271. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/mlp_layer.py +0 -0
  272. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/ssm/triton_cache_manager.py +0 -0
  273. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/tensor_parallel/__init__.py +0 -0
  274. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  275. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/tensor_parallel/data.py +0 -0
  276. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/tensor_parallel/mappings.py +0 -0
  277. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/tensor_parallel/random.py +0 -0
  278. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/tensor_parallel/utils.py +0 -0
  279. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/timers.py +0 -0
  280. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/__init__.py +0 -0
  281. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/attention.py +0 -0
  282. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/cuda_graphs.py +0 -0
  283. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  284. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  285. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/dot_product_attention.py +0 -0
  286. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/enums.py +0 -0
  287. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  288. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  289. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/identity_op.py +0 -0
  290. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/module.py +0 -0
  291. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/__init__.py +0 -0
  292. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  293. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  294. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/moe_layer.py +0 -0
  295. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/moe_utils.py +0 -0
  296. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/router.py +0 -0
  297. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/shared_experts.py +0 -0
  298. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  299. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  300. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/multi_latent_attention.py +0 -0
  301. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  302. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/spec_utils.py +0 -0
  303. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/torch_layer_norm.py +0 -0
  304. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/torch_norm.py +0 -0
  305. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/transformer_layer.py +0 -0
  306. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/transformer/utils.py +0 -0
  307. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron/core/utils.py +0 -0
  308. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron_core.egg-info/dependency_links.txt +0 -0
  309. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/megatron_core.egg-info/top_level.txt +0 -0
  310. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/setup.cfg +0 -0
  311. {megatron_core-0.14.0rc4 → megatron_core-0.14.0rc5}/setup.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc4
+Version: 0.14.0rc5
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -31,7 +31,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch
 Requires-Dist: numpy<2.0.0
-Requires-Dist: packaging~=25.0
+Requires-Dist: packaging
 Provides-Extra: mlm
 Requires-Dist: flask-restful; extra == "mlm"
 Requires-Dist: sentencepiece; extra == "mlm"
@@ -43,12 +43,18 @@ Requires-Dist: einops~=0.8; extra == "dev"
 Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
 Requires-Dist: nvtx~=0.2; extra == "dev"
 Requires-Dist: transformers~=4.53; extra == "dev"
-Requires-Dist: multi-storage-client~=0.20.3; extra == "dev"
+Requires-Dist: multi-storage-client~=0.20; extra == "dev"
 Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
 Requires-Dist: setuptools<80.0.0; extra == "dev"
-Requires-Dist: nvidia-modelopt[torch]<0.32.0,>=0.31.0a0; sys_platform != "darwin" and extra == "dev"
+Requires-Dist: mamba-ssm~=2.2; extra == "dev"
+Requires-Dist: causal-conv1d~=1.5; extra == "dev"
+Requires-Dist: nv-grouped-gemm~=1.1; extra == "dev"
+Requires-Dist: transformer-engine[pytorch]<2.7.0,>=2.5.0a0; extra == "dev"
+Requires-Dist: nvidia-resiliency-ext<0.5.0,>=0.4.0a0; extra == "dev"
+Requires-Dist: nvidia-modelopt[torch]<0.34.0,>=0.33.0a0; sys_platform != "darwin" and extra == "dev"
 Requires-Dist: megatron-energon[av_decode]~=6.0; extra == "dev"
 Requires-Dist: flashinfer-python; extra == "dev"
+Requires-Dist: onnxscript; extra == "dev"
 Provides-Extra: lts
 Requires-Dist: tqdm; extra == "lts"
 Requires-Dist: einops; extra == "lts"
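
The dependency changes above are mostly pin adjustments: `packaging` is unpinned, `multi-storage-client` is loosened, `nvidia-modelopt` moves to the 0.33 line, and several previously implicit dev dependencies (mamba-ssm, causal-conv1d, nv-grouped-gemm, transformer-engine, nvidia-resiliency-ext, onnxscript) are now declared. A minimal sketch, not part of the package, of checking an installed distribution against one of the new specifiers using `packaging` (the queried package name is just the example from the diff):

```python
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

# New dev-extra pin copied from the diff above.
modelopt_pin = SpecifierSet("<0.34.0,>=0.33.0a0")

installed = version("nvidia-modelopt")  # e.g. "0.33.1"
# prereleases=True matters here: the lower bound 0.33.0a0 is a pre-release.
print(modelopt_pin.contains(installed, prereleases=True))
```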
megatron/core/activations.py (new file)

@@ -0,0 +1,23 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+import torch
+import torch.nn.functional as F
+
+from megatron.core.jit import jit_fuser
+
+
+@jit_fuser
+def squared_relu(x: torch.Tensor) -> torch.Tensor:
+    """Squared ReLU activation"""
+    return torch.pow(F.relu(x), 2)
+
+
+@jit_fuser
+def quick_gelu(x: torch.Tensor) -> torch.Tensor:
+    """Quick GELU activation"""
+    return x * torch.sigmoid(1.702 * x)
+
+
+@jit_fuser
+def fast_gelu(x: torch.Tensor) -> torch.Tensor:
+    """Fast GELU activation"""
+    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
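
The new module collects three `jit_fuser`-compiled element-wise activations in one place. A quick usage sketch (shapes arbitrary), assuming megatron-core 0.14.0rc5 is installed:

```python
import torch

from megatron.core.activations import fast_gelu, quick_gelu, squared_relu

x = torch.randn(4, 8)
print(squared_relu(x).shape)  # torch.Size([4, 8]); zero wherever x <= 0
print(quick_gelu(x).shape)    # sigmoid-based GELU approximation
print(fast_gelu(x).shape)     # tanh-based GELU approximation
```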
megatron/core/distributed/param_and_grad_buffer.py

@@ -157,6 +157,14 @@ class _ParamAndGradBucketGroup:
         self.param_gather_dispatched = False
         self.grad_reduce_handle = None
 
+        # Each time a local shard is created from bucket.param_data or bucket.grad_data, it
+        # introduces some CPU overheads. We use these two lists to cache the created local
+        # shards to avoid unnecessary CPU operations. This does not increase GPU memory usage
+        # because it only saves a slice view, which shares the same memory with bucket.param_data
+        # or bucket.grad_data.
+        self.cached_param_buffer_shard_list = [None] * len(self.buckets)
+        self.cached_grad_buffer_shard_list = [None] * len(self.buckets)
+
     def reset(self):
         """
         Reset metadata in bucket group in preparation for the next iteration of training.
@@ -229,10 +237,14 @@ class _ParamAndGradBucketGroup:
         with _coalescing_manager(
             self.intra_distributed_optimizer_instance_group, async_ops=async_op
         ) as cm:
-            for bucket in self.buckets:
-                local_data_view = shard_buffer(
-                    bucket.param_data, self.intra_distributed_optimizer_instance_size
-                )[self.intra_distributed_optimizer_instance_rank]
+            for idx, bucket in enumerate(self.buckets):
+                if self.cached_param_buffer_shard_list[idx] is None:
+                    self.cached_param_buffer_shard_list[idx] = shard_buffer(
+                        bucket.param_data, self.intra_distributed_optimizer_instance_size
+                    )
+                local_data_view = self.cached_param_buffer_shard_list[idx][
+                    self.intra_distributed_optimizer_instance_rank
+                ]
                 dist_all_gather_func(
                     bucket.param_data,
                     local_data_view,
@@ -352,11 +364,15 @@ class _ParamAndGradBucketGroup:
 
         # Coalesce communication kernels across buckets in the bucket group.
         with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm:
-            for bucket in self.buckets:
+            for idx, bucket in enumerate(self.buckets):
                 if self.ddp_config.use_distributed_optimizer:
-                    local_data_view = shard_buffer(
-                        bucket.grad_data, self.intra_distributed_optimizer_instance_size
-                    )[self.intra_distributed_optimizer_instance_rank]
+                    if self.cached_grad_buffer_shard_list[idx] is None:
+                        self.cached_grad_buffer_shard_list[idx] = shard_buffer(
+                            bucket.grad_data, self.intra_distributed_optimizer_instance_size
+                        )
+                    local_data_view = self.cached_grad_buffer_shard_list[idx][
+                        self.intra_distributed_optimizer_instance_rank
+                    ]
                     dist_reduce_scatter_func(
                         local_data_view,
                         bucket.grad_data,
@@ -382,10 +398,14 @@ class _ParamAndGradBucketGroup:
                 self.inter_distributed_optimizer_instance_group, async_ops=async_op
             ) as cm,
         ):
-            for bucket in self.buckets:
-                local_data_view = shard_buffer(
-                    bucket.grad_data, self.intra_distributed_optimizer_instance_size
-                )[self.intra_distributed_optimizer_instance_rank]
+            for idx, bucket in enumerate(self.buckets):
+                if self.cached_grad_buffer_shard_list[idx] is None:
+                    self.cached_grad_buffer_shard_list[idx] = shard_buffer(
+                        bucket.grad_data, self.intra_distributed_optimizer_instance_size
+                    )
+                local_data_view = self.cached_grad_buffer_shard_list[idx][
+                    self.intra_distributed_optimizer_instance_rank
+                ]
 
             torch.distributed.all_reduce(
                 local_data_view,
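
The change keeps one list slot per bucket and fills it lazily, so `shard_buffer` runs once per bucket instead of once per iteration; since the cached entries are slice views, no extra memory is allocated. A standalone sketch of the pattern (this `shard_buffer` is a simplified stand-in for the real helper in `param_and_grad_buffer.py`):

```python
import torch


def shard_buffer(buf: torch.Tensor, num_shards: int):
    # Evenly slice a flat buffer into views; no data is copied.
    assert buf.numel() % num_shards == 0
    shard = buf.numel() // num_shards
    return [buf[i * shard : (i + 1) * shard] for i in range(num_shards)]


buckets = [torch.zeros(1024), torch.zeros(2048)]
cached_shards = [None] * len(buckets)  # filled lazily, reused every iteration

for step in range(3):  # e.g. training iterations
    for idx, data in enumerate(buckets):
        if cached_shards[idx] is None:  # pay the Python slicing cost only once
            cached_shards[idx] = shard_buffer(data, num_shards=4)
        local_view = cached_shards[idx][1]  # this rank's shard
        # The view shares storage with the bucket, so caching costs no memory.
        assert local_view.untyped_storage().data_ptr() == data.untyped_storage().data_ptr()
```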
megatron/core/extensions/transformer_engine.py

@@ -1192,6 +1192,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"):
            """
            prefix should be module_name to make keys identical to sequetial ones.
            """
+           singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
            sharded_state_dict = {}
            full_state_dict = self.state_dict(prefix="", keep_vars=True)
            num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms
@@ -1199,23 +1200,27 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"):
            ep_axis = len(sharded_offsets)
            extra_states = self._split_extra_state(full_state_dict["_extra_state"])
            for gemm_idx in range(self.num_gemms):
+               global_expert_idx = local_expert_indices_offset + gemm_idx
                state_dict = {
                    f"{gemm_idx}.weight": full_state_dict[f"weight{gemm_idx}"],
                    f"{gemm_idx}._extra_state": extra_states[gemm_idx],
                }
                if self.use_bias:
                    state_dict[f"{gemm_idx}.bias"] = full_state_dict[f"bias{gemm_idx}"]
-               sub_sd = make_sharded_tensors_for_checkpoint(
-                   state_dict,
-                   "",
-                   tp_axis_map,
-                   (
+               if singleton_local_shards:
+                   expert_prefix = f"{global_expert_idx}.{prefix}"
+                   new_sharded_offsets = sharded_offsets
+               else:
+                   expert_prefix = prefix
+                   new_sharded_offsets = (
                        *sharded_offsets,
-                       (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts),
-                   ),
+                       (ep_axis, global_expert_idx, num_global_experts),
+                   )
+               sub_sd = make_sharded_tensors_for_checkpoint(
+                   state_dict, '', tp_axis_map, new_sharded_offsets
                )
                # Remove expert layers indexing from sharded keys
-               replace_prefix_for_sharding(sub_sd, f"{gemm_idx}.", prefix)
+               replace_prefix_for_sharding(sub_sd, f"{gemm_idx}.", expert_prefix)
                sharded_state_dict.update(
                    {
                        f"{prefix}weight{gemm_idx}": sub_sd[f"{gemm_idx}.weight"],
megatron/core/fusions/fused_weighted_squared_relu.py (new file)

@@ -0,0 +1,110 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+import torch
+import torch.nn.functional as F
+
+from megatron.core.activations import squared_relu
+from megatron.core.jit import jit_fuser
+from megatron.core.utils import nvtx_decorator
+
+###################### WEIGHTED SQUARED ReLU FUSION ######################
+
+
+@jit_fuser
+def weighted_squared_relu(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
+    """Element-wise weight applied after Squared-ReLU.
+
+    Args:
+        x (torch.Tensor): Input tensor.
+        weights (torch.Tensor): Weight tensor that will be broadcast-multiplied with the
+            activation result. Typically of shape ``(B, 1)`` so it can be broadcast across
+            the hidden dimension.
+
+    Returns:
+        torch.Tensor: ``squared_relu(x) * weights`` with original ``dtype`` preserved.
+    """
+    out_dtype = x.dtype
+    res = torch.pow(F.relu(x), 2) * weights
+    return res.to(out_dtype)
+
+
+@jit_fuser
+def _squared_relu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+    """Gradient of Squared-ReLU.
+
+    The derivative of ``(ReLU(x))^2`` w.r.t ``x`` is ``2 * ReLU(x)``.
+    """
+    return g * 2 * F.relu(x)
+
+
+@jit_fuser
+def weighted_squared_relu_back(g: torch.Tensor, x: torch.Tensor, weights: torch.Tensor):
+    """Backward for weighted Squared-ReLU.
+
+    Returns gradients w.r.t ``x`` and ``weights``.
+    """
+    input_dtype = x.dtype
+    w_dtype = weights.dtype
+
+    # Gradient w.r.t. the input.
+    input_grad = _squared_relu_back(g * weights, x)
+
+    # Gradient w.r.t. the weights.
+    weights_grad = squared_relu(x) * g.to(w_dtype)
+    # Sum across the hidden dimension so each token has a single scalar weight.
+    weights_grad = torch.sum(weights_grad, dim=-1, keepdim=True)
+
+    return input_grad.to(input_dtype), weights_grad.to(w_dtype)
+
+
+class WeightedSquaredReLUFunction(torch.autograd.Function):
+    """Autograd wrapper around the weighted Squared-ReLU fused kernels."""
+
+    @staticmethod
+    @nvtx_decorator()
+    def forward(ctx, input: torch.Tensor, weights: torch.Tensor):
+        """forward method for `WeightedSquaredReLUFunction`
+
+        Args:
+            ctx : context object to store intermediate tensors.
+            input (torch.Tensor): input tensor.
+            weights (torch.Tensor): weight tensor.
+            fp8_input_store (bool): a bool flag to indicate if storing input in fp8.
+        """
+        ctx.save_for_backward(input, weights)
+        return weighted_squared_relu(input, weights)
+
+    @staticmethod
+    @nvtx_decorator()
+    def backward(ctx, grad_output: torch.Tensor):
+        """backward method for `WeightedSquaredReLUFunction`
+
+        Args:
+            ctx : context object to store intermediate tensors.
+            grad_output (torch.Tensor): gradient of the output of the forward function.
+        """
+        input, weights = ctx.saved_tensors
+        inp_grad, w_grad = weighted_squared_relu_back(grad_output, input, weights)
+        return inp_grad, w_grad
+
+
+def weighted_squared_relu_impl(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
+    """Token-wise weighted Squared-ReLU fusion with optional FP8 storage.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape ``(B, *, hidden_size)`` where ``*`` can be
+            the sequence dimension.
+        weights (torch.Tensor): Per-token weights broadcastable to the output of
+            ``squared_relu``.
+
+    Returns:
+        torch.Tensor: Output tensor with the same shape as ``input`` except that the hidden
+            dimension remains unchanged.
+    """
+    ori_shape = input.shape
+    assert len(ori_shape) in [2, 3]
+    input = input.view(-1, ori_shape[-1])
+
+    output = WeightedSquaredReLUFunction.apply(input, weights)
+
+    return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
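
A small consistency check (a sketch, assuming the rc5 wheel is installed) comparing the fused autograd path against an unfused PyTorch reference, for both the output and the two gradients:

```python
import torch
import torch.nn.functional as F

from megatron.core.fusions.fused_weighted_squared_relu import weighted_squared_relu_impl

tokens, hidden = 6, 16
x = torch.randn(tokens, hidden, requires_grad=True)
w = torch.rand(tokens, 1, requires_grad=True)  # one scalar weight per token

out = weighted_squared_relu_impl(x, w)
out.sum().backward()

# Unfused reference built from plain ops.
x_ref = x.detach().clone().requires_grad_(True)
w_ref = w.detach().clone().requires_grad_(True)
ref = torch.pow(F.relu(x_ref), 2) * w_ref
ref.sum().backward()

torch.testing.assert_close(out, ref)
torch.testing.assert_close(x.grad, x_ref.grad)
torch.testing.assert_close(w.grad, w_ref.grad)
```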
megatron/core/package_info.py

@@ -4,7 +4,7 @@
 MAJOR = 0
 MINOR = 14
 PATCH = 0
-PRE_RELEASE = 'rc4'
+PRE_RELEASE = 'rc5'
 
 # Use the following formatting: (major, minor, patch, pre-release)
 VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
megatron/core/tensor_parallel/layers.py

@@ -923,8 +923,6 @@ class ColumnParallelLinear(torch.nn.Module):
                 "`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time."
             )
 
-        self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
-
         # Hook adding a default empty _extra_state for state dict
         self._register_load_state_dict_pre_hook(
             lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
@@ -932,6 +930,12 @@ class ColumnParallelLinear(torch.nn.Module):
             )
         )
 
+    def _forward_impl(self, *args, **kwargs):
+        if self.weight is not None and not self.weight.requires_grad:
+            return linear_with_frozen_weight(*args, **kwargs)
+        else:
+            return linear_with_grad_accumulation_and_async_allreduce(*args, **kwargs)
+
     def forward(
         self,
         input_: torch.Tensor,
@@ -989,11 +993,6 @@ class ColumnParallelLinear(torch.nn.Module):
             self.embedding_activation_buffer.append(input_parallel)
 
         # Matrix multiply.
-        if not weight.requires_grad:
-            self._forward_impl = linear_with_frozen_weight
-        else:
-            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
-
         allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
 
         if self.config._cpu_offloading_context is not None:
@@ -1203,8 +1202,6 @@ class RowParallelLinear(torch.nn.Module):
         else:
             self.register_parameter("bias", None)
 
-        self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
-
         # Hook adding a default empty _extra_state for state dict
         self._register_load_state_dict_pre_hook(
             lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
@@ -1212,6 +1209,12 @@ class RowParallelLinear(torch.nn.Module):
             )
         )
 
+    def _forward_impl(self, *args, **kwargs):
+        if self.weight is not None and not self.weight.requires_grad:
+            return linear_with_frozen_weight(*args, **kwargs)
+        else:
+            return linear_with_grad_accumulation_and_async_allreduce(*args, **kwargs)
+
     def forward(self, input_):
         """Forward of RowParallelLinear
 
@@ -1230,11 +1233,6 @@ class RowParallelLinear(torch.nn.Module):
             assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_, group=self.tp_group)
         # Matrix multiply.
-        if not self.weight.requires_grad:
-            self._forward_impl = linear_with_frozen_weight
-        else:
-            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
-
         allreduce_dgrad = False
 
         if self.config._cpu_offloading_context is not None:
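
Previously `_forward_impl` was a function attribute assigned in the constructor and re-assigned inside `forward` based on `weight.requires_grad`; rc5 replaces it with a method that dispatches on every call. A minimal sketch of the pattern (names here are illustrative, not the megatron-core implementation):

```python
import torch


def fast_path(x, w):
    return x @ w.t()  # stand-in for the grad-accumulation / async-allreduce path


def frozen_path(x, w):
    return x @ w.t()  # stand-in for linear_with_frozen_weight


class Linearish(torch.nn.Module):
    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_f, in_f))

    def _forward_impl(self, x):
        # Re-evaluated on every call: freezing or unfreezing the weight
        # mid-training immediately selects the matching path.
        if not self.weight.requires_grad:
            return frozen_path(x, self.weight)
        return fast_path(x, self.weight)

    def forward(self, x):
        return self._forward_impl(x)


m = Linearish(4, 3)
y1 = m(torch.randn(2, 4))       # trainable: fast path
m.weight.requires_grad_(False)  # freeze
y2 = m(torch.randn(2, 4))       # frozen path, no stale cached choice
```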
megatron/core/transformer/mlp.py

@@ -198,12 +198,15 @@ class MLP(MegatronModule):
         self, prefix: str = "", sharded_offsets: tuple = (), metadata: Optional[dict] = None
     ) -> ShardedStateDict:
         sharded_state_dict = {}
+        singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
         for name, module in self._modules.items():
             sub_sd = module.sharded_state_dict(f"{prefix}{name}.", sharded_offsets, metadata)
             if self.config.gated_linear_unit and name == "linear_fc1":
                 for k, v in sub_sd.items():
                     if k in (f"{prefix}{name}.weight", f"{prefix}{name}.bias"):
-                        sub_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)
+                        sub_sd[k] = apply_swiglu_sharded_factory(
+                            v, sharded_offsets, singleton_local_shards
+                        )
             sharded_state_dict.update(sub_sd)
         return sharded_state_dict
 
@@ -213,7 +216,9 @@
 
 
 # pylint: disable=missing-function-docstring
-def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
+def apply_swiglu_sharded_factory(
+    original_sh_ten, sharded_offsets, singleton_local_shards: bool = False
+):
     # We must split the tensor into 2 parts, each sharded separately.
     # This requires a ShardedTensorFactory which `chunk`s during saving
     # and `cat`s during loading
@@ -235,13 +240,25 @@ def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
     def sh_ten_build_fn(
         key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
     ):
-        offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag * 2)
-        offset_v = (swiglu_shard_axis + prepend_axis_num, rank_offset + axis_frag, axis_frag * 2)
+        if singleton_local_shards:
+            offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag)
+            offset_v = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag)
+            w_key = f'{key}_w'
+            v_key = f'{key}_v'
+        else:
+            offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag * 2)
+            offset_v = (
+                swiglu_shard_axis + prepend_axis_num,
+                rank_offset + axis_frag,
+                axis_frag * 2,
+            )
+            w_key = key
+            v_key = key
         if flattened_range is None:
             tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
             return [
                 ShardedTensor.from_rank_offsets(
-                    key,
+                    w_key,
                     tensor_w,
                     *sharded_offsets,
                     offset_w,
@@ -249,7 +266,7 @@ def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
                     prepend_axis_num=prepend_axis_num,
                 ),
                 ShardedTensor.from_rank_offsets(
-                    key,
+                    v_key,
                     tensor_v,
                     *sharded_offsets,
                     offset_v,
@@ -258,6 +275,10 @@ def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
                 ),
             ]
         else:
+            if singleton_local_shards:
+                raise NotImplementedError(
+                    'singleton_local_shards not implemented for SwiGLU MLP flattened tensors'
+                )
             # Here we need to map a slice `t` (`flattened_range` specifies slice start and stop)
             # of the *original* flattened tensor into slices `w` and `v` of chunked
             # and flattened tensor.
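
For orientation, a sketch (shapes illustrative) of what `sh_ten_build_fn` splits: the gated-linear-unit fc1 weight stores W and V stacked along the shard axis, and checkpointing saves the two halves either as offsets into one doubled axis (the default) or, with `singleton_local_shards`, under separate `_w` / `_v` keys:

```python
import torch

ffn_hidden, hidden = 8, 4
fused_fc1 = torch.randn(2 * ffn_hidden, hidden)  # [W; V] stacked on dim 0

tensor_w, tensor_v = torch.chunk(fused_fc1, 2, dim=0)
assert tensor_w.shape == tensor_v.shape == (ffn_hidden, hidden)

# Default: one checkpoint key, halves at offsets 0 and axis_frag of a doubled axis.
# singleton_local_shards: two keys (f"{key}_w", f"{key}_v") with an un-doubled
# axis, so each rank's saved shard is a self-contained tensor.
```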